ใน Tensorflow รับชื่อของ Tensors ทั้งหมดในกราฟ


118

ฉันกำลังสร้างตาข่ายประสาทด้วยTensorflowและskflow; ด้วยเหตุผลบางอย่างที่ฉันต้องการที่จะได้รับค่าของเทนเซอร์ภายในบางอย่างสำหรับการป้อนข้อมูลที่กำหนดดังนั้นฉันใช้myClassifier.get_layer_value(input, "tensorName"), การเป็นmyClassifierskflow.estimators.TensorFlowEstimator

อย่างไรก็ตามฉันพบว่ายากที่จะหาไวยากรณ์ที่ถูกต้องของชื่อเทนเซอร์แม้จะรู้ชื่อ (และฉันก็สับสนระหว่างการดำเนินการกับเทนเซอร์) ดังนั้นฉันจึงใช้เทนเซอร์บอร์ดเพื่อพล็อตกราฟและมองหาชื่อ

มีวิธีแจกแจงเทนเซอร์ทั้งหมดในกราฟโดยไม่ใช้เทนเซอร์บอร์ดหรือไม่?

คำตอบ:


189

คุณทำได้

[n.name for n in tf.get_default_graph().as_graph_def().node]

นอกจากนี้หากคุณกำลังสร้างต้นแบบในโน้ตบุ๊ก IPython คุณสามารถแสดงกราฟได้โดยตรงในโน้ตบุ๊กดูshow_graphฟังก์ชันในสมุดบันทึก Deep Dream ของ Alexander


2
คุณสามารถกรองสิ่งนี้สำหรับตัวแปรเช่นโดยเพิ่มif "Variable" in n.opที่ส่วนท้ายของความเข้าใจ
Radu

มีวิธีรับโหนดเฉพาะหรือไม่ถ้าคุณรู้ชื่อ?
Rocket Pingu

หากต้องการอ่านเพิ่มเติมเกี่ยวกับโหนดกราฟ: tensorflow.org/extend/tool_developers/#nodes
Ivan

3
คำสั่งด้านบนให้ชื่อของการดำเนินการ / โหนดทั้งหมด ในการรับชื่อของเทนเซอร์ทั้งหมดให้ทำ: tensors_per_node = [node.values ​​() สำหรับโหนดใน graph.get_operations ()] tensor_names = [tensor.name สำหรับเทนเซอร์ใน tensors_per_node สำหรับเทนเซอร์ในเทนเซอร์]
gebbissimo

24

ไม่มีทางที่จะทำมันเล็กน้อยเร็วกว่าในคำตอบของยาโรสลาฟโดยใช้เป็นget_operations นี่คือตัวอย่างสั้น ๆ :

import tensorflow as tf

a = tf.constant(1.3, name='const_a')
b = tf.Variable(3.1, name='variable_b')
c = tf.add(a, b, name='addition')
d = tf.multiply(c, a, name='multiply')

for op in tf.get_default_graph().get_operations():
    print(str(op.name))

2
คุณไม่สามารถใช้ Tensors tf.get_operations()ได้ คุณจะได้รับการดำเนินการเท่านั้น
Soulduck

14

ฉันจะพยายามสรุปคำตอบ:

ในการรับโหนดทั้งหมด(ประเภทtensorflow.core.framework.node_def_pb2.NodeDef):

all_nodes = [n for n in tf.get_default_graph().as_graph_def().node]

ในการรับopsทั้งหมด(ประเภทtensorflow.python.framework.ops.Operation):

all_ops = tf.get_default_graph().get_operations()

ในการรับตัวแปรทั้งหมด(ประเภทtensorflow.python.ops.resource_variable_ops.ResourceVariable):

all_vars = tf.global_variables()

ในการรับเทนเซอร์ทั้งหมด(ประเภทtensorflow.python.framework.ops.Tensor) :

all_tensors = [tensor for op in tf.get_default_graph().get_operations() for tensor in op.values()]

11

tf.all_variables() สามารถรับข้อมูลที่คุณต้องการได้

นอกจากนี้การกระทำนี้ทำในวันนี้ใน TensorFlow Learn ซึ่งมีฟังก์ชันget_variable_namesในตัวประมาณค่าที่คุณสามารถใช้เพื่อดึงชื่อตัวแปรทั้งหมดได้อย่างง่ายดาย


ฟังก์ชันนี้เลิกใช้งานแล้ว
CAFEBABE

8
... และผู้สืบทอดคือtf.global_variables()
bluenote10

11
สิ่งนี้ดึงเฉพาะตัวแปรไม่ใช่เทนเซอร์
Rajarshee Mitra

ใน Tensorflow 1.9.0 แสดงให้เห็นว่าall_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02
stackoverYC

5

ฉันคิดว่าสิ่งนี้จะทำเกินไป:

print(tf.contrib.graph_editor.get_tensors(tf.get_default_graph()))

แต่เมื่อเทียบกับคำตอบของ Salvado และ Yaroslav แล้วฉันไม่รู้ว่าอันไหนดีกว่ากัน


อันนี้ใช้งานได้กับกราฟที่นำเข้าจากไฟล์ frozen_inference_graph.pb ที่ใช้ใน API การตรวจจับวัตถุ tensorflow ขอบคุณ
simo23

4

คำตอบที่ยอมรับจะให้รายการสตริงที่มีชื่อเท่านั้น ฉันชอบแนวทางอื่นซึ่งช่วยให้คุณสามารถเข้าถึงเทนเซอร์ได้โดยตรง (เกือบ)

graph = tf.get_default_graph()
list_of_tuples = [op.values() for op in graph.get_operations()]

list_of_tuplesตอนนี้มีทุกเทนเซอร์แต่ละตัวอยู่ในทูเพิล คุณยังสามารถดัดแปลงเพื่อรับเทนเซอร์ได้โดยตรง:

graph = tf.get_default_graph()
list_of_tuples = [op.values()[0] for op in graph.get_operations()]

4

เนื่องจาก OP ขอรายชื่อเทนเซอร์แทนรายการของการดำเนินการ / โหนดรหัสจึงควรแตกต่างกันเล็กน้อย:

graph = tf.get_default_graph()    
tensors_per_node = [node.values() for node in graph.get_operations()]
tensor_names = [tensor.name for tensors in tensors_per_node for tensor in tensors]

3

คำตอบก่อนหน้านี้ดีฉันแค่อยากแบ่งปันฟังก์ชันยูทิลิตี้ที่ฉันเขียนเพื่อเลือก Tensors จากกราฟ:

def get_graph_op(graph, and_conds=None, op='and', or_conds=None):
    """Selects nodes' names in the graph if:
    - The name contains all items in and_conds
    - OR/AND depending on op
    - The name contains any item in or_conds

    Condition starting with a "!" are negated.
    Returns all ops if no optional arguments is given.

    Args:
        graph (tf.Graph): The graph containing sought tensors
        and_conds (list(str)), optional): Defaults to None.
            "and" conditions
        op (str, optional): Defaults to 'and'. 
            How to link the and_conds and or_conds:
            with an 'and' or an 'or'
        or_conds (list(str), optional): Defaults to None.
            "or conditions"

    Returns:
        list(str): list of relevant tensor names
    """
    assert op in {'and', 'or'}

    if and_conds is None:
        and_conds = ['']
    if or_conds is None:
        or_conds = ['']

    node_names = [n.name for n in graph.as_graph_def().node]

    ands = {
        n for n in node_names
        if all(
            cond in n if '!' not in cond
            else cond[1:] not in n
            for cond in and_conds
        )}

    ors = {
        n for n in node_names
        if any(
            cond in n if '!' not in cond
            else cond[1:] not in n
            for cond in or_conds
        )}

    if op == 'and':
        return [
            n for n in node_names
            if n in ands.intersection(ors)
        ]
    elif op == 'or':
        return [
            n for n in node_names
            if n in ands.union(ors)
        ]

ดังนั้นหากคุณมีกราฟพร้อม ops:

['model/classifier/dense/kernel',
'model/classifier/dense/kernel/Assign',
'model/classifier/dense/kernel/read',
'model/classifier/dense/bias',
'model/classifier/dense/bias/Assign',
'model/classifier/dense/bias/read',
'model/classifier/dense/MatMul',
'model/classifier/dense/BiasAdd',
'model/classifier/ArgMax/dimension',
'model/classifier/ArgMax']

จากนั้นวิ่ง

get_graph_op(tf.get_default_graph(), ['dense', '!kernel'], 'or', ['Assign'])

ผลตอบแทน:

['model/classifier/dense/kernel/Assign',
'model/classifier/dense/bias',
'model/classifier/dense/bias/Assign',
'model/classifier/dense/bias/read',
'model/classifier/dense/MatMul',
'model/classifier/dense/BiasAdd']

โดยการใช้ไซต์ของเรา หมายความว่าคุณได้อ่านและทำความเข้าใจนโยบายคุกกี้และนโยบายความเป็นส่วนตัวของเราแล้ว
Licensed under cc by-sa 3.0 with attribution required.