怎么理解numpy的where()函数?

不太理解numpy.where()这个函数,看了官方文件,不太明白,比如下面这段 [xv if c else yv for (c,xv,yv) in …
关注者
75
被浏览
284,009

12 个回答

官方解释连接如下,可惜对于小白来说有点难以理解

numpy.where - NumPy v1.14 Manual

我的理解如下:

numpy.where()分两种调用方式:

1、三个参数np.where(cond,x,y):满足条件(cond)输出x,不满足输出y

2、一个参数np.where(arry):输出arry中‘真’值的坐标(‘真’也可以理解为非零)

实例:

1、np.where(cond,x,y):

同理:

2、np.where(arry)

np.where(x)输出的是八个不为0的数(为'真'的数)的坐标,第一个array[ ]是横坐标,第二个array[ ]是纵坐标;

即如下图所示:

同理:

如有错误欢迎指正!


以下是在看《Python科学计算(第二版)》时看到的关于NumPy的where函数的介绍(感觉用语比我这样野生的要专业):

在NumPy中,where()函数可以看作判断表达式的数组版本:

x = where(condition,y,z)

其中condition、y和z都是数组,它的返回值是一个形状与condition相同的数组。当condition中的某个元素为True时,x中对应下标的值从数组y获取,否则从数组z获取:

如果y和z是单个数值或者它们的形状与condition的不同,将先通过广播运算使其形状一致:

由于运算是在C语言级别完成的,所以计算效率比较高。

也欢迎关注我的知乎账号 @石溪 ,将持续发布机器学习数学基础及Python数据分析编程应用等方面的精彩内容。

条件逻辑的数组运算:np.where

这个其实功能上类似于python内置列表中的列表解析式,但是其表述更为简洁,在大数据运算方面更快(因为列表解析式的底层是纯python),从例子中可以看出,赋值既可以是标量,也可以是数组形式

import numpy as np  
arr = np.random.randn(4,4) 
print(arr) 
print(np.where(arr>0,2,-2)) 
print(np.where(arr>0,2,arr))  
[[ 0.19699344 -0.6502777  -1.03611804 -0.43403437]  
 [-1.95661572  0.44830588 -0.98746604 -0.57244612]  
 [ 0.44935834 -0.67782579 -0.49945472 -0.46147115]  
 [-0.26284806 -0.4260144   0.43380332 -0.04461859]] 
[[ 2 -2 -2 -2]  
 [-2  2 -2 -2]  
 [ 2 -2 -2 -2]