Softmax


Softmax の定義

逆温度 λ>0\lambda>0softmax:RdRd\text{softmax}:\mathbb R^d\to\mathbb R^d は次のように定義されます.

softmax(x)=1i=1dexp(λxi)[exp(λx1)exp(λxd)]\text{softmax}(\mathbf{x}) = \frac{1}{\sum\limits_{i=1}^d\exp(\lambda x_i)} \begin{bmatrix} \exp(\lambda x_1) \\ \dots \\ \exp(\lambda x_d) \end{bmatrix}

Softmax は log-sum-exp の勾配

log-sum-exp 関数 lse:RdR\text{lse}:\mathbb R^d\to\mathbb R は次のように定義されます.

lse(x)=1λlogi=1dexp(λxi)\text{lse}(\mathbf{x}) = \frac{1}{\lambda} \log\sum\limits_{i=1}^d\exp(\lambda x_i)

softmax(x)\text{softmax}(\mathbf{x}) は,log-sum-exp 関数の勾配です. つまり,次が成り立ちます.

Proposition 1.
lse(x)=softmax(x)\nabla\text{lse}(\mathbf{x}) = \text{softmax}(\mathbf{x})

証明

lse(x)=1λlogi=1dexp(λxi)=1λ1i=1dexp(λxi)i=1dexp(λxi)=1λλi=1dexp(λxi)[exp(λx1)exp(λxd)]=1i=1dexp(λxi)[exp(λx1)exp(λxd)]=softmax(x)\begin{align*} \nabla\text{lse}(\mathbf{x}) &= \frac{1}{\lambda} \nabla\log\sum\limits_{i=1}^d\exp(\lambda x_i) \\ &= \frac{1}{\lambda} \frac{1}{\sum\limits_{i=1}^d\exp(\lambda x_i)} \nabla\sum\limits_{i=1}^d\exp(\lambda x_i) \\ &= \frac{1}{\lambda} \frac{\lambda}{\sum\limits_{i=1}^d\exp(\lambda x_i)} \begin{bmatrix} \exp(\lambda x_1) \\ \dots \\ \exp(\lambda x_d) \end{bmatrix} \\ &= \frac{1}{\sum\limits_{i=1}^d\exp(\lambda x_i)} \begin{bmatrix} \exp(\lambda x_1) \\ \dots \\ \exp(\lambda x_d) \end{bmatrix} \\ &= \text{softmax}(\mathbf{x}) \end{align*}

Softmax の微分

softmax(x)\text{softmax}(\mathbf{x}) の微分は,次のように計算できます.

Proposition 2.
Jsoftmax(x)=λ(diag(softmax(x))softmax(x)softmax(x)T)J_{\text{softmax}}(\mathbf{x}) = \lambda(\text{diag}(\text{softmax}(\mathbf{x})) - \text{softmax}(\mathbf{x})\text{softmax}(\mathbf{x})^T)

証明

(softmax(x))ixj\displaystyle \frac{\partial (\text{softmax}(\mathbf{x}))_i}{\partial x_j} を計算します.

i=ji = j の場合

(softmax(x))ixi=xi(exp(λxi)k=1dexp(λxk))=λexp(λxi)k=1dexp(λxk)λexp(λxi)2(k=1dexp(λxk))2=λsoftmax(x)iλsoftmax(x)i2\begin{align*} \frac{\partial (\text{softmax}(\mathbf{x}))_i}{\partial x_i} &= \frac{\partial}{\partial x_i}\left(\frac{\exp(\lambda x_i)}{\sum\limits_{k=1}^d\exp(\lambda x_k)}\right) \\ &= \frac{\lambda \exp(\lambda x_i)\sum\limits_{k=1}^d\exp(\lambda x_k) - \lambda\exp(\lambda x_i)^2}{\left(\sum\limits_{k=1}^d\exp(\lambda x_k)\right)^2} \\ &= \lambda\text{softmax}(\mathbf x)_i - \lambda\text{softmax}(\mathbf x)_i^2 \\ \end{align*}

iji \neq j の場合

(softmax(x))ixj=xj(exp(λxi)k=1dexp(λxk))=λexp(λxi)exp(λxj)(k=1dexp(λxk))2=λsoftmax(x)isoftmax(x)j\begin{align*} \frac{\partial (\text{softmax}(\mathbf{x}))_i}{\partial x_j} &= \frac{\partial}{\partial x_j}\left(\frac{\exp(\lambda x_i)}{\sum\limits_{k=1}^d\exp(\lambda x_k)}\right) \\ &= -\frac{\lambda \exp(\lambda x_i)\exp(\lambda x_j)}{\left(\sum\limits_{k=1}^d\exp(\lambda x_k)\right)^2} \\ &= -\lambda \text{softmax}(\mathbf x)_i\text{softmax}(\mathbf x)_j \\ \end{align*}

よって,微分は次のように計算できます.

Jsoftmax(x)=λ(diag(softmax(x))softmax(x)softmax(x)T)J_{\text{softmax}}(\mathbf{x}) = \lambda(\text{diag}(\text{softmax}(\mathbf{x})) - \text{softmax}(\mathbf{x})\text{softmax}(\mathbf{x})^T)

References