วิธีการที่ไม่tf.app.run()
ทำงานใน Tensorflow แปลสาธิต?
ในtensorflow/models/rnn/translate/translate.py
มีการเรียกร้องให้tf.app.run()
เป็น มันถูกจัดการอย่างไร?
if __name__ == "__main__":
tf.app.run()
วิธีการที่ไม่tf.app.run()
ทำงานใน Tensorflow แปลสาธิต?
ในtensorflow/models/rnn/translate/translate.py
มีการเรียกร้องให้tf.app.run()
เป็น มันถูกจัดการอย่างไร?
if __name__ == "__main__":
tf.app.run()
คำตอบ:
if __name__ == "__main__":
หมายถึงไฟล์ปัจจุบันจะถูกดำเนินการภายใต้เปลือกแทนที่จะนำเข้าเป็นโมดูล
tf.app.run()
อย่างที่คุณเห็นผ่านไฟล์ app.py
def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list."""
f = flags.FLAGS
# Extract the args from the optional `argv` list.
args = argv[1:] if argv else None
# Parse the known flags from that list, or from the command
# line otherwise.
# pylint: disable=protected-access
flags_passthrough = f._parse_flags(args=args)
# pylint: enable=protected-access
main = main or sys.modules['__main__'].main
# Call the main function, passing through any arguments
# to the final program.
sys.exit(main(sys.argv[:1] + flags_passthrough))
มาแบ่งกันทีละบรรทัด:
flags_passthrough = f._parse_flags(args=args)
สิ่งนี้ทำให้มั่นใจได้ว่าอาร์กิวเมนต์ที่คุณผ่านบรรทัดคำสั่งนั้นถูกต้องเช่นที่
python my_model.py --data_dir='...' --max_iteration=10000
จริงแล้วคุณสมบัตินี้มีการใช้งานตามargparse
โมดูลมาตรฐานของไพธ อน
main = main or sys.modules['__main__'].main
ครั้งแรกที่main
ด้านขวาของ
เป็นอาร์กิวเมนต์แรกของการทำงานในปัจจุบัน=
run(main=None, argv=None)
ในขณะที่sys.modules['__main__']
หมายถึงไฟล์ที่กำลังทำงานอยู่ (เช่นmy_model.py
)
ดังนั้นจึงมีสองกรณี:
คุณไม่มีmain
ฟังก์ชั่นในmy_model.py
จากนั้นคุณต้องโทรออกtf.app.run(my_main_running_function)
คุณมีmain
ฟังก์ชั่นmy_model.py
ค่ะ (นี่เป็นกรณีส่วนใหญ่)
บรรทัดสุดท้าย:
sys.exit(main(sys.argv[:1] + flags_passthrough))
ทำให้แน่ใจว่าคุณmain(argv)
หรือmy_main_running_function(argv)
ฟังก์ชั่นถูกเรียกด้วยการแยกวิเคราะห์อย่างถูกต้อง
abseil
ที่ TF ต้องดูดซับabseil.io/docs/python/guides/flags
main = main or sys.modules['__main__'].main
และ sys.exit(main(sys.argv[:1] + flags_passthrough))
หมายความว่าอย่างไร
main()
?
tf.app
ไม่มีอะไรพิเศษในการเป็น นี่เป็นเพียงสคริปต์จุดเริ่มต้นทั่วไปเท่านั้น
เรียกใช้โปรแกรมด้วยฟังก์ชัน 'main' และรายการ 'argv' ที่เป็นทางเลือก
มันไม่มีส่วนเกี่ยวข้องกับโครงข่ายประสาทและมันก็เรียกฟังก์ชั่นหลักผ่านข้อโต้แย้งใด ๆ
ในแง่ง่ายงานของtf.app.run()
คือการแรกตั้งธงทั่วโลกสำหรับการใช้งานในภายหลังเช่น:
from tensorflow.python.platform import flags
f = flags.FLAGS
จากนั้นเรียกใช้ฟังก์ชันหลักที่กำหนดเองของคุณด้วยชุดอาร์กิวเมนต์
ตัวอย่างเช่นในฐานรหัสTensorFlow NMTจุดเริ่มต้นแรกสำหรับการดำเนินการโปรแกรมสำหรับการฝึกอบรม / การอนุมานเริ่มต้นที่จุดนี้ (ดูรหัสด้านล่าง)
if __name__ == "__main__":
nmt_parser = argparse.ArgumentParser()
add_arguments(nmt_parser)
FLAGS, unparsed = nmt_parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
หลังจากแยกวิเคราะห์ข้อโต้แย้งโดยใช้argparse
เมื่อtf.app.run()
คุณเรียกใช้ฟังก์ชัน "main" ซึ่งกำหนดไว้เช่น:
def main(unused_argv):
default_hparams = create_hparams(FLAGS)
train_fn = train.train
inference_fn = inference.inference
run_main(FLAGS, default_hparams, train_fn, inference_fn)
ดังนั้นหลังจากตั้งค่าสถานะสำหรับการใช้งานทั่วโลกtf.app.run()
เพียงแค่เรียกmain
ใช้ฟังก์ชันที่คุณส่งไปพร้อมกับargv
เป็นพารามิเตอร์
PS: ในฐานะที่เป็นคำตอบของ Salvador Daliกล่าวว่าเป็นเพียงการปฏิบัติงานวิศวกรรมซอฟต์แวร์ที่ดีผมว่าถึงแม้ว่าฉันไม่แน่ใจว่า TensorFlow ดำเนินการเพิ่มประสิทธิภาพการทำงานใด ๆ ของmain
ฟังก์ชั่นกว่าที่ทำงานโดยใช้ CPython ปกติ
รหัส Google ขึ้นอยู่กับการตั้งค่าสถานะส่วนกลางที่เข้าถึงในไลบรารี / ไบนารี / สคริปต์หลามและ tf.app.run () จะแยกวิเคราะห์สถานะเหล่านั้นเพื่อสร้างสถานะโกลบอลในรูปแบบ FLAG (หรือตัวแปรอื่น ๆ ที่คล้ายกัน) จากนั้นเรียกไพ ธ เมน ) เท่าที่ควร
หากพวกเขาไม่มีการเรียก tf.app.run () นี้ผู้ใช้อาจลืมทำการแยกวิเคราะห์ FLAGs ซึ่งนำไปสู่ไลบรารี / ไบนารี / สคริปต์เหล่านี้ซึ่งไม่สามารถเข้าถึง FLAG ที่ต้องการได้
2.0 คำตอบที่เข้ากันได้ : ถ้าคุณต้องการที่จะใช้tf.app.run()
ในการTensorflow 2.0
ที่เราควรจะใช้คำสั่ง
tf.compat.v1.app.run()
หรือคุณสามารถใช้tf_upgrade_v2
ในการแปลงรหัส1.x
2.0
tf.flags.DEFINE_integer('batch_size', 128, 'Number of images to process in a batch.')
และจากนั้นหากคุณใช้tf.app.run()
มันจะตั้งค่าสิ่งต่าง ๆ เพื่อให้คุณสามารถเข้าถึงค่าที่ส่งผ่านของค่าสถานะที่คุณกำหนดไว้เช่นtf.flags.FLAGS.batch_size
จากที่ใดก็ตามที่คุณต้องการในรหัสของคุณ