TensorFlow บันทึกลงใน / โหลดกราฟจากไฟล์


101

จากสิ่งที่ฉันรวบรวมมาจนถึงตอนนี้มีหลายวิธีในการทิ้งกราฟ TensorFlow ลงในไฟล์แล้วโหลดลงในโปรแกรมอื่น แต่ฉันยังไม่พบตัวอย่าง / ข้อมูลที่ชัดเจนเกี่ยวกับวิธีการทำงาน สิ่งที่ฉันรู้อยู่แล้วคือ:

  1. บันทึกตัวแปรของโมเดลลงในไฟล์จุดตรวจ (.ckpt) โดยใช้ a tf.train.Saver()และเรียกคืนในภายหลัง (ที่มา )
  2. บันทึกโมเดลลงในไฟล์. pb และโหลดกลับมาใช้tf.train.write_graph()และtf.import_graph_def()(ที่มา )
  3. โหลดแบบจำลองจากไฟล์. pb ฝึกใหม่และถ่ายโอนข้อมูลลงในไฟล์. pb ใหม่โดยใช้ Bazel ( ซอร์ส )
  4. ตรึงกราฟเพื่อบันทึกกราฟและน้ำหนักร่วมกัน (ที่มา )
  5. ใช้as_graph_def()เพื่อบันทึกแบบจำลองและสำหรับน้ำหนัก / ตัวแปรให้จับคู่เป็นค่าคงที่ (ที่มา )

อย่างไรก็ตามฉันไม่สามารถไขข้อข้องใจเกี่ยวกับวิธีการต่างๆเหล่านี้ได้:

  1. เกี่ยวกับไฟล์จุดตรวจพวกเขาบันทึกเฉพาะน้ำหนักที่ฝึกแล้วของแบบจำลองหรือไม่? สามารถโหลดไฟล์จุดตรวจลงในโปรแกรมใหม่และใช้ในการรันโมเดลได้หรือไม่หรือใช้เป็นวิธีในการบันทึกน้ำหนักในโมเดลในช่วงเวลา / ขั้นตอนที่กำหนด
  2. เกี่ยวกับtf.train.write_graph()น้ำหนัก / ตัวแปรได้รับการบันทึกด้วยหรือไม่?
  3. เกี่ยวกับ Bazel สามารถบันทึกลงใน / โหลดจากไฟล์. pb เพื่อฝึกอบรมใหม่ได้หรือไม่? มีคำสั่ง Bazel ง่ายๆเพียงเพื่อถ่ายโอนกราฟลงใน. pb หรือไม่?
  4. เกี่ยวกับการแช่แข็งสามารถโหลดกราฟแช่แข็งโดยใช้ได้tf.import_graph_def()หรือไม่?
  5. การสาธิต Android สำหรับ TensorFlow โหลดในโมเดล Inception ของ Google จากไฟล์. pb ถ้าฉันต้องการแทนที่ไฟล์. pb ของตัวเองฉันจะทำอย่างไร ฉันจะต้องเปลี่ยนรหัส / วิธีการใด ๆ หรือไม่?
  6. โดยทั่วไปแล้ววิธีการทั้งหมดนี้แตกต่างกันอย่างไร? หรือกว้างกว่านั้นคือas_graph_def()/.ckpt/.pb ต่างกันอย่างไร

ในระยะสั้นสิ่งที่ฉันกำลังมองหาคือวิธีการบันทึกทั้งกราฟ (เช่นการดำเนินการต่างๆและอื่น ๆ ) และน้ำหนัก / ตัวแปรลงในไฟล์ซึ่งสามารถใช้เพื่อโหลดกราฟและน้ำหนักลงในโปรแกรมอื่นได้ สำหรับการใช้งาน (ไม่จำเป็นต้องดำเนินการต่อ / ฝึกอบรมใหม่)

เอกสารเกี่ยวกับหัวข้อนี้ไม่ตรงไปตรงมามากนักดังนั้นคำตอบ / ข้อมูลใด ๆ จะได้รับการชื่นชมอย่างมาก


2
API ใหม่ล่าสุด / สมบูรณ์ที่สุดคือเมตากราฟซึ่งจะช่วยให้คุณสามารถบันทึกทั้งสามอย่างพร้อมกัน - 1) กราฟ 2) ค่าพารามิเตอร์ 3) คอลเลกชัน: tensorflow.org/versions/r0.10/how_tos/meta_graph/ index.html
Yaroslav Bulatov

คำตอบ:


81

มีหลายวิธีในการแก้ไขปัญหาในการบันทึกโมเดลใน TensorFlow ซึ่งอาจทำให้สับสนเล็กน้อย ตอบคำถามย่อยแต่ละข้อของคุณ:

  1. ไฟล์ด่าน (ผลิตเช่นโดยการเรียกsaver.save()บนtf.train.Saverวัตถุ) มีเพียงน้ำหนักและตัวแปรอื่น ๆ ที่กำหนดไว้ในโปรแกรมเดียวกัน หากต้องการใช้ในโปรแกรมอื่นคุณต้องสร้างโครงสร้างกราฟที่เกี่ยวข้องขึ้นมาใหม่ (เช่นโดยการรันโค้ดเพื่อสร้างอีกครั้งหรือเรียกใช้tf.import_graph_def()) ซึ่งจะบอก TensorFlow ว่าจะทำอย่างไรกับน้ำหนักเหล่านั้น โปรดทราบว่าการโทรsaver.save()ยังสร้างไฟล์ที่มี a MetaGraphDefซึ่งมีกราฟและรายละเอียดวิธีการเชื่อมโยงน้ำหนักจากจุดตรวจกับกราฟนั้น ดูบทแนะนำสำหรับรายละเอียดเพิ่มเติม

  2. tf.train.write_graph()เขียนเฉพาะโครงสร้างกราฟ ไม่ใช่น้ำหนัก

  3. Bazel ไม่เกี่ยวข้องกับการอ่านหรือเขียนกราฟ TensorFlow (บางทีฉันอาจเข้าใจผิดคำถามของคุณ: อย่าลังเลที่จะชี้แจงในความคิดเห็น)

  4. สามารถโหลดกราฟแช่แข็งได้โดยใช้tf.import_graph_def(). ในกรณีนี้น้ำหนักจะฝังอยู่ในกราฟดังนั้นคุณไม่จำเป็นต้องโหลดจุดตรวจแยกต่างหาก

  5. การเปลี่ยนแปลงหลักคือการอัปเดตชื่อของเทนเซอร์ที่ป้อนเข้าไปในโมเดลและชื่อของเทนเซอร์ที่ดึงมาจากแบบจำลอง ในการสาธิต TensorFlow Android นี้จะสอดคล้องกับinputNameและสตริงที่มีการส่งผ่านไปยังoutputNameTensorFlowClassifier.initializeTensorFlow()

  6. นี่GraphDefคือโครงสร้างของโปรแกรมซึ่งโดยทั่วไปจะไม่เปลี่ยนแปลงตามกระบวนการฝึกอบรม จุดตรวจคือภาพรวมของสถานะของกระบวนการฝึกอบรมซึ่งโดยทั่วไปจะมีการเปลี่ยนแปลงในทุกขั้นตอนของกระบวนการฝึกอบรม ด้วยเหตุนี้ TensorFlow จึงใช้รูปแบบการจัดเก็บที่แตกต่างกันสำหรับข้อมูลประเภทนี้และ API ระดับต่ำให้วิธีต่างๆในการบันทึกและโหลด ไลบรารีระดับสูงขึ้นเช่นMetaGraphDefไลบรารีKerasและskflowสร้างกลไกเหล่านี้เพื่อมอบวิธีที่สะดวกยิ่งขึ้นในการบันทึกและกู้คืนโมเดลทั้งหมด


นี่หมายความว่าเอกสารC ++ APIอยู่เมื่อมีการระบุว่าคุณสามารถโหลดกราฟที่บันทึกด้วยtf.train.write_graph()แล้วดำเนินการได้หรือไม่
mnicky

2
เอกสาร C ++ API ไม่ได้โกหก แต่ไม่มีรายละเอียดบางอย่าง รายละเอียดที่สำคัญที่สุดคือนอกจากที่GraphDefบันทึกไว้tf.train.write_graph()แล้วคุณยังต้องจำชื่อของเทนเซอร์ที่คุณต้องการป้อนและดึงข้อมูลเมื่อเรียกใช้กราฟ (ข้อ 5 ด้านบน)
mrry

@mrry: ฉันพยายามใช้ตัวอย่าง DeepDream ของเทนเซอร์โฟลว์ แต่ดูเหมือนว่าจะต้องใช้โมเดลที่กำหนดไว้ล่วงหน้าในรูปแบบ pb! ฉันใช้ตัวอย่าง Cifar10 แต่มันสร้างจุดตรวจเท่านั้น! ฉันไม่พบไฟล์ pb หรืออะไรเลย! ฉันจะแปลงจุดตรวจเป็นรูปแบบ pb ที่ตัวอย่าง deepdream ใช้ได้อย่างไร
Rika

2
@ Coderx7 ฉันคิดว่าคุณไม่สามารถแปลง. ckpt เป็น. pb ได้เนื่องจากจุดตรวจมีเพียงน้ำหนักและตัวแปรและไม่รู้อะไรเกี่ยวกับโครงสร้างของกราฟ
davidivad

1
มีรหัสง่ายๆในการโหลดไฟล์. pb แล้วเรียกใช้หรือไม่
ก้อง

โดยการใช้ไซต์ของเรา หมายความว่าคุณได้อ่านและทำความเข้าใจนโยบายคุกกี้และนโยบายความเป็นส่วนตัวของเราแล้ว
Licensed under cc by-sa 3.0 with attribution required.