พารามิเตอร์ class_weight ใน scikit-learn ทำงานอย่างไร


116

ฉันมีปัญหามากในการทำความเข้าใจว่าclass_weightพารามิเตอร์ใน Logistic Regression ของ scikit-learn ทำงานอย่างไร

สถานการณ์

ฉันต้องการใช้การถดถอยโลจิสติกเพื่อทำการจำแนกไบนารีในชุดข้อมูลที่ไม่สมดุลมาก ชั้นเรียนมีข้อความว่า 0 (ลบ) และ 1 (บวก) และข้อมูลที่สังเกตได้อยู่ในอัตราส่วนประมาณ 19: 1 โดยกลุ่มตัวอย่างส่วนใหญ่มีผลลบ

ความพยายามครั้งแรก: เตรียมข้อมูลการฝึกอบรมด้วยตนเอง

ฉันแบ่งข้อมูลที่ฉันมีเป็นชุดที่ไม่ปะติดปะต่อสำหรับการฝึกอบรมและการทดสอบ (ประมาณ 80/20) จากนั้นฉันสุ่มตัวอย่างข้อมูลการฝึกด้วยมือเพื่อรับข้อมูลการฝึกอบรมในสัดส่วนที่แตกต่างจาก 19: 1; ตั้งแต่ 2: 1 -> 16: 1

จากนั้นฉันได้ฝึกการถดถอยโลจิสติกในชุดย่อยข้อมูลการฝึกอบรมที่แตกต่างกันเหล่านี้และการเรียกคืนแบบพล็อต (= TP / (TP + FN)) เป็นฟังก์ชันของสัดส่วนการฝึกที่แตกต่างกัน แน่นอนว่าการเรียกคืนนั้นคำนวณจากตัวอย่างการทดสอบที่ไม่ปะติดปะต่อซึ่งมีสัดส่วนที่สังเกตได้เท่ากับ 19: 1 หมายเหตุแม้ว่าฉันจะฝึกโมเดลที่แตกต่างกันในข้อมูลการฝึกอบรมที่แตกต่างกันฉันก็คำนวณการเรียกคืนสำหรับทุกคนในข้อมูลการทดสอบ (ไม่ปะติดปะต่อ) เดียวกัน

ผลลัพธ์เป็นไปตามที่คาดไว้: การเรียกคืนประมาณ 60% ที่สัดส่วนการฝึก 2: 1 และลดลงค่อนข้างเร็วเมื่อถึงเวลา 16: 1 มีหลายสัดส่วน 2: 1 -> 6: 1 ที่การเรียกคืนสูงกว่า 5% อย่างเหมาะสม

ความพยายามครั้งที่สอง: การค้นหาแบบกริด

ต่อไปฉันต้องการทดสอบพารามิเตอร์การกำหนดมาตรฐานที่แตกต่างกันดังนั้นฉันจึงใช้ GridSearchCV และสร้างตารางที่มีค่าCพารามิเตอร์หลายค่ารวมทั้งclass_weightพารามิเตอร์ ในการแปลสัดส่วน n: m ของฉันเป็นค่าลบ: ตัวอย่างการฝึกอบรมเชิงบวกเป็นภาษาพจนานุกรมของclass_weightฉันคิดว่าฉันระบุพจนานุกรมหลาย ๆ แบบดังนี้:

{ 0:0.67, 1:0.33 } #expected 2:1
{ 0:0.75, 1:0.25 } #expected 3:1
{ 0:0.8, 1:0.2 }   #expected 4:1

และฉันยังรวมถึงและNoneauto

คราวนี้ผลลัพท์แตกโดยสิ้นเชิง การเรียกคืนทั้งหมดของฉันออกเล็ก ๆ (<0.05) มาสำหรับค่าของทุกยกเว้นclass_weight autoดังนั้นฉันสามารถสันนิษฐานได้ว่าความเข้าใจของฉันเกี่ยวกับการตั้งค่าclass_weightพจนานุกรมนั้นไม่ถูกต้อง ที่น่าสนใจคือclass_weightค่าของ 'auto' ในการค้นหาแบบกริดอยู่ที่ประมาณ 59% สำหรับค่าทั้งหมดCและฉันเดาว่ามันจะสมดุลเป็น 1: 1?

คำถามของฉัน

  1. คุณclass_weightจะใช้ข้อมูลการฝึกอบรมที่แตกต่างกันอย่างเหมาะสมอย่างไรเพื่อให้ได้ข้อมูลการฝึกอบรมจากสิ่งที่คุณให้มาจริง โดยเฉพาะพจนานุกรมใดที่ฉันclass_weightใช้ในการใช้ n: m สัดส่วนของค่าลบ: ตัวอย่างการฝึกอบรมเชิงบวก

  2. หากคุณส่งผ่านclass_weightพจนานุกรมต่างๆไปยัง GridSearchCV ในระหว่างการตรวจสอบความถูกต้องข้ามจะปรับสมดุลข้อมูลการฝึกอบรมตามพจนานุกรม แต่ใช้สัดส่วนตัวอย่างที่กำหนดจริงสำหรับการคำนวณฟังก์ชันการให้คะแนนของฉันในพับทดสอบหรือไม่ นี่เป็นสิ่งสำคัญเนื่องจากเมตริกใด ๆ จะมีประโยชน์กับฉันก็ต่อเมื่อมันมาจากข้อมูลในสัดส่วนที่สังเกตได้

  3. สิ่งที่ไม่autoคุ้มค่าของการclass_weightทำเท่าที่เป็นสัดส่วน? ฉันอ่านเอกสารและคิดว่า "ทำให้ข้อมูลสมดุลเป็นสัดส่วนผกผันกับความถี่ของข้อมูล" หมายความว่ามันทำให้เป็น 1: 1 ถูกต้องหรือไม่ ถ้าไม่มีใครสามารถชี้แจงได้?


เมื่อใช้ class_weight ฟังก์ชัน loss จะถูกแก้ไข ตัวอย่างเช่นแทนที่จะเป็นเอนโทรปีแบบไขว้มันจะกลายเป็นเอนโทรปีแบบชั่งน้ำหนัก towardsdatascience.com/…
prashanth

คำตอบ:


123

ก่อนอื่นอาจไม่ใช่เรื่องดีที่จะเพียงแค่เรียกคืนอย่างเดียว คุณสามารถเรียกคืนได้ 100% โดยจำแนกทุกอย่างเป็นคลาสเชิงบวก ฉันมักจะแนะนำให้ใช้ AUC ในการเลือกพารามิเตอร์จากนั้นค้นหาเกณฑ์สำหรับจุดปฏิบัติการ (พูดระดับความแม่นยำที่กำหนด) ที่คุณสนใจ

สำหรับวิธีการclass_weightทำงาน: จะลงโทษข้อผิดพลาดในตัวอย่างclass[i]ด้วยclass_weight[i]แทนที่จะเป็น 1 น้ำหนักคลาสที่สูงขึ้นหมายความว่าคุณต้องการให้ความสำคัญกับคลาสมากขึ้น จากที่คุณบอกดูเหมือนว่าคลาส 0 จะบ่อยกว่าคลาส 1 ถึง 19 เท่าดังนั้นคุณควรเพิ่มclass_weightคลาส 1 เทียบกับคลาส 0 โดยพูดว่า {0: .1, 1: .9} หากclass_weightผลรวมไม่เป็น 1 โดยทั่วไปจะเปลี่ยนพารามิเตอร์การทำให้เป็นมาตรฐาน

สำหรับวิธีการclass_weight="auto"ทำงานคุณสามารถดูการสนทนานี้ ในเวอร์ชัน dev คุณสามารถใช้ได้class_weight="balanced"ซึ่งง่ายต่อการเข้าใจ: โดยพื้นฐานแล้วหมายถึงการจำลองคลาสที่เล็กกว่าจนกว่าคุณจะมีตัวอย่างมากเท่าในคลาสที่ใหญ่กว่า แต่ในทางปริยาย


1
ขอบคุณ! คำถามด่วน: ฉันพูดถึงการจำเพื่อความชัดเจนและอันที่จริงฉันกำลังพยายามตัดสินใจว่าจะใช้ AUC ใดเป็นมาตรการของฉัน ความเข้าใจของฉันคือฉันควรจะขยายพื้นที่ให้ใหญ่ที่สุดภายใต้เส้นโค้ง ROC หรือพื้นที่ภายใต้การเรียกคืนเทียบกับเส้นโค้งความแม่นยำเพื่อค้นหาพารามิเตอร์ หลังจากเลือกพารามิเตอร์ด้วยวิธีนี้ฉันเชื่อว่าฉันเลือกเกณฑ์สำหรับการจำแนกโดยเลื่อนไปตามเส้นโค้ง นี่คือสิ่งที่คุณหมายถึง? ถ้าเป็นเช่นนั้นเส้นโค้งสองเส้นใดที่เหมาะสมที่สุดในการดูว่าเป้าหมายของฉันคือการจับ TP ให้ได้มากที่สุดหรือไม่ นอกจากนี้ขอขอบคุณสำหรับผลงานและการมีส่วนร่วมของ scikit-learn !!!
kilgoretrout

1
ฉันคิดว่าการใช้ ROC จะเป็นวิธีที่เป็นมาตรฐานมากกว่า แต่ฉันไม่คิดว่าจะมีความแตกต่างอย่างมาก คุณต้องมีเกณฑ์บางอย่างเพื่อเลือกจุดบนเส้นโค้ง
Andreas Mueller

3
@MiNdFrEaK ฉันคิดว่าสิ่งที่แอนดรูหมายถึงคือตัวประมาณค่าจำลองตัวอย่างในชั้นเรียนของชนกลุ่มน้อยเพื่อให้ตัวอย่างของคลาสต่างๆมีความสมดุล เป็นเพียงการสุ่มตัวอย่างมากเกินไปโดยปริยาย
Shawn TIAN

8
@MiNdFrEaK และ Shawn Tian: ตัวแยกประเภทตาม SV จะไม่สร้างตัวอย่างของคลาสที่เล็กกว่าเมื่อคุณใช้ 'บาลานซ์' มันเป็นการลงโทษความผิดพลาดที่เกิดขึ้นในชั้นเรียนขนาดเล็กอย่างแท้จริง การพูดเป็นอย่างอื่นถือเป็นความผิดพลาดและทำให้เข้าใจผิดโดยเฉพาะในชุดข้อมูลขนาดใหญ่เมื่อคุณไม่สามารถสร้างตัวอย่างเพิ่มเติมได้ คำตอบนี้ต้องได้รับการแก้ไข
Pablo Rivas

4
scikit-learn.org/dev/glossary.html#term-class-weightน้ำหนักคลาสจะถูกใช้แตกต่างกันไปขึ้นอยู่กับอัลกอริทึม: สำหรับแบบจำลองเชิงเส้น (เช่น linear SVM หรือการถดถอยโลจิสติก) น้ำหนักคลาสจะเปลี่ยนฟังก์ชันการสูญเสียโดย การถ่วงน้ำหนักการสูญเสียของแต่ละตัวอย่างตามน้ำหนักระดับชั้น สำหรับอัลกอริทึมแบบต้นไม้จะใช้น้ำหนักคลาสสำหรับการถ่วงน้ำหนักเกณฑ์การแยกซ้ำ อย่างไรก็ตามโปรดทราบว่าการปรับสมดุลใหม่นี้ไม่ได้คำนึงถึงน้ำหนักของตัวอย่างในแต่ละคลาส
prashanth

3

คำตอบแรกคือความเข้าใจวิธีการทำงานที่ดี แต่ฉันอยากจะเข้าใจว่าฉันควรใช้มันอย่างไรในทางปฏิบัติ

สรุป

  • สำหรับข้อมูลที่ไม่สมดุลในระดับปานกลางโดยไม่มีเสียงรบกวนการใช้น้ำหนักคลาสนั้นไม่แตกต่างกันมากนัก
  • สำหรับข้อมูลที่ไม่สมดุลในระดับปานกลางซึ่งมีสัญญาณรบกวนและไม่สมดุลอย่างยิ่งควรใช้น้ำหนักคลาส
  • param class_weight="balanced"ทำงานได้ดีในกรณีที่คุณไม่ต้องการปรับให้เหมาะสมด้วยตนเอง
  • เมื่อclass_weight="balanced"คุณจับภาพเหตุการณ์จริงได้มากขึ้น (การเรียกคืน TRUE ที่สูงขึ้น) แต่คุณก็มีแนวโน้มที่จะได้รับการแจ้งเตือนที่ผิดพลาด (ความแม่นยำของ TRUE ต่ำลง)
    • เป็นผลให้% TRUE รวมอาจสูงกว่าจริงเนื่องจากผลบวกปลอมทั้งหมด
    • AUC อาจทำให้คุณเข้าใจผิดที่นี่หากการเตือนที่ผิดพลาดเป็นปัญหา
  • ไม่จำเป็นต้องเปลี่ยนเกณฑ์การตัดสินใจเป็น% ความไม่สมดุลแม้จะเกิดความไม่สมดุลอย่างมากก็สามารถเก็บ 0.5 (หรือบางที่ก็ได้ขึ้นอยู่กับสิ่งที่คุณต้องการ)

NB

ผลลัพธ์อาจแตกต่างกันเมื่อใช้ RF หรือ GBM sklearn ไม่มี class_weight="balanced"สำหรับ GBM แต่มีlightgbmLGBMClassifier(is_unbalance=False)

รหัส

# scikit-learn==0.21.3
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, classification_report
import numpy as np
import pandas as pd

# case: moderate imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.8]) #,flip_y=0.1,class_sep=0.5)
np.mean(y) # 0.2

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.184
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X).mean() # 0.296 => seems to make things worse?
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.292 => seems to make things worse?

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.83
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X)) # 0.86 => about the same
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.86 => about the same

# case: strong imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.95])
np.mean(y) # 0.06

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.02
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X).mean() # 0.25 => huh??
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.22 => huh??
(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).mean() # same as last

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.64
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X)) # 0.84 => much better
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.85 => similar to manual
roc_auc_score(y,(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).astype(int)) # same as last

print(classification_report(y,LogisticRegression(C=1e9).fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True,normalize='index') # few prediced TRUE with only 28% TRUE recall and 86% TRUE precision so 6%*28%~=2%

print(classification_report(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True,normalize='index') # 88% TRUE recall but also lot of false positives with only 23% TRUE precision, making total predicted % TRUE > actual % TRUE
โดยการใช้ไซต์ของเรา หมายความว่าคุณได้อ่านและทำความเข้าใจนโยบายคุกกี้และนโยบายความเป็นส่วนตัวของเราแล้ว
Licensed under cc by-sa 3.0 with attribution required.