sparse_softmax_cross_entropy_with_logits กับ softmax_cross_entropy_with_logits ต่างกันอย่างไร


112

ฉันเพิ่งมาข้ามtf.nn.sparse_softmax_cross_entropy_with_logitsและผมก็ไม่สามารถคิดออกสิ่งที่แตกต่างเมื่อเทียบกับtf.nn.softmax_cross_entropy_with_logits

ความแตกต่างเพียงอย่างเดียวที่เวกเตอร์การฝึกอบรมyต้องเข้ารหัสแบบร้อนเดียวเมื่อใช้sparse_softmax_cross_entropy_with_logits?

อ่าน API ที่ผมไม่สามารถที่จะค้นหาความแตกต่างอื่น ๆ softmax_cross_entropy_with_logitsเมื่อเทียบกับ แต่ทำไมเราถึงต้องการฟังก์ชันพิเศษ?

ไม่ควรsoftmax_cross_entropy_with_logitsให้ผลลัพธ์เช่นเดียวsparse_softmax_cross_entropy_with_logitsกับที่มาพร้อมกับข้อมูลการฝึกอบรม / เวกเตอร์ที่เข้ารหัสร้อนเดียวหรือไม่?


1
ฉันสนใจที่จะดูการเปรียบเทียบประสิทธิภาพของพวกเขาว่าสามารถใช้ทั้งสองอย่างได้หรือไม่ (เช่นกับป้ายกำกับภาพพิเศษ) ฉันคาดหวังว่าเวอร์ชันที่กระจัดกระจายจะมีประสิทธิภาพมากขึ้นอย่างน้อยก็ต้องใช้หน่วยความจำ
Yibo Yang

1
ดูคำถามนี้ซึ่งกล่าวถึงฟังก์ชันข้ามเอนโทรปีทั้งหมดในเทนเซอร์โฟลว์ (ปรากฎว่ามีจำนวนมาก)
Maxim

คำตอบ:


175

การมีฟังก์ชั่นที่แตกต่างกันสองอย่างถือเป็นความสะดวกสบายเนื่องจากให้ผลลัพธ์เหมือนกัน

ความแตกต่างนั้นง่ายมาก:

  • สำหรับsparse_softmax_cross_entropy_with_logitsป้ายกำกับต้องมีรูปร่าง [batch_size] และ dtype int32 หรือ int64 แต่ละป้ายเป็น int [0, num_classes-1]อยู่ในช่วง
  • สำหรับsoftmax_cross_entropy_with_logitsป้ายกำกับต้องมีรูปร่าง [batch_size, num_classes] และ dtype float32 หรือ float64

ป้ายชื่อที่ใช้ในการsoftmax_cross_entropy_with_logitsเป็นรุ่นร้อนหนึ่งsparse_softmax_cross_entropy_with_logitsของป้ายที่ใช้ในการ

อีกความแตกต่างเล็ก ๆ เป็นที่ที่มีsparse_softmax_cross_entropy_with_logitsคุณสามารถให้ -1 เป็นป้ายที่จะมีการสูญเสีย0บนฉลากนี้


15
-1 ถูกต้องหรือไม่? ตามที่เอกสารอ่าน: "แต่ละรายการในป้ายกำกับต้องเป็นดัชนีใน [0, num_classes) ค่าอื่น ๆ จะทำให้เกิดข้อยกเว้นเมื่อเรียกใช้ op นี้บน CPU และส่งคืน NaN สำหรับแถวการสูญเสียและการไล่ระดับสีบน GPU"
user1761806

1
[0, num_classes) = [0, num_classes-1]
Karthik C

24

ฉันต้องการเพิ่ม 2 สิ่งในคำตอบที่ยอมรับซึ่งคุณสามารถพบได้ในเอกสาร TF

อันดับแรก:

tf.nn.softmax_cross_entropy_with_logits

หมายเหตุ: แม้ว่าชั้นเรียนจะไม่รวมกัน แต่ความน่าจะเป็นของพวกเขาก็ไม่จำเป็นต้องเป็น สิ่งที่ต้องมีก็คือป้ายกำกับแต่ละแถวเป็นการแจกแจงความน่าจะเป็นที่ถูกต้อง หากไม่เป็นเช่นนั้นการคำนวณของการไล่ระดับสีจะไม่ถูกต้อง

ประการที่สอง:

tf.nn.sparse_softmax_cross_entropy_with_logits

หมายเหตุ: สำหรับการดำเนินการนี้ความน่าจะเป็นของฉลากที่ระบุถือเป็นเอกสิทธิ์ นั่นคือไม่อนุญาตให้ใช้คลาสแบบอ่อนและเวกเตอร์ป้ายกำกับต้องระบุดัชนีเฉพาะสำหรับคลาสจริงสำหรับแต่ละแถวของบันทึก (รายการมินิแบทช์แต่ละรายการ)


4
เราควรใช้อะไรหากคลาสนั้นไม่สามารถใช้ร่วมกันได้ ฉันหมายถึงถ้าเรารวมป้ายกำกับหมวดหมู่หลาย ๆ
Hayro

ฉันยังอ่านสิ่งนี้ ดังนั้นหมายความว่าเราใช้ความน่าจะเป็นของคลาสกับเอนโทรปีแบบไขว้แทนที่จะใช้เป็นเวกเตอร์ onehot
Shamane Siriwardhana

@Hayro - คุณหมายความว่าคุณไม่สามารถทำการเข้ารหัสร้อนได้หรือไม่? ฉันคิดว่าคุณจะต้องดูรุ่นอื่น สิ่งนี้กล่าวถึงบางสิ่งเช่น "มันจะเหมาะสมกว่าในการสร้างตัวแยกประเภทการถดถอยโลจิสติกส์แบบไบนารี 4 ตัว" ก่อนอื่นให้แน่ใจว่าคุณสามารถแยกคลาสได้
ashley

21

ฟังก์ชันทั้งสองจะคำนวณผลลัพธ์ที่เหมือนกันและsparse_softmax_cross_entropy_with_logitsจะคำนวณเอนโทรปีแบบไขว้โดยตรงบนฉลากแบบกระจัดกระจายแทนที่จะแปลงด้วยการเข้ารหัสแบบร้อนเดียว

คุณสามารถตรวจสอบได้โดยรันโปรแกรมต่อไปนี้:

import tensorflow as tf
from random import randint

dims = 8
pos  = randint(0, dims - 1)

logits = tf.random_uniform([dims], maxval=3, dtype=tf.float32)
labels = tf.one_hot(pos, dims)

res1 = tf.nn.softmax_cross_entropy_with_logits(       logits=logits, labels=labels)
res2 = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=tf.constant(pos))

with tf.Session() as sess:
    a, b = sess.run([res1, res2])
    print a, b
    print a == b

ที่นี่ฉันสร้างแบบสุ่ม logitsเวกเตอร์ของความยาวdimsและสร้างป้ายกำกับที่เข้ารหัสแบบร้อนเดียว (โดยที่องค์ประกอบในposคือ 1 และอื่น ๆ เป็น 0)

หลังจากนั้นฉันคำนวณ softmax และ softmax แบบเบาบางและเปรียบเทียบผลลัพธ์ของมัน ลองรันใหม่สองสามครั้งเพื่อให้แน่ใจว่าจะให้ผลลัพธ์เดียวกันเสมอ

โดยการใช้ไซต์ของเรา หมายความว่าคุณได้อ่านและทำความเข้าใจนโยบายคุกกี้และนโยบายความเป็นส่วนตัวของเราแล้ว
Licensed under cc by-sa 3.0 with attribution required.