ทำไมเราต้องเรียก zero_grad () ใน PyTorch


126

zero_grad()ต้องมีการเรียกวิธีการนี้ในระหว่างการฝึกอบรม แต่เอกสารประกอบก็ไม่ค่อยมีประโยชน์

|  zero_grad(self)
|      Sets gradients of all model parameters to zero.

ทำไมเราต้องเรียกวิธีนี้?

คำตอบ:


164

ในPyTorchเราจำเป็นต้องตั้งค่าการไล่ระดับสีเป็นศูนย์ก่อนที่จะเริ่มทำ backpropragation เนื่องจาก PyTorch สะสมการไล่ระดับสีในการย้อนกลับในภายหลัง สะดวกในขณะฝึก RNN ดังนั้นการดำเนินการเริ่มต้นคือการสะสม (เช่นผลรวม) การไล่ระดับสีในทุกการloss.backward()โทร

ด้วยเหตุนี้เมื่อคุณเริ่มลูปการฝึกของคุณคุณควรzero out the gradientsจะอัปเดตพารามิเตอร์ให้ถูกต้อง มิฉะนั้นการไล่ระดับสีจะชี้ไปในทิศทางอื่นนอกเหนือจากทิศทางที่ตั้งใจไว้ไปสู่ค่าต่ำสุด (หรือสูงสุดในกรณีของวัตถุประสงค์ในการขยายสูงสุด)

นี่คือตัวอย่างง่ายๆ:

import torch
from torch.autograd import Variable
import torch.optim as optim

def linear_model(x, W, b):
    return torch.matmul(x, W) + b

data, targets = ...

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

optimizer = optim.Adam([W, b])

for sample, target in zip(data, targets):
    # clear out the gradients of all Variables 
    # in this optimizer (i.e. W, b)
    optimizer.zero_grad()
    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()
    optimizer.step()

หรือหากคุณกำลังทำเชื้อสายวานิลลาไล่ระดับสีให้ทำดังนี้

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

for sample, target in zip(data, targets):
    # clear out the gradients of Variables 
    # (i.e. W, b)
    W.grad.data.zero_()
    b.grad.data.zero_()

    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()

    W -= learning_rate * W.grad.data
    b -= learning_rate * b.grad.data

หมายเหตุ : ในการสะสม (คือผลรวม ) ของการไล่ระดับสีเกิดขึ้นเมื่อ.backward()มีการเรียกร้องให้lossเมตริกซ์


3
ขอบคุณมากนี่เป็นประโยชน์จริงๆ! คุณรู้หรือไม่ว่าเทนเซอร์โฟลว์มีพฤติกรรมหรือไม่?
layser

เพื่อให้แน่ใจว่า .. ถ้าคุณไม่ทำเช่นนี้คุณจะพบปัญหาการไล่ระดับสีที่ระเบิดใช่ไหม?
zwep

3
@zwep หากเราสะสมการไล่ระดับสีก็ไม่ได้หมายความว่าขนาดของมันจะเพิ่มขึ้นตัวอย่างเช่นหากสัญญาณของการไล่ระดับสียังคงพลิก ดังนั้นจึงไม่สามารถรับประกันได้ว่าคุณจะประสบปัญหาการไล่ระดับสีแบบระเบิด นอกจากนี้ยังมีการไล่ระดับสีแบบระเบิดแม้ว่าคุณจะเป็นศูนย์อย่างถูกต้องก็ตาม
Tom Roth

เมื่อคุณเรียกใช้การไล่ระดับสีวานิลลาคุณไม่ได้รับข้อผิดพลาด "Leaf Variable ที่ต้องการการไล่ระดับสีในการดำเนินการแบบแทนที่" เมื่อคุณพยายามอัปเดตน้ำหนัก
MUAS

3

zero_grad() เริ่มการวนซ้ำโดยไม่สูญเสียจากขั้นตอนสุดท้ายหากคุณใช้วิธีการไล่ระดับสีเพื่อลดข้อผิดพลาด (หรือการสูญเสีย)

หากคุณไม่ใช้zero_grad()การสูญเสียจะลดลงไม่เพิ่มขึ้นตามที่กำหนด

ตัวอย่างเช่น:

หากคุณใช้zero_grad()คุณจะได้รับผลลัพธ์ต่อไปนี้:

model training loss is 1.5
model training loss is 1.4
model training loss is 1.3
model training loss is 1.2

หากคุณไม่ใช้zero_grad()คุณจะได้รับผลลัพธ์ต่อไปนี้:

model training loss is 1.4
model training loss is 1.9
model training loss is 2
model training loss is 2.8
model training loss is 3.5
โดยการใช้ไซต์ของเรา หมายความว่าคุณได้อ่านและทำความเข้าใจนโยบายคุกกี้และนโยบายความเป็นส่วนตัวของเราแล้ว
Licensed under cc by-sa 3.0 with attribution required.