ฉันกำลังเล่นกับ convnets เล็กน้อย โดยเฉพาะฉันใช้ชุดข้อมูล cats-vs-dogs kaggle ซึ่งประกอบไปด้วยรูปภาพ 25,000 ภาพที่มีป้ายกำกับว่าเป็น cat หรือ dog (12500 อัน)
ฉันจัดการเพื่อให้บรรลุความถูกต้องจำแนก 85% ในชุดทดสอบของฉัน แต่ฉันกำหนดเป้าหมายของการบรรลุความแม่นยำ 90%
ปัญหาหลักของฉันคือ overfitting อย่างใดก็มักจะเกิดขึ้นเสมอ (ปกติหลังจากยุค 8-10) สถาปัตยกรรมของเครือข่ายของฉันได้รับแรงบันดาลใจมาจาก VGG-16 โดยเฉพาะอย่างยิ่งภาพของฉันได้รับการปรับขนาดเป็นจากนั้นฉันเรียกใช้:
Convolution 1 128x128x32 (kernel size is 3, strides is 1)
Convolution 2 128x128x32 (kernel size is 3, strides is 1)
Max pool 1 64x64x32 (kernel size is 2, strides is 2)
Convolution 3 64x64x64 (kernel size is 3, strides is 1)
Convolution 4 64x64x64 (kernel size is 3, strides is 1)
Max pool 2 32x32x64 (kernel size is 2, strides is 2)
Convolution 5 16x16x128 (kernel size is 3, strides is 1)
Convolution 6 16x16x128 (kernel size is 3, strides is 1)
Max pool 3 8x8x128 (kernel size is 2, strides is 2)
Convolution 7 8x8x256 (kernel size is 3, strides is 1)
Max pool 4 4x4x256 (kernel size is 2, strides is 2)
Convolution 8 4x4x512 (kernel size is 3, strides is 1)
Fully connected layer 1024 (dropout 0.5)
Fully connected layer 1024 (dropout 0.5)
เลเยอร์ทั้งหมดยกเว้นอันสุดท้ายมี relus เป็นฟังก์ชันเปิดใช้
โปรดทราบว่าฉันได้ลองใช้ชุดค่าผสมที่แตกต่างกัน (ฉันเริ่มด้วยชุดรูปแบบที่ง่ายกว่า)
นอกจากนี้ฉันได้เพิ่มชุดข้อมูลโดยทำมิเรอร์รูปภาพเพื่อให้โดยรวมแล้วฉันมี 50,000 ภาพ
นอกจากนี้ฉันกำลังทำให้รูปภาพเป็นปกติโดยใช้การปรับค่าสูงสุดขั้นต่ำโดยที่ X คือรูปภาพ
รหัสถูกเขียนใน tensorflow และขนาดของชุดงานคือ 128
ชุดข้อมูลการฝึกอบรมจบลงด้วยการล้นและมีความแม่นยำ 100% ในขณะที่ข้อมูลการตรวจสอบดูเหมือนว่าจะหยุดเรียนที่ประมาณ 84-85%
ฉันได้พยายามเพิ่ม / ลดอัตราการออกกลางคัน
เครื่องมือเพิ่มประสิทธิภาพที่ใช้คือ AdamOptimizer ด้วยอัตราการเรียนรู้ 0.0001
ในขณะนี้ฉันได้เล่นกับปัญหานี้ในช่วง 3 สัปดาห์ที่ผ่านมาและ 85% ดูเหมือนจะเป็นอุปสรรคต่อหน้าฉัน
สำหรับบันทึกนี้ฉันรู้ว่าฉันสามารถใช้การเรียนรู้การถ่ายโอนเพื่อให้ได้ผลลัพธ์ที่สูงขึ้นมาก แต่ฉันสนใจที่จะสร้างเครือข่ายนี้เป็นประสบการณ์การเรียนรู้ด้วยตนเอง
ปรับปรุง:
ฉันกำลังใช้งานเครือข่ายเดียวกันด้วยขนาดแบทช์ที่แตกต่างกันในกรณีนี้ฉันใช้ขนาดแบตช์ที่เล็กกว่ามาก (16 แทน 128) จนถึงตอนนี้ฉันถึงความแม่นยำ 87.5% (แทนที่จะเป็น 85%) ที่กล่าวว่าเครือข่ายจบลงด้วยการ overfitting ต่อไป ถึงกระนั้นฉันก็ยังไม่เข้าใจว่าการลดลง 50% ของหน่วยไม่ได้ช่วย ... เห็นได้ชัดว่าฉันกำลังทำอะไรผิดที่นี่ ความคิดใด ๆ
อัปเดต 2:
ดูเหมือนว่าปัญหาจะเกี่ยวข้องกับขนาดแบทช์เช่นเดียวกับขนาดที่เล็กกว่า (16 แทนที่จะเป็น 128) ตอนนี้ฉันได้รับความถูกต้องแม่นยำ 92.8% ในชุดทดสอบของฉันด้วยขนาดแบตช์ที่เล็กลง ด้วยความแม่นยำ 100%) อย่างไรก็ตามการสูญเสีย (ข้อผิดพลาด) จะลดลงและโดยทั่วไปจะมีเสถียรภาพมากขึ้น ข้อเสียคือเวลาทำงานช้าลงมาก แต่มันก็คุ้มค่ากับการรอ