การเปิดใช้งาน GELU คืออะไร


18

ฉันกำลังอ่านกระดาษ BERTซึ่งใช้GELU (Gaussian Error Linear Unit)ซึ่งระบุสมการเป็น ซึ่งจะอยู่ที่ประมาณ

GELยู(x)=xP(Xx)=xΦ(x).
0.5x(1+เสื้อanชั่วโมง[2/π(x+0.044715x3)])

คุณช่วยทำให้สมการง่ายขึ้นและอธิบายว่ามันประมาณได้อย่างไร

คำตอบ:


19

ฟังก์ชั่น GELU

เราสามารถขยายการแจกแจงสะสมของN(0,1) , คือΦ(x) , ดังนี้:

GELU(x):=xP(Xx)=xΦ(x)=0.5x(1+erf(x2))

โปรดทราบว่านี่เป็นคำนิยามไม่ใช่สมการ (หรือความสัมพันธ์) ผู้เขียนได้ให้เหตุผลบางอย่างสำหรับข้อเสนอนี้เช่นการเปรียบเทียบสุ่มอย่างไรก็ตามในทางคณิตศาสตร์นี่เป็นเพียงคำจำกัดความ

นี่คือเนื้อเรื่องของ GELU:

การประมาณทาห์

สำหรับการประมาณค่าตัวเลขประเภทนี้แนวคิดหลักคือการหาฟังก์ชั่นที่คล้ายกัน (ส่วนใหญ่ขึ้นอยู่กับประสบการณ์) กำหนดพารามิเตอร์แล้วปรับให้เข้ากับชุดของคะแนนจากฟังก์ชันต้นฉบับ

รู้ว่าerf(x)อยู่ใกล้กับtanh(x)

และอนุพันธ์อันดับแรกของerf(x2)เกิดขึ้นพร้อมกับของtanh(2πx)ที่x=0ซึ่งก็คือ2πเราดำเนินการเพื่อให้พอดีกับ

tanh(2π(x+ax2+x3+x4+dx5))
(หรือด้วยเงื่อนไขเพิ่มเติม) กับชุดของคะแนน(xผม,ERF(xผม2)))

ฉันได้ติดตั้งฟังก์ชั่นนี้กับตัวอย่าง 20 ตัวระหว่าง(-1.5,1.5) ( ใช้ไซต์นี้ ) และนี่คือสัมประสิทธิ์:

โดยการตั้งค่า= C = d = 0 , Bก็จะประมาณ0.04495641 ด้วยตัวอย่างเพิ่มเติมจากช่วงที่กว้างขึ้น (ไซต์นั้นอนุญาตให้ 20 เท่านั้น) สัมประสิทธิ์bจะใกล้เคียงกับ0.044715ของกระดาษมากขึ้น ในที่สุดเราก็ได้a=c=d=0b0.04495641b0.044715

GELU(x)=xΦ(x)=0.5x(1+erf(x2))0.5x(1+tanh(2π(x+0.044715x3)))

มีค่าเฉลี่ยข้อผิดพลาด Squared 108สำหรับx[10,10] ]

โปรดทราบว่าหากเราไม่ได้ใช้ประโยชน์จากความสัมพันธ์ระหว่างอนุพันธ์อันดับแรกคำศัพท์2πจะรวมอยู่ในพารามิเตอร์ดังนี้

0.5x(1+tanh(0.797885x+0.035677x3))
ซึ่งมีความสวยงามน้อยกว่า (การวิเคราะห์น้อยกว่าตัวเลขมากกว่า)!

ใช้ความเท่าเทียมกัน

ตามที่แนะนำโดย@BookYourLuckเราสามารถใช้ฟังก์ชั่นพาริตี้เพื่อ จำกัด พื้นที่ของพหุนามที่เราค้นหา นั่นคือเนื่องจากerfเป็นฟังก์ชันคี่เช่นf(x)=f(x)และtanhยังเป็นฟังก์ชันคี่ฟังก์ชันพหุนามฟังก์ชันpol(x)ภายในtanhก็ควรจะแปลกด้วย (ควรมีพลังคี่ของx ) มี

erf(x)tanh(pol(x))=tanh(pol(x))=tanh(pol(x))erf(x)

x2x40.23x20x2

การประมาณ Sigmoid

erf(x)2(σ(x)12)104 for x[10,10].

Here is a Python code for generating data points, fitting the functions, and calculating the mean squared errors:

import math
import numpy as np
import scipy.optimize as optimize


def tahn(xs, a):
    return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs]


def sigmoid(xs, a):
    return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs]


print_points = 0
np.random.seed(123)
# xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0,
#       .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2]
# xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8)))
# xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6)))
xs = np.arange(-10, 10, 0.001)
erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs])
ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs])

# Fit tanh and sigmoid curves to erf points
tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs)
print('Tanh fit: a=%5.5f' % tuple(tanh_popt))

sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs)
print('Sigmoid fit: a=%5.5f' % tuple(sig_popt))

# curves used in https://mycurvefit.com:
# 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))
# 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3))
y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs])
tanh_error_paper = (np.square(ys - y_paper_tanh)).mean()
y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs])
tanh_error_alt = (np.square(ys - y_alt_tanh)).mean()

# curve used in https://mycurvefit.com:
# 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5)
y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs])
sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean()
y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs])
sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean()

print('Paper tanh error:', tanh_error_paper)
print('Alternative tanh error:', tanh_error_alt)
print('Paper sigmoid error:', sigmoid_error_paper)
print('Alternative sigmoid error:', sigmoid_error_alt)

if print_points == 1:
    print(len(xs))
    for x, erf in zip(xs, erfs):
        print(x, erf)

Output:

Tanh fit: a=0.04485
Sigmoid fit: a=1.70099
Paper tanh error: 2.4329173471294176e-08
Alternative tanh error: 2.698034519269613e-08
Paper sigmoid error: 5.6479106346814546e-05
Alternative sigmoid error: 5.704246564663601e-05

2
Why is the approximation needed? Couldn't they just use erf function?
SebiSebi

8

First note that

Φ(x)=12erfc(x2)=12(1+erf(x2))
by parity of erf. We need to show that
erf(x2)tanh(2π(x+ax3))
for a0.044715.

For large values of x, both functions are bounded in [1,1]. For small x, the respective Taylor series read

tanh(x)=xx33+o(x3)
and
erf(x)=2π(xx33)+o(x3).
Substituting, we get that
tanh(2π(x+ax3))=2π(x+(a-23π)x3)+โอ(x3)
และ
อีR(x2)=2π(x-x36)+โอ(x3).
การเทียบค่าสัมประสิทธิ์สำหรับ x3เราพบว่า
a0.04553992412
ใกล้กับกระดาษ 0.044715.

โดยการใช้ไซต์ของเรา หมายความว่าคุณได้อ่านและทำความเข้าใจนโยบายคุกกี้และนโยบายความเป็นส่วนตัวของเราแล้ว
Licensed under cc by-sa 3.0 with attribution required.