ดูเหมือนว่าคุณได้แก้ไขปัญหาในตัวอย่างเฉพาะของคุณแล้ว แต่ฉันคิดว่ามันยังคงคุ้มค่าที่จะศึกษาอย่างละเอียดมากขึ้นเกี่ยวกับความแตกต่างระหว่างกำลังสองน้อยที่สุดและการถดถอยโลจิสติกความน่าจะเป็นสูงสุด
มารับสัญกรณ์กัน Let LS(yi,y^i)=12(yi−y^i)2และLL(yi,y^i)=yilogy^i+(1−yi)log(1−y^i)) ถ้าเรากำลังทำโอกาสสูงสุด (หรือต่ำสุดในเชิงลบบันทึกความน่าจะเป็นที่ฉันทำที่นี่) เรามี
β L:=argminข∈ Rβ^L:=argminb∈Rp−∑i=1nyilogg−1(xTib)+(1−yi)log(1−g−1(xTib))
กับgเป็นฟังก์ชั่นการเชื่อมโยงของเรา
หรืออีกวิธีหนึ่งที่เรามี
β S : = argmin ข∈ R P 1β^S:=argminb∈Rp12∑i=1n(yi−g−1(xTib))2
เป็นวิธีกำลังสองน้อยที่สุด ดังนั้น β SลดLSและในทำนองเดียวกันสำหรับLLβ^SLSLL L
Let fSและfLเป็นฟังก์ชั่นที่สอดคล้องกับวัตถุประสงค์ของการลดLSและLLตามลำดับขณะที่จะทำเพื่อβ Sและβ L สุดท้ายให้H = กรัม- 1ดังนั้นYฉัน = H ( x T ฉันข ) โปรดทราบว่าหากเราใช้ลิงก์แบบบัญญัติเรามี
h ( z ) = 1β^Sβ^Lh=g−1y^i=h(xTib)h(z)=11+e−z⟹h′(z)=h(z)(1−h(z)).
สำหรับการถดถอยโลจิสติกปกติเรามี
∂fL∂bj=−∑i=1nh′(xTib)xij(yih(xTib)−1−yi1−h(xTib)).
การใช้h′=h⋅(1−h)เราสามารถทำให้มันง่ายขึ้นถึง
∂fL∂bj=−∑i=1nxij(yi(1−y^i)−(1−yi)y^i)=−∑i=1nxij(yi−y^i)
so
∇fL(b)=−XT(Y−Y^).
Next let's do second derivatives. The Hessian
HL:=∂2fL∂bj∂bk=∑i=1nxijxiky^i(1−y^i).
This means that HL=XTAX where A=diag(Y^(1−Y^)). HL does depend on the current fitted values Y^ but Y has dropped out, and HL is PSD. Thus our optimization problem is convex in b.
Let's compare this to least squares.
∂fS∂bj=−∑i=1n(yi−y^i)h′(xTib)xij.
This means we have
∇fS(b)=−XTA(Y−Y^).
This is a vital point: the gradient is almost the same except for all i y^i(1−y^i)∈(0,1) so basically we're flattening the gradient relative to ∇fL. This'll make convergence slower.
For the Hessian we can first write
∂fS∂bj=−∑i=1nxij(yi−y^i)y^i(1−y^i)=−∑i=1nxij(yiy^i−(1+yi)y^2i+y^3i).
This leads us to
HS:=∂2fS∂bj∂bk=−∑i=1nxijxikh′(xTib)(yi−2(1+yi)y^i+3y^2i).
Let B=diag(yi−2(1+yi)y^i+3y^2i). We now have
HS=−XTABX.
Unfortunately for us, the weights in B are not guaranteed to be non-negative: if yi=0 then yi−2(1+yi)y^i+3y^2i=y^i(3y^i−2) which is positive iff y^i>23. Similarly, if yi=1 then yi−2(1+yi)y^i+3y^2i=1−4y^i+3y^2i which is positive when y^i<13 (it's also positive for y^i>1 but that's not possible). This means that HS is not necessarily PSD, so not only are we squashing our gradients which will make learning harder, but we've also messed up the convexity of our problem.
All in all, it's no surprise that least squares logistic regression struggles sometimes, and in your example you've got enough fitted values close to 0 or 1 so that y^i(1−y^i) can be pretty small and thus the gradient is quite flattened.
Connecting this to neural networks, even though this is but a humble logistic regression I think with squared loss you're experiencing something like what Goodfellow, Bengio, and Courville are referring to in their Deep Learning book when they write the following:
One recurring theme throughout neural network design is that the gradient of the cost function must be large and predictable enough to serve as a good guide for the learning algorithm. Functions that saturate (become very flat) undermine this objective because they make the gradient become very small. In many cases this happens because the activation functions used to produce the output of the hidden units or the output units saturate. The negative log-likelihood helps to avoid this problem for many models. Many output units involve an exp function that can saturate when its argument is very negative. The log function in the negative log-likelihood cost function undoes the exp of some output units. We will discuss the interaction between the cost function and the choice of output unit in Sec. 6.2.2.
and, in 6.2.2,
Unfortunately, mean squared error and mean absolute error often lead to poor results when used with gradient-based optimization. Some output units that saturate produce very small gradients when combined with these cost functions. This is one reason that the cross-entropy cost function is more popular than mean squared error or mean absolute error, even when it is not necessary to estimate an entire distribution p(y|x).
(both excerpts are from chapter 6).