สถาปัตยกรรมของซีเอ็นเอ็นเพื่อการถดถอย?


32

ฉันกำลังทำงานกับปัญหาการถดถอยที่อินพุตเป็นภาพและฉลากเป็นค่าต่อเนื่องระหว่าง 80 และ 350 ภาพเป็นสารเคมีบางอย่างหลังจากปฏิกิริยาเกิดขึ้น สีที่ปรากฎออกมาบ่งบอกถึงความเข้มข้นของสารเคมีอื่นที่เหลืออยู่และนั่นคือสิ่งที่แบบจำลองนั้นต้องการออก - ความเข้มข้นของสารเคมีนั้น ภาพสามารถหมุนพลิกสะท้อนและออกที่คาดหวังควรจะยังคงเหมือนเดิม การวิเคราะห์แบบนี้ทำในห้องปฏิบัติการจริง (เครื่องจักรพิเศษมากให้ความเข้มข้นของสารเคมีโดยใช้การวิเคราะห์สีเช่นเดียวกับที่ฉันกำลังฝึกรุ่นนี้ให้ทำ)

จนถึงตอนนี้ฉันได้ทดลองกับแบบจำลองโดยอ้างอิงจาก VGG (หลายลำดับของบล็อก Conv-conv-conv-conv-pool) ก่อนที่จะทำการทดลองกับสถาปัตยกรรมที่ใหม่กว่านี้ (Inception, ResNets ฯลฯ ) ฉันคิดว่าฉันจะทำการวิจัยถ้ามีสถาปัตยกรรมอื่น ๆ ที่ใช้กันโดยทั่วไปสำหรับการถดถอยโดยใช้รูปภาพ

ชุดข้อมูลมีลักษณะดังนี้:

ป้อนคำอธิบายรูปภาพที่นี่

ชุดข้อมูลมีตัวอย่างประมาณ 5,000 250x250 ตัวอย่างซึ่งฉันได้ปรับขนาดเป็น 64x64 เพื่อให้การฝึกอบรมง่ายขึ้น เมื่อฉันพบสถาปัตยกรรมที่มีแนวโน้มฉันจะทดลองกับภาพความละเอียดที่ใหญ่ขึ้น

จนถึงตอนนี้โมเดลที่ดีที่สุดของฉันมีข้อผิดพลาดกำลังสองเฉลี่ยทั้งชุดการฝึกอบรมและการตรวจสอบความถูกต้องประมาณ 0.3 ซึ่งอยู่ไกลจากที่ยอมรับได้ในกรณีใช้งานของฉัน

รุ่นที่ดีที่สุดของฉันมีลักษณะเช่นนี้:

// pseudo code
x = conv2d(x, filters=32, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=32, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=32, kernel=[3,3])->batch_norm()->relu()
x = maxpool(x, size=[2,2], stride=[2,2])

x = conv2d(x, filters=64, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=64, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=64, kernel=[3,3])->batch_norm()->relu()
x = maxpool(x, size=[2,2], stride=[2,2])

x = conv2d(x, filters=128, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=128, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=128, kernel=[3,3])->batch_norm()->relu()
x = maxpool(x, size=[2,2], stride=[2,2])

x = dropout()->conv2d(x, filters=128, kernel=[1, 1])->batch_norm()->relu()
x = dropout()->conv2d(x, filters=32, kernel=[1, 1])->batch_norm()->relu()

y = dense(x, units=1)

// loss = mean_squared_error(y, labels)

คำถาม

สถาปัตยกรรมที่เหมาะสมสำหรับเอาท์พุทการถดถอยจากอินพุตภาพคืออะไร?

แก้ไข

ฉันได้ใช้คำอธิบายใหม่และลบความแม่นยำออกแล้ว

แก้ไข 2

ฉันได้ปรับโครงสร้างคำถามของฉันดังนั้นหวังว่าจะชัดเจนในสิ่งที่ฉันหลังจาก


4
ความแม่นยำไม่ใช่ตัวชี้วัดที่สามารถนำไปใช้โดยตรงกับปัญหาการถดถอย คุณหมายถึงอะไรเมื่อคุณพูดว่าความแม่นยำของคุณคือ 30% ความแม่นยำใช้กับงานการจัดหมวดหมู่เท่านั้นไม่ใช่การถดถอย
นิวเคลียร์วัง

1
คุณหมายถึงอะไรโดย"คาดการณ์ได้อย่างถูกต้อง 30% ของเวลา" ? คุณกำลังถดถอยจริงหรือ
Firebug

1
เหตุใดคุณจึงเรียกใช้การถดถอยปัญหานี้ คุณไม่พยายามจัดประเภทเป็นป้ายกำกับใช่หรือไม่ ฉลากสำคัญหรือไม่
Aksakal

2
ฉันไม่ต้องการสิ่งเดียวกันกับ vgg ฉันกำลังทำอะไรที่คล้ายกัน vgg หมายถึงชุดของ convs ตามด้วยการรวมกำไรสูงสุดตามด้วยการเชื่อมต่ออย่างสมบูรณ์ ดูเหมือนว่าวิธีการทั่วไปสำหรับการทำงานกับรูปภาพ แต่แล้วอีกครั้งนั่นคือจุดรวมของคำถามเดิมของฉัน ดูเหมือนว่าทุกความคิดเห็นเหล่านี้แม้ว่าจะมีความลึกซึ้งต่อฉัน แต่ก็พลาดจุดที่ฉันถามมาตั้งแต่แรก
rodrigo-silveira

1
นอกจากนี้เราอาจสามารถให้ความช่วยเหลือได้ดียิ่งขึ้นหากคุณให้คำอธิบายปัญหาได้ดีขึ้น 1) ภาพอะไรบ้าง? ความละเอียดของพวกเขาคืออะไร? สิ่งที่มีความสัมพันธ์ที่มีระหว่างภาพและการตอบสนองของคุณ ? คือความสัมพันธ์นี้หมุนคงที่คือถ้าฉันหมุนภาพวงกลมของคุณโดยมุมโดยพลθทำผมคาดว่าปีการเปลี่ยนแปลง? 2) คุณทราบหรือไม่ว่าภาพ 5000 ภาพในการฝึกอบรมสถาปัตยกรรม VGG-net นั้นเป็นความทุกข์ยาก? คุณคำนวณจำนวนพารามิเตอร์ของสถาปัตยกรรมของคุณหรือไม่ มีวิธีใดบ้างที่คุณจะได้ภาพมากกว่านี้? ถ้าคุณทำไม่ได้คุณอาจต้องการ ...y[80,350]θy
DeltaIV

คำตอบ:


42

ก่อนอื่นคำแนะนำทั่วไป: ทำการค้นหาวรรณกรรมก่อนที่จะเริ่มทำการทดสอบในหัวข้อที่คุณไม่คุ้นเคย คุณจะประหยัดเวลาได้มาก

ในกรณีนี้ดูเอกสารที่มีอยู่คุณอาจสังเกตเห็นว่า

  1. ซีเอ็นเอ็นถูกนำมาใช้หลายครั้งสำหรับการถดถอย: นี่เป็นแบบคลาสสิก แต่มันเก่า (ใช่ 3 ปีใน DL) กระดาษที่ทันสมัยกว่าจะไม่ใช้ AlexNet สำหรับงานนี้ นี่เป็นเรื่องล่าสุด แต่เป็นปัญหาที่ซับซ้อนกว่ามาก (การหมุน 3D) และฉันก็ไม่คุ้นเคย
  2. การถดถอยด้วย CNN ไม่ใช่ปัญหาเล็กน้อย ดูอีกครั้งในกระดาษแรกคุณจะเห็นว่าพวกเขามีปัญหาที่พวกเขาสามารถสร้างข้อมูลที่ไม่มีที่สิ้นสุด วัตถุประสงค์ของพวกเขาคือเพื่อทำนายมุมการหมุนที่จำเป็นในการแก้ไขภาพ 2D ซึ่งหมายความว่าฉันสามารถใช้ชุดการฝึกอบรมของฉันและขยายมันได้โดยการหมุนแต่ละภาพด้วยมุมที่กำหนดเองและฉันจะได้ชุดการฝึกที่ถูกต้องและใหญ่กว่า ดังนั้นดูเหมือนว่าปัญหาจะค่อนข้างง่ายตราบใดที่ปัญหาการเรียนรู้ลึกดำเนินไป โดยวิธีการสังเกตเทคนิคการเสริมข้อมูลอื่น ๆ ที่พวกเขาใช้:

    เราใช้การแปล (สูงถึง 5% ของความกว้างของภาพ), การปรับความสว่างในช่วง [−0.2, 0.2], การปรับแกมม่าด้วยγ∈ [−0.5, 0.1] และสัญญาณรบกวนแบบเกาส์ที่มีค่าเบี่ยงเบนมาตรฐานในช่วง [0 , 0.02]

    k

    yxα=atan2(y,x)>11%ของข้อผิดพลาดที่เป็นไปได้สูงสุด พวกเขาทำได้ดีขึ้นเล็กน้อยโดยใช้สองเครือข่ายในซีรีส์: อันแรกจะทำการจำแนก (คาดการณ์ว่ามุมจะอยู่ในหรือชั้นเรียนจากนั้นภาพที่หมุนตามจำนวนที่คาดการณ์โดยเครือข่ายแรกจะถูกป้อนไปยังเครือข่ายประสาทอื่น (สำหรับการถดถอยในครั้งนี้) ซึ่งจะทำนายการหมุนเพิ่มเติมขั้นสุดท้ายในช่วง[180°,90°],[90°,0°],[0°,90°][90°,180°][45°,45°]

    ในปัญหาที่ง่ายกว่ามาก (หมุน MNIST) คุณจะได้สิ่งที่ดีกว่าแต่ก็ยังไม่ได้อยู่ใต้ข้อผิดพลาด RMSE ซึ่งเป็นของข้อผิดพลาดที่เป็นไปได้สูงสุด2.6%

ดังนั้นเราสามารถเรียนรู้อะไรจากสิ่งนี้ ก่อนอื่นรูปภาพ 5000 ภาพเป็นชุดข้อมูลขนาดเล็กสำหรับงานของคุณ กระดาษแผ่นแรกใช้เครือข่ายที่ถูกฝึกบนภาพคล้ายกับที่พวกเขาต้องการที่จะเรียนรู้งานการถดถอย: ไม่เพียง แต่คุณจะต้องเรียนรู้งานที่แตกต่างจากงานที่ออกแบบสถาปัตยกรรม (การจำแนก) แต่ชุดฝึกอบรมของคุณไม่ ไม่มองอะไรเลยเหมือนชุดฝึกอบรมที่มักใช้กับเครือข่ายเหล่านี้ (CIFAR-10/100 หรือ ImageNet) ดังนั้นคุณอาจไม่ได้รับผลประโยชน์ใด ๆ จากการเรียนรู้การถ่ายโอน ตัวอย่างของ MATLAB มีรูป 5000 รูป แต่ภาพเหล่านั้นเป็นสีดำและสีขาวและมีความหมายคล้ายกันมาก

ถ้าเช่นนั้นความเป็นจริงนั้นทำได้ดีกว่า 0.3 อย่างไร ก่อนอื่นเราต้องเข้าใจว่าคุณหมายถึงอะไรโดยการสูญเสียเฉลี่ย 0.3 คุณหมายถึงว่าข้อผิดพลาด RMSE คือ 0.3 หรือไม่

1Ni=1N(h(xi)yi)2

โดยที่คือขนาดของชุดฝึกอบรมของคุณ (เช่น ),เป็นผลลัพธ์ของ CNN ของคุณสำหรับรูปภาพและคือความเข้มข้นที่สอดคล้องกันของสารเคมีหรือไม่ เนื่องจากจากนั้นสมมติว่าคุณคลิปการคาดการณ์ของ CNN ของคุณระหว่าง 80 และ 350 (หรือคุณใช้ logit เพื่อทำให้พอดีในช่วงเวลานั้น) คุณจะได้รับข้อผิดพลาดน้อยกว่าจริงจังคุณคาดหวังอะไร ดูเหมือนว่าฉันจะไม่ผิดพลาดใหญ่เลยNN<5000h(xi)xiyiyi[80,350]0.12%

นอกจากนี้เพียงแค่พยายามคำนวณจำนวนพารามิเตอร์ในเครือข่ายของคุณ: ฉันกำลังรีบและฉันอาจจะทำผิดพลาดโง่ดังนั้นโดยทั้งหมดตรวจสอบการคำนวณของฉันด้วยsummaryฟังก์ชั่นบางอย่างจากกรอบสิ่งที่คุณอาจใช้ อย่างไรก็ตามฉันจะบอกว่าคุณมีประมาณ

9×(3×32+2×32×32+32×64+2×64×64+64×128+2×128×128)+128×128+128×32+32×32×32=533344

(หมายเหตุฉันข้ามพารามิเตอร์ของเลเยอร์ norm แบทช์ แต่พวกเขาเป็นเพียง 4 พารามิเตอร์สำหรับเลเยอร์ดังนั้นพวกเขาจึงไม่สร้างความแตกต่าง) คุณมีพารามิเตอร์ครึ่งล้านและตัวอย่าง 5,000 รายการ ... คุณคาดหวังอะไร แน่นอนว่าจำนวนพารามิเตอร์ไม่ได้เป็นตัวบ่งชี้ที่ดีสำหรับความจุของเครือข่ายประสาทเทียม (เป็นแบบจำลองที่ไม่สามารถระบุตัวตนได้) แต่ถึงกระนั้น ... ฉันไม่คิดว่าคุณจะทำได้ดีกว่านี้มาก แต่คุณสามารถลอง บางสิ่ง:

  • ทำให้อินพุตทั้งหมดเป็นปกติ (ตัวอย่างเช่น rescale ความเข้มของ RGB ของแต่ละพิกเซลระหว่าง -1 ถึง 1 หรือใช้มาตรฐาน) และเอาต์พุตทั้งหมด สิ่งนี้จะเป็นประโยชน์อย่างยิ่งหากคุณมีปัญหาการลู่เข้า
  • ไปที่โทนสีเทา: สิ่งนี้จะลดช่องสัญญาณอินพุตของคุณจาก 3 เป็น 1 ภาพทั้งหมดของคุณดูเหมือน (ถึงดวงตาที่ไม่ได้รับการฝึกฝนอย่างสูงของฉัน) ให้เป็นสีที่ค่อนข้างคล้ายกัน คุณแน่ใจหรือว่ามันเป็นสีที่ใช้ในการทำนายและไม่ใช่การมีอยู่ของพื้นที่ที่มืดหรือสว่างกว่า? บางทีคุณแน่ใจ (ฉันไม่ใช่ผู้เชี่ยวชาญ): ในกรณีนี้ให้ข้ามคำแนะนำนี้y
  • การเพิ่มข้อมูล: เนื่องจากคุณบอกว่าการพลิกหมุนโดยมุมใดก็ได้หรือการทำมิเรอร์รูปภาพของคุณควรทำให้ได้ผลลัพธ์ที่เหมือนกันคุณสามารถเพิ่มขนาดของชุดข้อมูลได้มากขึ้น โปรดทราบว่าด้วยชุดข้อมูลที่ใหญ่กว่าข้อผิดพลาดในชุดการฝึกอบรมจะเพิ่มขึ้น: สิ่งที่เรากำลังมองหาที่นี่คือช่องว่างเล็ก ๆ ระหว่างชุดการสูญเสียชุดการฝึกอบรมและชุดการสูญเสียการทดสอบ นอกจากนี้หากการสูญเสียชุดการฝึกอบรมเพิ่มขึ้นเป็นจำนวนมากนี่อาจเป็นข่าวดี: อาจหมายความว่าคุณสามารถฝึกอบรมเครือข่ายที่ลึกกว่าในชุดการฝึกอบรมที่ใหญ่กว่านี้ได้โดยไม่ต้องเสี่ยงกับการล้น ลองเพิ่มเลเยอร์เพิ่มเติมและดูว่าตอนนี้คุณได้รับชุดฝึกอบรมที่เล็กลงและชุดทดสอบสูญเสียหรือไม่ ในที่สุดคุณสามารถลองใช้เทคนิคการเติมข้อมูลอื่น ๆ ที่ฉันอ้างถึงข้างต้นได้หากมันสมเหตุสมผลในบริบทของแอปพลิเคชันของคุณ
  • ใช้เคล็ดลับการจัดหมวดหมู่แล้วถดถอย: เครือข่ายแรกกำหนดว่าควรเป็นหนึ่งใน 10 พูดเช่นฯลฯ เครือข่ายที่สองจะคำนวณการแก้ไข : การจัดกึ่งกลางและการทำให้เป็นมาตรฐานอาจช่วยได้เช่นกัน พูดไม่ได้ถ้าไม่ลองy[80,97],[97,124][0,27]
  • ลองใช้สถาปัตยกรรมสมัยใหม่ (Inception or ResNet) แทนสถาปัตยกรรมโบราณ ResNet มีพารามิเตอร์น้อยกว่า VGG-net แน่นอนว่าคุณต้องการใช้ ResNets ขนาดเล็กที่นี่ - ฉันไม่คิดว่า ResNet-101 สามารถช่วยในชุดข้อมูลภาพ 5000 ภาพ คุณสามารถเพิ่มชุดข้อมูลได้มากแม้ว่า ....
  • เนื่องจากผลลัพธ์ของคุณไม่แปรผันกับการหมุนความคิดที่ดีอีกอย่างหนึ่งก็คือการใช้CNNsแบบแบ่งกลุ่มอย่างใดอย่างหนึ่งซึ่งผลลัพธ์ (เมื่อใช้เป็นตัวแยกประเภท) จะไม่แปรผันกับการหมุนแบบไม่ต่อเนื่องหรือCNNs ที่นำพาได้ซึ่งเอาท์พุทไม่เปลี่ยนแปลงจากการหมุนอย่างต่อเนื่อง คุณสมบัติ invariance จะช่วยให้คุณได้รับผลลัพธ์ที่ดีโดยมีการเพิ่มข้อมูลน้อยลงหรือไม่มีสิ่งใดเลย (สำหรับสิ่งที่เกี่ยวข้องกับการหมุน: แน่นอนว่าคุณยังต้องใช้ da ประเภทอื่น ๆ ) ซีเอ็นเอ็นที่มีความเสมอภาคของกลุ่มนั้นมีความเป็นผู้ใหญ่มากกว่าซีเอ็นเอ็นที่นำออกมาได้จากมุมมองการนำไปใช้ดังนั้นฉันจะลองซีเอ็นเอ็นของกลุ่มก่อน คุณสามารถลองการจำแนกประเภทแล้ว - การถดถอยโดยใช้ G-CNN สำหรับส่วนการจัดประเภทหรือคุณอาจทดลองด้วยวิธีการถดถอยที่บริสุทธิ์ อย่าลืมเปลี่ยนเลเยอร์บนสุดตามลำดับ
  • ทดลองกับขนาดของชุดงาน (ใช่แล้วใช่ฉันรู้ว่าการแฮ็คพารามิเตอร์สูงเกินไปไม่เจ๋ง แต่นี่เป็นสิ่งที่ดีที่สุดที่ฉันสามารถทำได้ในช่วงเวลา จำกัด และฟรี :-)
  • ในที่สุดก็มีสถาปัตยกรรมที่ได้รับการพัฒนาเป็นพิเศษเพื่อให้การทำนายที่แม่นยำด้วยชุดข้อมูลขนาดเล็ก ที่สุดของพวกเขาใช้convolutions พอง : ตัวอย่างที่มีชื่อเสียงหนึ่งคือผสมขนาดหนาแน่นสับสนเครือข่ายประสาท การดำเนินการไม่ได้เป็นเรื่องเล็กน้อย

3
ขอบคุณสำหรับคำตอบโดยละเอียด ฉันได้ทำการเพิ่มข้อมูลที่สำคัญแล้ว พยายามใช้สองรูปแบบของโมเดลการลงทะเบียน เห็นการปรับปรุงที่เหลือเชื่อ ยังคงมีวิธีที่จะไป ฉันจะลองคำแนะนำของคุณสักเล็กน้อย ขอบคุณอีกครั้ง.
rodrigo-silveira

@ rodrigo-silveira คุณยินดีแจ้งให้เราทราบว่ามันไปอย่างไร บางทีเราสามารถพูดคุยในการแชทเมื่อคุณได้ผลลัพธ์
DeltaIV

1
คำตอบที่ดีสมควรได้รับมากขึ้น ^
Gilly

1
สงบมาก!
Karthik Thiagarajan

1
ฉันจะให้คะแนน 10k กับคุณถ้าทำได้ คำตอบที่น่าอัศจรรย์
Boppity Bop
โดยการใช้ไซต์ของเรา หมายความว่าคุณได้อ่านและทำความเข้าใจนโยบายคุกกี้และนโยบายความเป็นส่วนตัวของเราแล้ว
Licensed under cc by-sa 3.0 with attribution required.