การไล่ระดับสีไม่พบวิธีแก้ปัญหาสำหรับกำลังสองน้อยที่สุดธรรมดาบนชุดข้อมูลนี้?


12

ฉันได้ศึกษาการถดถอยเชิงเส้นและลองใช้ชุดด้านล่าง {(x, y)} ซึ่ง x ระบุพื้นที่ของบ้านเป็นตารางฟุตและ y ระบุราคาเป็นดอลลาร์ นี่คือตัวอย่างแรกในแอนดรูอึ้งหมายเหตุ

2104.400
1600.330
2400.369
1416.232
3000.540

ฉันพัฒนารหัสตัวอย่าง แต่เมื่อฉันเรียกใช้ราคาจะเพิ่มขึ้นในแต่ละขั้นตอนในขณะที่ควรลดลงในแต่ละขั้นตอน รหัสและผลลัพธ์ที่ได้รับด้านล่าง biasคือ W 0 X 0โดยที่ X 0 = 1 featureWeightsคืออาร์เรย์ของ [X 1 , X 2 , ... , X N ]

ฉันยังพยายามวิธีการแก้ปัญหาหลามออนไลน์อยู่ที่นี่และอธิบายที่นี่ แต่ตัวอย่างนี้ยังให้ผลลัพธ์เดียวกัน

ช่องว่างในการทำความเข้าใจแนวคิดอยู่ที่ไหน

รหัส:

package com.practice.cnn;

import java.util.Arrays;

public class LinearRegressionExample {

    private float ALPHA = 0.0001f;
    private int featureCount = 0;
    private int rowCount = 0;

    private float bias = 1.0f;
    private float[] featureWeights = null;

    private float optimumCost = Float.MAX_VALUE;

    private boolean status = true;

    private float trainingInput[][] = null;
    private float trainingOutput[] = null;

    public void train(float[][] input, float[] output) {
        if (input == null || output == null) {
            return;
        }

        if (input.length != output.length) {
            return;
        }

        if (input.length == 0) {
            return;
        }

        rowCount = input.length;
        featureCount = input[0].length;

        for (int i = 1; i < rowCount; i++) {
            if (input[i] == null) {
                return;
            }

            if (featureCount != input[i].length) {
                return;
            }
        }

        featureWeights = new float[featureCount];
        Arrays.fill(featureWeights, 1.0f);

        bias = 0;   //temp-update-1
        featureWeights[0] = 0;  //temp-update-1

        this.trainingInput = input;
        this.trainingOutput = output;

        int count = 0;
        while (true) {
            float cost = getCost();

            System.out.print("Iteration[" + (count++) + "] ==> ");
            System.out.print("bias -> " + bias);
            for (int i = 0; i < featureCount; i++) {
                System.out.print(", featureWeights[" + i + "] -> " + featureWeights[i]);
            }
            System.out.print(", cost -> " + cost);
            System.out.println();

//          if (cost > optimumCost) {
//              status = false;
//              break;
//          } else {
//              optimumCost = cost;
//          }

            optimumCost = cost;

            float newBias = bias + (ALPHA * getGradientDescent(-1));

            float[] newFeaturesWeights = new float[featureCount];
            for (int i = 0; i < featureCount; i++) {
                newFeaturesWeights[i] = featureWeights[i] + (ALPHA * getGradientDescent(i));
            }

            bias = newBias;

            for (int i = 0; i < featureCount; i++) {
                featureWeights[i] = newFeaturesWeights[i];
            }
        }
    }

    private float getCost() {
        float sum = 0;
        for (int i = 0; i < rowCount; i++) {
            float temp = bias;
            for (int j = 0; j < featureCount; j++) {
                temp += featureWeights[j] * trainingInput[i][j];
            }

            float x = (temp - trainingOutput[i]) * (temp - trainingOutput[i]);
            sum += x;
        }
        return (sum / rowCount);
    }

    private float getGradientDescent(final int index) {
        float sum = 0;
        for (int i = 0; i < rowCount; i++) {
            float temp = bias;
            for (int j = 0; j < featureCount; j++) {
                temp += featureWeights[j] * trainingInput[i][j];
            }

            float x = trainingOutput[i] - (temp);
            sum += (index == -1) ? x : (x * trainingInput[i][index]);
        }
        return ((sum * 2) / rowCount);
    }

    public static void main(String[] args) {
        float[][] input = new float[][] { { 2104 }, { 1600 }, { 2400 }, { 1416 }, { 3000 } };

        float[] output = new float[] { 400, 330, 369, 232, 540 };

        LinearRegressionExample example = new LinearRegressionExample();
        example.train(input, output);
    }
}

เอาท์พุท:

Iteration[0] ==> bias -> 0.0, featureWeights[0] -> 0.0, cost -> 150097.0
Iteration[1] ==> bias -> 0.07484, featureWeights[0] -> 168.14847, cost -> 1.34029099E11
Iteration[2] ==> bias -> -70.60721, featureWeights[0] -> -159417.34, cost -> 1.20725801E17
Iteration[3] ==> bias -> 67012.305, featureWeights[0] -> 1.51299168E8, cost -> 1.0874295E23
Iteration[4] ==> bias -> -6.3599688E7, featureWeights[0] -> -1.43594258E11, cost -> 9.794949E28
Iteration[5] ==> bias -> 6.036088E10, featureWeights[0] -> 1.36281745E14, cost -> 8.822738E34
Iteration[6] ==> bias -> -5.7287012E13, featureWeights[0] -> -1.29341617E17, cost -> Infinity
Iteration[7] ==> bias -> 5.4369677E16, featureWeights[0] -> 1.2275491E20, cost -> Infinity
Iteration[8] ==> bias -> -5.1600908E19, featureWeights[0] -> -1.1650362E23, cost -> Infinity
Iteration[9] ==> bias -> 4.897313E22, featureWeights[0] -> 1.1057068E26, cost -> Infinity
Iteration[10] ==> bias -> -4.6479177E25, featureWeights[0] -> -1.0493987E29, cost -> Infinity
Iteration[11] ==> bias -> 4.411223E28, featureWeights[0] -> 9.959581E31, cost -> Infinity
Iteration[12] ==> bias -> -4.186581E31, featureWeights[0] -> -Infinity, cost -> Infinity
Iteration[13] ==> bias -> Infinity, featureWeights[0] -> NaN, cost -> NaN
Iteration[14] ==> bias -> NaN, featureWeights[0] -> NaN, cost -> NaN

นี่ไม่ใช่หัวข้อที่นี่
Michael R. Chernick

3
หากสิ่งต่าง ๆ ระเบิดขึ้นอย่างไม่มีที่สิ้นสุดเช่นที่นี่คุณอาจลืมหารด้วยขนาดของเวกเตอร์ที่ใดที่หนึ่ง
StasK

5
คำตอบที่ได้รับการยอมรับโดยแมทธิวนั้นชัดเจนทางสถิติ ซึ่งหมายความว่าคำถามที่ต้องการความเชี่ยวชาญทางสถิติ (และไม่ใช่การเขียนโปรแกรม) เพื่อตอบ มันทำให้มันในหัวข้อตามคำนิยาม ฉันลงคะแนนเพื่อเปิดใหม่
อะมีบาพูดว่า Reinstate Monica

คำตอบ:


35

คำตอบสั้น ๆ คือขนาดก้าวของคุณใหญ่เกินไป แทนที่จะก้าวลงมาจากกำแพงแคนยอนก้าวของคุณนั้นใหญ่จนคุณกระโดดข้ามจากด้านหนึ่งไปอีกด้านหนึ่งขึ้นไปอีกด้านหนึ่ง!

ฟังก์ชั่นค่าใช้จ่ายด้านล่าง:

ป้อนคำอธิบายรูปภาพที่นี่

คำตอบยาว ๆ คือมันยากสำหรับการไล่ระดับสีที่ไร้เดียงสาในการแก้ปัญหานี้เพราะชุดระดับของฟังก์ชันต้นทุนของคุณนั้นเป็นวงรีที่มีความยาวสูงแทนที่จะเป็นวงกลม เมื่อต้องการแก้ไขปัญหานี้อย่างแน่นหนาโปรดทราบว่ามีวิธีที่ซับซ้อนกว่าให้เลือก:

  • ขนาดขั้นตอน (กว่าการเข้ารหัสฮาร์ดค่าคงที่)
  • ทิศทางขั้นตอน (มากกว่าการไล่ระดับสีลาด)

ปัญหาพื้นฐาน

ปัญหาพื้นฐานคือชุดระดับของฟังก์ชั่นค่าใช้จ่ายของคุณเป็นวงรีที่มีความยาวสูงและทำให้เกิดปัญหาสำหรับการไล่ระดับสี รูปด้านล่างแสดงชุดระดับสำหรับฟังก์ชันต้นทุน

  • 026.789
  • ถ้าขนาดก้าวใหญ่เกินไปคุณจะกระโดดข้ามพื้นที่สีน้ำเงินล่างและขึ้นไปแทนที่จะลงมา
  • θ0

ฉันขอแนะนำให้อ่านคำตอบนี้ใน Quora

ป้อนคำอธิบายรูปภาพที่นี่

แก้ไขด่วน 1:

เปลี่ยนรหัสของคุณเป็นprivate float ALPHA = 0.0000002f;และคุณจะหยุดการแก้ไขปัญหา

แก้ไขด่วน 2:

XX

การแก้ไขขั้นสูงเพิ่มเติม

หากเป้าหมายคือการแก้ปัญหาสี่เหลี่ยมจัตุรัสน้อยที่สุดธรรมดาได้อย่างมีประสิทธิภาพแทนที่จะเรียนรู้การลาดลงของชั้นเรียนให้สังเกตว่า:

  • มีวิธีการที่ซับซ้อนมากขึ้นกับขนาดของขั้นตอนการคำนวณเช่นมีการค้นหาบรรทัดและกฎ Armijo
  • ใกล้กับคำตอบที่สภาพท้องถิ่นดีกว่าวิธีการของนิวตันได้รับการลู่เข้าแบบสมการกำลังสองและเป็นวิธีที่ยอดเยี่ยมในการเลือกทิศทางและขนาดของก้าว
  • การแก้กำลังสองน้อยที่สุดเทียบเท่ากับการแก้ระบบเชิงเส้น อัลกอริทึมสมัยใหม่ไม่ได้ใช้การไล่ระดับสีที่ไร้เดียงสา แทน:
    • k
    • สำหรับระบบขนาดใหญ่ที่พวกเขาไม่กำหนดมันเป็นปัญหาที่เพิ่มประสิทธิภาพและการใช้วิธีการที่กล่าวย้ำเช่นสเปซ Krylovวิธี

(XX)b=Xyb

ทางออกที่แท้จริงคือ

  26.789880528523071
   0.165118878075797

คุณจะพบว่าสิ่งเหล่านั้นบรรลุค่าต่ำสุดสำหรับฟังก์ชันต้นทุน


5
+1 มันเป็นความหรูหราที่จะให้คนอื่นแก้จุดบกพร่องรหัส!
Haitao Du

4
@ hxd1011 ฉันคิดว่ามันเป็นข้อผิดพลาดในการเขียนใบ้ในตอนแรก แต่กลับกลายเป็นว่าฉันเป็นตัวอย่างที่ให้คำแนะนำในสิ่งที่ผิดพลาดได้โดยการสืบเชื้อสายที่ไร้เดียงสา
แมทธิวกันน์

@ MatthunGunn ฉันได้ทางออก b = 0.99970686, m = 0.17655967 (y = mx + b) และคุณหมายถึงอะไรโดย "ขนาดก้าวกว่าการเข้ารหัสค่าคงที่" นั่นหมายความว่าเราควรเปลี่ยนมันทุกครั้งหรือไม่ หรือเราจำเป็นต้องคำนวณตามค่าอินพุต?
เหลืองอำพัน Beriwal

αiiααif

@Beriwal คุณจะพบว่า (26.789, .1651) จะมีต้นทุนที่ต่ำกว่าเล็กน้อย มันตกต่ำเล็กน้อยจาก (.9997, .1766) ในทิศทางที่ฟังก์ชั่นค่าใช้จ่ายมีความชันเล็กน้อย
Matthew Gunn

2

ดังที่แมทธิว (กันน์) ได้ระบุไว้แล้วรูปทรงของฟังก์ชั่นค่าใช้จ่ายหรือประสิทธิภาพ 3 มิตินั้นเป็นรูปวงรีสูงในกรณีนี้ ตั้งแต่รหัสของคุณ Java ใช้ค่าขั้นตอนขนาดเดียวสำหรับการคำนวณเชื้อสายลาดปรับปรุงน้ำหนัก (เช่นตัดแกน y และความลาดชันของฟังก์ชันเชิงเส้น) จะทั้งควบคุมโดยขั้นตอนนี้เพียงครั้งเดียวขนาด

ด้วยเหตุนี้ขั้นตอนขนาดเล็กมากที่ต้องการควบคุมการอัปเดตของน้ำหนักที่เกี่ยวข้องกับการไล่ระดับสีที่ใหญ่กว่า (ในกรณีนี้ความชันของฟังก์ชันเชิงเส้น) จะ จำกัด ความเร็วของน้ำหนักอื่นด้วยความชันที่น้อยลงอย่างมาก มีการอัพเดทการสกัดกั้นแกน y ของฟังก์ชันเชิงเส้น) ภายใต้เงื่อนไขปัจจุบันน้ำหนักหลังไม่ได้รวมกันเป็นมูลค่าที่แท้จริงประมาณ 26.7

ด้วยเวลาและความพยายามที่คุณลงทุนในการเขียนโค้ด Java ของคุณฉันขอแนะนำให้คุณแก้ไขด้วยการใช้ค่าขนาดขั้นตอนแบบแยกกันสองค่าคือขนาดขั้นตอนที่เหมาะสมสำหรับแต่ละน้ำหนัก แอนดรูว์อึ้งแนะนำในบันทึกย่อของเขาว่าควรใช้การปรับขนาดฟีเจอร์เพื่อให้แน่ใจว่ารูปทรงของฟังก์ชั่นค่าใช้จ่ายมีความสม่ำเสมอมากขึ้น (เช่นวงกลม) ในรูปแบบ อย่างไรก็ตามการแก้ไขโค้ด Java ของคุณเพื่อใช้ขนาดขั้นตอนที่แตกต่างกันสำหรับแต่ละน้ำหนักอาจเป็นการออกกำลังกายที่ดีนอกเหนือจากการดูการปรับขนาดฟีเจอร์

แนวคิดอื่นที่ควรพิจารณาคือวิธีเลือกค่าน้ำหนักเริ่มต้น ในโค้ด Java ของคุณคุณกำหนดค่าทั้งสองให้เป็นศูนย์ นอกจากนี้ยังเป็นเรื่องธรรมดาที่จะเริ่มต้นน้ำหนักให้มีค่าน้อยและมีเศษส่วน อย่างไรก็ตามในกรณีพิเศษนี้วิธีการทั้งสองวิธีนี้จะไม่สามารถใช้งานได้ในรูปทรงที่มีลักษณะเป็นวงรีสูง (เช่นไม่ใช่แบบวงกลม) ของฟังก์ชันต้นทุนสามมิติ เนื่องจากน้ำหนักของปัญหานี้สามารถพบได้โดยใช้วิธีอื่นเช่นวิธีแก้ปัญหาสำหรับระบบเชิงเส้นที่ Matthew แนะนำไว้ที่ส่วนท้ายของโพสต์คุณสามารถลองกำหนดน้ำหนักให้ใกล้เคียงกับน้ำหนักที่ถูกต้องมากขึ้นและดูว่ารหัสต้นฉบับของคุณเป็นอย่างไร โดยใช้คอนเวอร์เจนซ์ขนาดก้าวเดียว

รหัส Python ที่คุณพบนั้นเป็นวิธีการแก้ปัญหาในลักษณะเดียวกับรหัส Java ของคุณ - ทั้งคู่ใช้พารามิเตอร์ขนาดขั้นตอนเดียว ฉันปรับเปลี่ยนรหัส Python นี้เพื่อใช้ขนาดขั้นตอนที่แตกต่างกันสำหรับแต่ละน้ำหนัก ฉันได้รวมไว้ด้านล่าง

from numpy import *

def compute_error_for_line_given_points(b, m, points):
    totalError = 0
    for i in range(0, len(points)):
        x = points[i, 0]
        y = points[i, 1]
        totalError += (y - (m * x + b)) ** 2
    return totalError / float(len(points))

def step_gradient(b_current, m_current, points, learningRate_1, learningRate_2):
    b_gradient = 0
    m_gradient = 0
    N = float(len(points))
    for i in range(0, len(points)):
        x = points[i, 0]
        y = points[i, 1]
        b_gradient += -(2/N) * (y - ((m_current * x) + b_current))
        m_gradient += -(2/N) * x * (y - ((m_current * x) + b_current))
    new_b = b_current - (learningRate_1 * b_gradient)
    new_m = m_current - (learningRate_2 * m_gradient)
    return [new_b, new_m]

def gradient_descent_runner(points, starting_b, starting_m, learning_rate_1, learning_rate_2, num_iterations):
    b = starting_b
    m = starting_m
    for i in range(num_iterations):
        b, m = step_gradient(b, m, array(points), learning_rate_1, learning_rate_2)
    return [b, m]

def run():
    #points = genfromtxt("data.csv", delimiter=",")
    #learning_rate = 0.0001
    #num_iterations = 200

    points = genfromtxt("test_set.csv", delimiter=",")
    learning_rate_1 = 0.5
    learning_rate_2 = 0.0000001
    num_iterations = 1000

    initial_b = 0 # initial y-intercept guess
    initial_m = 0 # initial slope guess


    print("Starting gradient descent at b = {0}, m = {1}, error = {2}".format(initial_b, initial_m, compute_error_for_line_given_points(initial_b, initial_m, points)))
    print("Running...")

    [b, m] = gradient_descent_runner(points, initial_b, initial_m, learning_rate_1, learning_rate_2, num_iterations)

    print("After {0} iterations b = {1}, m = {2}, error = {3}".format(num_iterations, b, m, compute_error_for_line_given_points(b, m, points)))

if __name__ == '__main__':
    run()

มันทำงานภายใต้ Python 3 ซึ่งต้องการวงเล็บในอาร์กิวเมนต์สำหรับคำสั่ง "print" มิฉะนั้นมันจะทำงานภายใต้ Python 2 โดยการลบวงเล็บ คุณจะต้องสร้างไฟล์ CSV พร้อมข้อมูลจากตัวอย่างของ Andrew Ng

ใช้สามารถอ้างอิงข้ามรหัส Python เพื่อตรวจสอบรหัส Java ของคุณ

โดยการใช้ไซต์ของเรา หมายความว่าคุณได้อ่านและทำความเข้าใจนโยบายคุกกี้และนโยบายความเป็นส่วนตัวของเราแล้ว
Licensed under cc by-sa 3.0 with attribution required.