Numpy tutorial:np.argminとnp.partition

その買うを、もっとハッピーに。|ハピタス

numpy tutorial seriesの第何弾か忘れたが、numpyチュートリアルの一環としてnumpy.argmin(), numpy.partition()をやる。

スポンサーリンク

numpy.argmin()

numpy.argminとは何なのかをこのサイトを参考にして実践してみる。

import numpy as np

a=np.array([[1,2,4,7],[9,88,6,45],[9,76,3,4]])
print(a)
print(a.shape)
print(a.size)
[[ 1  2  4  7]
 [ 9 88  6 45]
 [ 9 76  3  4]]
(3, 4)
12
# axisの指定がないと行列は1Dにフラットされて
# 結果として0に最小値の1が含まれることになる。
np.argmin(a)
0
array([[ 1,  2,  4,  7],  # 0
       [ 9, 88,  6, 45],  # 1
       [ 9, 76,  3,  4]]) # 2
# 最小値1が行0に、次の2も行0に
# 次の3は行2に、次の4も行2に含まれる
np.argmin(a, axis=0) # axis=0はrow(行)
array([0, 0, 2, 2])
#        0   1   2   3
array([[ 1,  2,  4,  7],
       [ 9, 88,  6, 45],
       [ 9, 76,  3,  4]])
# 最小値1は列0に、次の6は列2に
# 次の3も列2に含まれている。
np.argmin(a, axis=1) # axis=1はcolumn(列)
array([0, 2, 2])

matrix(行列)のrow(行)とcolumn(列)は非常に覚え難い。特に、俺みたいな頭が悪い人間には拷問とも言える。上の例の場合、row(行)は上から下に数字が変わっていくので一行に数字を4つ含んでいる一方で、column(列)は左から右に数字が変わっていくので一列に数字を3つ含む3行✕4列の行列を意味する。自分で言っていることが、果たして正しいのかどうかさえも分からくなっているぐらいの酷い有様だ。numpy.argminは、軸に沿った最小値をインデックスで返す関数と覚えておけばいいようだ。

スポンサーリンク

numpy.partition()

numpy.partitionとは何なのかはこのサイトに以下のように書いてあった。

Return a partitioned copy of an array. Creates a copy of the array with its elements rearranged in such a way that the value of the element in k-th position is in the position it would be in a sorted array. All elements smaller than the k-th element are moved before this element and all equal or greater are moved behind it. The ordering of the elements in the two partitions is undefined.

分割した配列のコピーを返す。k-th(k番目に小さい)要素より小さい全要素は、この要素の前に移動され、等しいか大きい要素は後ろに移動される。2つのパーティションの要素の序列は未定義。

a = np.array([3, 4, 2, 1])
a
array([3, 4, 2, 1])
# 3番目に小さい3より小さい数字は前に
# 3より大きい数字は後ろに移動
# 要素の順序が2,1なのに注目
np.partition(a, 3)
array([2, 1, 3, 4])
# 先ず最少値の1より大きい数字が後ろに
# 次に3番目に小さい3より小さい数字が前に
# 3より大きい数字は後ろに移動
np.partition(a, (1, 3))
array([1, 2, 3, 4])

使い方としては、例えば、numpy.minで最小値を選んだ後に、2番目以降に小さい数字を選ぶ時などに使える。

a = np.array([[18, 12, 8, 22],
         [45, 17, 32, 15],
         [19, 10, 26, 9]])
a = a.flatten()
a
array([18, 12,  8, 22, 45, 17, 32, 15, 19, 10, 26,  9])
# 2番目に小さい数字9より小さい数字が前に
# 9より大きい数字は後ろに移動される。
np.partition(a,2)
array([ 8,  9, 10, 22, 45, 17, 32, 15, 19, 18, 26, 12])

上記の1D配列から2番目に小さい数字を抜き出す場合。

np.partition(a,2)[1]
9

最小値を抜き出す場合。

np.min(a)
8

最大値を抜き出す場合。

np.max(a)
45

最大値の次に大きい数字を抜き出す場合。

np.partition(a,11)[10]
32

覚えておくと色々使える便利な関数だ。

スポンサーリンク
スポンサーリンク