numpy tutorialの一環としてnumpy.linalg.lstsqをやる。この関数についてはこのサイトに以下のように書いてある。ついでに、scipy版とも何が違うのか比較してみる。
Solves the equation $ax = b$ by computing a vector $x$ that minimizes the Euclidean 2-norm $|| b – ax ||^2$. The equation may be under-, well-, or over- determined (i.e., the number of linearly independent rows of a can be less than, equal to, or greater than its number of linearly independent columns). If a is square and of full rank, then $x$ (but for round-off error) is the “exact” solution of the equation.
方程式$ax = b$をユークリッド2ノルム$|| b – ax ||^2$を最小にするベクトル$x$を算出することで解く。$a$が平方で最大階数なら、(丸め誤差がなければ)$x$は方程式の正確な解になる。
numpy.linalg.lstsq¶
Fit a line, $y = mx + c$, through some noisy data-points:
線$y = mx + c$を、幾つかのノイジーデータ点に適合させる。
import numpy as np
x = np.array([0, 1, 2, 3])
y = np.array([-1, 0.2, 0.9, 2.1])
A = np.vstack([x, np.ones(len(x))]).T
A
By examining the coefficients, we see that the line should have a gradient of roughly 1 and cut the $y$-axis at, more or less, -1.
係数を調べることで、線の傾きが約1で、$y軸$の-1辺りを通るだろうことが分かる。
We can rewrite the line equation as $y = Ap$, where $A = [[x 1]]$ and $p = [[m], [c]]$. Now use lstsq to solve for $p$:
直線方程式は、$A = [[x 1]]$と$p = [[m], [c]]$である、$y = Ap$として書き換えられる。次に、$p$を解くのにlstsqを使用する。
m, c = np.linalg.lstsq(A, y, rcond=None)[0]
print(m, c)
Plot the data along with the fitted line:
適合線と一緒にデータをプロットする。
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = 12, 8
plt.rcParams["font.size"] = "17"
plt.plot(x, y, 'o', label='Original data', markersize=10)
plt.plot(x, m*x + c, 'r', label='Fitted line')
plt.legend()
plt.show()
rcond=Noneの意味¶
np.linalg.lstsq(A, y, rcond=None)のrcond=Noneとは何なのか?こいつが何をしているのかを調べるために、先ずこいつを省いてみた。
m, c = np.linalg.lstsq(A, y)[0]
print(m, c)
rcond=Noneは警告をサプレッションするために必要らしい。よく見ると関数の説明書きにも書いてあった。
m, c = np.linalg.lstsq(A, y, rcond=0)[0]
print(m, c)
scipy.linalg.lstsq¶
scipy.linalg.lstsqとnumpy.linalg.lstsqを比較する。scipy版についてはこのサイトに以下のように書いてある。
Compute a vector $x$ such that the 2-norm $|b – A x|$ is minimized.
2-ノルム$|b – A x|$が最少になるようなベクトル$x$を算出する。
from scipy.linalg import lstsq
n, d = lstsq(A, y)[0]
print(n, d)
import matplotlib.pyplot as plt
plt.plot(x, y, 'o', label='Original data', markersize=10)
plt.plot(x, n*x + d, 'r', label='Fitted line')
plt.legend()
plt.show()
両関数は同じような機能を有しているようだ。