tf.app.run () ทำงานอย่างไร


148

วิธีการที่ไม่tf.app.run()ทำงานใน Tensorflow แปลสาธิต?

ในtensorflow/models/rnn/translate/translate.pyมีการเรียกร้องให้tf.app.run()เป็น มันถูกจัดการอย่างไร?

if __name__ == "__main__":
    tf.app.run() 

คำตอบ:


134
if __name__ == "__main__":

หมายถึงไฟล์ปัจจุบันจะถูกดำเนินการภายใต้เปลือกแทนที่จะนำเข้าเป็นโมดูล

tf.app.run()

อย่างที่คุณเห็นผ่านไฟล์ app.py

def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""
  f = flags.FLAGS

  # Extract the args from the optional `argv` list.
  args = argv[1:] if argv else None

  # Parse the known flags from that list, or from the command
  # line otherwise.
  # pylint: disable=protected-access
  flags_passthrough = f._parse_flags(args=args)
  # pylint: enable=protected-access

  main = main or sys.modules['__main__'].main

  # Call the main function, passing through any arguments
  # to the final program.
  sys.exit(main(sys.argv[:1] + flags_passthrough))

มาแบ่งกันทีละบรรทัด:

flags_passthrough = f._parse_flags(args=args)

สิ่งนี้ทำให้มั่นใจได้ว่าอาร์กิวเมนต์ที่คุณผ่านบรรทัดคำสั่งนั้นถูกต้องเช่นที่ python my_model.py --data_dir='...' --max_iteration=10000จริงแล้วคุณสมบัตินี้มีการใช้งานตามargparseโมดูลมาตรฐานของไพธ อน

main = main or sys.modules['__main__'].main

ครั้งแรกที่mainด้านขวาของ เป็นอาร์กิวเมนต์แรกของการทำงานในปัจจุบัน= run(main=None, argv=None)ในขณะที่sys.modules['__main__']หมายถึงไฟล์ที่กำลังทำงานอยู่ (เช่นmy_model.py)

ดังนั้นจึงมีสองกรณี:

  1. คุณไม่มีmainฟังก์ชั่นในmy_model.pyจากนั้นคุณต้องโทรออกtf.app.run(my_main_running_function)

  2. คุณมีmainฟังก์ชั่นmy_model.pyค่ะ (นี่เป็นกรณีส่วนใหญ่)

บรรทัดสุดท้าย:

sys.exit(main(sys.argv[:1] + flags_passthrough))

ทำให้แน่ใจว่าคุณmain(argv)หรือmy_main_running_function(argv)ฟังก์ชั่นถูกเรียกด้วยการแยกวิเคราะห์อย่างถูกต้อง


67
ชิ้นส่วนปริศนาที่หายไปสำหรับผู้ใช้ Tensorflow เริ่มต้น: Tensorflow มีกลไกการจัดการแฟล็กบรรทัดคำสั่งในตัว คุณสามารถกำหนดค่าสถานะของคุณเช่นtf.flags.DEFINE_integer('batch_size', 128, 'Number of images to process in a batch.')และจากนั้นหากคุณใช้tf.app.run()มันจะตั้งค่าสิ่งต่าง ๆ เพื่อให้คุณสามารถเข้าถึงค่าที่ส่งผ่านของค่าสถานะที่คุณกำหนดไว้เช่นtf.flags.FLAGS.batch_sizeจากที่ใดก็ตามที่คุณต้องการในรหัสของคุณ
isarandi

1
นี่เป็นคำตอบที่ดีกว่าของสาม (ในปัจจุบัน) ในความคิดของฉัน มันอธิบายถึง "tf.app.run () ทำงานอย่างไร" ในขณะที่อีกสองคำตอบเพียงแค่พูดในสิ่งที่มันทำ
Thomas Fauskanger

ดูเหมือนว่าธงจะได้รับการจัดการโดยabseilที่ TF ต้องดูดซับabseil.io/docs/python/guides/flags
34322

75

มันเป็นเพียง wrapper ที่รวดเร็วมากที่จัดการการแยกวิเคราะห์สถานะ ดูรหัส


12
หมายความว่า "จัดการการแยกวิเคราะห์ธง" บางทีคุณสามารถเพิ่มลิงค์เพื่อแจ้งผู้เริ่มต้นได้
Pinocchio

4
แยกวิเคราะห์อาร์กิวเมนต์บรรทัดคำสั่งที่ให้กับโปรแกรมโดยใช้แพ็กเกจแฟล็ก (ซึ่งใช้ไลบรารี 'argparse' มาตรฐานภายใต้หน้าปกพร้อมด้วยโปรแกรมเสริม) มันเชื่อมโยงจากรหัสที่ฉันเชื่อมโยงกับคำตอบของฉัน
dga

1
ใน app.py สิ่งที่ทำ main = main or sys.modules['__main__'].mainและ sys.exit(main(sys.argv[:1] + flags_passthrough))หมายความว่าอย่างไร
hAcKnRoCk

3
นี้ดูเหมือนว่าแปลกให้ฉันทำไมตัดฟังก์ชั่นหลักในทุกสิ่งที่ถ้าคุณเพียงแค่สามารถเรียกมันโดยตรงmain()?
Charlie Parker

2
hAcKnRoCk: หากไม่มีไฟล์หลักในไฟล์มันจะใช้ไฟล์ใด ๆ ใน sys.modules [' main '] .main sys.exit หมายถึงการเรียกใช้คำสั่งหลักที่พบโดยใช้ args และแฟล็กใด ๆ ที่ผ่านและเพื่อออกด้วยค่าส่งคืนของ main @CharlieParker - สำหรับความเข้ากันได้กับไลบรารีแอพไพ ธ อนที่มีอยู่ของ Google เช่น gflags และ google-apputils ดูตัวอย่างเช่นgithub.com/google/google-apputils
DGA

8

tf.appไม่มีอะไรพิเศษในการเป็น นี่เป็นเพียงสคริปต์จุดเริ่มต้นทั่วไปเท่านั้น

เรียกใช้โปรแกรมด้วยฟังก์ชัน 'main' และรายการ 'argv' ที่เป็นทางเลือก

มันไม่มีส่วนเกี่ยวข้องกับโครงข่ายประสาทและมันก็เรียกฟังก์ชั่นหลักผ่านข้อโต้แย้งใด ๆ


5

ในแง่ง่ายงานของtf.app.run()คือการแรกตั้งธงทั่วโลกสำหรับการใช้งานในภายหลังเช่น:

from tensorflow.python.platform import flags
f = flags.FLAGS

จากนั้นเรียกใช้ฟังก์ชันหลักที่กำหนดเองของคุณด้วยชุดอาร์กิวเมนต์

ตัวอย่างเช่นในฐานรหัสTensorFlow NMTจุดเริ่มต้นแรกสำหรับการดำเนินการโปรแกรมสำหรับการฝึกอบรม / การอนุมานเริ่มต้นที่จุดนี้ (ดูรหัสด้านล่าง)

if __name__ == "__main__":
  nmt_parser = argparse.ArgumentParser()
  add_arguments(nmt_parser)
  FLAGS, unparsed = nmt_parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

หลังจากแยกวิเคราะห์ข้อโต้แย้งโดยใช้argparseเมื่อtf.app.run()คุณเรียกใช้ฟังก์ชัน "main" ซึ่งกำหนดไว้เช่น:

def main(unused_argv):
  default_hparams = create_hparams(FLAGS)
  train_fn = train.train
  inference_fn = inference.inference
  run_main(FLAGS, default_hparams, train_fn, inference_fn)

ดังนั้นหลังจากตั้งค่าสถานะสำหรับการใช้งานทั่วโลกtf.app.run()เพียงแค่เรียกmainใช้ฟังก์ชันที่คุณส่งไปพร้อมกับargvเป็นพารามิเตอร์

PS: ในฐานะที่เป็นคำตอบของ Salvador Daliกล่าวว่าเป็นเพียงการปฏิบัติงานวิศวกรรมซอฟต์แวร์ที่ดีผมว่าถึงแม้ว่าฉันไม่แน่ใจว่า TensorFlow ดำเนินการเพิ่มประสิทธิภาพการทำงานใด ๆ ของmainฟังก์ชั่นกว่าที่ทำงานโดยใช้ CPython ปกติ


2

รหัส Google ขึ้นอยู่กับการตั้งค่าสถานะส่วนกลางที่เข้าถึงในไลบรารี / ไบนารี / สคริปต์หลามและ tf.app.run () จะแยกวิเคราะห์สถานะเหล่านั้นเพื่อสร้างสถานะโกลบอลในรูปแบบ FLAG (หรือตัวแปรอื่น ๆ ที่คล้ายกัน) จากนั้นเรียกไพ ธ เมน ) เท่าที่ควร

หากพวกเขาไม่มีการเรียก tf.app.run () นี้ผู้ใช้อาจลืมทำการแยกวิเคราะห์ FLAGs ซึ่งนำไปสู่ไลบรารี / ไบนารี / สคริปต์เหล่านี้ซึ่งไม่สามารถเข้าถึง FLAG ที่ต้องการได้


1

2.0 คำตอบที่เข้ากันได้ : ถ้าคุณต้องการที่จะใช้tf.app.run()ในการTensorflow 2.0ที่เราควรจะใช้คำสั่ง

tf.compat.v1.app.run()หรือคุณสามารถใช้tf_upgrade_v2ในการแปลงรหัส1.x2.0

โดยการใช้ไซต์ของเรา หมายความว่าคุณได้อ่านและทำความเข้าใจนโยบายคุกกี้และนโยบายความเป็นส่วนตัวของเราแล้ว
Licensed under cc by-sa 3.0 with attribution required.