ทดสอบว่าอาร์เรย์ numpy มีเพียงศูนย์หรือไม่


93

เราเริ่มต้นอาร์เรย์ numpy ด้วยศูนย์ดังต่อไปนี้:

np.zeros((N,N+1))

แต่เราจะตรวจสอบได้อย่างไรว่าองค์ประกอบทั้งหมดในเมทริกซ์อาร์เรย์ n * n จำนวนที่กำหนดเป็นศูนย์หรือไม่
วิธีการเพียงแค่ต้องคืนค่า True หากค่าทั้งหมดเป็นศูนย์จริง ๆ

คำตอบ:


73

ตรวจสอบnumpy.count_nonzero

>>> np.count_nonzero(np.eye(4))
4
>>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]])
5

9
คุณต้องการnot np.count_nonzero(np.eye(4))ส่งคืนTrueก็ต่อเมื่อค่าทั้งหมดเป็น 0
J. Martinot-Lagarde

166

คำตอบอื่น ๆ ที่โพสต์ไว้ที่นี่จะใช้งานได้ แต่ฟังก์ชันที่ชัดเจนและมีประสิทธิภาพที่สุดในการใช้คือnumpy.any():

>>> all_zeros = not np.any(a)

หรือ

>>> all_zeros = not a.any()
  • เป็นที่ต้องการมากกว่าnumpy.all(a==0)เพราะใช้ RAM น้อยกว่า (ไม่ต้องการอาร์เรย์ชั่วคราวที่สร้างขึ้นโดยa==0คำศัพท์)
  • นอกจากนี้ยังเร็วกว่าnumpy.count_nonzero(a)เนื่องจากสามารถย้อนกลับได้ทันทีเมื่อพบองค์ประกอบที่ไม่ใช่ศูนย์แรก
    • แก้ไข:ตามที่ @Rachel ชี้ให้เห็นในความคิดเห็นnp.any()ไม่ใช้ตรรกะ "ลัดวงจร" อีกต่อไปดังนั้นคุณจะไม่เห็นประโยชน์ด้านความเร็วสำหรับอาร์เรย์ขนาดเล็ก

3
ณ นาทีที่ผ่านมา numpy ของanyและallทำไม่ได้ลัดวงจร ผมเชื่อว่าพวกเขาจะมีน้ำตาลและlogical_or.reduce logical_and.reduceเปรียบเทียบกันและการลัดวงจรของฉันis_in: all_false = np.zeros(10**8) all_true = np.ones(10**8) %timeit np.any(all_false) 91.5 ms ± 1.82 ms per loop %timeit np.any(all_true) 93.7 ms ± 6.16 ms per loop %timeit is_in(1, all_true) 293 ns ± 1.65 ns per loop
Rachel

3
นั่นเป็นจุดที่ดีขอบคุณ ดูเหมือนว่าการลัดวงจรจะเคยเป็นพฤติกรรม แต่ก็หายไปในบางจุด มีการอภิปรายที่น่าสนใจในคำตอบสำหรับคำถามนี้
Stuart Berg

50

ฉันจะใช้ np.all ที่นี่ถ้าคุณมีอาร์เรย์ a:

>>> np.all(a==0)

3
ฉันชอบที่คำตอบนี้ตรวจสอบค่าที่ไม่ใช่ศูนย์เช่นกัน np.all(a==a[0])ตัวอย่างเช่นหนึ่งสามารถตรวจสอบว่าทุกองค์ประกอบในอาร์เรย์จะเหมือนกันโดยการทำ ขอบคุณมาก!
aignas

9

ดังที่คำตอบอื่นกล่าวไว้คุณสามารถใช้ประโยชน์จากการประเมินจริง / เท็จหากคุณรู้ว่านั่น0เป็นองค์ประกอบที่เป็นเท็จเพียงอย่างเดียวที่อาจเกิดขึ้นในอาร์เรย์ของคุณ องค์ประกอบทั้งหมดในอาร์เรย์เป็นเท็จหากไม่มีองค์ประกอบที่แท้จริงอยู่ในนั้น *

>>> a = np.zeros(10)
>>> not np.any(a)
True

อย่างไรก็ตามคำตอบอ้างว่าanyเร็วกว่าตัวเลือกอื่นเนื่องจากส่วนหนึ่งเกิดจากการลัดวงจร ในฐานะของ 2018 Numpy ของallและไม่ได้มีการลัดวงจรany

หากคุณทำสิ่งนี้บ่อยๆการสร้างเวอร์ชันลัดวงจรของคุณเองทำได้ง่ายมากโดยใช้numba:

import numba as nb

# short-circuiting replacement for np.any()
@nb.jit(nopython=True)
def sc_any(array):
    for x in array.flat:
        if x:
            return True
    return False

# short-circuiting replacement for np.all()
@nb.jit(nopython=True)
def sc_all(array):
    for x in array.flat:
        if not x:
            return False
    return True

สิ่งเหล่านี้มักจะเร็วกว่ารุ่นของ Numpy แม้ว่าจะไม่ลัดวงจรก็ตาม count_nonzeroช้าที่สุด

ข้อมูลบางอย่างเพื่อตรวจสอบประสิทธิภาพ:

import numpy as np

n = 10**8
middle = n//2
all_0 = np.zeros(n, dtype=int)
all_1 = np.ones(n, dtype=int)
mid_0 = np.ones(n, dtype=int)
mid_1 = np.zeros(n, dtype=int)
np.put(mid_0, middle, 0)
np.put(mid_1, middle, 1)
# mid_0 = [1 1 1 ... 1 0 1 ... 1 1 1]
# mid_1 = [0 0 0 ... 0 1 0 ... 0 0 0]

ตรวจสอบ:

## count_nonzero
%timeit np.count_nonzero(all_0) 
# 220 ms ± 8.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.count_nonzero(all_1)
# 150 ms ± 4.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

### all
# np.all
%timeit np.all(all_1)
%timeit np.all(mid_0)
%timeit np.all(all_0)
# 56.8 ms ± 3.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.4 ms ± 1.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 55.9 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_all
%timeit sc_all(all_1)
%timeit sc_all(mid_0)
%timeit sc_all(all_0)
# 44.4 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.7 ms ± 599 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 288 ns ± 6.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

### any
# np.any
%timeit np.any(all_0)
%timeit np.any(mid_1)
%timeit np.any(all_1)
# 60.7 ms ± 1.38 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 60 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.7 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_any
%timeit sc_any(all_0)
%timeit sc_any(mid_1)
%timeit sc_any(all_1)
# 41.7 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.4 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 287 ns ± 12.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

* มีประโยชน์allและanyเท่าเทียมกัน:

np.all(a) == np.logical_not(np.any(np.logical_not(a)))
np.any(a) == np.logical_not(np.all(np.logical_not(a)))
not np.all(a) == np.any(np.logical_not(a))
not np.any(a) == np.all(np.logical_not(a))

-8

หากคุณกำลังทดสอบค่าศูนย์ทั้งหมดเพื่อหลีกเลี่ยงคำเตือนเกี่ยวกับฟังก์ชัน numpy อื่นจากนั้นลองใช้การตัดบรรทัดยกเว้นบล็อกจะบันทึกโดยต้องทำการทดสอบศูนย์ก่อนการดำเนินการที่คุณสนใจเช่น

try: # removes output noise for empty slice 
    mean = np.mean(array)
except:
    mean = 0
โดยการใช้ไซต์ของเรา หมายความว่าคุณได้อ่านและทำความเข้าใจนโยบายคุกกี้และนโยบายความเป็นส่วนตัวของเราแล้ว
Licensed under cc by-sa 3.0 with attribution required.