python numpy.where () ทำงานอย่างไร


94

ฉันกำลังเล่นnumpyและขุดเอกสารและได้พบกับเวทมนตร์บางอย่าง ฉันกำลังพูดถึงnumpy.where():

>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
(array([2, 2, 2]), array([0, 1, 2]))

พวกเขาบรรลุภายในได้อย่างไรว่าคุณสามารถผ่านบางสิ่งบางอย่างx > 5เข้าไปในวิธีการได้? ฉันเดาว่ามันมีบางอย่างที่เกี่ยวข้อง__gt__แต่ฉันกำลังมองหาคำอธิบายโดยละเอียด

คำตอบ:


75

พวกเขาบรรลุภายในได้อย่างไรว่าคุณสามารถส่งบางอย่างเช่น x> 5 ไปเป็นวิธีการได้?

คำตอบสั้น ๆ คือพวกเขาไม่ทำ

การดำเนินการทางตรรกะใด ๆ บนอาร์เรย์ numpy จะส่งคืนอาร์เรย์บูลีน (เช่น__gt__, __lt__ฯลฯ ทั้งหมดอาร์เรย์ผลตอบแทนบูลที่ได้รับเงื่อนไขที่เป็นจริง)

เช่น

x = np.arange(9).reshape(3,3)
print x > 5

ผลตอบแทน:

array([[False, False, False],
       [False, False, False],
       [ True,  True,  True]], dtype=bool)

นี่เป็นเหตุผลเดียวกับที่บางสิ่งเช่นif x > 5:เพิ่ม ValueError ถ้าxเป็นอาร์เรย์ numpy มันคืออาร์เรย์ของค่า True / False ไม่ใช่ค่าเดียว

นอกจากนี้อาร์เรย์จำนวนนับยังสามารถจัดทำดัชนีได้โดยอาร์เรย์บูลีน เช่นx[x>5]อัตราผลตอบแทน[6 7 8]ในกรณีนี้

ตรงไปตรงมาก็ค่อนข้างหายากที่คุณต้องการจริงแต่มันเป็นเพียงแค่ผลตอบแทนดัชนีที่อาร์เรย์แบบบูลคือnumpy.where Trueโดยปกติคุณสามารถทำสิ่งที่คุณต้องการได้ด้วยการสร้างดัชนีบูลีนอย่างง่าย


10
เพียงเพื่อชี้ให้เห็นว่าnumpy.whereจะมี 2 รูปแบบการดำเนินงาน 'แรกผลตอบแทนindicesที่condition is Trueและถ้าพารามิเตอร์ที่ไม่จำเป็นxและyเป็นปัจจุบัน (รูปร่างเช่นเดียวกับconditionหรือ broadcastable รูปร่างดังกล่าว!) ก็จะกลับค่าจากxเมื่อเป็นอย่างอื่นจากcondition is True yดังนั้นสิ่งนี้ทำให้whereมีความหลากหลายมากขึ้นและทำให้สามารถใช้งานได้บ่อย ขอบคุณ
ทาน

1
นอกจากนี้ยังสามารถมีค่าใช้จ่ายในบางกรณีโดยใช้__getitem__ไวยากรณ์ของ[]over อย่างใดอย่างหนึ่งnumpy.whereหรือnumpy.take. เนื่องจาก__getitem__ต้องรองรับการหั่นด้วยจึงมีค่าใช้จ่ายบางส่วน ฉันเห็นความแตกต่างของความเร็วที่เห็นได้ชัดเมื่อทำงานกับโครงสร้างข้อมูล Python Pandas และการสร้างดัชนีคอลัมน์ที่มีขนาดใหญ่มากอย่างมีเหตุผล ในกรณีที่ถ้าคุณไม่จำเป็นต้องหั่นแล้วtakeและwhereเป็นจริงดี
ely

24

คำตอบเก่า มันค่อนข้างสับสน มันทำให้คุณมีสถานที่ (ทั้งหมด) ที่สถานะของคุณเป็นจริง

ดังนั้น:

>>> a = np.arange(100)
>>> np.where(a > 30)
(array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
       48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
       65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
       82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98,
       99]),)
>>> np.where(a == 90)
(array([90]),)

a = a*40
>>> np.where(a > 1000)
(array([26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
       43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
       60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
       77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
       94, 95, 96, 97, 98, 99]),)
>>> a[25]
1000
>>> a[26]
1040

ฉันใช้เป็นอีกทางเลือกหนึ่งของ list.index () แต่ก็มีประโยชน์อื่น ๆ อีกมากมายเช่นกัน ฉันไม่เคยใช้กับอาร์เรย์ 2D

http://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html

คำตอบใหม่ ดูเหมือนว่าบุคคลนั้นกำลังถามบางสิ่งที่เป็นพื้นฐานมากกว่า

คำถามคือคุณจะใช้บางสิ่งที่ช่วยให้ฟังก์ชัน (เช่นที่ใด) ทราบสิ่งที่ร้องขอ

ก่อนอื่นโปรดทราบว่าการเรียกใช้ตัวดำเนินการเปรียบเทียบจะเป็นสิ่งที่น่าสนใจ

a > 1000
array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True`,  True,  True,  True,  True,  True,  True,  True,  True,  True], dtype=bool)`

ซึ่งทำได้โดยการโอเวอร์โหลดเมธอด "__gt__" ตัวอย่างเช่น:

>>> class demo(object):
    def __gt__(self, item):
        print item


>>> a = demo()
>>> a > 4
4

อย่างที่คุณเห็น "a> 4" เป็นรหัสที่ถูกต้อง

คุณสามารถรับรายการและเอกสารของฟังก์ชันที่โอเวอร์โหลดทั้งหมดได้ที่นี่: http://docs.python.org/reference/datamodel.html

สิ่งที่เหลือเชื่อคือการทำสิ่งนี้ง่ายเพียงใด การดำเนินการทั้งหมดใน python ทำได้ในลักษณะดังกล่าว การพูดว่า a> b เทียบเท่ากับ a. gt (ข)!


3
การโอเวอร์โหลดตัวดำเนินการเปรียบเทียบนี้ดูเหมือนจะไม่ทำงานได้ดีกับนิพจน์เชิงตรรกะที่ซับซ้อนกว่านี้เช่นฉันทำไม่ได้np.where(a > 30 and a < 50)หรือnp.where(30 < a < 50)เพราะมันพยายามประเมินตรรกะ AND ของสองอาร์เรย์ของบูลีนซึ่งค่อนข้างไม่มีความหมาย มีวิธีเขียนเงื่อนไขดังกล่าวด้วยnp.whereหรือไม่?
davidA

@meowsqueaknp.where((a > 30) & (a < 50))
tibalt

เหตุใด np.where () จึงส่งคืนรายการในตัวอย่างของคุณ
Andreas Yankopolus

0

np.whereส่งคืนค่า tuple ของความยาวเท่ากับมิติของ numpy ndarray ที่เรียกว่า (หรืออีกนัยหนึ่งndim) และแต่ละรายการของ tuple เป็นดัชนีที่เป็นตัวเลขของค่าเหล่านั้นทั้งหมดใน ndarray เริ่มต้นซึ่งเงื่อนไขเป็นจริง (โปรดอย่าสับสนมิติกับรูปร่าง)

ตัวอย่างเช่น:

x=np.arange(9).reshape(3,3)
print(x)
array([[0, 1, 2],
      [3, 4, 5],
      [6, 7, 8]])
y = np.where(x>4)
print(y)
array([1, 2, 2, 2], dtype=int64), array([2, 0, 1, 2], dtype=int64))


y คือทูเปิลของความยาว 2 เนื่องจากx.ndimเป็น 2 รายการที่ 1 ในทูเปิลมีหมายเลขแถวขององค์ประกอบทั้งหมดที่มากกว่า 4 และรายการที่ 2 มีหมายเลขคอลัมน์ของรายการทั้งหมดที่มากกว่า 4 อย่างที่คุณเห็น [1,2,2 , 2] สอดคล้องกับหมายเลขแถวที่ 5,6,7,8 และ [2,0,1,2] ตรงกับหมายเลขคอลัมน์ที่ 5,6,7,8 โปรดทราบว่า ndarray เคลื่อนที่ไปตามมิติแรก (row-wise ).

ในทำนองเดียวกัน

x=np.arange(27).reshape(3,3,3)
np.where(x>4)


จะคืนค่าทูเพิลของความยาว 3 เนื่องจาก x มี 3 มิติ

แต่เดี๋ยวก่อนยังมีอีกมากใน np ที่ไหน!

เมื่อทั้งสองมีปากเสียงจะมีการเพิ่มการnp.where; มันจะทำการแทนที่สำหรับชุดค่าผสมแถว - คอลัมน์คู่ที่ได้รับจากทูเปิลข้างต้น

x=np.arange(9).reshape(3,3)
y = np.where(x>4, 1, 0)
print(y)
array([[0, 0, 0],
   [0, 0, 1],
   [1, 1, 1]])
โดยการใช้ไซต์ของเรา หมายความว่าคุณได้อ่านและทำความเข้าใจนโยบายคุกกี้และนโยบายความเป็นส่วนตัวของเราแล้ว
Licensed under cc by-sa 3.0 with attribution required.