Numpy tutorial:np.expand_dimsとnp.where

numpy tutorialの一環として、このサイトを参考にしてnumpy.expand_dims, numpy.whereの使い方を調べる。

スポンサーリンク

numpy.expand_dims

numpy.expand_dimsの説明は以下のように書いてある。

“Insert a new axis that will appear at the axis position in the expanded array shape.”
拡張された配列形状の軸位置に表示される新たな軸を挿入する。

import numpy as np
x = np.array([1,2,3])
print(x)
print(x.shape)
[1 2 3]
(3,)

The following is equivalent to x[np.newaxis,:] or x[np.newaxis]:
以下はx[np.newaxis,:]またはx[np.newaxis]に等しい。

# (3,) → (1, 3)に変換
y = np.expand_dims(x, axis=0)
print(y)
print(y.shape)
#z = np.expand_dims(y, axis=0)
#print(z)
#print(z.shape)
[[1 2 3]]
(1, 3)
# (3,) → (3, 1)に変換
y = np.expand_dims(x,axis=1) #Equivalent to x[:,np.newaxis]
print(y)
print(y.shape)
[[1]
 [2]
 [3]]
(3, 1)

Note that some examples may use None instead of np.newaxis. These are the same objects:
いくつかの例は、np.newaxisの代わりにNoneを使うかもしれないことに留意する。これらは同じオブジェクトである。

np.newaxis is None
True

numpy.where

numpy.whereの説明は以下のように書いてある。

“Return elements, either from x or y, depending on condition. If only condition is given, return condition.nonzero().”
条件によりxかyの要素を返す。条件だけの場合condition.nonzero()を返す。

a = np.array([[1,2,3],
                 [4,5,6],
                 [7,8,9]])
print(a)
print(a.shape)
[[1 2 3]
 [4 5 6]
 [7 8 9]]
(3, 3)
# 8のインデックスの行[2]列[1]を返す
np.where(a == 8)
(array([2]), array([1]))
# 8以上のインデックスを返す
# 9のインデックスの行[2]列[2]
np.where(a >= 8)
(array([2, 2]), array([1, 2]))
# True = [1, 2], [3, 4]
# False = [9, 8], [7, 6]
# [True=1,False=8],[True=3,True=4]
np.where([[True,False],[True,True]],
...          [[1, 2], [3, 4]],
...          [[9, 8], [7, 6]])
array([[1, 8],
       [3, 4]])
# True = [1, 2], [3, 4]
# False = [9, 8], [7, 6]
# [False=9,True=2],[False=7,False=6]
np.where([[False,True],[False,False]],
          [[1, 2], [3, 4]],
          [[9, 8], [7, 6]])
array([[9, 2],
       [7, 6]])
np.where([[0, 1], [1, 0]])
(array([0, 1]), array([1, 0]))
# x > 5 → 6,7,8は全て行2
# 6は列0、7は列1、8は列2
x = np.arange(9.).reshape(3, 3)
print(x)
np.where( x > 5 )
[[0. 1. 2.]
 [3. 4. 5.]
 [6. 7. 8.]]
(array([2, 2, 2]), array([0, 1, 2]))
# 行列が1Dに平坦化される。
x[np.where( x > 3.0 )] #Note: result is 1D.
array([4., 5., 6., 7., 8.])
# x < 5 → 0,1,2,3,4以外を-1に置き換える
np.where(x < 5, x, -1) #Note: broadcasting.
array([[ 0.,  1.,  2.],
       [ 3.,  4., -1.],
       [-1., -1., -1.]])
# x < 5 → 6,7,8以外を-1に置き換える
np.where(x > 5, x, -1) #Note: broadcasting.
array([[-1., -1., -1.],
       [-1., -1., -1.],
       [ 6.,  7.,  8.]])

Find the indices of elements of x that are in a.
aの中にあるxの要素のインデックスを探し出す。

a = [3, 4, 7]
ix = np.isin(x, a)
ix
array([[False, False, False],
       [ True,  True, False],
       [False,  True, False]])
# 3,4,7(True)の位置を見つける
# 3は行1列0、4は行1列1、7は行2列1
np.where(ix)
(array([1, 1, 2]), array([0, 1, 1]))