このノートブックでは視覚におけるSparse codingのモデル (Olshausen & Field, Nature. 1996)を解説し、Pythonによる実装を行います。コードを実行したい場合はhttps://github.com/takyamamoto/SparseCoding-OlshausenField-ModelをCloneしてsparse-coding.ipynbか、train.pyを実行してください。

画像のsparse coding

画像 $\boldsymbol{I} (\boldsymbol{x})$が基底関数(basis function) $\Phi = [\phi_i(\boldsymbol{x})]$の線形和 (係数は$\boldsymbol{r}=[r_i] $)で表されるとします。

$$ \boldsymbol{I}(\boldsymbol{x}) = \sum_i r_i \phi_i (\boldsymbol{x}) + \epsilon(\boldsymbol{x})= \Phi \boldsymbol{r} + \epsilon(\boldsymbol{x}) \tag{1} $$

ただし、$\boldsymbol{x}$は画像上の座標, $\epsilon$は平均0のGaussianノイズを表します。また、$\boldsymbol{r}$は入力よりも高次の神経細胞の活動、$\Phi$は重み行列とします。

Sparse codingは、少数の基底で画像 (や目的変数)を表すことを目的とします。要は(1)式において、ほとんどが0で、一部だけ0以外の値を取るという疎 (=sparse, スカスカ)な係数$\boldsymbol{r}$を求めたい、ということです。なお、sparse codingは統計学や機械学習において一般に使用される語であることを念のために言っておきます。

目的関数の設定

Sparse codingのための目的関数(cost function) $E$は(2)式のようになります。

$$ E = \underbrace{\left\|\boldsymbol{I}-\Phi \boldsymbol{r}\right\|^2}_{\text{preserve information}} + \lambda \underbrace{\sum_i S\left(\frac{r_i}{\sigma}\right)}_{\text{sparseness of}\ r_i} \tag{2} $$

ただし、$\lambda$は正則化係数、$\sigma$は定数(scaling constant)です。ここで、第一項が復元損失、第二項が罰則項 (係数が大きな値とならないようにする項)となっています。

$S(x)$としては $-\exp(-x^2), \ln(1+x^2), |x|$ などの関数が用いられます。これらの関数は原点において尖った形状をしており、解が0になりやすくなっています。描画すると以下のようになります。ただし、スケールを合わせるために $-\exp(-x^2)$には$1$を加算しました。

import numpy as np
import matplotlib.pyplot as plt

x = np.arange(-5, 5, 0.1)

plt.figure(figsize=(5,4))
plt.plot(x, np.log(1+x**2), label=r"$\ln(1+x^2)$")
plt.plot(x, np.abs(x), label=r"$|x|$")
plt.plot(x, 1-np.exp(-x**2), label=r"$1-\exp(-x^2)$")
plt.xlabel("x")
plt.ylabel(r"$S(x)$")
plt.legend()
plt.tight_layout()
plt.show()

次節からはSparse codingに基づいて最適な$\boldsymbol{r}$と$\Phi$を得る方法について説明していきます。

Locally Competitive Algorithm (LCA)

$\boldsymbol{r}$の勾配法による更新則は、目的関数 $E$の微分により次のように得られます。

$$ \begin{aligned} \frac{d \boldsymbol{r}}{dt} &= -\frac{\eta_\boldsymbol{r}}{2}\frac{\partial E}{\partial \boldsymbol{r}}\\ &=\eta_\boldsymbol{r} \cdot\left[\Phi^T (\boldsymbol{I}-\Phi\boldsymbol{r})- \frac{\lambda}{2\sigma}S'\left(\frac{r_i}{\sigma}\right)\right] \end{aligned}\tag{3} $$

ただし、$\eta_{\boldsymbol{r}}$は学習率です。この式により$\boldsymbol{r}$が収束するまで最適化しますが、単なる勾配法ではなく、(Olshausen & Field, 1996)では共役勾配法 (conjugate gradient method)を用いています。しかし、共役勾配法は実装が煩雑で非効率であるため、より効率的かつ生理学的な妥当性の高い学習法として、LCA (locally competitive algorithm)が提案されています (Rozell et al., Neural Comput. 2008)。LCAは側抑制 (local competition, lateral inhibition)と閾値関数 (thresholding function)を用いる更新則です。LCAによる更新を行うRNNは通常のRNNとは異なり、コスト関数(またはエネルギー関数)を最小化する動的システムです。このような機構はHopfield networkで用いられているために、Olshausenは"the Hopfield trick"と呼んでいます。

軟判定閾値関数を用いる場合 (ISTA)

$S(x)=|x|$とした場合の閾値関数を用いる手法としてISTA(Iterative Shrinkage Thresholding Algorithm)があります。ISTAはL1-norm正則化項に対する近接勾配法で、要はLasso回帰に用いる勾配法です。

解くべき問題は(4)式で表されます。

$$ \boldsymbol{r} = \mathop{\rm arg~min}\limits_{\boldsymbol{r}}\left\{\|\boldsymbol{I}-\Phi\boldsymbol{r}\|^2_2+\lambda\|\boldsymbol{r}\|_1\right\}\tag{4} $$

詳細は後述しますが、次のように更新することで解が得られます。

  1. $\boldsymbol{r}(0)$を要素が全て0のベクトルで初期化
  2. $\boldsymbol{r}_*(t+1)=\boldsymbol{r}(t)+\eta_\boldsymbol{r}\cdot \Phi^T(\boldsymbol{I}-\Phi\boldsymbol{r}(t))$
  3. $\boldsymbol{r}(t+1) = S_\lambda(\boldsymbol{r}_*(t+1))$
  4. $\boldsymbol{r}$が収束するまで2と3を繰り返す

ここで$S_\lambda(\cdot)$は軟判定閾値関数 (Soft thresholding function)と呼ばれ、(5)式で表されます。

$$ S_\lambda(y)= \begin{cases} y-\lambda & (y>\lambda)\\ 0 & (-\lambda\leq y\leq\lambda)\\ y+\lambda & (y<-\lambda) \end{cases}\tag{5} $$

$S_\lambda(\cdot)$を描画すると次のようになります。

import numpy as np
import matplotlib.pyplot as plt

# thresholding function of S(x)=|x|
def soft_thresholding_func(x, lmda):
    return np.maximum(x - lmda, 0) - np.maximum(-x - lmda, 0)

plt.figure(figsize=(5,4))
x = np.arange(-5, 5, 0.1)
y = soft_thresholding_func(x, 1)
plt.plot(x, x, "k--", label=r"$y=x$")
plt.plot(x, y, label=r"Soft thresholding ($\lambda=1$)")
plt.xlabel("x")
plt.ylabel(r"$S_\lambda (x)$")
plt.legend()
plt.tight_layout()
plt.show()

なお、Soft thresholding関数は次の目的関数$C$を最小化する$x$を求めることで導出できます。

$$ C=\frac{1}{2}(y-x)^2+\lambda |x| \tag{6} $$

ただし、$x, y, \lambda$はスカラー値であるとします。$|x|$が微分できませんが、これは場合分けを考えることで解決します。$x\geq 0$を考えると、(6)式は

$$ C=\frac{1}{2}(y-x)^2+\lambda x = \{x-(y-\lambda)\}^2+\lambda(y-\lambda) \tag{7} $$

となります。ここで、(7)式の最小値を与える$x$は場合分けをして考えると、$y-\lambda\geq0$のとき二次関数の頂点を考えて$x=y-\lambda$となります。 一方で$y-\lambda<0$のときは$x\geq0$において単調増加な関数となるので、最小となるのは$x=0$のときです。同様の議論を$x\leq0$に対しても行うことで (5)式が得られます。

Cauchy分布の閾値関数を用いる場合

次に、$S(x)=\ln(1+x^2)$を用いた場合を考えます。これはCauchy分布を$\boldsymbol{r}$の事前分布 $p(\boldsymbol{r})=\prod_i p(r_i)=\prod_i \exp\left[-\alpha \ln(1+r_i^2)\right]$としたときの負の対数事前分布 $g(\boldsymbol{r}):=-\ln p(\boldsymbol{r})=\alpha \sum_i \ln(1+r_i^2)$に対応しています。

前節と同様に次式で表される、1次元の目的関数$C$を最小にする$x$を求めます。

$$ C=(y-x)^2+\lambda \ln(1+x^2) \tag{8} $$

今度は微分可能なので(8)式の勾配が0となるような$x$を求めます。

$$ \frac{\partial C}{\partial x}=-2(y-x)+\frac{2\lambda x}{1+x^2} \tag{9} $$

(9)式 $=0$とおいて$x$について整理したいところですが、まずは$y$について解いてみると、

$$ y=x+\frac{\lambda x}{1+x^2} \tag{10} $$

となります。後は(10)式の逆関数を求めればよいのですが、どう考えても複雑なのでWolfram alphaを使いましょう (結果)。解を整理すると(11)式のように表されます。

$$ \begin{aligned} S_\lambda(y)&=\frac{1}{3}\left[y-\frac{h_\lambda(y)}{\sqrt[3]{2}}+\frac{\sqrt[3]{2} \cdot g_\lambda(y)}{h_\lambda(y)}\right]\\ f_\lambda(y) &:= 9\lambda y - 2y^3-18y\\ g_\lambda(y) &:= 3\lambda - y^2+3\\ h_\lambda(y) &:= \left[\sqrt{f_\lambda(y)^2+4g_\lambda(y)^3}+f_\lambda(y)\right]^\frac{1}{3} \end{aligned}\tag{11} $$

(11)式を実装してみると次のようになります。

# thresholding function of S(x)=ln(1+x^2)
def ln_thresholding_func(x, lmda):
    f = 9*lmda*x - 2*np.power(x, 3) - 18*x
    g = 3*lmda - np.square(x) + 3
    h = np.cbrt(np.sqrt(np.square(f) + 4*np.power(g, 3)) + f)
    two_croot = np.cbrt(2) # cubic root of two
    return (1/3)*(x - h / two_croot + two_croot*g / (1e-8+h))

plt.figure(figsize=(5,4))
x = np.arange(-5, 5, 0.1)
y = ln_thresholding_func(x, 1)
plt.plot(x, x, "k--", label=r"$y=x$")
plt.plot(x, y, label=r"$\ln(1+x^2)$ thresholding ($\lambda=1$)")
plt.xlabel("x")
plt.ylabel(r"$S_\lambda (x)$")
plt.legend()
plt.tight_layout()
plt.show()

Soft thresholding 関数を緩やかにしたような形状であることが分かります。しかし、これをそのまま用いてもスパースな表現は得られませんでした (ハイパーパラメータの調整不足かもしれませんが)。そこで、(Mayo et al., arXiv. 2020)で提案されているCauchy thresholding関数を用いてみましょう。

Cauchy thresholding関数は次式で表されます。

$$ S_\lambda(y)= \begin{cases} \dfrac{1}{2}\left(y+\sqrt{y^2-4\lambda}\right) & (y>2\lambda)\\ 0 & (-2\lambda\leq y\leq 2\lambda)\\ \dfrac{1}{2}\left(y-\sqrt{y^2-4\lambda}\right) & (y<-2\lambda) \end{cases}\tag{12} $$

実装は次のようになります。

def cauchy_thresholding_func(x, lmda):
    f = 0.5*(x + np.sqrt(np.maximum(x**2 - lmda,0)))
    g = 0.5*(x - np.sqrt(np.maximum(x**2 - lmda,0)))
    return f*(x>=lmda) + g*(x<=-lmda) 

plt.figure(figsize=(5,4))
x = np.arange(-5, 5, 0.1)
ln_thresholding_func
y = cauchy_thresholding_func(x, 1)
plt.plot(x, x, "k--", label=r"$y=x$")
plt.plot(x, y, label=r"Cauchy thresholding ($\lambda=1$)")
plt.xlabel("x")
plt.ylabel(r"$S_\lambda (x)$")
plt.legend()
plt.tight_layout()
plt.show()

硬判定閾値関数 (hard thresholding function)に似た形状をしていることが分かります。

重み行列の更新則

$\boldsymbol{r}$が収束したら勾配法により$\Phi$を更新します。

$$ \begin{aligned} \Delta \phi_i(\boldsymbol{x}) &= -\eta \frac{\partial E}{\partial \Phi}\\ &=\eta\cdot\left[\left([\boldsymbol{I}-\Phi\boldsymbol{r}\right)\boldsymbol{r}^T\right] \end{aligned}\tag{13} $$

Sparse coding networkの実装

ネットワークの構造

入力層を含め2層の単純な構造です。今回は、入力はランダムに切り出した16×16 (=256)の画像パッチとし、これを入力層の256個のニューロンが受け取るとします。入力層のニューロンは次層の100個のニューロンに投射するとします。100個のニューロンが入力をSparseに符号化するようにその活動および重み行列を最適化します。下図はネットワークの構造図です。ただし、$a_i$は本記事における$r_i$に対応します。

fig2a

(Olshausen, 2002; Fig. 2a)

ライブラリのimport

import numpy as np
import matplotlib.pyplot as plt
import network
from tqdm.notebook import tqdm
import scipy.io as sio

np.random.seed(0)

画像データの読み込み

データはhttp://www.rctn.org/bruno/sparsenet/からダウンロードできます。MATLABコードもダウンロードできますが、手元では動かすことができませんでした。IMAGES.matには白色化された10枚の自然画像が記録されています。画像の描画などは『Predictive coding (Rao & Ballard, 1999) モデルの実装』を参照してください。

# datasets from http://www.rctn.org/bruno/sparsenet/
mat_images = sio.loadmat('datasets/IMAGES.mat')
imgs = mat_images['IMAGES']

ネットワークの実装

LCAに基づき、ネットワークを実装します。基本的には__call__関数を実行して係数rと重み行列Phiを更新します。外部から呼び出し、rを収束するまで更新し、収束したときにtrainingTrueにしてPhiを更新します。また、閾値関数としてはsoft_thresholding_func, ln_thresholding_func, cauchy_thresholding_funcの3つを使用できるようにしています。どれを用いるかはコメントアウトを外すなどして変更してください。

class OlshausenField1996Model:
    def __init__(self, num_inputs, num_units, batch_size,
                 lr_r=1e-2, lr_Phi=1e-2, lmda=5e-3):
        self.lr_r = lr_r # learning rate of r
        self.lr_Phi = lr_Phi # learning rate of Phi
        self.lmda = lmda # regularization parameter
        
        self.num_inputs = num_inputs
        self.num_units = num_units
        self.batch_size = batch_size
        
        # Weights
        Phi = np.random.randn(self.num_inputs, self.num_units).astype(np.float32)
        self.Phi = Phi * np.sqrt(1/self.num_units)

        # activity of neurons
        self.r = np.zeros((self.batch_size, self.num_units))
    
    def initialize_states(self):
        self.r = np.zeros((self.batch_size, self.num_units))
        
    def normalize_rows(self):
        self.Phi = self.Phi / np.maximum(np.linalg.norm(self.Phi, ord=2, axis=0, keepdims=True), 1e-8)

    # thresholding function of S(x)=|x|
    def soft_thresholding_func(self, x, lmda):
        return np.maximum(x - lmda, 0) - np.maximum(-x - lmda, 0)

    # thresholding function of S(x)=ln(1+x^2)
    def ln_thresholding_func(self, x, lmda):
        f = 9*lmda*x - 2*np.power(x, 3) - 18*x
        g = 3*lmda - np.square(x) + 3
        h = np.cbrt(np.sqrt(np.square(f) + 4*np.power(g, 3)) + f)
        two_croot = np.cbrt(2) # cubic root of two
        return (1/3)*(x - h / two_croot + two_croot*g / (1e-8+h))

    # thresholding function https://arxiv.org/abs/2003.12507
    def cauchy_thresholding_func(self, x, lmda):
        f = 0.5*(x + np.sqrt(np.maximum(x**2 - lmda,0)))
        g = 0.5*(x - np.sqrt(np.maximum(x**2 - lmda,0)))
        return f*(x>=lmda) + g*(x<=-lmda) 

    def calculate_total_error(self, error):
        recon_error = np.mean(error**2)
        sparsity_r = self.lmda*np.mean(np.abs(self.r)) 
        return recon_error + sparsity_r
        
    def __call__(self, inputs, training=True):
        # Updates                
        error = inputs - self.r @ self.Phi.T
        
        r = self.r + self.lr_r * error @ self.Phi
        self.r = self.soft_thresholding_func(r, self.lmda)
        #self.r = self.cauchy_thresholding_func(r, self.lmda)
        
        if training:  
            error = inputs - self.r @ self.Phi.T
            dPhi = error.T @ self.r
            self.Phi += self.lr_Phi * dPhi
            
        return error, self.r

モデルの定義

上で実装したネットワークをmodelとして定義します。また、シミュレーションで用いる定数などを初期化しておきます。

# Simulation constants
H, W, num_images = imgs.shape
num_iter = 500 # number of iterations
nt_max = 1000 # Maximum number of simulation time
batch_size = 250 # Batch size

sz = 16 # image patch size
num_units = 100 # number of neurons (units)

eps = 1e-2 # small value which determines convergence
error_list = [] # List to save errors

# Define model
model = OlshausenField1996Model(num_inputs=sz**2, num_units=num_units,
                                batch_size=batch_size)

シミュレーションの実行

シミュレーションを実行します。外側のfor loopでは画像パッチの作成とrの初期化を行います。内側のfor loopではrが収束するまで更新を行い、収束したときに重み行列Phiを更新します。

# Run simulation
for iter_ in tqdm(range(num_iter)):
    # Get the coordinates of the upper left corner of clopping image randomly.
    beginx = np.random.randint(0, W-sz, batch_size)
    beginy = np.random.randint(0, H-sz, batch_size)

    inputs_list = []

    # Get images randomly
    for i in range(batch_size):        
        idx = np.random.randint(0, num_images)
        img = imgs[:, :, idx]
        clop = img[beginy[i]:beginy[i]+sz, beginx[i]:beginx[i]+sz].flatten()
        inputs_list.append(clop - np.mean(clop))
        
    inputs = np.array(inputs_list) # Input image patches
    
    model.initialize_states() # Reset states
    model.normalize_rows() # Normalize weights
    
    # Input an image patch until latent variables are converged 
    r_tm1 = model.r # set previous r (t minus 1)

    for t in range(nt_max):
        # Update r without update weights 
        error, r = model(inputs, training=False)
        dr = r - r_tm1 

        # Compute norm of r
        dr_norm = np.linalg.norm(dr, ord=2) / (eps + np.linalg.norm(r_tm1, ord=2))
        r_tm1 = r # update r_tm1
        
        # Check convergence of r, then update weights
        if dr_norm < eps:
            error, r = model(inputs, training=True)
            break
        
        # If failure to convergence, break and print error
        if t >= nt_max-2: 
            print("Error at patch:", iter_)
            print(dr_norm)
            break
   
    error_list.append(model.calculate_total_error(error)) # Append errors

    # Print moving average error
    if iter_ % 100 == 99:  
        print("iter: "+str(iter_+1)+"/"+str(num_iter)+", Moving error:",
              np.mean(error_list[iter_-99:iter_]))
iter: 100/500, Moving error: 0.05376330095599187
iter: 200/500, Moving error: 0.049364071115642096
iter: 300/500, Moving error: 0.04899764202648089
iter: 400/500, Moving error: 0.04885324009619198
iter: 500/500, Moving error: 0.04789606380023766

訓練中の損失の描画

訓練中の損失の変化を描画してみましょう。損失が低下し、学習が進行したことが分かります。

# Plot error
plt.figure(figsize=(5, 3))
plt.ylabel("Error")
plt.xlabel("Iterations")
plt.plot(np.arange(len(error_list)), np.array(error_list))
plt.tight_layout()
plt.show()

重み行列 (受容野)の描画

学習後の重み行列 Phi ($\Phi$)を可視化してみましょう。

# Plot Receptive fields
fig = plt.figure(figsize=(6, 6))
plt.subplots_adjust(hspace=0.1, wspace=0.1)
for i in tqdm(range(num_units)):
    plt.subplot(10, 10, i+1)
    plt.imshow(np.reshape(model.Phi[:, i], (sz, sz)), cmap="gray")
    plt.axis("off")

fig.suptitle("Receptive fields", fontsize=20)
plt.subplots_adjust(top=0.9)
plt.show()

白色がON領域(興奮)、黒色がOFF領域(抑制)を表します。Gaborフィルタ様の局所受容野が得られており、これは一次視覚野(V1)における単純型細胞(simple cells)の受容野に類似しています。

なお、cauchy_thresholding_funcを用い、lmda=1e-3とした結果は次図のようになりました。

RF_cauchy_thresholding

ICAとPCAによる自然画像の分析

参考までに自然画像に対してICAとPCAを実行してみます。ICAを実行することはSparse codingを行うことと同じであり、結果としてGaborフィルタ様の局所受容野が得られます。一方で、PCAを実行しても広域的な受容野しか得られません (Sanger TD, 1989)。

ライブラリのimport

sklearn.decompositionを持ちます。

from sklearn.decomposition import FastICA, PCA

画像パッチの作成

白色化前の画像を用いた方がPCAの結果が綺麗です。

mat_images_raw = sio.loadmat('datasets/IMAGES_RAW.mat')
imgs_raw = mat_images_raw['IMAGESr']

# Simulation constants
H, W, num_images = imgs_raw.shape

num_patches = 15000
patchs_list = []
w, h = 16, 16 # patch size

# generate patches
for i in tqdm(range(num_patches)):
    i = np.random.randint(0, num_images)
    # Get the coordinates of the upper left corner of clopping image randomly.
    beginx = np.random.randint(0, W-w-1)
    beginy = np.random.randint(0, H-h-1)
    img_clopped = imgs_raw[beginy:beginy+h, beginx:beginx+w, i]
    patchs_list.append(img_clopped.flatten())

patches = np.array(patchs_list)

ICAの実行

# perform ICA
n_comp = 100
ica = FastICA(n_components=n_comp)
ica.fit(patches)
ica_filters = ica.components_

# plot filters
plt.figure(figsize=(6,6))
plt.subplots_adjust(hspace=0.1, wspace=0.1)
for i in tqdm(range(n_comp)):
    plt.subplot(10, 10, i+1)
    plt.imshow(np.reshape(ica_filters[i], (w, h)), cmap="gray")
    plt.axis("off")
plt.suptitle("ICA", fontsize=20)
plt.subplots_adjust(top=0.9)
plt.show()

他にはK-means法を用いることでもGaborフィルタを得ることはできます。

PCAの実行

# perform PCA
pca = PCA(n_components=n_comp)
pca.fit(patches)
pca_filters = pca.components_

# plot filters
plt.figure(figsize=(6,6))
plt.subplots_adjust(hspace=0.1, wspace=0.1)
for i in tqdm(range(n_comp)):
    plt.subplot(10, 10, i+1)
    plt.imshow(np.reshape(pca_filters[i], (w, h)), cmap="gray")
    plt.axis("off")
plt.suptitle("PCA", fontsize=20)
plt.subplots_adjust(top=0.9)
plt.show()

まとめ

  • Sparse codingのモデルを実装した。
  • LCA (locally competitive algorithm)は効率良くsparce codingをneural networkで実現するための手法。
  • Sparce codingによりV1の単純型細胞と同様のGaborフィルタが得られる。
  • 脳においても発火の抑制などによる正則化によりsparse codingが実現されていると考えられる。