ปัญหานี้สามารถแก้ไขได้อย่างมีประสิทธิภาพด้วยตัวเลขบริสุทธิ์โดยการประมวลผลอาร์เรย์เป็นชิ้น ๆ :
def find_first(x):
idx, step = 0, 32
while idx < x.size:
nz, = x[idx: idx + step].nonzero()
if len(nz): # found non-zero, return it
return nz[0] + idx
# move to the next chunk, increase step
idx += step
step = min(9600, step + step // 2)
return -1
step
อาร์เรย์จะถูกประมวลผลในก้อนขนาด ยิ่งstep
ขั้นตอนยาวมากเท่าไหร่การประมวลผล Zeroed Array ก็จะเร็วขึ้นเท่านั้น (กรณีที่แย่ที่สุด) ยิ่งมีขนาดเล็กการประมวลผลอาร์เรย์ก็จะเร็วขึ้นโดยไม่มีศูนย์ในช่วงเริ่มต้น เคล็ดลับคือการเริ่มต้นด้วยขนาดเล็กstep
และเพิ่มขึ้นแบบทวีคูณ ยิ่งไปกว่านั้นไม่จำเป็นต้องเพิ่มให้สูงกว่าเกณฑ์เนื่องจากผลประโยชน์ที่ จำกัด
ฉันได้เปรียบเทียบโซลูชันกับโซลูชัน ndarary.nonzero และ numba บริสุทธิ์กับ 10 ล้านอาร์เรย์ลอย
import numpy as np
from numba import jit
from timeit import timeit
def find_first(x):
idx, step = 0, 32
while idx < x.size:
nz, = x[idx: idx + step].nonzero()
if len(nz):
return nz[0] + idx
idx += step
step = min(9600, step + step // 2)
return -1
@jit(nopython=True)
def find_first_numba(vec):
"""return the index of the first occurence of item in vec"""
for i in range(len(vec)):
if vec[i]:
return i
return -1
SIZE = 10_000_000
# First only
x = np.empty(SIZE)
find_first_numba(x[:10])
print('---- FIRST ----')
x[:] = 0
x[0] = 1
print('ndarray.nonzero', timeit(lambda: x.nonzero()[0][0], number=100)*10, 'ms')
print('find_first', timeit(lambda: find_first(x), number=1000), 'ms')
print('find_first_numba', timeit(lambda: find_first_numba(x), number=1000), 'ms')
print('---- LAST ----')
x[:] = 0
x[-1] = 1
print('ndarray.nonzero', timeit(lambda: x.nonzero()[0][0], number=100)*10, 'ms')
print('find_first', timeit(lambda: find_first(x), number=100)*10, 'ms')
print('find_first_numba', timeit(lambda: find_first_numba(x), number=100)*10, 'ms')
print('---- NONE ----')
x[:] = 0
print('ndarray.nonzero', timeit(lambda: x.nonzero()[0], number=100)*10, 'ms')
print('find_first', timeit(lambda: find_first(x), number=100)*10, 'ms')
print('find_first_numba', timeit(lambda: find_first_numba(x), number=100)*10, 'ms')
print('---- ALL ----')
x[:] = 1
print('ndarray.nonzero', timeit(lambda: x.nonzero()[0][0], number=100)*10, 'ms')
print('find_first', timeit(lambda: find_first(x), number=100)*10, 'ms')
print('find_first_numba', timeit(lambda: find_first_numba(x), number=100)*10, 'ms')
และผลลัพธ์บนเครื่องของฉัน:
---- FIRST ----
ndarray.nonzero 54.733994480002366 ms
find_first 0.0013148509997336078 ms
find_first_numba 0.0002839310000126716 ms
---- LAST ----
ndarray.nonzero 54.56336712999928 ms
find_first 25.38929685000312 ms
find_first_numba 8.022820680002951 ms
---- NONE ----
ndarray.nonzero 24.13432420999925 ms
find_first 25.345200140000088 ms
find_first_numba 8.154927100003988 ms
---- ALL ----
ndarray.nonzero 55.753537260002304 ms
find_first 0.0014760300018679118 ms
find_first_numba 0.0004358099977253005 ms
บริสุทธิ์ndarray.nonzero
เป็นที่แน่นอนคลาย โซลูชัน numba เร็วกว่าประมาณ 5 เท่าสำหรับกรณีที่ดีที่สุด มันเร็วขึ้นประมาณ 3 เท่าในกรณีที่เลวร้ายที่สุด