この記事は(Dabney, et al., Nature. 2020)におけるDistributional Reinforcement Learningを実装しながら理解しようという内容です。解説とか言うのは恐れ多いので自分用のメモだと思ってください…。また、どちらかというと神経科学寄りの内容です(深層強化学習への応用については触れません)。

この研究はDeepMindとHarvardの内田先生のラボの共同研究で、アニメーション付きの解説記事をDeepMindが書いています (DeepMindのブログ)。Botvinick氏と内田先生の講演をCiNetで聞いたにも関わらず理解が疎かだったのですが、論文が公開されたので、ちゃんと理解しておこうという次第です。また、コード(MATLAB, Python)も公開されており(https://doi.org/10.17605/OSF.IO/UX5RG) 、この記事ではこのコードをかなり参考にしています。

Classical TD learning vs Distributional TD learning

Classical TD learning

TD (Temporal difference) learningにおいて、報酬予測誤差(reward prediction error, RPE) $\delta_{i}$は次のように計算されます (この式はDistributional TD learningでも共通です)。 $$ \delta_{i}=r+\gamma V_{j}\left(x^{\prime}\right)-V_{i}(x) $$ ただし、現在の状態を$x$, 次の状態を$x'$, 予測価値分布を$V(x)$, 報酬信号を$r$, 時間割引率(time discount)を$\gamma$としました。 また、$V_{j}\left(x^{\prime}\right)$は予測価値分布$V\left(x^{\prime}\right)$からのサンプルです。 このRPEは脳内において主に中脳のVTA(腹側被蓋野)やSNc(黒質緻密部)におけるドパミン(dopamine)ニューロンの発火率として表現されています。

ただし、VTAとSNcのドパミンニューロンの役割は同一ではありません。ドパミンニューロンへの入力が異なっています (Watabe-Uchida et al., Neuron. 2012)00281-4)。 また、細かいですがドパミンニューロンの発火は報酬量に対して線形ではなく、やや飽和する非線形な応答関数 (Hill functionで近似可能)を持ちます(Eshel et al., Nat. Neurosci. 2016)。このため著者実装では報酬 $r$に非線形関数がかかっているものもあります。

先ほどRPEはドパミンニューロンの発火率で表現されている、といいました。RPEが正の場合はドパミンニューロンの発火で表現できますが、単純に考えると負の発火率というものはないため、負のRPEは表現できないように思います。ではどうしているかというと、RPEが0(予想通りの報酬が得られた場合)でもドパミンニューロンは発火しており、RPEが正の場合にはベースラインよりも発火率が上がるようになっています。逆にRPEが負の場合にはベースラインよりも発火率が減少する(抑制される)ようになっています (Schultz et al., Science. 1997; Chang et al., Nat Neurosci. 2016)。発火率というのを言い換えればISI (inter-spike interval, 発火間隔)の長さによってPREが符号化されている(ISIが短いと正のRPE, ISIが長いと負のRPEを表現)ともいえます (Bayer et al., J. Neurophysiol. 2007)。

予測価値(分布) $V(x)$ですが、これは線条体(striatum)のパッチ (SNcに抑制性の投射をする)やVTAのGABAニューロン (VTAのドパミンニューロンに投射して減算抑制をする, (Eshel, et al., Nature. 2015))などにおいて表現されています。 この予測価値は通常のTD learningでは次式により更新されます。 $$ V_{i}(x) \leftarrow V_{i}(x)+\alpha_{i} f\left(\delta_{i}\right) $$ ただし、$\alpha_{i}$は学習率(learning rate), $f(\cdot)$はRPEに対する応答関数です。生理学的には$f(\delta)=\delta$を使うのが妥当ですが、後の分位数(quantile)モデルでは$f(\delta)=\text{sign}(\delta)$を用います。

Distributional TD learning

Distributional TD learningではRPEの正負に応じて、予測報酬の更新を異なる学習率($\alpha_{i}^{+}, \alpha_{i}^{-}$)を用いて行います。 $$ \begin{cases} V_{i}(x) \leftarrow V_{i}(x)+\alpha_{i}^{+} f\left(\delta_{i}\right) &\text{for } \delta_{i} \gt 0\\ V_{i}(x) \leftarrow V_{i}(x)+\alpha_{i}^{-} f\left(\delta_{i}\right) &\text{for } \delta_{i} \leq 0 \end{cases} $$ ここで、シミュレーションにおいては$\alpha_{i}^{+}, \alpha_{i}^{-}\sim U(0, 1)$とします($U$は一様分布)。さらにasymmetric scaling factor $\tau_i$を次式により定義します。 $$ \tau_i=\frac{\alpha_{i}^{+}}{\alpha_{i}^{+}+ \alpha_{i}^{-}} $$ なお、$\alpha_{i}^{+}, \alpha_{i}^{-}\in [0, 1]$より$\tau_i \in [0,1]$です。

Classical TD learningとDistributional TD learningにおける各ニューロンのRPEに対する発火率を表現したのが次図となります。

#collapse-hide

import numpy as np
from matplotlib import pyplot as plt
 
# Classical TD learning
N = 10
cmap = plt.cm.brg(np.linspace(0, 0.5, N))
x = np.arange(-1, 1, 1e-2)[:, None]
theta = np.linspace(np.pi/6, np.pi/3, N)
alpha = np.tan(theta)
y = alpha * x
 
# Plot
plt.figure(figsize=(8, 4))
def hide_ticks(): #上と右の軸を表示しないための関数
    plt.gca().spines['right'].set_visible(False)
    plt.gca().spines['top'].set_visible(False)
    plt.gca().yaxis.set_ticks_position('left')
    plt.gca().xaxis.set_ticks_position('bottom')
 
plt.subplot(1,2,1)
plt.axvline(x=0, color="gray", linestyle="dashed", linewidth=2)
plt.axhline(y=0, color="gray", linestyle="dashed", linewidth=2)
for i in range(N):
    if i == N//2:        
        plt.plot(x, y[:, i], color=cmap[N//2], alpha=1, linewidth=3,
                 label="Neutral")
    else:
        plt.plot(x, y[:, i], color=cmap[N//2], alpha=0.2)
 
hide_ticks()
plt.ylim(-1,1); plt.xlim(-1,1)
plt.xticks([]); plt.yticks([])
plt.legend(loc='upper left')
plt.title("Classical TD learning")
plt.xlabel("RPE")
plt.ylabel("Firing")
 
 
# Distributional TD learning
N = 20
cmap = plt.cm.brg(np.linspace(0, 0.5, N))
x = np.arange(-1, 1, 1e-2)[:, None]
theta = np.linspace(np.pi/16, np.pi*7/16, N)
alpha_pos = np.tan(theta)
alpha_neg = np.tan(theta)[::-1]
 
y = (alpha_pos*(x>0) + (alpha_neg)*(x<=0))*x
 
# Plot
ax = plt.subplot(1,2,2)
plt.axvline(x=0, color="gray", linestyle="dashed", linewidth=2)
plt.axhline(y=0, color="gray", linestyle="dashed", linewidth=2)
for i in range(N):
    if i == 0:        
        plt.plot(x, y[:, i], color=cmap[i], alpha=1, linewidth=3,
                 label="Pessimistic")
    elif i == N//2:        
        plt.plot(x, y[:, i], color=cmap[i], alpha=1, linewidth=3,
                 label="Neutral")
    elif i == N-1:        
        plt.plot(x, y[:, i], color=cmap[i], alpha=1, linewidth=3,
                 label="Optimistic")
    else:
        plt.plot(x, y[:, i], color=cmap[i], alpha=0.2)
hide_ticks()
handles, labels = ax.get_legend_handles_labels()
ax.legend(reversed(handles), reversed(labels), loc='upper left')
plt.ylim(-1,1); plt.xlim(-1,1)
plt.xticks([]); plt.yticks([])
plt.title("Distributional TD learning")
plt.xlabel("RPE")
plt.ylabel("Firing")
plt.show()

Classical TD learningではRPEに比例して発火する細胞しかありませんが、Distributional TD learningではRPEの正負に応じて発火率応答が変化していることがわかります。 特に$\alpha_{i}^{+} \gt \alpha_{i}^{-}$の細胞を楽観的細胞 (optimistic cells)、$\alpha_{i}^{+}\lt \alpha_{i}^{-}$の細胞を悲観的細胞 (pessimistic cells)と著者らは呼んでいます。実際には2群に分かれているわけではなく、gradientに遷移しています。楽観的・悲観的の意味に関しては後でも触れますが、ここではイメージだけお伝えしておきます。まず楽観的細胞ではRPEが正なら「結構もらえるやん」、RPEが負なら「まあそういうときもあるよね」となり最終的な予測価値は通常よりも高くなります。逆に悲観的細胞ではRPEが正なら「もらえたけどいつもそうではないやろ」、RPEが負なら「やっぱあんまもらえんよな」となり最終的な予測価値は通常よりも低くなります。収束する予測価値が細胞ごとに異なることで、$V$には報酬の期待値ではなく複雑な形状の報酬分布が符号化されます。その仕組みについて、次節から見ていきます。

分位数(Quantile)モデルと報酬分布の符号化

RPEに対する応答がsign関数のモデルと報酬分布の分位点への予測価値の収束

さて、Distributional RLモデルでどのようにして報酬分布が学習されるかについてみていきます。この節ではRPEに対する応答関数$f(\cdot)$が符合関数(sign function)の場合を考えます。結論から言うと、この場合はasymmetric scaling factor $\tau_i$は分位数(quantile)となり、予測価値 $V_i$は報酬分布の$\tau_i$分位数に収束します

どういうことかを簡単なシミュレーションで見てみましょう。今、報酬分布を平均2, 標準偏差5の正規分布とします (すなわち$r \sim N(2, 5^2)$となります)。また、$\tau_i = 0.25, 0.5, 0.75 (i=1,2,3)$とします。このとき、3つの予測価値 $V_i \ (i=1,2,3)$はそれぞれ$N(2, 5^2)$の0.25, 0.5, 0.75分位数に収束します。下図はシミュレーションの結果です。左が$V_i$の変化で、右が報酬分布と0.25, 0.5, 0.75分位数の位置 (黒短線)となっています。対応する分位数に見事に収束していることが分かります。

#collapse-hide

import seaborn as sns
from tqdm import tqdm
from matplotlib import gridspec
 
############
### init ###
############
response_func = lambda r: np.sign(r) # RPEの応答関数
 
num_cells = 3 # ニューロン(ユニット)の数
num_steps = 5000 # 訓練回数
base_lrate = 0.02 # ベースラインの学習率
 
reward_mu = 5 # 報酬の平均(正規分布)
reward_sigma = 2 # 報酬の標準偏差(正規分布)
 
distribution = np.zeros(num_cells) # 価値分布を記録する配列
dist_trans = np.zeros((num_steps, num_cells)) # 価値分布を記録する配列
 
alpha_pos = np.array([.1, .2, .3]) # RPEが正のときの学習率
alpha_neg = np.array([.3, .2, .1]) # RPEが負のときの学習率 
tau = alpha_pos / (alpha_pos + alpha_neg) # Asymmetric scaling factor
 
##############
# simulation #
##############
for step in tqdm(range(num_steps)): # 25000 steps
    # 報酬がrandomに選ばれる
    reward = np.random.normal(reward_mu, reward_sigma, size=(1,))
     
    # 報酬誤差(step毎に更新) reward応答をlinearとする
    delta = reward - distribution # (3, )
 
    # deltaが負なら1, 正なら0
    valence = np.array(delta <= 0., dtype=np.float32) # (3, )
 
    # 予測価値分布の更新
    alpha = valence * alpha_neg + (1. - valence) * alpha_pos
    distribution += alpha * response_func(delta) * base_lrate
    dist_trans[step] = distribution # 予測価値分布変化の記録
 
################
# Results plot #
################
steps = np.arange(num_steps)
ylim = (0, 10) # y軸のlim
gs = gridspec.GridSpec(1, 2, width_ratios=[1, 0.25]) 
plt.figure(figsize=(6,4))
plt.subplot(gs[0]) # 予測価値の変化
for i in range(num_cells):    
    plt.plot(steps, dist_trans[:, i],
             label=str((i+1)*25)+"%tile ("+r"$\tau=$"+str((i+1)*0.25)+")")
plt.title("Convergence of value prediction to \n percentile of reward distribution")
plt.xlim(0, num_steps)
plt.ylim(ylim)
plt.xlabel("Learning steps")
plt.ylabel("Learned Value")
plt.legend()
 
# 報酬のサンプリング
rewards = np.random.normal(reward_mu, reward_sigma, size=(1000,))
percentile = np.percentile(rewards, q=[25, 50, 75]) # 報酬の四分位数を取得
plt.subplot(gs[1]) # 報酬分布とその分位数
sns.kdeplot(rewards, bw=1, shade=True, vertical=True)
sns.rugplot(percentile, color='k', lw=2, height=0.2, vertical=True)
plt.title("Reward\n distribution")
plt.ylim(ylim)
plt.xlabel("Density")
plt.tight_layout()
plt.show()
100%|███████████████████████████████████████████████████████████████████████████| 5000/5000 [00:00<00:00, 79574.72it/s]

ここでoptimisticな細胞($\tau=0.75$)は中央値よりも高い予測価値、pessimisticな細胞($\tau=0.25$)は中央値よりも低い予測価値に収束しています。 つまり細胞の楽観度というものは、細胞が期待する報酬が大きいほど上がります。

同様のシミュレーションを今度は200個の細胞 (ユニット)で行います。報酬は0.1, 1, 2 μLのジュースがそれぞれ確率0.3, 0.6, 0.1で出るとします (Extended Data Fig.1と同じような分布にしています)。なお、著者らはシミュレーションとマウスに対してVariable-magnitude task (異なる量の報酬(ジュース)が異なる確率で出る)とVariable-probability task (一定量の報酬がある確率で出る)を行っています。以下はVariable-magnitude taskを行う、ということです。学習結果は次図のようになります。左はGround Truthの報酬分布で、右は$V_i$に対してカーネル密度推定 (KDE)することによって得た予測価値分布です。2つの分布はほぼ一致していることが分かります。

#collapse-hide

response_func = lambda r: np.sign(r) # RPEの応答関数
 
juice_amounts = np.array([0.1, 1, 2]) # reward(ジュース)の量(uL)
juice_probs = np.array([0.3, 0.6, 0.1]) # 各ジュースが出る確率
 
num_cells = 200 # ニューロン(ユニット)の数
num_steps = 25000 # 訓練回数
base_lrate = 0.02 # ベースラインの学習率
   
distribution = np.zeros(num_cells) # 価値分布を記録する配列
 
alpha_pos = np.random.random(size=(num_cells)) # RPEが正のときの学習率
alpha_neg = np.random.random(size=(num_cells)) # RPEが負のときの学習率 
tau = alpha_pos / (alpha_pos + alpha_neg) # Asymmetric scaling factor
 
##############
# simulation #
##############
for step in tqdm(range(num_steps)): # 25000 steps
    # 報酬がrandomに選ばれる
    reward = (np.random.choice(juice_amounts, p=juice_probs)) #(1, )
     
    # 報酬誤差(step毎に更新) reward応答をlinearとする
    delta = reward - distribution # (200, )
 
    # deltaが負なら1, 正なら0
    valence = np.array(delta <= 0., dtype=np.float32) # (200, )
 
    # 予測価値分布の更新
    alpha = valence * alpha_neg + (1. - valence) * alpha_pos
    distribution += alpha* response_func(delta) * base_lrate
 
# tauの大きさでソートする
ind = np.argsort(tau)
tau = tau[ind]
alpha_pos = alpha_pos[ind]
alpha_neg = alpha_neg[ind]
distribution = distribution[ind]
 
################
# Results plot #
################
# 報酬をサンプリング
rewards = (np.random.choice(juice_amounts,size=1000, p=juice_probs))
 
# 結果の描画(価値・報酬分布)
plt.figure(figsize=(8,4))
plt.subplot(1,2,1) # Ground Truth (Reward分布)
plt.title("Reward distribution")
sns.rugplot(rewards, color='k', lw=2, zorder=10)
sns.kdeplot(rewards, bw=.15, color='k', lw=1., shade=True)
plt.xlabel("Reward")
plt.ylabel("Density")
 
plt.subplot(1,2,2) # 学習後のValue(Reward)の分布
plt.title("Learned Value distribution")
sns.kdeplot(distribution, bw=.15, color='k', lw=1., shade=True)
sns.rugplot(distribution, color='k', lw=2, zorder=10)
plt.xlabel("Value")
plt.ylabel("Density")
plt.tight_layout()
plt.show()
100%|█████████████████████████████████████████████████████████████████████████| 25000/25000 [00:00<00:00, 31986.89it/s]

そして$V_i$の経験累積分布関数(CDF)は$r$のサンプリングしたCDFとほぼ同一となっています (下図左)。また、$\tau_i$の関数である$V_i$は分位点関数 (quantile function)または累積分布関数の逆関数 (inverse cumulative distribution function)となっています (下図右)。右の図を転置すると左の青い曲線とだいたい一致しそうなことが分かります。

#collapse-hide

# 結果の描画(累積分布)
plt.figure(figsize=(8,4))
plt.subplot(1,2,1) # 累積分布
sns.kdeplot(distribution, cumulative=True,bw=.05, label="Learned Value")
sns.kdeplot(rewards, cumulative=True, bw=.05, label="Reward (GT)")
plt.xlabel("Reward (Learned Value)")
plt.ylabel("Cumulative probability")
 
plt.subplot(1,2,2) # 累積分布
plt.plot(tau, distribution)
plt.xlabel("Asymmetric scaling factors ("+ r"$\tau$)")
plt.ylabel("Learned Value")
plt.tight_layout()
plt.show()

sign関数を用いたDistributional RLと分位点回帰

それでは、なぜ予測価値 $V_i$は$\tau_i$ 分位点に収束するのでしょうか。Extended Data Fig.1のように平衡点で考えてもよいのですが、後のために分位点回帰との関連について説明します。分位点回帰については記事を書いたので先にそちらを読んでもらうと分かりやすいと思います (→Pythonによる分位点回帰 (Quantile regression))。

実はDistributional RL (かつ、RPEの応答関数にsign関数を用いた場合)における予測報酬 $V_i$の更新式は、分位点回帰(Quantile regression)を勾配法で行うときの更新式とほとんど同じです。分位点回帰では$\delta$の関数$\rho_{\tau}(\delta)$を次のように定義します。 $$ \rho_{\tau}(\delta)=\left|\tau-\mathbb{I}_{\delta \leq 0}\right|\cdot |\delta|=\left(\tau-\mathbb{I}_{\delta \leq 0}\right)\cdot \delta $$ そして、この関数を最小化することで回帰を行います。ここで$\tau$は分位点です。また$\delta=r-V$としておきます。今回、どんな行動をしても未来の報酬に影響はないので$\gamma=0$としています。

ここで、 $$ \frac{\partial \rho_{\tau}(\delta)}{\partial \delta}=\rho_{\tau}^{\prime}(\delta)=\left|\tau-\mathbb{I}_{\delta \leq 0}\right| \cdot \operatorname{sign}(\delta) $$ なので、$r$を観測値とすると、 $$ \frac{\partial \rho_{\tau}(\delta)}{\partial V}=\frac{\partial \rho_{\tau}(\delta)}{\partial \delta}\frac{\partial \delta(V)}{\partial V}=-\left|\tau-\mathbb{I}_{\delta \leq 0}\right| \cdot \operatorname{sign}(\delta) $$ となります。ゆえに$V$の更新式は $$ V \leftarrow V - \beta\cdot\frac{\partial \rho_{\tau}(\delta)}{\partial V}=V+\beta \left|\tau-\mathbb{I}_{\delta \leq 0}\right| \cdot \operatorname{sign}(\delta) $$ です。ただし、$\beta$はベースラインの学習率です。個々の$V_i$について考え、符号で場合分けをすると $$ \begin{cases} V_{i} \leftarrow V_{i}+\beta\cdot |\tau_i|\cdot\operatorname{sign}\left(\delta_{i}\right) &\text { for } \delta_{i}>0\\ V_{i} \leftarrow V_{i}+\beta\cdot |\tau_i-1|\cdot\operatorname{sign}\left(\delta_{i}\right) &\text { for } \delta_{i} \leq 0 \end{cases} $$ となります。$0 \leq \tau_i \leq 1$であり、$\tau_i=\alpha_{i}^{+} / \left(\alpha_{i}^{+} + \alpha_{i}^{-}\right)$であることに注意すると上式は次のように書けます。 $$ \begin{cases} V_{i} \leftarrow V_{i}+\beta\cdot \frac{\alpha_{i}^{+}}{\alpha_{i}^{+}+\alpha_{i}^{-}}\cdot\operatorname{sign}\left(\delta_{i}\right) &\text { for } \delta_{i}>0\\ V_{i} \leftarrow V_{i}+\beta\cdot \frac{\alpha_{i}^{-}}{\alpha_{i}^{+}+\alpha_{i}^{-}}\cdot\operatorname{sign}\left(\delta_{i}\right) &\text { for } \delta_{i} \leq 0 \end{cases} $$ これは前節で述べたDistributional RLの更新式とほぼ同じです。いくつか違う点もありますが、RPEが正の場合と負の場合に更新される値の比は同じとなっています。

このようにRPEの応答関数にsign関数を用いた場合、報酬分布を上手く符号化することができます。しかし実際のドパミンニューロンはsign関数のような生理的に妥当でない応答はせず、RPEの大きさに応じた活動をします。そこで次節ではRPEの応答関数を線形にしたときの話をします。

Expectile モデルとドパミンニューロンからの報酬分布のDecoding

RPEに対する応答が線形なモデルとExpectile回帰

節の最後で述べたようにドパミンニューロンの活動はsign関数ではなく線形な応答をする、とした方が生理学的に妥当です (発火率を表現するならば$f(\delta)=c+\delta\quad(c > 0)$とした方が良いのでしょうが)。それでは予測価値の更新式を $$ \begin{cases} V_{i}(x) \leftarrow V_{i}(x)+\alpha_{i}^{+} \delta_{i} &\text{for } \delta_{i} \gt 0\\ V_{i}(x) \leftarrow V_{i}(x)+\alpha_{i}^{-} \delta_{i} &\text{for } \delta_{i} \leq 0 \end{cases} $$ とした場合は、分位点回帰ではなく何に対応するのでしょうか。結論から言えば、この場合はエクスペクタイル回帰(Expectile regression)と同じになります。そもそも、expectileというのは聞きなれないですが、expectileという用語自体はexpectationとquantileを合わせたような概念、というところから来ています。中央値(median)に対する分位数(quantile)が、平均(mean)あるいは期待値(expectation)に対するexpectileの関係と同じであると捉えると良いです。 もう少し言えば、前者は誤差のL1ノルム, 後者はL2ノルムの損失関数を最小化することにより得られます (cf. Quantile and Expectile Regressions)。

分位点回帰で用いた損失関数は$$\rho_{\tau}(\delta)=\left|\tau-\mathbb{I}_{\delta \leq 0}\right|\cdot |\delta|$$でしたが、最後の$|\delta|$を$\delta^2$として、 $$\rho^E_{\tau}(\delta)=\left|\tau-\mathbb{I}_{\delta \leq 0}\right|\cdot \delta^2$$ とします。これを微分すれば $$ \frac{\partial \rho^E_{\tau}(\delta)}{\partial \delta}=\rho_{\tau}^{E\prime}(\delta)=2 \cdot \left|\tau-\mathbb{I}_{\delta \leq 0}\right| \cdot \delta $$ となり、上記の予測価値の更新式がExpectile回帰の損失関数から導けることが分かります。

報酬分布のデコーディング (decoding)

それで、RPEの応答を線形とした場合は報酬分布を上手く学習できるのかという話ですが、実はRPEの応答をsign関数とした場合と同じように学習後の予測価値の分布を求めても報酬分布は復元されません (簡単な修正で確認できます)。そこで報酬分布をデコーディングする方法を考えます。

デコーデイングには各細胞が学習した予測価値(またはreversal points) $V_i$, asymmetries $\tau_i$, および報酬分布(ただし報酬の下限と上限からの一様分布)からのサンプル $z_m (m=1,2,\cdots, M)$を用います。$N$を推定する$V_i$の数、$M=100$を1つの報酬サンプル集合$\{z_m\}$内の要素数としたとき、次の損失関数を最小にする集合$\{z_m\}$を求めます。 $$ \mathcal{L}(z, V, \tau)=\frac{1}{M} \sum_{m-1}^{M} \sum_{n=1}^{N}\left|\tau_{n}-\mathbb{I}_{z_{m} \leq V_{n}}\right|\left(z_{m}-V_{n}\right)^{2} $$ ここで、集合$\{z_m\}$は20000回サンプリングするとします。損失関数$\mathcal{L}$を最小化する集合の分布が推定された報酬分布となっているので、それをplotします。以下はその結果とコードです (このコードはほとんど著者実装のままです)。灰色が元の報酬分布で、紫がデコーデイングされた分布です。完全とはいきませんが、ある程度は推定できていることが分かります。

#collapse-hide

import scipy.stats
import scipy.optimize
 
def expectile_loss_fn(expectiles, taus, samples):
  """Expectile loss function, corresponds to distributional TD model """
  # distributional TD model: delta_t = (r + \gamma V*) - V_i
  # expectile loss: delta = sample - expectile
  delta = (samples[None, :] - expectiles[:, None])
 
  # distributional TD model: alpha^+ delta if delta > 0, alpha^- delta otherwise
  # expectile loss: |taus - I_{delta <= 0}| * delta^2
 
  # Note: When used to decode we take the gradient of this loss,
  # and then evaluate the mean-squared gradient. That is because *samples* must
  # trade-off errors with all expectiles to zero out the gradient of the 
  # expectile loss.
  indic = np.array(delta <= 0., dtype=np.float32)
  grad = -0.5 * np.abs(taus[:, None] - indic) * delta
  return np.mean(np.square(np.mean(grad, axis=-1)))
 
def run_decoding(reversal_points, taus, minv=0., maxv=1., method=None,
                 max_samples=1000, max_epochs=10, M=100):
  """Run decoding given reversal points and asymmetries (taus)."""
   
  # sort
  ind = list(np.argsort(reversal_points))
  points = reversal_points[ind]
  tau = taus[ind]
 
  # Robustified optimization to infer distribution
  # Generate max_epochs sets of samples,
  # each starting the optimization at the best of max_samples initial points.
  sampled_dist = []
  for _ in range(max_epochs):
      # Randomly search for good initial conditions
      # This significantly improves the minima found
      samples = np.random.uniform(minv, maxv, size=(max_samples, M))
      fvalues = np.array([expectile_loss_fn(points, tau, x0) for x0 in samples])
 
      # Perform loss minimizing on expectile loss (w.r.t samples)
      x0 = np.array(sorted(samples[fvalues.argmin()]))
      fn_to_minimize = lambda x: expectile_loss_fn(points, tau, x)
      result = scipy.optimize.minimize(
              fn_to_minimize, method=method,
              bounds=[(minv, maxv) for _ in x0], x0=x0)['x']
      sampled_dist.extend(result.tolist())
 
  return sampled_dist, expectile_loss_fn(points, tau, np.array(sampled_dist))
 
 
# reward distribution
juice_amounts = np.array([0.1, 0.3, 1.2, 2.5, 5, 10, 20])
juice_empirical_probs = np.array(
    [0.06612594, 0.09090909, 0.14847358, 0.15489467,
     0.31159175, 0.1509519 , 0.07705306])
 
# samples of reward (1000, )
sampled_empirical_dist = np.random.choice(
    juice_amounts, p=juice_empirical_probs, size=1000)
 
n_trials = 10 # num of simulation trial
n_epochs = 20000 # num of simulation epoch
num_cells = 151  # num of cells or units
n_decodings = 5 # num of decodings
 
# Global scale for learning rates
beta = 0.2
 
# Distributional TD simulation and decoding
distribution = np.zeros((n_trials, num_cells))
alpha_pos = np.random.random((num_cells))*beta
alpha_neg = np.random.random((num_cells))*beta 
# alpha_neg = beta - alpha_pos としてもよい
 
# Simulation
for trial in tqdm(range(n_trials)):
    for step in range(n_epochs):
        # Sample reward
        reward = np.random.choice(juice_amounts, p=juice_empirical_probs)
        # Compute TD error
        delta = reward - distribution[trial]
        # Update distributional value estimate
        valence = np.array(delta <= 0., dtype=np.float32)
        alpha = valence * alpha_neg + (1. - valence) * alpha_pos
        distribution[trial] += alpha * delta
 
# Decoding from distributional TD (DTD) simulation
dtd_samples = [] # 
dtd_losses = [] # decoding loss
taus = alpha_pos / (alpha_pos + alpha_neg)
 
asym_variance = 0.2
 
for t in tqdm(range(n_decodings)):
    # Add noise to the scaling, but have mean 0.5 giving symmetric updates
    scaling_noise = np.tanh(np.random.normal(size=len(taus))) * asym_variance
    noisy_tau = np.clip(taus + scaling_noise, 0., 1.) # add noise
 
    # Run decoding for distributional TD
    values = run_decoding(
      distribution.mean(0), noisy_tau, 
      minv=juice_amounts.min(), maxv=juice_amounts.max(),
      max_epochs=1, M=100, max_samples=20000, method='TNC')
 
    dtd_samples.append(values[0])
    dtd_losses.append(values[1])
    # print(t, values[1]) 
 
# results of decoding
dtd_reward_decode = np.array(dtd_samples).flatten()
 
# plot
fig = plt.figure(figsize=(8, 5))
# Ground truth
sns.kdeplot(sampled_empirical_dist, bw=.75, color='k', lw=0., shade=True)
sns.rugplot(sampled_empirical_dist, color="red", lw=2, zorder=10, label="Empirical")
 
# decoded distribution
sns.kdeplot(dtd_reward_decode, bw=.75, color=plt.cm.plasma(0), lw=4., zorder=5, shade=False)
sns.rugplot(dtd_reward_decode, color=plt.cm.plasma(0), label='Decoded')
for draw in dtd_samples:
  sns.kdeplot(draw, bw=.5, color=plt.cm.plasma(0.), alpha=.5, lw=1., shade=False)
 
plt.tick_params(top=False, right=False, labelsize=14)
plt.legend(loc='best', fontsize=16)
plt.xlabel("Reward", fontsize=16)
plt.ylabel("Density", fontsize=16)
plt.title("Distributional TD Decoding", fontsize=18)
plt.tight_layout()
plt.show()
100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:05<00:00,  1.81it/s]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:33<00:00,  6.69s/it]

このようにしてRPEに対する応答が線形であるとした場合でも報酬分布を推定できました。同じことを著者らはドパミンニューロンの活動に対しても行い、報酬分布がデコーデイングされることを示しています。ただ、デコーデイングの手間が結構かかっている気がするので、学習した予測価値分布を利用するときにはどのような処理をしているのかは気になります。

参考文献

  • Dabney, W., Kurth-Nelson, Z., Uchida, N. et al. A distributional code for value in dopamine-based reinforcement learning. Nature (2020). https://doi.org/10.1038/s41586-019-1924-6
  • Watabe-Uchida, M. et al. Whole-Brain Mapping of Direct Inputs to Midbrain Dopamine Neurons. Neuron 74, 5, 858 - 873 (2012). https://doi.org/10.1016/j.neuron.2012.03.01700281-4) 00281-4)
  • Eshel, N., Tian, J., Bukwich, M. et al. Dopamine neurons share common response function for reward prediction error. Nat Neurosci 19, 479–486 (2016). https://doi.org/10.1038/nn.4239
  • Schultz, W., Dayan, P., Montague, P.R. A neural substrate of prediction and reward. Science. 275, 1593-9 (1997). doi:10.1126/science.275.5306.1593
  • Chang, C., Esber, G., Marrero-Garcia, Y. et al. Brief optogenetic inhibition of dopamine neurons mimics endogenous negative reward prediction errors. Nat Neurosci 19, 111–116 (2016) doi:10.1038/nn.4191
  • Bayer, H.M., Lau, B., Glimcher, P.W. Statistics of midbrain dopamine neuron spike trains in the awake primate. J Neurophysiol. 98(3):1428-39 (2007). https://doi.org/10.1152/jn.01140.2006
  • Eshel, N., Bukwich, M., Rao, V. et al. Arithmetic and local circuitry underlying dopamine prediction errors. Nature 525, 243–246 (2015). https://doi.org/10.1038/nature14855