จะแสดงรายการการดำเนินการทั้งหมดที่ใช้ใน Tensorflow SavedModel ได้อย่างไร


10

หากฉันบันทึกแบบจำลองของฉันโดยใช้tensorflow.saved_model.saveฟังก์ชันในรูปแบบ SavedModel ฉันจะดึง Tensorflow Ops ที่ใช้ในรุ่นนี้ได้อย่างไร เนื่องจากโมเดลสามารถกู้คืนการดำเนินการเหล่านี้ถูกเก็บไว้ในกราฟฉันเดาว่าอยู่ในsaved_model.pbไฟล์ ถ้าฉันโหลด protobuf นี้ (ไม่ใช่โมเดลทั้งหมด) ส่วนไลบรารีของ protobuf จะแสดงรายการเหล่านี้ แต่นี่ยังไม่ได้บันทึกและติดแท็กเป็นคุณลักษณะทดลองสำหรับตอนนี้ โมเดลที่สร้างใน Tensorflow 1.x จะไม่มีส่วนนี้

ดังนั้นวิธีที่รวดเร็วและเชื่อถือได้ในการดึงรายการการดำเนินการที่ใช้แล้ว (Like MatchingFilesหรือWriteFile) จากแบบจำลองในรูปแบบ SavedModel คืออะไร

ตอนนี้ฉันสามารถหยุดสิ่งทั้งปวงtensorflowjs-converterได้ ขณะที่พวกเขายังตรวจสอบการดำเนินงานที่รองรับ ในปัจจุบันนี้ไม่ทำงานเมื่อ LSTM อยู่ในรูปแบบที่เห็นนี่ มีวิธีที่ดีกว่าในการทำเช่นนี้ใน Ops แน่นอนหรือไม่

ตัวอย่างแบบ:

class FileReader(tf.Module):

@tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
def read_disk(self, file_name):
    input_scalar = tf.reshape(file_name, [])
    output = tf.io.read_file(input_scalar)
    return tf.stack([output], name='content')

file_reader = FileReader()

tf.saved_model.save(file_reader, 'file_reader')

คาดว่าจะส่งออก Ops ทั้งหมดที่มีในกรณีนี้อย่างน้อย:


1
มันยากที่จะบอกว่าสิ่งที่คุณต้องการคืออะไรsaved_model.pbมันคืออะไรtf.GraphDefหรือSavedModelข้อความ protobuf? หากคุณมีtf.GraphDefที่เรียกว่าคุณจะได้รับรายชื่อของปฏิบัติการที่ใช้กับgd หากคุณมีรูปแบบการโหลดที่คุณสามารถทำได้sorted(set(n.op for n in gd.node)) sorted(set(op.type for op in tf.get_default_graph().get_operations()))ถ้าเป็นSavedModelคุณสามารถได้รับtf.GraphDefจากมัน (เช่นsaved_model.meta_graphs[0].graph_def)
jdehesa

ฉันต้องการดึงตัวเลือกจาก SavedModel ที่เก็บไว้ ดังนั้นตัวเลือกสุดท้ายที่คุณกำลังอธิบาย อะไรคือsaved_modelตัวแปรในตัวอย่างสุดท้ายของคุณหรือไม่ ผลลัพธ์ของtf.saved_model.load('/path/to/model')หรือโหลด protobuf ของไฟล์ save_model.pb
sampers

คำตอบ:


1

ถ้าsaved_model.pbเป็นSavedModelข้อความ protobuf คุณจะได้รับการดำเนินการโดยตรงจากที่นั่น สมมติว่าเราสร้างแบบจำลองดังต่อไปนี้:

import tensorflow as tf

class FileReader(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
    def read_disk(self, file_name):
        input_scalar = tf.reshape(file_name, [])
        output = tf.io.read_file(input_scalar)
        return tf.stack([output], name='content')

file_reader = FileReader()
tf.saved_model.save(file_reader, 'tmp')

ตอนนี้เราสามารถค้นหาการดำเนินการที่ใช้โดยรุ่นดังกล่าวดังนี้

from tensorflow.core.protobuf.saved_model_pb2 import SavedModel

saved_model = SavedModel()
with open('tmp/saved_model.pb', 'rb') as f:
    saved_model.ParseFromString(f.read())
model_op_names = set()
# Iterate over every metagraph in case there is more than one
for meta_graph in saved_model.meta_graphs:
    # Add operations in the graph definition
    model_op_names.update(node.op for node in meta_graph.graph_def.node)
    # Go through the functions in the graph definition
    for func in meta_graph.graph_def.library.function:
        # Add operations in each function
        model_op_names.update(node.op for node in func.node_def)
# Convert to list, sorted if you want
model_op_names = sorted(model_op_names)
print(*model_op_names, sep='\n')
# Const
# Identity
# MergeV2Checkpoints
# NoOp
# Pack
# PartitionedCall
# Placeholder
# ReadFile
# Reshape
# RestoreV2
# SaveV2
# ShardedFilename
# StatefulPartitionedCall
# StringJoin

ฉันลองอะไรเช่นนี้ แต่น่าเสียดายที่นี่ไม่ใช่สิ่งที่ฉันคาดหวัง: บอกว่าฉันมีแบบจำลองที่ทำสิ่งนี้: input_scalar = tf.reshape(file_name, []) output = tf.io.read_file(input_scalar) return tf.stack([output], name='content')จากนั้น ReadFile Op ตามที่ระบุไว้ที่นี่จะอยู่ในนั้น แต่ไม่มีการพิมพ์ออกมา
sampers

1
@sampers ฉันได้แก้ไขคำตอบด้วยตัวอย่างเช่นคุณแนะนำ ฉันจะได้รับการReadFileดำเนินการในการส่งออก เป็นไปได้หรือไม่ว่าในกรณีของคุณการดำเนินการนั้นไม่ได้อยู่ระหว่างอินพุตและเอาต์พุตของโมเดลที่บันทึกไว้? ในกรณีนี้ฉันคิดว่ามันอาจถูกตัด
jdehesa

แน่นอนกับรุ่นที่กำหนดมันใช้งานได้ น่าเสียดายสำหรับโมดูลที่สร้างขึ้นใน tf2 มันไม่ได้ ถ้าฉันสร้าง tf.Module ด้วย 1 ฟังก์ชั่นที่มีคำอธิบายประกอบfile_nameอาร์กิวเมนต์@tf.functionซึ่งมีการโทรที่ฉันระบุไว้ในความคิดเห็นก่อนหน้าของฉันก็จะให้รายการดังต่อไปนี้:Const, NoOp, PartitionedCall, Placeholder, StatefulPartitionedCall
sampers

เพิ่มแบบจำลองในคำถามของฉัน
sampers

@sampers ฉันได้อัปเดตคำตอบของฉันแล้ว ก่อนหน้านี้ฉันเคยใช้ TF 1.x ฉันไม่คุ้นเคยกับการเปลี่ยนแปลงของวัตถุนิยามกราฟใน TF 2.x ฉันคิดว่าคำตอบตอนนี้ครอบคลุมทุกอย่างในรูปแบบที่บันทึกไว้ ฉันคิดว่าการดำเนินการที่สอดคล้องกับฟังก์ชั่น Python ที่คุณเขียนนั้นอยู่ในsaved_model.meta_graphs[0].graph_def.library.function[0](การnode_defรวบรวมภายในวัตถุฟังก์ชันนั้น)
jdehesa
โดยการใช้ไซต์ของเรา หมายความว่าคุณได้อ่านและทำความเข้าใจนโยบายคุกกี้และนโยบายความเป็นส่วนตัวของเราแล้ว
Licensed under cc by-sa 3.0 with attribution required.