フィッティング
|
SciPyフィッティング (fitting)SciPy モジュール「scipy.optimize.curve_fit」を使ったフィッティング (fitthing) の方法を示します. 目次はじめにフィッティングとは実際に観測されるデータは,誤差を持ちます.この誤差を持ったデータから,推定される曲線 (関数) を得る手法をフィティングと言います.具体例を示しましょう.図のようにバネが吊り下げられています.そこに,重りを1[kg]から10[kg]まで変化させて,バネの長さを測定します.この測定結果は以下のとおりです. この測定結果から,重さがゼロの時のバネの長さ(初期長さ)とバネ定数を求めることを考えます.測定結果をプロットすると,図が得られます.このプロットから最も確からしい直線を引くことができれば,最も確からしい初期長さとバネ定数を得ることができます.バネの長さ\(y\)と重りの重さ\(x\)には, \begin{align} y=a+bx \end{align} の関係があります.\(a\) が初期長さ,\(b\) がバネ定数です.図が最も確からしい直線です.これは,誤差の二乗である\(\sum\left[f(x_i)-y_i\right]^2\)を最小にした直線です. フィッティング方法Python には,フィッティングのためのモジュール「scipy.optimize.curve_fit」があります.これを使うと容易に誤差を持つデータを任意の関数でフィッティングすることができます.これを使うためにのステップは,次のとおりです.
フィッティングに必要な手順を踏んでおり,この流れは簡単に理解できます. 使ってみようここでは,誤差を持ったデータとして,\(y=\sin(x)\exp(-x/5)+\mathrm{Noise}\) を考えます.\(\mathrm{Noise}\) が観測で,それは標準偏差1を持った正規分布に0.05を乗じた値とします.これを,フィッティングパラメーター\((a,\,b,\,c)\)を持った関数 \(f(x)=a\sin(x)\exp(-bx)+c\) でフィッティングします.ノイズの無い元の関数と比較すると,\((a,\,b,\,c)\simeq(1,\,0.2,\,0)\) となることが分かります.サンプル数が増えると,この値になるはずです. それでは,実際のプログラムを以下に示します.先に示したステップのモジュールのインポートは 004 行,フィッティング関数の定義は 056 行,データの準備は 065 と 067 行,フィッティングの実行は 023 行です. このプログラムはメインルーチンで,(1) データ\((x,\,y)\)のデータを作成し,(2) インスタンス fit を生成し.(3) フィッティングの操作を行い,(4)プロットを描きます.メインルーチンを見れば分かるでしょう. 001 # -*- coding: utf-8 -*- 002 import numpy as np 003 import matplotlib.pyplot as plt 004 from scipy.optimize import curve_fit 005 006 # =========================================================== 007 # フィッティングのクラス 008 # =========================================================== 009 class FITTING(object): 010 # ------------------------------------------------------- 011 # コンストラクター 012 # ------------------------------------------------------- 013 def __init__(self, function, x, y): 014 self.f = function 015 self.x = x 016 self.y = y 017 018 019 # ------------------------------------------------------- 020 # フィッティングの実行 021 # ------------------------------------------------------- 022 def do_fitting(self): 023 self.popt, self.pcov = curve_fit(self.f, self.x, self.y) 024 025 print('a: {0:e}\nb: {1:e}\nc: {2:e}'.\ 026 format(self.popt[0], self.popt[1], self.popt[2])) 027 028 029 # ------------------------------------------------------- 030 # フィッティング結果のプロット 031 # ------------------------------------------------------- 032 def plot(self, Nx=65): 033 xmin, xmax = min(self.x), max(self.x) 034 xp = np.linspace(xmin, xmax, Nx) 035 fig = plt.figure() 036 plot = fig.add_subplot(1,1,1) 037 plot.set_xlabel("x", fontsize=12, fontname='serif') 038 plot.set_ylabel("y", fontsize=12, fontname='serif') 039 plot.tick_params(axis='both', length=10, which='major') 040 plot.tick_params(axis='both', length=5, which='minor') 041 plot.set_xlim(xmin, xmax) 042 plot.set_ylim([-1.2,1.2]) 043 plot.minorticks_on() 044 plot.plot(xp, self.f(xp, *self.popt), 'b-') 045 plot.plot(x, y, 'ro', markersize=10) 046 fig.tight_layout() 047 plt.show() 048 fig.savefig('result.pdf', orientation='portrait', \ 049 transparent=False, bbox_inches=None, frameon=None) 050 fig.clf() 051 052 053 # ------------------------------------------------------- 054 # データをフィッティングする関数 055 # ------------------------------------------------------- 056 def fit_func(x, a, b, c): 057 return a*np.sin(x)*np.exp(-b*x)+c 058 059 060 # ------------------------------------------------------- 061 # メイン関数 062 # ------------------------------------------------------- 063 if __name__ == "__main__": 064 065 x = np.linspace(0, 6*np.pi, 32) 066 noise = 0.05*np.random.normal(size=x.size) 067 y = np.sin(x)*np.exp(-x/5) + noise 068 069 fit = FITTING(fit_func, x, y) 070 fit.do_fitting() 071 fit.plot(Nx=257) 072 データをフィッティング(当てはめ)する関数は,056 – 057 行で定義されています.fit_func(x, a, b, c) がデータをフィッティングする関数です.第一引数のxが関数の独立変数です.残りの (a, b, c) がフィッティングパラメーターです.これらの3つの値を調整し,関数とデータの距離を最小にします. クラス FITTING には,フィッティングを行い結果を表示する機能 (関数) があります.コンストラクター __init__ はフィットする関数とデータをインスタンス変数に,do_fitting フィッティングを実行する,plot は結果を表示する関数です.コンストラクター __init__ は説明するまでも無いでしょう.結果を表示する関数 plot の動作については,Matplotlib を参照ください. 本ページのテーマであるフィッティングを行う関数は,do_fitting です.関数(メソッド) curve_fit(self.f, self.x, self.y) です.引数 self.x がフィッティングする関数,引数 (self.x, self.y) はデータです.戻り値の self.popt はパラメータの最適値です.これは配列で,フィットする関数 fit_func(x, a, b, c) のフィッティングパラメータと\((a,\,b,\,c)\)=(self.popt[0], self.popt[1], self.popt[2]) の関係があります.戻り値のself.pcov は共分散です.これについては,あとで説明します. プログラムを実行すると,端末に popt と pconv の値が以下のように表示されます.popt は推定される値 \((1,\,0.2,\,0)\) に近いことが分かるでしょう. popt = [ 0.93104142 0.18006138 0.00347886] pcov = [[ 3.66792254e-03 7.50297351e-04 -1.80400603e-04] [ 7.50297351e-04 2.65394736e-04 -3.83320805e-05] [ -1.80400603e-04 -3.83320805e-05 1.12286047e-04]] 合わせて,図1 に示すプロットも表示されます. 共分散とはscipy.optimize.curve_fit は,あまり聞き慣れない共分散が計算結果として出力されます. scipy.optimize.curve_fitフィッティングのモジュール「scipy.optimize.curve_fit」の引数と戻り値,エラー/ワーニングについて説明します. scipy.optimize.curve_fit( f, xdata, ydata, p0=None, sigma=None, absolute_sigma=False, check_finite=True, bounds=(-inf, inf), method=None, jac=None, **kwargs) 引数
戻り値
エラー/ワーニング
サンプルプログラムフィッティング結果をプロット中に001 ''' 002 測定データとそのフィッティング結果を表示する. 003 ''' 004 import matplotlib.pyplot as plt 005 from matplotlib.backends.backend_pdf import PdfPages 006 import numpy as np 007 from scipy import interpolate 008 import scipy.optimize as optimize 009 010 011 # ================================================================= 012 # クラス: データを読み込んでフィッティングパラメーターを決める. 013 # ================================================================= 014 class Fit_Fuction(): 015 ''' 016 データからフィッティング関数を決めるクラス. 017 ''' 018 def __init__(self, data_file, x_column=0, y_column=1): 019 '''初期化を実行する.''' 020 print('data file: {0:s}'.format(data_file)) 021 self.data = np.genfromtxt(data_file) 022 self.x, self.y = self.data[:,x_column], self.data[:,y_column] 023 print(self.x) 024 print('\n\n') 025 print(self.y) 026 # フィッティングの実行.sol[0]がパラメーターリスト 027 self.sol = optimize.curve_fit(self.f, self.x, self.y) 028 029 def f(self, x, a, s, x0): 030 '''関数の定義: 関数の値を返す.''' 031 return a*np.exp(-(x-x0)**2/(2*s**2)) 032 033 def fitted_f(self, x): 034 '''フィッティング結果関数の値を返す.''' 035 return self.f(x, *self.sol[0]) 036 037 def get_data(self): 038 '''フィッティングの元データを返す.''' 039 return self.x, self.y 040 041 def get_fit_results(self): 042 '''フィッティングのパラメーターを返す.''' 043 return self.sol # sol[0] がフィッティング結果 044 045 046 # ================================================================= 047 # クラス: データをフィッティング関数をプロット 048 # ================================================================= 049 class Plot_Data(): 050 ''' 051 データとフィッティング関数をプロットするクラス. 052 ''' 053 def __init__(self, fit_class): 054 '''プロットのクラスの初期化をおこなう.''' 055 self.x, self.y = fit_class.get_data() 056 self.xmin, self.xmax = 0, 1.0 057 self.ymin, self.ymax = 0, 1.2 058 self.fit_x = np.linspace(self.xmin, self.xmax, 512) 059 self.fit_y = fit_class.fitted_f(self.fit_x) 060 title_fn = '$f(x) = a_0\exp[-x^2/(2\sigma^2)]$\n' 061 title_res = ' $a_0={0:g}$\n $\sigma={1:g}$\n $x_0={2:g}$'.\ 062 format(*fit_class.get_fit_results()[0]) 063 self.title = title_fn + title_res # プロット中のタイトル 064 self.xl, self.yl = 0.05, 0.8 # タイトルの座標 065 066 def mk_plot(self): 067 '''元データとフィッティング結果をプロットする''' 068 pp = PdfPages('plot_data.pdf') 069 fig = plt.figure() 070 ax1 = fig.add_subplot(2,2,1) 071 ax1.text(self.xl, self.yl, self.title, color='red', fontsize=10) 072 ax1.set_xlabel("x", fontsize=12, fontname='serif') 073 ax1.set_ylabel("y", fontsize=12, fontname='serif') 074 ax1.tick_params(direction='in', axis='both', length=10, which='major', 075 bottom=True, top=True, left=True, right=True) 076 ax1.tick_params(direction='in', axis='both', length=5, which='minor', 077 bottom=True, top=True, left=True, right=True) 078 ax1.set_xlim([self.xmin, self.xmax]) 079 ax1.set_ylim([self.ymin, self.ymax]) 080 ax1.minorticks_on() 081 ax1.plot(self.fit_x, self.fit_y, color='red', linestyle='solid', linewidth=1.0) 082 ax1.scatter(self.x, self.y, color='blue', marker='o', s=20) 083 084 plt.tight_layout() 085 plt.subplots_adjust(top=0.9) 086 plt.draw() 087 plt.show() 088 fig.savefig(pp, format='pdf', orientation='portrait', \ 089 transparent=False, bbox_inches=None, frameon=None) 090 fig.clf() 091 pp.close() 092 093 094 # ================================================================= 095 # メインルーチン 096 # ================================================================= 097 if __name__ == '__main__': 098 fit_func = Fit_Fuction('sample.dat', x_column=0, y_column=1) 099 plot = Plot_Data(fit_func) 100 plot.mk_plot() ページ作成情報参考資料
更新履歴
|