การใช้ torch.no_grad ใน pytorch คืออะไร


21

ฉันใหม่เพื่อ pytorch และเริ่มต้นด้วยการนี้รหัส GitHub ฉันไม่เข้าใจความคิดเห็นในบรรทัด 60-61 "because weights have requires_grad=True, but we don't need to track this in autograd"ในรหัส ฉันเข้าใจว่าเราพูดถึงrequires_grad=Trueตัวแปรที่เราต้องคำนวณการไล่ระดับสีสำหรับการใช้ autograd แต่มันหมายความว่า"tracked by autograd"อย่างไร

คำตอบ:


24

wrapper "with torch.no_grad ()" ตั้งค่าสถานะ require_grad ทั้งหมดชั่วคราวให้เป็นเท็จ ตัวอย่างจากการสอน PyTorch อย่างเป็นทางการ ( https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#gradients ):

x = torch.randn(3, requires_grad=True)
print(x.requires_grad)
print((x ** 2).requires_grad)

with torch.no_grad():
    print((x ** 2).requires_grad)

ออก:

True
True
False

ผมขอแนะนำให้คุณอ่านบทเรียนทั้งหมดจากเว็บไซต์ด้านบน

ในตัวอย่างของคุณ: ฉันเดาว่าผู้เขียนไม่ต้องการ PyTorch ในการคำนวณการไล่ระดับสีของตัวแปรที่กำหนดใหม่ w1 และ w2 เนื่องจากเขาเพียงต้องการอัปเดตค่าของพวกเขา


6
with torch.no_grad()

จะทำให้การดำเนินการทั้งหมดในบล็อกไม่มีการไล่ระดับสี

ใน pytorch คุณไม่สามารถทำ inplacement เปลี่ยนแปลงของ W1 และ W2 require_grad = Trueซึ่งมีสองตัวแปรที่มี ฉันคิดว่าการหลีกเลี่ยงการเปลี่ยนแปลงการแทนที่ของ w1 และ w2 นั้นเป็นเพราะจะทำให้เกิดข้อผิดพลาดในการคำนวณการกระจายกลับ เนื่องจากการเปลี่ยนตำแหน่งจะเปลี่ยน w1 และ w2 โดยสิ้นเชิง

อย่างไรก็ตามถ้าคุณใช้สิ่งนี้no_grad()คุณสามารถควบคุม w1 และ w2 ใหม่ได้โดยไม่มีการไล่ระดับสีเนื่องจากถูกสร้างขึ้นโดยการดำเนินการซึ่งหมายความว่าคุณเปลี่ยนค่าของ w1 และ w2 เท่านั้นไม่ใช่ส่วนที่ไล่ระดับสีพวกเขายังคงมีข้อมูลการไล่ระดับสีตัวแปร และการขยายพันธุ์กลับสามารถดำเนินการต่อ

โดยการใช้ไซต์ของเรา หมายความว่าคุณได้อ่านและทำความเข้าใจนโยบายคุกกี้และนโยบายความเป็นส่วนตัวของเราแล้ว
Licensed under cc by-sa 3.0 with attribution required.