รถไฟ / ทดสอบ / การตรวจสอบการตั้งค่าการแยกใน Sklearn


59

ฉันจะแยกเมทริกซ์ข้อมูลและเวกเตอร์เลเบลที่สอดคล้องกันเป็น X_train, X_test, X_val, y_train, y_test, y_test, y_val ด้วย Sklearn ได้อย่างไร เท่าที่ฉันรู้sklearn.cross_validation.train_test_splitมีเพียงความสามารถในการแยกออกเป็นสองไม่ใช่ในสาม ...

คำตอบ:


81

คุณสามารถใช้sklearn.model_selection.train_test_splitสองครั้ง ก่อนอื่นให้แยกรถไฟทดสอบและแยกรถไฟอีกครั้งเป็นการตรวจสอบและรถไฟ บางสิ่งเช่นนี้

 X_train, X_test, y_train, y_test 
    = train_test_split(X, y, test_size=0.2, random_state=1)

 X_train, X_val, y_train, y_val 
    = train_test_split(X_train, y_train, test_size=0.2, random_state=1)

1
ใช่มันใช้งานได้แน่นอน แต่ฉันหวังว่าจะมีอะไรที่หรูหรากว่า;) ไม่เป็นไรฉันยอมรับคำตอบนี้
Hendrik

1
ฉันต้องการเพิ่มว่าหากคุณต้องการใช้ชุดการตรวจสอบเพื่อค้นหาพารามิเตอร์ไฮเปอร์ที่ดีที่สุดคุณสามารถทำสิ่งต่อไปนี้หลังจากแยก: gist.github.com/albertotb/1bad123363b186267e3aeaa26610b54b
skd

12
ดังนั้นรถไฟขบวนสุดท้ายทดสอบสัดส่วนการตรวจสอบในตัวอย่างนี้คืออะไร เพราะในวินาทีtrain_test_split คุณทำสิ่งนี้กับการแบ่ง 80/20 ก่อนหน้า ดังนั้นค่าของคุณคือ 20% จาก 80% สัดส่วนที่แยกกันไม่ได้ตรงไปตรงมาในลักษณะนี้
โมนิก้า Heddneck

1
ฉันเห็นด้วยกับ @Monica Heddneck ว่ารถไฟ 64% การตรวจสอบความถูกต้อง 16% และการทดสอบ splt 20% นั้นชัดเจนขึ้น เป็นการอนุมานที่น่ารำคาญที่คุณต้องทำด้วยวิธีนี้
เพอร์รี่

32

มีคำตอบที่ดีสำหรับคำถามนี้เกี่ยวกับSOที่ใช้ numpy และ pandas

คำสั่ง (ดูคำตอบสำหรับการสนทนา):

train, validate, test = np.split(df.sample(frac=1), [int(.6*len(df)), int(.8*len(df))])

สร้างการแยก 60%, 20%, 20% สำหรับชุดการฝึกอบรมการตรวจสอบและการทดสอบ


2
ฉันสามารถเห็น.6ความหมาย 60% ... แต่สิ่งที่.8หมายถึงอะไร
Tom Hale

1
@ TomHale np.splitจะแบ่งที่ 60% ของความยาวของอาเรย์สับแล้ว 80% ของความยาว (ซึ่งเป็นข้อมูลเพิ่มเติม 20%) จึงเหลือ 20% ของข้อมูล นี่คือสาเหตุที่คำจำกัดความของฟังก์ชั่น คุณสามารถทดสอบ / เล่นกับ: x = np.arange(10.0)ตามด้วยnp.split(x, [ int(len(x)*0.6), int(len(x)*0.8)])
0_0

3

ส่วนใหญ่คุณจะพบว่าตัวเองไม่ได้แยกมันออกมา แต่ในขั้นตอนแรกคุณจะแบ่งข้อมูลของคุณในชุดการฝึกอบรมและการทดสอบ หลังจากนั้นคุณจะทำการค้นหาพารามิเตอร์โดยใช้ตัวแยกส่วนที่ซับซ้อนมากขึ้นเช่นการตรวจสอบความถูกต้องของข้อมูลด้วยอัลกอริทึม 'แยก k-fold' หรือ 'ลาออกหนึ่งครั้ง (ห่วง)'


3

คุณสามารถใช้train_test_splitสองครั้ง ฉันคิดว่านี่เป็นสิ่งที่ตรงไปตรงมาที่สุด

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=1)
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.25, random_state=1)

ด้วยวิธีนี้train, val, testชุดจะเป็น 60%, 20%, 20% ของชุดข้อมูลตามลำดับ


2

คำตอบที่ดีที่สุดข้างต้นไม่ได้กล่าวถึงว่าโดยการแยกสองครั้งโดยใช้train_test_splitไม่เปลี่ยนขนาดพาร์ติชันจะไม่ให้พาร์ทิชันเริ่มแรก:

x_train, x_remain = train_test_split(x, test_size=(val_size + test_size))

จากนั้นส่วนของการตรวจสอบความถูกต้องและชุดการทดสอบในการเปลี่ยนแปลง x_remainและสามารถนับเป็น

new_test_size = np.around(test_size / (val_size + test_size), 2)
# To preserve (new_test_size + new_val_size) = 1.0 
new_val_size = 1.0 - new_test_size

x_val, x_test = train_test_split(x_remain, test_size=new_test_size)

ในโอกาสนี้พาร์ทิชันเริ่มต้นทั้งหมดจะถูกบันทึกไว้


1

นี่คือวิธีการอื่น (สมมติว่ามีการแบ่งสามทางเท่ากัน):

# randomly shuffle the dataframe
df = df.reindex(np.random.permutation(df.index))

# how many records is one-third of the entire dataframe
third = int(len(df) / 3)

# Training set (the top third from the entire dataframe)
train = df[:third]

# Testing set (top half of the remainder two third of the dataframe)
test = df[third:][:third]

# Validation set (bottom one third)
valid = df[-third:]

สิ่งนี้สามารถทำให้รัดกุมขึ้น แต่ฉันเก็บไว้อย่างละเอียดเพื่อวัตถุประสงค์ในการอธิบาย


0

รับtrain_frac=0.8ฟังก์ชั่นนี้จะสร้างการแบ่ง 80% / 10% / 10%:

import sklearn

def data_split(examples, labels, train_frac, random_state=None):
    ''' https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html
    param data:       Data to be split
    param train_frac: Ratio of train set to whole dataset

    Randomly split dataset, based on these ratios:
        'train': train_frac
        'valid': (1-train_frac) / 2
        'test':  (1-train_frac) / 2

    Eg: passing train_frac=0.8 gives a 80% / 10% / 10% split
    '''

    assert train_frac >= 0 and train_frac <= 1, "Invalid training set fraction"

    X_train, X_tmp, Y_train, Y_tmp = sklearn.model_selection.train_test_split(
                                        examples, labels, train_size=train_frac, random_state=random_state)

    X_val, X_test, Y_val, Y_test   = sklearn.model_selection.train_test_split(
                                        X_tmp, Y_tmp, train_size=0.5, random_state=random_state)

    return X_train, X_val, X_test,  Y_train, Y_val, Y_test

0

การเพิ่มคำตอบของ @ hh32ในขณะที่เคารพสัดส่วนที่กำหนดไว้ล่วงหน้าเช่น (75, 15, 10):

train_ratio = 0.75
validation_ratio = 0.15
test_ratio = 0.10

# train is now 75% of the entire data set
# the _junk suffix means that we drop that variable completely
x_train, x_test, y_train, y_test = train_test_split(dataX, dataY, test_size=1 - train_ratio)

# test is now 10% of the initial data set
# validation is now 15% of the initial data set
x_val, x_test, y_val, y_test = train_test_split(x_test, y_test, test_size=test_ratio/(test_ratio + validation_ratio)) 

print(x_train, x_val, x_test)

0

การขยายคำตอบของ@ hh32ด้วยอัตราส่วนที่สงวนไว้

# Defines ratios, w.r.t. whole dataset.
ratio_train = 0.8
ratio_val = 0.1
ratio_test = 0.1

# Produces test split.
x_remaining, x_test, y_remaining, y_test = train_test_split(
    x, y, test_size=test_ratio)

# Adjusts val ratio, w.r.t. remaining dataset.
ratio_remaining = 1 - ratio_test
ratio_val_adjusted = ratio_val / ratio_remaining

# Produces train and val splits.
x_train, x_val, y_train, y_val = train_test_split(
    x_remaining, y_remaining, test_size=ratio_val_adjusted)

เนื่องจากชุดข้อมูลที่เหลือจะลดลงหลังจากการแยกครั้งแรกอัตราส่วนใหม่ที่เกี่ยวกับชุดข้อมูลที่ลดลงจะต้องคำนวณโดยการแก้สมการ:

RremainingRnew=Rold

โดยการใช้ไซต์ของเรา หมายความว่าคุณได้อ่านและทำความเข้าใจนโยบายคุกกี้และนโยบายความเป็นส่วนตัวของเราแล้ว
Licensed under cc by-sa 3.0 with attribution required.