TensorFlow ทำไมถึงมี 3 ไฟล์หลังจากบันทึกโมเดล


113

หลังจากอ่านเอกสารฉันได้บันทึกโมเดลไว้TensorFlowนี่คือรหัสสาธิตของฉัน:

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

แต่หลังจากนั้นฉันพบว่ามี 3 ไฟล์

model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta

และฉันไม่สามารถกู้คืนโมเดลด้วยการกู้คืนmodel.ckptไฟล์เนื่องจากไม่มีไฟล์ดังกล่าว นี่คือรหัสของฉัน

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

เหตุใดจึงมี 3 ไฟล์?


2
คุณทราบวิธีแก้ไขปัญหานี้หรือไม่? ฉันจะโหลดโมเดลอีกครั้ง (โดยใช้ Keras) ได้อย่างไร?
rajkiran

คำตอบ:


116

ลองสิ่งนี้:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
    saver.restore(sess, "/tmp/model.ckpt")

TensorFlow บันทึกวิธีการบันทึกสามชนิดของไฟล์เพราะมันเก็บโครงสร้างกราฟแยกต่างหากจากค่าตัวแปร .metaไฟล์อธิบายโครงสร้างกราฟที่บันทึกไว้เพื่อให้คุณจำเป็นต้องนำเข้ามันก่อนที่จะคืนด่าน (มิฉะนั้นจะไม่ทราบว่าตัวแปรค่าด่านบันทึกไว้ตรงตามลักษณะที่)

หรือคุณสามารถทำได้:

# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

...

# Now load the checkpoint variable values
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "/tmp/model.ckpt")

แม้ว่าจะไม่มีชื่อไฟล์model.ckptแต่คุณยังคงอ้างถึงจุดตรวจที่บันทึกไว้ด้วยชื่อนั้นเมื่อกู้คืน จากsaver.pyซอร์สโค้ด :

ผู้ใช้จะต้องโต้ตอบกับคำนำหน้าที่ผู้ใช้ระบุเท่านั้น ... แทนชื่อพา ธ ทางกายภาพใด ๆ


1
ดังนั้นจึงไม่ใช้. ดัชนีและ. data? 2 ไฟล์นั้นถูกใช้เมื่อใด
ajfbiw.s

26
@ ajfbiw.s .meta เก็บโครงสร้างกราฟ,. data เก็บค่าของตัวแปรแต่ละตัวในกราฟ, .index ระบุ checkpiont ดังนั้นในตัวอย่างด้านบน: import_meta_graph ใช้. meta และ saver.restore ใช้. data และ. index
TK Bartel

อ้อเข้าใจแล้ว. ขอบคุณ
ajfbiw.s

1
มีโอกาสที่คุณจะบันทึกโมเดลด้วย TensorFlow เวอร์ชันที่แตกต่างจากที่คุณใช้โหลดหรือไม่ ( github.com/tensorflow/tensorflow/issues/5639 )
TK Bartel

5
ไม่มีใครรู้ว่าสิ่งที่ 00000และ00001หมายเลขหมายความว่าอย่างไร ในvariables.data-?????-of-?????ไฟล์
Ivan Talalaev

55
  • ไฟล์เมตา : อธิบายโครงสร้างกราฟที่บันทึกไว้รวมถึง GraphDef, SaverDef และอื่น ๆ จากนั้นให้ใช้tf.train.import_meta_graph('/tmp/model.ckpt.meta')จะเรียกคืนและSaverGraph

  • ไฟล์ดัชนี : เป็นตารางสตริงที่ไม่เปลี่ยนรูป (tensorflow :: table :: Table) แต่ละคีย์เป็นชื่อของเทนเซอร์และค่าของมันคือ BundleEntryProto แบบอนุกรม BundleEntryProto แต่ละรายการจะอธิบายข้อมูลเมตาของเทนเซอร์: ไฟล์ "ข้อมูล" ใดที่มีเนื้อหาของเทนเซอร์, ออฟเซ็ตในไฟล์นั้น, การตรวจสอบ, ข้อมูลเสริมบางอย่างเป็นต้น

  • ไฟล์ข้อมูล : เป็นคอลเลกชัน TensorBundle บันทึกค่าของตัวแปรทั้งหมด


ฉันมีไฟล์ pb ที่ฉันมีสำหรับการจัดประเภทรูปภาพ ฉันสามารถใช้เพื่อจัดหมวดหมู่วิดีโอแบบเรียลไทม์ได้หรือไม่

คุณช่วยแจ้งให้เราทราบได้ไหมว่าการใช้ Keras 2 ฉันจะโหลดโมเดลได้อย่างไรหากบันทึกเป็น 3 ไฟล์
rajkiran

5

ฉันกำลังการฟื้นฟู embeddings คำได้รับการฝึกฝนจากWord2Vec tensorflow กวดวิชา

ในกรณีที่คุณสร้างจุดตรวจหลายจุด:

เช่นไฟล์ที่สร้างขึ้นมีลักษณะเช่นนี้

model.ckpt-55695.data-00000 ของ 00001

model.ckpt-55695.index

model.ckpt-55695.meta

ลองดู

def restore_session(self, session):
   saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
   saver.restore(session, './tmp/model.ckpt-55695')

เมื่อเรียก restore_session ():

def test_word2vec():
   opts = Options()    
   with tf.Graph().as_default(), tf.Session() as session:
       with tf.device("/cpu:0"):            
           model = Word2Vec(opts, session)
           model.restore_session(session)
           model.get_embedding("assistance")

"00000-of-00001" ใน "model.ckpt-55695.data-00000-of-00001" หมายความว่าอย่างไร
hafiz031

0

ตัวอย่างเช่นหากคุณฝึก CNN ด้วยการออกกลางคันคุณสามารถทำได้:

def predict(image, model_name):
    """
    image -> single image, (width, height, channels)
    model_name -> model file that was saved without any extensions
    """
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./' + model_name + '.meta')
        saver.restore(sess, './' + model_name)
        # Substitute 'logits' with your model
        prediction = tf.argmax(logits, 1)
        # 'x' is what you defined it to be. In my case it is a batch of RGB images, that's why I add the extra dimension
        return prediction.eval(feed_dict={x: image[np.newaxis,:,:,:], keep_prob_dnn: 1.0})
โดยการใช้ไซต์ของเรา หมายความว่าคุณได้อ่านและทำความเข้าใจนโยบายคุกกี้และนโยบายความเป็นส่วนตัวของเราแล้ว
Licensed under cc by-sa 3.0 with attribution required.