概要

scGenの論文の解説.実際はほぼVAEの話になった.実験の詳細とかは割愛.LPSが異なる生物に与える影響とかを予測していてかなりアツい.

細胞が外界から刺激を受けたときにどんな反応をするかを予測する.潜在空間で摂動$\delta$が加えられたときに,遺伝子発現空間での変化をニューラルネットワークを使って予測する.

VAE  

variational autoencoder(VAE)では確率分布$P(\boldsymbol{x} _ i;\boldsymbol{\theta})$に従って新しいデータ点が生成される.ただし,確率分布は$P(\boldsymbol{x} _ i;\boldsymbol{\theta})$の対数尤度(を各$\boldsymbol{x} _ i$について足したもの)が最大となるように取る.潜在変数を$\boldsymbol{z}$とすれば,この確率は次のようになる:

P(xi;θ)=P(xizi;θ)P(zi;θ)dzi.(1)\begin{aligned} P(\boldsymbol{x} _ i;\boldsymbol{\theta})=\int P(\boldsymbol{x} _ i\mid \boldsymbol{z} _ i;\boldsymbol{\theta})P(\boldsymbol{z} _ i;\boldsymbol{\theta})\,d\boldsymbol{z} _ i. \end{aligned}\tag{1}

$\boldsymbol{x} _ i$を生成しそうな$\boldsymbol{z} _ i$が潜在空間から正規分布$P(\boldsymbol{z} _ i;\boldsymbol{\theta})$に従ってサンプリングされるような確率分布を求めることが目標になる.

ここで,$P(\boldsymbol{x} _ i\mid \boldsymbol{z} _ i;\boldsymbol{\theta})$に近い確率分布$Q(\boldsymbol{z} _ i\mid \boldsymbol{x} _ i;\boldsymbol{\theta})$をニューラルネットワークで作る.2つの確率分布の近さの評価としてKullback Leibler divergenceを使う:

KL(Q(zixi;ϕ)P(zixi;θ))=EQ(zixi;ϕ)[logQ(zixi;ϕ)logP(zixi;θ)]=EQ(zixi;ϕ)[logQ(zixi;ϕ)logP(zi;θ)P(xi;θ)P(xizi;θ)]=EQ(zixi;ϕ)[logQ(zixi;ϕ)logP(zi;θ)logP(xizi;θ)+logP(xi;θ)]=EQ(zixi;ϕ)[logP(xizi;θ)]+logP(xi;θ)+KL(Q(zixi;ϕ)P(zi;θ)).(2)\begin{aligned} &\text{KL}(Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\phi})\|P(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\theta}))\\ &= E _ {Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\phi})}\left[\log Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\phi})-\log P(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\theta})\right]\\ &= E _ {Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\phi})}\Bigl[\log Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\phi})-\log\frac{P(\boldsymbol{z} _ i;\boldsymbol{\theta})}{P(\boldsymbol{x} _ i;\boldsymbol{\theta})}P(\boldsymbol{x} _ i\mid\boldsymbol{z} _ i;\boldsymbol{\theta})\Bigr]\\ &= E _ {Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\phi})}\bigl[\log Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\phi})-\log P(\boldsymbol{z} _ i;\boldsymbol{\theta})-\log P(\boldsymbol{x} _ i\mid\boldsymbol{z} _ i;\boldsymbol{\theta})+\log P(\boldsymbol{x} _ i;\boldsymbol{\theta})\bigr]\\ &= -E _ {Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\phi})}[\log P(\boldsymbol{x} _ i\mid\boldsymbol{z} _ i;\boldsymbol{\theta})]+\log P(\boldsymbol{x} _ i;\boldsymbol{\theta})+\text{KL}(Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\phi})\| P(\boldsymbol{z} _ i;\boldsymbol{\theta})). \end{aligned}\tag{2}

式変形の途中でBayesの定理を用いた.$Q(\boldsymbol{z} _ i\mid \boldsymbol{x} _ i;\boldsymbol{\theta})$は$P(\boldsymbol{x} _ i\mid \boldsymbol{z} _ i;\boldsymbol{\theta})$の近似であるので,これらの量のKullback Leibler divergenceはほぼ$0$である.よって,

logP(xi;θ)EQ(zixi;ϕ)[logP(xizi;θ)]KL(Q(zixi;ϕ)P(zi;θ)).(3)\begin{aligned} \log P(\boldsymbol{x} _ i;\boldsymbol{\theta})\simeq E _ {Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\phi})}[\log P(\boldsymbol{x} _ i\mid\boldsymbol{z} _ i;\boldsymbol{\theta})]-\text{KL}(Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\phi})\| P(\boldsymbol{z} _ i;\boldsymbol{\theta})). \end{aligned}\tag{3}

(3)式第1項は解析的に解くことは困難なので,Monte Carlo法によるサンプリングによって決める.ただし,$Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\phi})$に従って$\boldsymbol{z} _ i$を選び出す:

ziQ(zixi;ϕ)(4)\boldsymbol{z} _ i\sim Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\phi})\tag{4}

のは不便であるので,再パラメータ化を考える.すなわち,

zi=gϕ(ϵ,xi),ϵp(ϵ)(5)\boldsymbol{z} _ i = g _ {\boldsymbol{\phi}}(\boldsymbol{\epsilon}, \boldsymbol{x} _ i),\quad \boldsymbol{\epsilon}\sim p(\boldsymbol{\epsilon})\tag{5}

となる関数$g _ {\boldsymbol{\phi}}$とノイズ$\boldsymbol{\epsilon}$を適当に選んでやる.こうすれば,

EQ(zixi;ϕ)[f(zi)]=Q(zixi;ϕ)f(zi)dzi=p(ϵ)f(zi)dϵ=Ep(ϵ)[f(zi)],zi=gϕ(ϵ,xi)(6)\begin{aligned} E _ {Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\phi})}[f(\boldsymbol{z} _ i)] &= \int Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\phi}) f(\boldsymbol{z} _ i)d\boldsymbol{z} _ i\\ &= \int p(\boldsymbol{\epsilon}) f(\boldsymbol{z} _ i)d\boldsymbol{\epsilon}\\ &= E _ {p(\boldsymbol{\epsilon})}[f(\boldsymbol{z} _ i)],\quad \boldsymbol{z} _ i = g _ {\boldsymbol{\phi}}(\boldsymbol{\epsilon}, \boldsymbol{x} _ i) \end{aligned}\tag{6}

となる.よって,再パラメータ化を施せば(3)第1項は

EQ(zixi;ϕ)[logP(xizi;θ)]=1Ll=1LlogP(xizi(l);θ)(7)E _ {Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\phi})}[\log P(\boldsymbol{x} _ i\mid\boldsymbol{z} _ i;\boldsymbol{\theta})] = \frac{1}{L}\sum _ {l=1} ^ L\log P(\boldsymbol{x} _ i\mid\boldsymbol{z} _ i ^ {(l)};\boldsymbol{\theta})\tag{7}

となる.ただし,

zi(l)=gϕ(ϵ(l),xi),ϵ(l)p(ϵ)(8)\boldsymbol{z} _ i ^ {(l)}=g _ {\boldsymbol{\phi}}(\boldsymbol{\epsilon} ^ {(l)}, \boldsymbol{x} _ i),\quad \boldsymbol{\epsilon} ^ {(l)}\sim p(\boldsymbol{\epsilon})\tag{8}

である.よって,(3)は次のようになる:

logP(xi;θ)=1Ll=1LlogP(xizi(l);θ)KL(Q(zixi;ϕ)P(zi;θ)).(9)\log P(\boldsymbol{x} _ i;\boldsymbol{\theta})=\frac{1}{L}\sum _ {l=1} ^ L\log P(\boldsymbol{x} _ i\mid\boldsymbol{z} _ i ^ {(l)};\boldsymbol{\theta})-\text{KL}(Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\phi})\| P(\boldsymbol{z} _ i;\boldsymbol{\theta})).\tag{9}

ここで,$Q(\boldsymbol{z}\mid\boldsymbol{x};\boldsymbol{\theta})$が多変量正規分布であると仮定する:

Q(zx;θ)=1(2π)nΣexp[12t(zμ)Σ1(zμ)].(10)Q(\boldsymbol{z}\mid\boldsymbol{x};\boldsymbol{\theta})=\frac{1}{\sqrt{(2\pi) ^ n\mid\Sigma\mid}}\exp\left[-\frac{1}{2} ^ t(\boldsymbol{z}-\boldsymbol\mu)\Sigma ^ {-1}(\boldsymbol{z}-\boldsymbol\mu)\right].\tag{10}

ただし,$\Sigma,\boldsymbol\mu$は$\boldsymbol{x}$によって決まり,特に$\Sigma$は対角行列であるとする:

Σ=(σ12Oσ22Oσn2).(11)\begin{aligned} \Sigma = \begin{pmatrix} \sigma _ 1 ^ 2 &&&O\\ & \sigma _ 2 ^ 2 &&\\ &&\ddots&\\ O &&& \sigma _ n ^ 2 \end{pmatrix}. \end{aligned}\tag{11}

さらに,$P(\boldsymbol{z};\boldsymbol{\theta})$も正規分布であるとする:

P(z;θ)=1(2π)nexp(z22).(12)P(\boldsymbol{z};\boldsymbol{\theta}) = \frac{1}{\sqrt{(2\pi) ^ n}}\exp\left(-\frac{\mid\boldsymbol{z}\mid ^ 2}{2}\right).\tag{12}

まず,

Q(zx;θ)logP(z;θ)dz=1(2π)nΣexp[12t(zμ)Σ1(zμ)]×[n2log(2π)z22]dz=n2log(2π)12(2π)nΣ×z2exp[12t(zμ)Σ1(zμ)]dz(13)\begin{aligned} &\int Q(\boldsymbol{z}\mid\boldsymbol{x};\boldsymbol{\theta})\log P(\boldsymbol{z};\boldsymbol{\theta})\,d\boldsymbol{z} \\ &= \frac{1}{\sqrt{(2\pi) ^ n\mid\Sigma\mid}}\int\exp\left[-\frac{1}{2} ^ t(\boldsymbol{z}-\boldsymbol\mu)\Sigma ^ {-1}(\boldsymbol{z}-\boldsymbol\mu)\right]\times\left[-\frac{n}{2}\log(2\pi)-\frac{\mid\boldsymbol{z}\mid ^ 2}{2}\right]\,d\boldsymbol{z}\\ &= -\frac{n}{2}\log(2\pi)-\frac{1}{2\sqrt{(2\pi) ^ n\mid\Sigma\mid}}\times\int\mid\boldsymbol{z}\mid ^ 2\exp\left[-\frac{1}{2} ^ t(\boldsymbol{z}-\boldsymbol\mu)\Sigma ^ {-1}(\boldsymbol{z}-\boldsymbol\mu)\right]\,d\boldsymbol{z} \end{aligned}\tag{13}

第2項の積分は

i=1nzi2exp[12j=1n(zjμj)2σj2]dz=i=1nexp[(z1μ1)22σ12]dz1××zi2exp[(zjμj)22σj2]dzi××exp[(znμn)22σn2]dzn=i=1n(2π)n12jiσjzi2exp[(zjμj)22σj2]dzi=i=1n(2π)n2j=1nσj(σi2+μi2)(14)\begin{aligned} \sum _ {i=1} ^ n\int z _ i ^ 2\exp\left[-\frac{1}{2}\sum _ {j=1} ^ n\frac{(z _ j-\mu _ {j}) ^ 2}{\sigma _ j ^ 2}\right]\,d\boldsymbol{z}&=\sum _ {i=1} ^ n\int \exp\left[-\frac{(z _ 1-\mu _ {1}) ^ 2}{2\sigma _ 1 ^ 2}\right]\,dz _ 1\times\\ &\qquad\dots\times\int z _ i ^ 2\exp\left[-\frac{(z _ j-\mu _ {j}) ^ 2}{2\sigma _ j ^ 2}\right]\,dz _ i\times\\ &\qquad\dots\times\int \exp\left[-\frac{(z _ n-\mu _ {n}) ^ 2}{2\sigma _ n ^ 2}\right]\,dz _ n\\ &= \sum _ {i=1} ^ n(2\pi) ^ {\frac{n-1}{2}}\prod _ {j\neq i}\sigma _ j\int z _ i ^ 2\exp\left[-\frac{(z _ j-\mu _ {j}) ^ 2}{2\sigma _ j ^ 2}\right]\,dz _ i\\ &= \sum _ {i=1} ^ n(2\pi) ^ {\frac{n}{2}}\prod _ {j=1} ^ n\sigma _ j(\sigma _ i ^ 2+\mu _ {i} ^ 2) \end{aligned}\tag{14}

のように変形できるので,

Q(zx;θ)logP(z;θ)dz=n2log(2π)12i=1n(σi2+μi2).(15)\int Q(\boldsymbol{z}\mid\boldsymbol{x};\boldsymbol{\theta})\log P(\boldsymbol{z};\boldsymbol{\theta})\,d\boldsymbol{z}= -\frac{n}{2}\log(2\pi)-\frac{1}{2}\sum _ {i=1} ^ n(\sigma _ i ^ 2+\mu _ {i} ^ 2).\tag{15}

次に,

Q(zx;θ)logQ(zx;θ)dz=1(2π)nΣexp[12i=1n(ziμi)2σi2]×[n2log2π12logΣ12j=1n(zjμj)2σj2]dz=n2log2π12logΣ12(2π)nΣ×i=1n(ziμi)2σi2exp[j=1n(zjμj)22σj2]dz(16)\begin{aligned} &\int Q(\boldsymbol{z}\mid\boldsymbol{x};\boldsymbol{\theta})\log Q(\boldsymbol{z}\mid\boldsymbol{x};\boldsymbol{\theta})\,d\boldsymbol{z}\\ &=\frac{1}{\sqrt{(2\pi) ^ n\mid\Sigma\mid}}\int\exp\left[-\frac{1}{2}\sum _ {i=1} ^ n\frac{(z _ i-\mu _ {i}) ^ 2}{\sigma _ i ^ 2}\right]\times\left[-\frac{n}{2}\log 2\pi-\frac{1}{2}\log\mid\Sigma\mid-\frac{1}{2}\sum _ {j=1} ^ n\frac{(z _ j-\mu _ {j}) ^ 2}{\sigma _ j ^ 2}\right]\,d\boldsymbol{z}\\ &=-\frac{n}{2}\log 2\pi-\frac{1}{2}\log\mid\Sigma\mid-\frac{1}{2\sqrt{(2\pi) ^ n\mid\Sigma\mid}}\times\sum _ {i=1} ^ n\int\frac{(z _ i-\mu _ {i}) ^ 2}{\sigma _ i ^ 2}\exp\left[-\sum _ {j=1} ^ n\frac{(z _ j-\mu _ {j}) ^ 2}{2\sigma _ j ^ 2}\right]\,d\boldsymbol{z} \end{aligned}\tag{16}

第3項の積分は

i=1n1σi2zi2j=1nexp(zj22σj2)dz=i=1n1σi2exp(z122σ12) dz1××zi2exp(zi22σi2) dzi××exp(zn22σn2)dzn=i=1n1σi2(2π)n12jiσj2πσi3=i=1n(2π)n2j=1nσj(17)\begin{aligned} \sum _ {i=1} ^ n \frac{1}{\sigma _ i ^ 2} \int z _ i ^ 2 \prod _ {j=1} ^ n \exp \left(-\frac{z _ j ^ 2}{2 \sigma _ j ^ 2}\right) d\boldsymbol{z}&= \sum _ {i=1} ^ n\frac{1}{\sigma _ i ^ 2}\int \exp \left(-\frac{z _ 1 ^ 2}{2\sigma _ 1 ^ 2}\right)\ dz _ 1\times\\ \qquad&\dots\times\int z _ i ^ 2\exp\left(-\frac{z _ i ^ 2}{2\sigma _ i ^ 2}\right)\ dz _ i\times\\ \qquad&\dots\times\int\exp\left(-\frac{z _ n ^ 2}{2\sigma _ n ^ 2}\right)\,dz _ n\\ &=\sum _ {i=1} ^ n\frac{1}{\sigma _ i ^ 2}(2\pi) ^ {\frac{n-1}{2}}\prod _ {j\neq i}\sigma _ j\sqrt{2\pi}\sigma _ i ^ 3\\ &=\sum _ {i=1} ^ n (2\pi) ^ {\frac{n}{2}}\prod _ {j=1} ^ n\sigma _ j \end{aligned}\tag{17}

のように変形できるので,

Q(zx;θ)logQ(zx;θ)dz=n2log(2π)12j=1n(1+logσj2).(18)\int Q(\boldsymbol{z}\mid\boldsymbol{x};\boldsymbol{\theta})\log Q(\boldsymbol{z}\mid\boldsymbol{x};\boldsymbol{\theta})\,d\boldsymbol{z}=-\frac{n}{2}\log(2\pi)-\frac{1}{2}\sum _ {j=1} ^ n(1+\log\sigma _ j ^ 2).\tag{18}

(15),(18)から,(3) 第2項は,

KL(Q(zixi;ϕ)P(zi;θ))=Q(zixi;θ)logQ(zixi;θ)dzi+Q(zixi;θ)logP(zi;θ)dzi=12j=1n[(1+logσj2)μj2σj2].(19)\begin{aligned} &-\text{KL}(Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\phi})\| P(\boldsymbol{z} _ i;\boldsymbol{\theta}))\\ &=-\int Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\theta})\log Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\theta})\,d\boldsymbol{z} _ i+\int Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\theta})\log P(\boldsymbol{z} _ i;\boldsymbol{\theta})\,d\boldsymbol{z} _ i\\ &= \frac{1}{2}\sum _ {j=1} ^ n[(1+\log{\sigma _ j} ^ 2)-\mu _ j ^ 2-\sigma _ j ^ 2]. \end{aligned}\tag{19}

また,今回は$Q(\boldsymbol{z} _ i\mid\boldsymbol{x} _ i;\boldsymbol{\phi})$が正規分布$\mathcal{N}(\boldsymbol{z} _ i,\boldsymbol{\mu},\boldsymbol{\sigma}\odot\boldsymbol{\sigma}E)$であるとしたので,再パラメータ化は

gϕ(ϵ,xi)=μ(xi)+σ(xi)ϵ,ϵN(ϵ,0,E)(20)g _ {\boldsymbol{\phi}}(\boldsymbol{\epsilon},\boldsymbol{x} _ i) = \boldsymbol{\mu}(\boldsymbol{x} _ i)+\boldsymbol{\sigma}(\boldsymbol{x} _ i)\odot\boldsymbol{\epsilon},\quad\boldsymbol{\epsilon}\sim\mathcal{N}(\boldsymbol{\epsilon},0,E)\tag{20}

とする.これは,

EN(ϵ,0,E)[μ(xi)+σ(xi)ϵ]=μ(xi)(21)E _ {\mathcal{N}(\boldsymbol{\epsilon},0,E)}[\boldsymbol{\mu}(\boldsymbol{x} _ i)+\boldsymbol{\sigma}(\boldsymbol{x} _ i)\odot\boldsymbol{\epsilon}]=\boldsymbol{\mu}(\boldsymbol{x} _ i)\tag{21}

および

EN(ϵ,0,E)[(μi+σiϵiμi)(μj+σjϵjμj)]=δijσiσj(22)E _ {\mathcal{N}(\boldsymbol{\epsilon},0,E)}[(\mu _ i+\sigma _ i\epsilon _ i-\mu _ i)(\mu _ j+\sigma _ j\epsilon _ j-\mu _ j)]= \delta _ {ij}\sigma _ i\sigma _ j \tag{22}

から分かる.

以上から,

logP(xi;θ)=12j=1n[(1+logσj2)μj2σj2]+1Ll=1LlogP(xizi(l);θ).(23)\begin{aligned} \log P(\boldsymbol{x} _ i;\boldsymbol{\theta}) &= \frac{1}{2}\sum _ {j=1} ^ n[(1+\log{\sigma _ j} ^ 2)-\mu _ j ^ 2-\sigma _ j ^ 2]+\frac{1}{L}\sum _ {l=1} ^ L\log P(\boldsymbol{x} _ i\mid\boldsymbol{z} _ i ^ {(l)};\boldsymbol{\theta}). \end{aligned}\tag{23}

ただし,

zi(l)=μ(l)(xi)+σ(l)(xi)ϵ(l),ϵ(l)N(ϵ,0,E).(24)\begin{aligned} \boldsymbol{z} _ i ^ {(l)}&= \boldsymbol{\mu} ^ {(l)}(\boldsymbol{x} _ i)+\boldsymbol{\sigma} ^ {(l)}(\boldsymbol{x} _ i)\odot\boldsymbol{\epsilon} ^ {(l)},\\ \boldsymbol{\epsilon} ^ {(l)}&\sim\mathcal{N}(\boldsymbol{\epsilon},0,E). \end{aligned}\tag{24}

これが最大になるように,NNを構成すればよい.

摂動$\delta$の予想

各状態にある細胞を抽出し,細胞の数によるバイアスを無くすために,調節する.最後に

δ=z1z0(25)\delta=\overline{z _ 1}-\overline{z _ 0}\tag{25}

を計算する.$\overline{z _ 0}$は各状態の潜在変数の平均,$\overline{z _ 1}$は摂動があったときの潜在変数の平均.