最近はずっとPythonで、RもMatlabも使っていません。頑張れば使えないわけではないですが、できればPythonで完結させたいものです。というわけで今回は基本的な統計解析の描画をPythonでやってみます。

やりたいことは回帰における信頼帯 (confidence band) または信頼区間 (confidence interval) の描画です。線形回帰を例にしてみます。

seabornを使う場合

「Pythonで信頼区間を描画」と検索すればseabornのsns.regplotを用いた方法がすぐに出てきます。今回はこれを使わずに描画を目指しますが、とりあえずコードを紹介。

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("whitegrid")
np.random.seed(0)

# Generate toy data
n_data = 100
x = np.random.randn(n_data)
y = x + np.random.randn(n_data) + 10
# Plot
plt.figure(figsize=(4, 4))
sns.regplot(x, y)
plt.xlabel("x"); plt.ylabel("y")
plt.tight_layout()
plt.show()

描画はできるものの、肝心の統計値が出てきません。

Statsmodels + matplotlibの場合

Statsmodelsとmatplotlibを用いてconfidence bandを描画します。要はfill_betweenで上部信頼限界と下部信頼限界の間を塗りつぶせばいいわけです。 (参考)Using python statsmodels for OLS linear regression

計算式

回帰直線の$1−\alpha$ 信頼区間 $I_y$は

$$ I_y = \hat{y}\pm t_{n-2, \alpha}^* \cdot \frac{\hat{\sigma}}{\sqrt{n}}\sqrt{1+\left(\frac{x-\bar{x}}{s_x}\right)^2} $$

ただし、$n$はサンプルサイズ、$t_{n-2, \alpha}^*$ は$t$値、$\hat{\sigma}^2$ は誤差分散の推定量(二乗誤差を$n−2$で割ったもの)、$s_x$は$x$の標準偏差です。

実装

まず、ライブラリをimportします。

import statsmodels.api as sm
from scipy import stats

次に回帰を実行します。

# Regression
X = sm.add_constant(x) # constant intercept term
 
# Model: y ~ a*x + c
model = sm.OLS(y, X)
fitted = model.fit()
x_pred = np.linspace(x.min(), x.max(), 50)
X_pred = sm.add_constant(x_pred)
y_pred = fitted.predict(X_pred)

回帰の結果を見ておきましょう。

#print(fitted.params)     # the estimated parameters for the regression line
print(fitted.summary())  # summary statistics for the regression
                            OLS Regression Results                            
==============================================================================
Dep. Variable:                      y   R-squared:                       0.544
Model:                            OLS   Adj. R-squared:                  0.540
Method:                 Least Squares   F-statistic:                     117.0
Date:                Wed, 08 Jul 2020   Prob (F-statistic):           2.06e-18
Time:                        15:19:28   Log-Likelihood:                -144.67
No. Observations:                 100   AIC:                             293.3
Df Residuals:                      98   BIC:                             298.6
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
const         10.0752      0.104     96.834      0.000       9.869      10.282
x1             1.1147      0.103     10.817      0.000       0.910       1.319
==============================================================================
Omnibus:                        5.184   Durbin-Watson:                   1.995
Prob(Omnibus):                  0.075   Jarque-Bera (JB):                3.000
Skew:                           0.210   Prob(JB):                        0.223
Kurtosis:                       2.262   Cond. No.                         1.06
==============================================================================

Warnings:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.

Congfidence bandの計算を行います。

# Congfidence band
y_hat = fitted.predict(X) # x is an array from line 12 above
y_err = y - y_hat
mean_x = np.mean(x)
dof = n_data - fitted.df_model - 1 # degree of freedom
alpha = 0.025
t = stats.t.ppf(1-alpha, df=dof) # t-value
s_err = np.sum(y_err**2)
std_err = np.sqrt(s_err/(n_data-2))
std_x = np.std(x)
conf = t*std_err/np.sqrt(n_data)*np.sqrt(1+((x_pred-mean_x)/std_x)**2) 
upper = y_pred + abs(conf)
lower = y_pred - abs(conf)
# Plot
plt.figure(figsize=(4, 4))
plt.scatter(x, y)
plt.plot(x_pred, y_pred, '-', linewidth=2)
plt.fill_between(x_pred, lower, upper, color='#888888', alpha=0.4)
plt.xlabel("x"); plt.ylabel("y")
plt.tight_layout()
plt.show()