แบบจำลองการถดถอยโลจิสติกอย่างง่ายบรรลุความแม่นยำในการจำแนกประเภท 92% สำหรับ MNIST อย่างไร


64

แม้ว่าภาพทั้งหมดในชุดข้อมูล MNIST จะอยู่กึ่งกลาง แต่มีขนาดใกล้เคียงกันและไม่มีการหมุน แต่ก็มีการเปลี่ยนแปลงของลายมือที่สำคัญที่ไขปริศนาว่าแบบจำลองเชิงเส้นบรรลุความแม่นยำในการจำแนกสูงอย่างไร

เท่าที่ฉันสามารถมองเห็นได้เนื่องจากความแปรปรวนของลายมือที่สำคัญตัวเลขควรแยกกันไม่ออกเป็นเส้นตรงในพื้นที่มิติ 784 กล่าวคือควรมีความซับซ้อนเล็กน้อย (แม้ว่าจะไม่ซับซ้อนมาก) ไม่ใช่ขอบเขตเชิงเส้นที่แยกตัวเลขที่แตกต่างกัน คล้ายกับตัวอย่างอ้างถึงเป็นอย่างดีซึ่งคลาสบวกและลบไม่สามารถคั่นด้วยตัวแยกประเภทเชิงเส้นใด ๆ ดูเหมือนจะทำให้ฉันงงงวยว่าการถดถอยโลจิสติกหลายระดับนั้นให้ความแม่นยำสูงด้วยคุณลักษณะเชิงเส้นอย่างสิ้นเชิงได้อย่างไร (ไม่มีคุณสมบัติพหุนาม)XOR

ยกตัวอย่างเช่นเมื่อกำหนดพิกเซลในภาพความแตกต่างของตัวเลขและเขียนด้วยลายมือจะทำให้พิกเซลนั้นสว่างหรือไม่ ดังนั้นกับชุดของน้ำหนักเรียนรู้แต่ละพิกเซลสามารถทำให้ดูเป็นหลักเป็นเช่นเดียวกับ3เท่านั้นที่มีการรวมกันของค่าพิกเซลมันควรจะเป็นไปได้ที่จะบอกว่าไม่ว่าจะเป็นหลักเป็นหรือ3สิ่งนี้เป็นจริงสำหรับคู่หลักส่วนใหญ่ ดังนั้นการถดถอยแบบลอจิสติกเป็นอย่างไรซึ่งสุ่มเลือกการตัดสินใจของแต่ละพิกเซลอย่างอิสระ (โดยไม่พิจารณาการพึ่งพาระหว่างพิกเซลใด ๆ เลย) สามารถบรรลุความแม่นยำสูงได้232323

ฉันรู้ว่าฉันผิดที่ใดที่หนึ่งหรือแค่ประเมินความแปรปรวนของภาพมากเกินไป อย่างไรก็ตามมันจะดีมากถ้ามีคนช่วยฉันด้วยสัญชาตญาณว่าตัวเลขจะแยกออกเป็นเส้นตรงได้อย่างไร


ลองดูที่ตำราเรียนสถิติการเรียนรู้ที่มี Sparsity: The Lasso และ Generalisation 3.3.1 ตัวอย่าง: Handwritten Digits web.stanford.edu/~hastie/StatLearnSparsity_files/SLS.pdf
Adrian

ฉันอยากรู้อยากเห็น: แบบจำลองเชิงเส้นที่ถูกลงโทษ (เช่น glmnet) ทำปัญหาได้ดีเพียงใด? หากฉันจำได้ว่าสิ่งที่คุณกำลังรายงานคือความแม่นยำที่ไม่เป็นตัวอย่าง
หน้าผา AB

คำตอบ:


82

tl; drแม้ว่านี่จะเป็นชุดข้อมูลการจำแนกภาพ แต่ก็ยังคงเป็นงานที่ง่ายมากซึ่งหนึ่งสามารถค้นหาการแมปโดยตรงจากอินพุตไปยังการทำนายได้อย่างง่ายดาย


ตอบ:

นี่เป็นคำถามที่น่าสนใจมากและต้องขอบคุณความเรียบง่ายของการถดถอยโลจิสติกที่คุณสามารถหาคำตอบได้

78478428×28

ทราบอีกครั้งว่าสิ่งเหล่านี้มีน้ำหนัก

ทีนี้ลองดูที่ภาพด้านบนแล้วจดจ่อกับตัวเลขสองหลักแรก (เช่นศูนย์และหนึ่ง) ตุ้มน้ำหนักสีน้ำเงินหมายความว่าความเข้มของพิกเซลนี้มีส่วนช่วยอย่างมากสำหรับคลาสนั้นและค่าสีแดงหมายความว่ามันมีส่วนช่วยในทางลบ

0

1

2378

ด้วยวิธีนี้คุณจะเห็นได้ว่าการถดถอยโลจิสติกส์มีโอกาสดีมากที่จะได้รับภาพจำนวนมากและนั่นเป็นสาเหตุที่คะแนนสูงมาก


รหัสในการทำซ้ำตัวเลขด้านบนนั้นค่อนข้างเก่า แต่นี่คุณไป:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b

y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 

correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train model
batch_size = 64
with tf.Session() as sess:

    loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []

    sess.run(tf.global_variables_initializer()) 

    for step in range(1, 1001):

        x_batch, y_batch = mnist.train.next_batch(batch_size) 
        sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})

        l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
        l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
        loss_tr.append(l_tr)
        acc_tr.append(a_tr)
        loss_ts.append(l_ts)
        acc_ts.append(a_ts)

    weights = sess.run(W)      
    print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) 

# Plotting:
for i in range(10):
    plt.subplot(2, 5, i+1)
    weight = weights[:,i].reshape([28,28])
    plt.title(i)
    plt.imshow(weight, cmap='RdBu')  # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)
    frame1.axes.get_yaxis().set_visible(False)

9
2378

13
แน่นอนว่ามันช่วยให้ตัวอย่างของ MNIST นั้นอยู่กึ่งกลางปรับขนาดและปรับความเปรียบต่างก่อนที่ตัวแยกประเภทจะเห็น คุณไม่ต้องตอบคำถามเช่น "จะเกิดอะไรขึ้นถ้าขอบของศูนย์เป็นจริงผ่านกลางกล่อง" เนื่องจากตัวประมวลผลล่วงหน้าได้ผ่านไปนานแล้วทำให้ศูนย์ทั้งหมดดูเหมือนกัน
ฮอบส์

1
@EricDuminil ฉันได้เพิ่มการยกย่องสคริปต์ด้วยคำแนะนำของคุณ ขอบคุณมากสำหรับการป้อนข้อมูล! : D
Djib2011

1
@NishishAgarwal หากคุณคิดว่าคำตอบนี้เป็นคำตอบสำหรับคำถามของคุณให้ลองทำเครื่องหมายเช่นนั้น
sintax

7
สำหรับคนที่สนใจ แต่ไม่คุ้นเคยเป็นพิเศษกับการประมวลผลประเภทนี้คำตอบนี้ให้ตัวอย่างที่ยอดเยี่ยมสำหรับกลศาสตร์
chrylis -on strike-
โดยการใช้ไซต์ของเรา หมายความว่าคุณได้อ่านและทำความเข้าใจนโยบายคุกกี้และนโยบายความเป็นส่วนตัวของเราแล้ว
Licensed under cc by-sa 3.0 with attribution required.