วิธีที่ดีที่สุดในการบันทึกโมเดลที่ได้รับการฝึกฝนใน PyTorch


193

ฉันกำลังมองหาวิธีอื่นในการบันทึกโมเดลที่ได้รับการฝึกฝนใน PyTorch จนถึงตอนนี้ฉันได้พบสองทางเลือก

  1. torch.save ()เพื่อบันทึกโมเดลและtorch.load ()เพื่อโหลดโมเดล
  2. model.state_dict ()เพื่อบันทึกโมเดลที่ผ่านการฝึกอบรมและmodel.load_state_dict ()เพื่อโหลดโมเดลที่บันทึกไว้

ฉันได้พบกับการสนทนานี้ที่แนะนำวิธีที่ 2 มากกว่าวิธีที่ 1

คำถามของฉันคือทำไมทำไมถึงเลือกวิธีที่สอง เป็นเพียงเพราะโมดูลtorch.nnมีสองฟังก์ชั่นเหล่านั้นและเราสนับสนุนให้ใช้พวกเขา?


2
ฉันคิดว่าเป็นเพราะ torch.save () บันทึกตัวแปรกลางทั้งหมดเช่นกันเช่นเอาท์พุทกลางสำหรับการเผยแพร่กลับใช้ แต่คุณจะต้องบันทึกพารามิเตอร์ของแบบจำลองเช่นน้ำหนัก / อคติเป็นต้นบางครั้งอดีตอาจมีขนาดใหญ่กว่าหลัง
Dawei Yang

2
ผมทดสอบและtorch.save(model, f) torch.save(model.state_dict(), f)ไฟล์ที่บันทึกมีขนาดเท่ากัน ตอนนี้ฉันสับสน นอกจากนี้ฉันพบว่าใช้ดองเพื่อบันทึก model.state_dict () ช้ามาก ฉันคิดว่าวิธีที่ดีที่สุดคือการใช้torch.save(model.state_dict(), f)ตั้งแต่คุณจัดการการสร้างแบบจำลองและไฟฉายจัดการการโหลดน้ำหนักของแบบจำลองจึงช่วยขจัดปัญหาที่เป็นไปได้ การอ้างอิง: Discuss.pytorch.org/t/saving-torch-models/838/4
Dawei Yang

ดูเหมือน PyTorch ได้กล่าวถึงเรื่องนี้อย่างชัดเจนยิ่งขึ้นในบทแนะนำของพวกเขา- มีข้อมูลที่ดีมากมายที่นั่นซึ่งไม่ได้ระบุไว้ในคำตอบที่นี่รวมถึงการบันทึกมากกว่าหนึ่งรุ่นในแต่ละครั้งและรุ่นเริ่มต้นที่อบอุ่น
whlteXbread

มีอะไรผิดปกติกับการใช้งานpickle?
Charlie Parker

1
@CharlieParker torch.save อิงจากดอง ต่อไปนี้มาจากบทช่วยสอนที่เชื่อมโยงด้านบน: "[torch.save] จะบันทึกโมดูลทั้งหมดโดยใช้โมดูล pickle ของ Python ข้อเสียของวิธีการนี้คือข้อมูลที่ต่อเนื่องถูกผูกไว้กับคลาสเฉพาะและโครงสร้างไดเรกทอรีที่แน่นอนที่ใช้เมื่อโมเดล จะถูกบันทึกเหตุผลนี้เป็นเพราะ pickle ไม่ได้บันทึก class model เอง แต่มันจะบันทึกพา ธ ไปยังไฟล์ที่มี class ซึ่งถูกใช้ในช่วงเวลาในการโหลดด้วยเหตุนี้โค้ดของคุณจึงสามารถแตกได้หลายวิธีเมื่อ ใช้ในโครงการอื่นหรือหลังจาก refactors "
David Miller

คำตอบ:


215

ฉันได้พบหน้านี้ใน repo GitHub ของพวกเขาฉันจะวางเนื้อหาที่นี่


วิธีการที่แนะนำสำหรับการบันทึกแบบจำลอง

มีวิธีการหลักสองวิธีในการทำให้เป็นอนุกรมและการกู้คืนแบบจำลอง

คำแนะนำแรก (แนะนำ) บันทึกและโหลดเฉพาะพารามิเตอร์โมเดล:

torch.save(the_model.state_dict(), PATH)

จากนั้นในภายหลัง:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

ประการที่สองบันทึกและโหลดโมเดลทั้งหมด:

torch.save(the_model, PATH)

จากนั้นในภายหลัง:

the_model = torch.load(PATH)

อย่างไรก็ตามในกรณีนี้ข้อมูลต่อเนื่องจะถูกผูกไว้กับคลาสที่เฉพาะเจาะจงและโครงสร้างไดเรกทอรีที่แน่นอนที่ใช้ดังนั้นจึงสามารถแตกในรูปแบบต่างๆเมื่อใช้ในโครงการอื่น ๆ หรือหลังจาก refactors ร้ายแรงบางอย่าง


8
ตาม @smth discuss.pytorch.org/t/saving-and-loading-a-model-in-pytorch/...โหลดรูปแบบการฝึกอบรมรุ่นโดยค่าเริ่มต้น ดังนั้นจำเป็นต้องเรียก the_model.eval () ด้วยตนเองหลังจากโหลดถ้าคุณกำลังโหลดเพื่ออนุมานไม่ใช่ดำเนินการฝึกอบรมต่อ
WillZ

วิธีที่สองให้ข้อผิดพลาดstackoverflow.com/questions/53798009/..บน windows 10 ไม่สามารถแก้ไขได้
Gulzar

มีตัวเลือกในการบันทึกโดยไม่จำเป็นต้องเข้าถึงคลาสรุ่นหรือไม่
Michael D

ด้วยวิธีการที่คุณติดตาม * args และ ** kwargs ที่คุณต้องผ่านสำหรับกรณีโหลดได้อย่างไร
Mariano Kamp

มีอะไรผิดปกติกับการใช้งานpickle?
Charlie Parker

144

ขึ้นอยู่กับสิ่งที่คุณต้องการจะทำ

กรณี # 1: บันทึกโมเดลเพื่อใช้เป็นข้อสรุป : คุณบันทึกโมเดลคุณกู้คืนโมเดลจากนั้นคุณเปลี่ยนโมเดลเป็นโหมดการประเมินผล สิ่งนี้จะทำเพราะคุณมักจะมีBatchNormและDropoutชั้นที่โดยค่าเริ่มต้นอยู่ในโหมดรถไฟในการก่อสร้าง:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

กรณีที่ # 2: บันทึกแบบจำลองเพื่อดำเนินการฝึกอบรมในภายหลัง : หากคุณต้องการฝึกอบรมแบบจำลองที่คุณกำลังจะบันทึกไว้คุณต้องบันทึกมากกว่าแบบจำลอง คุณต้องบันทึกสถานะของเครื่องมือเพิ่มประสิทธิภาพ, ยุค, คะแนน, ฯลฯ คุณจะทำเช่นนี้:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

ในการฝึกอบรมต่อคุณจะต้องทำสิ่งต่าง ๆ เช่น: state = torch.load(filepath)จากนั้นเพื่อเรียกคืนสถานะของแต่ละวัตถุบางอย่างดังนี้:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

เนื่องจากคุณกลับมาทำงานการฝึกอบรมอย่าเรียกใช้model.eval()เมื่อคุณกู้คืนสถานะเมื่อโหลด

กรณีที่ # 3: รูปแบบที่บุคคลอื่นจะใช้โดยไม่สามารถเข้าถึงรหัสของคุณ : ใน Tensorflow คุณสามารถสร้าง.pbไฟล์ที่กำหนดทั้งสถาปัตยกรรมและน้ำหนักของแบบจำลอง Tensorflow serveนี้จะมีประโยชน์มากเป็นพิเศษเมื่อใช้ วิธีที่เท่าเทียมกันในการทำเช่นนี้ใน Pytorch คือ:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

วิธีนี้ยังไม่กันกระสุนและเนื่องจาก pytorch ยังคงมีการเปลี่ยนแปลงมากมายฉันไม่แนะนำ


1
มีไฟล์แนะนำที่จะสิ้นสุดใน 3 รายหรือไม่? หรือมันคือ. pth เสมอ
Verena Haunschmid

1
ใน Case # 3 torch.loadส่งคืนเฉพาะ OrderedDict คุณจะรับโมเดลเพื่อทำนายได้อย่างไร
Alber8295

สวัสดีฉันขอทราบวิธีการที่กล่าวถึง "กรณีที่ 2: บันทึกแบบจำลองเพื่อดำเนินการฝึกอบรมภายหลัง" ฉันจัดการเพื่อโหลดจุดตรวจสอบเป็นแบบจำลองจากนั้นฉันไม่สามารถเรียกใช้หรือดำเนินการต่อเพื่อฝึกอบรมโมเดลเช่น "model.to (อุปกรณ์) model = train_model_epoch (โมเดลเกณฑ์การเพิ่มประสิทธิภาพ
กำหนดการ

1
สวัสดีสำหรับกรณีที่เป็นการอนุมานใน pytorch doc อย่างเป็นทางการกล่าวว่าต้องบันทึก state_dict ของเครื่องมือเพิ่มประสิทธิภาพสำหรับการอนุมานหรือการฝึกอบรมให้เสร็จ "เมื่อบันทึกจุดตรวจทั่วไปเพื่อใช้สำหรับการอนุมานหรือเริ่มการฝึกอบรมต่อคุณต้องบันทึกมากกว่า state_dict ของรุ่นสิ่งสำคัญคือต้องบันทึก state_dict ของเครื่องมือเพิ่มประสิทธิภาพด้วยเนื่องจากมีบัฟเฟอร์และพารามิเตอร์ที่ปรับปรุงเป็นโมเดลรถไฟ . "
Mohammed Awney

1
ในกรณีที่ # 3 ควรกำหนดคลาสของโมเดลไว้ที่ใดที่หนึ่ง
Michael D

12

ดองดำเนินห้องสมุดหลามโปรโตคอลไบนารี serializing และ de-serializing วัตถุหลาม

เมื่อคุณimport torch(หรือเมื่อคุณใช้ PyTorch) มันจะช่วยimport pickleให้คุณและคุณไม่จำเป็นต้องโทรpickle.dump()และpickle.load()โดยตรงซึ่งเป็นวิธีการบันทึกและโหลดวัตถุ

ในความเป็นจริงtorch.save()และtorch.load()จะห่อpickle.dump()และpickle.load()สำหรับคุณ

state_dictคำตอบอื่น ๆ ที่กล่าวสมควรได้รับการบันทึกมากขึ้นเพียงไม่กี่

state_dictเรามีอะไรใน PyTorch มีอยู่สองstate_dictตัว

โมเดล PyTorch torch.nn.Moduleมีการmodel.parameters()เรียกเพื่อรับพารามิเตอร์ที่เรียนรู้ได้ (w และ b) พารามิเตอร์ที่เรียนรู้ได้เหล่านี้เมื่อตั้งค่าแบบสุ่มจะอัปเดตตามเวลาที่เราเรียนรู้ พารามิเตอร์ learnable state_dictเป็นครั้งแรก

ประการที่สองstate_dictคือคำสั่งรัฐเพิ่มประสิทธิภาพ คุณจำได้ว่าเครื่องมือเพิ่มประสิทธิภาพใช้เพื่อปรับปรุงพารามิเตอร์ที่เรียนรู้ได้ของเรา แต่เครื่องมือเพิ่มประสิทธิภาพstate_dictได้รับการแก้ไข ไม่มีอะไรให้เรียนรู้ในนั้น

เนื่องจากstate_dictวัตถุเป็นพจนานุกรมภาษาไพ ธ อนพวกเขาจึงสามารถบันทึกอัปเดตแก้ไขและเรียกคืนได้อย่างง่ายดายเพิ่มความเป็นโมดุลเดอเรชันให้กับโมเดล PyTorch และเครื่องมือเพิ่มประสิทธิภาพอย่างมาก

มาสร้างแบบจำลองง่าย ๆ เพื่ออธิบายสิ่งนี้:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

รหัสนี้จะส่งออกต่อไปนี้:

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

หมายเหตุนี่เป็นรุ่นที่น้อยที่สุด คุณอาจลองเพิ่มชุดข้อมูลตามลำดับ

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

หมายเหตุว่ามีเพียงชั้นกับพารามิเตอร์ learnable (ชั้นสับสนชั้นเชิงเส้น, ฯลฯ ) และบัฟเฟอร์ลงทะเบียน (ชั้น batchnorm) state_dictมีรายการในรูปแบบของ

สิ่งที่ไม่สามารถเรียนรู้ได้นั้นเป็นของวัตถุของเครื่องมือเพิ่มประสิทธิภาพstate_dictซึ่งมีข้อมูลเกี่ยวกับสถานะของเครื่องมือเพิ่มประสิทธิภาพเช่นเดียวกับพารามิเตอร์ที่ใช้

เรื่องที่เหลือก็เหมือนกัน ในช่วงการอนุมาน (นี่เป็นระยะเมื่อเราใช้แบบจำลองหลังการฝึกอบรม) สำหรับการทำนาย เราทำนายตามพารามิเตอร์ที่เราเรียนรู้ model.state_dict()ดังนั้นสำหรับการอนุมานเราก็จำเป็นต้องบันทึกค่าพารามิเตอร์

torch.save(model.state_dict(), filepath)

และเพื่อใช้ในภายหลัง model.load_state_dict (torch.load (filepath)) model.eval ()

หมายเหตุ: อย่าลืมบรรทัดสุดท้ายmodel.eval()สิ่งนี้สำคัญมากหลังจากโหลดแบบจำลอง

torch.save(model.parameters(), filepath)นอกจากนี้ยังไม่ได้พยายามที่จะบันทึก model.parameters()เป็นเพียงวัตถุกำเนิดไฟฟ้า

ในด้านอื่น ๆ ที่torch.save(model, filepath)จะช่วยประหยัดวัตถุรูปแบบของตัวเอง state_dictแต่เก็บไว้ในใจแบบไม่ได้มีการเพิ่มประสิทธิภาพของ ตรวจสอบคำตอบที่ยอดเยี่ยมอื่น ๆ โดย @Jadiel de Armas เพื่อบันทึก dict สถานะของเครื่องมือเพิ่มประสิทธิภาพ


ถึงแม้ว่ามันจะไม่ใช่วิธีแก้ปัญหาที่ตรงไปตรงมา แต่สาระสำคัญของปัญหาก็คือการวิเคราะห์อย่างลึกซึ้ง! upvote
Jason Young

7

หลักการทั่วไปของ PyTorch คือการบันทึกแบบจำลองโดยใช้นามสกุลไฟล์. ppt หรือ. pth

บันทึก / โหลดทั้งรุ่น บันทึก:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

โหลด:

คลาสของโมเดลต้องถูกกำหนดไว้ที่ใดที่หนึ่ง

model = torch.load(PATH)
model.eval()

4

หากคุณต้องการบันทึกโมเดลและต้องการเริ่มการฝึกอบรมในภายหลัง

GPU เดี่ยว: บันทึก:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

โหลด:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

หลาย GPU: บันทึก

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

โหลด:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU
โดยการใช้ไซต์ของเรา หมายความว่าคุณได้อ่านและทำความเข้าใจนโยบายคุกกี้และนโยบายความเป็นส่วนตัวของเราแล้ว
Licensed under cc by-sa 3.0 with attribution required.