ฝึกการแบตช์ใน Tensorflow


11

ขณะนี้ฉันกำลังพยายามฝึกอบรมโมเดลด้วยไฟล์ csv ขนาดใหญ่ (> 70GB ที่มีมากกว่า 60 ล้านแถว) หากต้องการทำเช่นนั้นฉันกำลังใช้ tf.contrib.learn.read_batch_examples ฉันดิ้นรนในการทำความเข้าใจว่าฟังก์ชั่นนี้อ่านข้อมูลได้อย่างไร หากฉันใช้ขนาดแบทช์เป็น 50,000 เช่นนั้นจะอ่านไฟล์ 50,000 บรรทัดแรกหรือไม่ หากฉันต้องการวนซ้ำไฟล์ทั้งหมด (1 ตอน) ฉันต้องใช้ num_rows / batch_size = 1.200 จำนวนขั้นตอนสำหรับเมธอด estimator.fit หรือไม่

นี่คือฟังก์ชั่นอินพุตที่ฉันใช้ในปัจจุบัน:

def input_fn(file_names, batch_size):
    # Read csv files and create examples dict
    examples_dict = read_csv_examples(file_names, batch_size)

    # Continuous features
    feature_cols = {k: tf.string_to_number(examples_dict[k],
                                           out_type=tf.float32) for k in CONTINUOUS_COLUMNS}

    # Categorical features
    feature_cols.update({
                            k: tf.SparseTensor(
                                indices=[[i, 0] for i in range(examples_dict[k].get_shape()[0])],
                                values=examples_dict[k],
                                shape=[int(examples_dict[k].get_shape()[0]), 1])
                            for k in CATEGORICAL_COLUMNS})

    label = tf.string_to_number(examples_dict[LABEL_COLUMN], out_type=tf.int32)

    return feature_cols, label


def read_csv_examples(file_names, batch_size):
    def parse_fn(record):
        record_defaults = [tf.constant([''], dtype=tf.string)] * len(COLUMNS)

        return tf.decode_csv(record, record_defaults)

    examples_op = tf.contrib.learn.read_batch_examples(
        file_names,
        batch_size=batch_size,
        queue_capacity=batch_size*2.5,
        reader=tf.TextLineReader,
        parse_fn=parse_fn,
        #read_batch_size= batch_size,
        #randomize_input=True,
        num_threads=8
    )

    # Important: convert examples to dict for ease of use in `input_fn`
    # Map each header to its respective column (COLUMNS order
    # matters!
    examples_dict_op = {}
    for i, header in enumerate(COLUMNS):
        examples_dict_op[header] = examples_op[:, i]

    return examples_dict_op

นี่คือรหัส im ที่ใช้สำหรับฝึกโมเดล:

def train_and_eval():
"""Train and evaluate the model."""

m = build_estimator(model_dir)
m.fit(input_fn=lambda: input_fn(train_file_name, batch_size), steps=steps)

จะเกิดอะไรขึ้นถ้าฉันเรียกฟังก์ชันพอดีอีกครั้งด้วย input_fn เดียวกัน มันเริ่มต้นที่จุดเริ่มต้นของไฟล์อีกครั้งหรือจะจำบรรทัดที่มันหยุดครั้งล่าสุดได้หรือไม่


ฉันพบmedium.com/@ilblackdragon/เป็นประโยชน์ในการแบทช์ภายใน tensorflow input_fn
fistynuts

เหยาเหยาตรวจสอบอันนั้นแล้ว? stackoverflow.com/questions/37091899/…
Frankstr

คำตอบ:


1

เนื่องจากยังไม่มีคำตอบฉันต้องการพยายามให้คำตอบที่เป็นประโยชน์อย่างน้อยที่สุด การรวมคำจำกัดความค่าคงที่จะช่วยให้บิตเข้าใจรหัสที่ระบุ

โดยทั่วไปการพูดชุดงานใช้ n ครั้งบันทึกหรือรายการ วิธีที่คุณกำหนดรายการขึ้นอยู่กับปัญหาของคุณ ในเทนเซอร์ไหลแบทช์จะถูกเข้ารหัสในมิติแรกของเทนเซอร์ ในกรณีของคุณด้วยไฟล์ csv อาจเป็นทีละบรรทัด ( reader=tf.TextLineReader) มันสามารถเรียนรู้โดยคอลัมน์ แต่ฉันไม่คิดว่านี่เกิดขึ้นในรหัสของคุณ หากคุณต้องการรถไฟกับชุดข้อมูลทั้งหมด (= หนึ่งยุค ) numBatches=numItems/batchSizeคุณสามารถทำได้โดยใช้

โดยการใช้ไซต์ของเรา หมายความว่าคุณได้อ่านและทำความเข้าใจนโยบายคุกกี้และนโยบายความเป็นส่วนตัวของเราแล้ว
Licensed under cc by-sa 3.0 with attribution required.