การเกิดขึ้นครั้งแรกของจำนวน numpy ของค่ามากกว่าค่าที่มีอยู่


144

ฉันมีอาร์เรย์ 1D เป็นจำนวนมากและฉันต้องการค้นหาตำแหน่งของดัชนีที่ค่าเกินค่าในอาร์เรย์ numpy

เช่น

aa = range(-10,10)

ค้นหาตำแหน่งในaaที่ที่ค่า5เกิน


2
หนึ่งควรมีความชัดเจนว่าจะไม่มีวิธีแก้ปัญหา (เนื่องจากเช่นคำตอบ argmax จะไม่ทำงานในกรณีนั้น (สูงสุด (0,0,0,0) = 0) เป็น ambrus แสดงความคิดเห็น
seanv507

คำตอบ:


199

มันเร็วกว่าเล็กน้อย (และดูดีกว่า)

np.argmax(aa>5)

เนื่องจากargmaxจะหยุดที่จุดแรกTrue("ในกรณีที่เกิดขึ้นหลายครั้งของค่าสูงสุดดัชนีที่สอดคล้องกับการเกิดขึ้นครั้งแรกจะถูกส่งคืน") และจะไม่บันทึกรายการอื่น

In [2]: N = 10000

In [3]: aa = np.arange(-N,N)

In [4]: timeit np.argmax(aa>N/2)
100000 loops, best of 3: 52.3 us per loop

In [5]: timeit np.where(aa>N/2)[0][0]
10000 loops, best of 3: 141 us per loop

In [6]: timeit np.nonzero(aa>N/2)[0][0]
10000 loops, best of 3: 142 us per loop

103
เพียงแค่คำเตือน: หากไม่มีค่า True ในอินพุตอาร์เรย์ np.argmax จะส่งคืน 0 อย่างมีความสุข (ซึ่งไม่ใช่สิ่งที่คุณต้องการในกรณีนี้)
ambrus

8
ผลลัพธ์นั้นถูกต้อง แต่ฉันพบว่าคำอธิบายนั้นน่าสงสัยเล็กน้อย ดูเหมือนจะไม่หยุดที่แรกargmax True(สิ่งนี้สามารถทดสอบได้โดยการสร้างอาร์เรย์บูลีนด้วยตำแหน่งเดียวTrueในตำแหน่งที่แตกต่างกัน) ความเร็วอาจถูกอธิบายโดยข้อเท็จจริงที่argmaxไม่จำเป็นต้องสร้างรายการผลลัพธ์
DrV

1
ฉันคิดว่าคุณพูดถูก @DrV argmaxคำอธิบายของฉันก็หมายความว่าจะเกี่ยวกับเหตุผลที่จะให้ผลที่ถูกต้องแม้จะมีความตั้งใจเดิมไม่จริงที่กำลังมองหาสูงสุดไม่ว่าทำไมมันจะเร็วเท่าที่ผมไม่สามารถเรียกร้องที่จะเข้าใจรายละเอียดด้านในของ
askewchan

1
@ George ฉันเกรงว่าฉันไม่รู้ว่าทำไม ฉันสามารถบอกได้ว่ามันเร็วกว่าในตัวอย่างที่ฉันแสดงดังนั้นฉันจะไม่พิจารณาโดยทั่วไปเร็วกว่าโดยที่ (i) รู้ว่าทำไมมันถึงเป็น (ดูความคิดเห็นของ @ DrV) หรือ (ii) ทดสอบกรณีเพิ่มเติม (เช่นไม่ว่าaaจะเรียงเป็นหรือไม่ดังในคำตอบของ @ Michael)
askewchan

3
@DrV ฉันเพิ่งรันargmaxอาร์เรย์บูลีน 10 ล้านองค์ประกอบโดยมีหนึ่งTrueตำแหน่งที่แตกต่างกันโดยใช้ NumPy 1.11.2 และตำแหน่งของTrueวัตถุ ดังนั้น 1.11.2 argmaxดูเหมือนว่าจะ "ลัดวงจร" ในอาร์เรย์บูลีน
Ulrich Stern

96

กำหนดเนื้อหาเรียงลำดับของอาร์เรย์ของคุณมีวิธีการที่เร็วยิ่งขึ้น: searchsorted

import time
N = 10000
aa = np.arange(-N,N)
%timeit np.searchsorted(aa, N/2)+1
%timeit np.argmax(aa>N/2)
%timeit np.where(aa>N/2)[0][0]
%timeit np.nonzero(aa>N/2)[0][0]

# Output
100000 loops, best of 3: 5.97 µs per loop
10000 loops, best of 3: 46.3 µs per loop
10000 loops, best of 3: 154 µs per loop
10000 loops, best of 3: 154 µs per loop

19
นี่เป็นคำตอบที่ดีที่สุดโดยสมมติว่ามีการเรียงลำดับอาร์เรย์ (ซึ่งไม่ได้ระบุไว้ในคำถาม) คุณสามารถหลีกเลี่ยงความอึดอัดใจ+1ด้วยnp.searchsorted(..., side='right')
askewchan

3
ฉันคิดว่าการsideโต้แย้งสร้างความแตกต่างหากมีค่าซ้ำในอาร์เรย์ที่เรียงลำดับ ไม่เปลี่ยนความหมายของดัชนีที่ส่งคืนซึ่งเป็นดัชนีที่คุณสามารถแทรกค่าข้อความค้นหาได้ตลอดเวลาเลื่อนรายการทั้งหมดต่อไปนี้ไปทางขวาและรักษาอาร์เรย์ที่เรียงลำดับ
กัส

@Gus sideมีผลเมื่อค่าเดียวกันอยู่ในทั้งเรียงและอาร์เรย์แทรกโดยไม่คำนึงถึงค่าซ้ำในทั้ง ค่าที่ซ้ำกันในอาร์เรย์ที่เรียงลำดับมีผลกระทบเกินจริง (ความแตกต่างระหว่างด้านข้างคือจำนวนครั้งที่ค่าที่ถูกแทรกปรากฏในอาร์เรย์ที่เรียงลำดับ) side จะเปลี่ยนความหมายของดัชนีที่ส่งคืนแม้ว่ามันจะไม่เปลี่ยนอาร์เรย์ผลลัพธ์จากการแทรกค่าลงในอาร์เรย์ที่เรียงลำดับที่ดัชนีเหล่านั้น ความแตกต่างที่ลึกซึ้ง แต่สำคัญ ในความเป็นจริงคำตอบนี้จะช่วยให้ดัชนีผิดถ้าไม่ได้อยู่ในN/2 aa
askewchan

ในฐานะที่เป็นนัยในความคิดเห็นข้างต้นคำตอบนี้ปิดโดยหนึ่งถ้าไม่ได้อยู่ในN/2 aaแบบฟอร์มที่ถูกต้องจะเป็นnp.searchsorted(aa, N/2, side='right')(ไม่มี+1) ทั้งสองรูปแบบให้ดัชนีเดียวกันเป็นอย่างอื่น พิจารณากรณีทดสอบของNการเป็นคี่ (และN/2.0เพื่อบังคับให้ลอยถ้าใช้หลาม 2)
askewchan

21

ผมยังมีความสนใจในเรื่องนี้และฉันได้เมื่อเทียบกับทุกคำตอบปัญหากับperfplot (ข้อจำกัดความรับผิดชอบ: ฉันเป็นผู้เขียนข้อตกลง)

หากคุณรู้ว่าอาเรย์ที่คุณค้นหานั้นเรียงลำดับแล้ว

numpy.searchsorted(a, alpha)

สำหรับคุณ. มันเป็นการดำเนินการเวลาคงที่นั่นคือความเร็วไม่ได้ขึ้นอยู่กับขนาดของอาเรย์ คุณไม่สามารถไปได้เร็วกว่านั้น

หากคุณไม่รู้อะไรเกี่ยวกับอาเรย์ของคุณคุณจะไม่ผิดพลาด

numpy.argmax(a > alpha)

เรียงแล้ว:

ป้อนคำอธิบายรูปภาพที่นี่

ไม่ได้เรียงลำดับ:

ป้อนคำอธิบายรูปภาพที่นี่

รหัสในการทำซ้ำพล็อต:

import numpy
import perfplot


alpha = 0.5

def argmax(data):
    return numpy.argmax(data > alpha)

def where(data):
    return numpy.where(data > alpha)[0][0]

def nonzero(data):
    return numpy.nonzero(data > alpha)[0][0]

def searchsorted(data):
    return numpy.searchsorted(data, alpha)

out = perfplot.show(
    # setup=numpy.random.rand,
    setup=lambda n: numpy.sort(numpy.random.rand(n)),
    kernels=[
        argmax, where,
        nonzero,
        searchsorted
        ],
    n_range=[2**k for k in range(2, 20)],
    logx=True,
    logy=True,
    xlabel='len(array)'
    )

4
np.searchsortedไม่ใช่เวลาคงที่ O(log(n))มันเป็นเรื่องจริง แต่กรณีทดสอบของคุณเป็นมาตรฐานที่ดีที่สุดsearchsorted(ซึ่งก็คือO(1))
MSeifert

@MSeifert คุณต้องการอาร์เรย์ชนิดใด / อัลฟ่าที่ต้องดู O (บันทึก (n))
Nico Schlömer

1
การรับไอเท็มที่ดัชนี sqrt (ความยาว) ทำให้ได้ประสิทธิภาพที่แย่มาก ฉันยังเขียนคำตอบ ที่นี่รวมถึงเกณฑ์มาตรฐานด้วย
MSeifert

ฉันสงสัยว่าsearchsorted(หรืออัลกอริทึมใด ๆ ) สามารถเอาชนะการO(log(n))ค้นหาแบบไบนารีสำหรับข้อมูลที่กระจายอย่างสม่ำเสมอ แก้ไข: searchsorted คือการค้นหาไบนารี
Mateen Ulhaq

16
In [34]: a=np.arange(-10,10)

In [35]: a
Out[35]:
array([-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,
         3,   4,   5,   6,   7,   8,   9])

In [36]: np.where(a>5)
Out[36]: (array([16, 17, 18, 19]),)

In [37]: np.where(a>5)[0][0]
Out[37]: 16

8

อาร์เรย์ที่มีขั้นตอนคงที่ระหว่างองค์ประกอบ

ในกรณีของ a rangeหรืออาเรย์ที่เพิ่มขึ้นแบบเส้นตรงอื่น ๆ คุณสามารถคำนวณดัชนีโดยทางโปรแกรมไม่จำเป็นต้องวนซ้ำจริง ๆ ในอาร์เรย์เลย:

def first_index_calculate_range_like(val, arr):
    if len(arr) == 0:
        raise ValueError('no value greater than {}'.format(val))
    elif len(arr) == 1:
        if arr[0] > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    first_value = arr[0]
    step = arr[1] - first_value
    # For linearly decreasing arrays or constant arrays we only need to check
    # the first element, because if that does not satisfy the condition
    # no other element will.
    if step <= 0:
        if first_value > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    calculated_position = (val - first_value) / step

    if calculated_position < 0:
        return 0
    elif calculated_position > len(arr) - 1:
        raise ValueError('no value greater than {}'.format(val))

    return int(calculated_position) + 1

หนึ่งอาจปรับปรุงที่เล็กน้อย ฉันแน่ใจว่ามันทำงานอย่างถูกต้องสำหรับอาร์เรย์และค่าตัวอย่างสองสามตัว แต่นั่นไม่ได้หมายความว่าจะไม่มีข้อผิดพลาดเกิดขึ้น

>>> import numpy as np
>>> first_index_calculate_range_like(5, np.arange(-10, 10))
16
>>> np.arange(-10, 10)[16]  # double check
6

>>> first_index_calculate_range_like(4.8, np.arange(-10, 10))
15

เนื่องจากมันสามารถคำนวณตำแหน่งโดยไม่มีการวนซ้ำใด ๆ ซึ่งจะเป็นเวลาคงที่ ( O(1)) และอาจเอาชนะวิธีการอื่น ๆ ที่กล่าวถึงทั้งหมดได้ อย่างไรก็ตามต้องมีขั้นตอนคงที่ในอาร์เรย์มิฉะนั้นจะให้ผลลัพธ์ที่ผิด

วิธีแก้ปัญหาทั่วไปโดยใช้ numba

แนวทางทั่วไปที่มากกว่านั้นก็คือการใช้ฟังก์ชัน numba:

@nb.njit
def first_index_numba(val, arr):
    for idx in range(len(arr)):
        if arr[idx] > val:
            return idx
    return -1

มันจะใช้ได้กับอาเรย์ใด ๆ แต่มันจะต้องวนซ้ำไปเรื่อย ๆ ในอาเรย์ดังนั้นโดยเฉลี่ยแล้วมันจะเป็นO(n):

>>> first_index_numba(4.8, np.arange(-10, 10))
15
>>> first_index_numba(5, np.arange(-10, 10))
16

เกณฑ์มาตรฐาน

แม้ว่า Nico Schlömerได้จัดทำเกณฑ์มาตรฐานบางอย่างแล้วฉันคิดว่ามันอาจมีประโยชน์ในการรวมโซลูชันใหม่ของฉันและเพื่อทดสอบ "ค่า" ที่แตกต่างกัน

การตั้งค่าการทดสอบ:

import numpy as np
import math
import numba as nb

def first_index_using_argmax(val, arr):
    return np.argmax(arr > val)

def first_index_using_where(val, arr):
    return np.where(arr > val)[0][0]

def first_index_using_nonzero(val, arr):
    return np.nonzero(arr > val)[0][0]

def first_index_using_searchsorted(val, arr):
    return np.searchsorted(arr, val) + 1

def first_index_using_min(val, arr):
    return np.min(np.where(arr > val))

def first_index_calculate_range_like(val, arr):
    if len(arr) == 0:
        raise ValueError('empty array')
    elif len(arr) == 1:
        if arr[0] > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    first_value = arr[0]
    step = arr[1] - first_value
    if step <= 0:
        if first_value > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    calculated_position = (val - first_value) / step

    if calculated_position < 0:
        return 0
    elif calculated_position > len(arr) - 1:
        raise ValueError('no value greater than {}'.format(val))

    return int(calculated_position) + 1

@nb.njit
def first_index_numba(val, arr):
    for idx in range(len(arr)):
        if arr[idx] > val:
            return idx
    return -1

funcs = [
    first_index_using_argmax, 
    first_index_using_min, 
    first_index_using_nonzero,
    first_index_calculate_range_like, 
    first_index_numba, 
    first_index_using_searchsorted, 
    first_index_using_where
]

from simple_benchmark import benchmark, MultiArgument

และแปลงถูกสร้างขึ้นโดยใช้:

%matplotlib notebook
b.plot()

รายการอยู่ที่จุดเริ่มต้น

b = benchmark(
    funcs,
    {2**i: MultiArgument([0, np.arange(2**i)]) for i in range(2, 20)},
    argument_name="array size")

ป้อนคำอธิบายรูปภาพที่นี่

ฟังก์ชัน numba ทำงานได้ดีที่สุดตามด้วยฟังก์ชันคำนวณและฟังก์ชันค้นหา โซลูชันอื่น ๆ ทำงานได้แย่กว่ามาก

รายการอยู่ท้าย

b = benchmark(
    funcs,
    {2**i: MultiArgument([2**i-2, np.arange(2**i)]) for i in range(2, 20)},
    argument_name="array size")

ป้อนคำอธิบายรูปภาพที่นี่

สำหรับอาร์เรย์ขนาดเล็กฟังก์ชัน numba จะทำงานได้อย่างรวดเร็วอย่างน่าอัศจรรย์อย่างไรก็ตามสำหรับอาร์เรย์ที่ใหญ่กว่านั้นมีฟังก์ชันที่ดีกว่าด้วยฟังก์ชันการคำนวณและฟังก์ชันค้นหา

รายการอยู่ที่ sqrt (len)

b = benchmark(
    funcs,
    {2**i: MultiArgument([np.sqrt(2**i), np.arange(2**i)]) for i in range(2, 20)},
    argument_name="array size")

ป้อนคำอธิบายรูปภาพที่นี่

มันน่าสนใจกว่านี้ numba อีกครั้งและฟังก์ชั่นการคำนวณมีประสิทธิภาพดีเยี่ยม แต่นี่เป็นจุดเริ่มต้นของกรณีค้นหาที่เลวร้ายที่สุดซึ่งจริงๆแล้วใช้งานไม่ได้ในกรณีนี้

การเปรียบเทียบฟังก์ชั่นเมื่อไม่มีค่าที่เป็นไปตามเงื่อนไข

อีกจุดที่น่าสนใจคือการทำงานของฟังก์ชั่นเหล่านี้หากไม่มีค่าที่ควรส่งคืนดัชนี:

arr = np.ones(100)
value = 2

for func in funcs:
    print(func.__name__)
    try:
        print('-->', func(value, arr))
    except Exception as e:
        print('-->', e)

ด้วยผลลัพธ์นี้:

first_index_using_argmax
--> 0
first_index_using_min
--> zero-size array to reduction operation minimum which has no identity
first_index_using_nonzero
--> index 0 is out of bounds for axis 0 with size 0
first_index_calculate_range_like
--> no value greater than 2
first_index_numba
--> -1
first_index_using_searchsorted
--> 101
first_index_using_where
--> index 0 is out of bounds for axis 0 with size 0

Searchsorted, argmax และ numba เพียงแค่คืนค่าที่ผิด อย่างไรก็ตามsearchsortedและnumbaส่งคืนดัชนีที่ไม่ใช่ดัชนีที่ถูกต้องสำหรับอาร์เรย์

ฟังก์ชั่นwhere, min, nonzeroและcalculateโยนข้อยกเว้น อย่างไรก็ตามข้อยกเว้นสำหรับการcalculateพูดสิ่งที่เป็นประโยชน์จริง ๆ เท่านั้น

นั่นหมายความว่าเราต้องตัดการเรียกเหล่านี้ในฟังก์ชัน wrapper ที่เหมาะสมซึ่งจะจับข้อยกเว้นหรือค่าส่งคืนที่ไม่ถูกต้องและจัดการอย่างเหมาะสมอย่างน้อยถ้าคุณไม่แน่ใจว่าค่าอาจอยู่ในอาร์เรย์


หมายเหตุ: การคำนวณและsearchsortedตัวเลือกใช้งานได้ในเงื่อนไขพิเศษเท่านั้น ฟังก์ชัน "คำนวณ" ต้องใช้ขั้นตอนคงที่และการค้นหาเรียงตามลำดับจะต้องมีการเรียงลำดับ ดังนั้นสิ่งเหล่านี้อาจเป็นประโยชน์ในสถานการณ์ที่เหมาะสม แต่ไม่ใช่วิธีแก้ไขปัญหาทั่วไปสำหรับปัญหานี้ ในกรณีที่คุณกำลังจัดการกับรายการ Python ที่เรียงลำดับคุณอาจต้องการดูโมดูลbisectแทนที่จะใช้ Numpys searchsorted


3

ฉันอยากจะเสนอ

np.min(np.append(np.where(aa>5)[0],np.inf))

สิ่งนี้จะส่งคืนดัชนีที่เล็กที่สุดซึ่งตรงกับเงื่อนไขในขณะที่คืนค่าอินฟินิตี้หากไม่ตรงตามเงื่อนไข (และwhereส่งคืนอาร์เรย์ว่าง)


1

ฉันจะไปกับ

i = np.min(np.where(V >= x))

โดยที่Vvector (อาร์เรย์ 1d) xคือค่าและiเป็นดัชนีผลลัพธ์

โดยการใช้ไซต์ของเรา หมายความว่าคุณได้อ่านและทำความเข้าใจนโยบายคุกกี้และนโยบายความเป็นส่วนตัวของเราแล้ว
Licensed under cc by-sa 3.0 with attribution required.