softmax 関数のリプシッツ性


定義

eiRde_{i}\in\mathbb R^dii 番目の基本ベクトルとする。

Δd\Delta_{d}dd 次元確率単体とする。つまり

Δd={sRd|i=1dsi=1,任意の i でsi0}.\Delta_{d} = \left\{ s \in\mathbb{R}^d \middle| \sum_{i=1}^d s_{i} = 1, \: \text{任意の $i$ で} s_{i}\ge 0 \right\}.

逆温度 λ>0\lambda>0 の softmax σλ:RdΔd\sigma_{\lambda}:\mathbb{R}^d\to\Delta_d を以下のように定義する

σλ(x)=1i=1dexp(λxi)[exp(λx1)exp(λxd)].\sigma_{\lambda}(\mathbf x) = \frac{1}{\sum\limits_{i=1}^d\exp(\lambda x_i)} \begin{bmatrix} \exp(\lambda x_1) \\ \dots \\ \exp(\lambda x_d) \end{bmatrix}.

命題

任意の 1p,q1\le p,q\le\inftyσλ\sigma_{\lambda} について以下が成り立つ

σλ(x)σλ(y)pLp,qxyq\|\sigma_{\lambda}(x)-\sigma_{\lambda}(y)\|_{p} \le L_{p,q}\|x-y\|_{q}

ここで

Lp,q=λ21+1p1qL_{p,q} = \lambda 2^{-1 + \frac{1}{p} - \frac{1}{q}}

証明

s=σλ(z)s=\sigma_{\lambda}(z) とするとヤコビ行列は以下のように書ける

J(z)=σλ(z)=λ(diag(s)ssT).J(z) = \nabla \sigma_{\lambda}(z) = \lambda(\mathrm{diag}(s)- ss^T).

平均値不等式より

σλ(x)σλ(y)p(supzsupuq=1J(z)up)xyq.\|\sigma_{\lambda}(x)-\sigma_{\lambda}(y)\|_{p} \le \left( \sup_{z} \sup_{\|u\|_{q}= 1} \|J(z) u\|_{p} \right) \|x-y\|_{q}.

したがって supzJ(z)qp\sup_{z}\|J(z)\|_{q\to p} を評価すれば十分である。

まず,標準基底 {ei}i=1d\{e_i\}_{i=1}^d に対して恒等式 diag(s)ssT=i<jsisj(eiej)(eiej)T\mathrm{diag}(s)-ss^T = \sum_{i<j}s_{i}s_{j}(e_{i}-e_{j})(e_{i}-e_{j})^T が成り立つから任意の uRdu\in\mathbb{R}^d

J(z)u=λi<jsisj(uiuj)(eiej).J(z)u = \lambda \sum_{i<j} s_{i}s_{j}(u_{i}-u_{j})(e_{i}-e_{j}).

三角不等式と eiejp=21/p\|e_i-e_j\|_{p}=2^{1/p}iji\neq j)より

J(z)upλi<jsisjuiujeiejp=λ21/pi<jsisjuiuj\begin{align*} \|J(z)u\|_{p} &\le \lambda \sum_{i<j} s_{i}s_{j} |u_{i}-u_{j}| \|e_{i}-e_{j}\|_{p} \\ &= \lambda 2^{1/p} \sum_{i<j} s_{i}s_{j} |u_{i}-u_{j}| \end{align*}

ここで s=σλ(z)s=\sigma_\lambda(z) は単体の内点を走る。関数 si<jsisjuiujs\mapsto\sum_{i<j}s_is_j|u_i-u_j| は連続なので, 境界を含む閉単体 Δd\Delta_d 上の上限と一致する。 したがって

supzsupuq=1J(z)up=supuq=1supzJ(z)up=λ21/psupuq=1supsΔdi<jsisjuiuj.\begin{align*} \sup_{z} \sup_{\|u\|_{q}= 1} \|J(z) u\|_{p} &= \sup_{\|u\|_{q}=1} \sup_{z} \|J(z)u\|_{p} \\ &= \lambda 2^{1/p} \sup_{\|u\|_{q}=1} \sup_{s \in\Delta_{d}} \sum_{i<j} s_{i}s_{j}|u_{i}-u_{j}|. \end{align*}

imin=arg miniui,imax=arg maxiuii_{min}=\argmin_{i}u_{i}, \: i_{max}=\argmax_{i}u_{i} とする。右辺について ssimin,imaxi_{min},i_{max} に集中させる場合が最も大きくなるから

supsΔdi<jsisjuiuj=uimaxuiminmaxS[0,1]S(1S)=uimaxuimin/4.\sup_{s \in \Delta_{d}} \sum_{i<j} s_{i}s_{j}|u_{i}-u_{j}| = |u_{i_{max}}-u_{i_{min}} | \max_{S \in[0,1]} S(1-S) = |u_{i_{max}}-u_{i_{min}} | /4.

等号成立条件は simin=simax=12s_{i_{min}}=s_{i_{max}}=\frac{1}{2}。 次に、任意の実数 a,ba,b について abq2q1(aq+bq)|a-b|^q \le 2^{q-1} (|a|^q + |b|^q) が成り立つから a=uimax,b=ujmina=u_{i_{max}},b=u_{j_{min}} に適用することで以下が成り立つ:

uimaxuimin211/q(uimaxq+uiminq)1/q211/q|u_{i_{max}}-u_{i_{min}}| \le 2^{1-1/q} (|u_{i_{max}}|^q + |u_{i_{min}}|^q)^{1/q} \le 2^{1-1/q}

等号成立条件は uimax=21/q,uimin=21/qu_{i_{max}}=2^{-1/q}, u_{i_{min}}=-2^{-1/q} 。 以上をまとめると

supzsupuq=1J(z)upλ21/p14211/q=λ21+1/p1/q=:Lp,q.\sup_{z} \sup_{\|u\|_{q}= 1} \|J(z) u\|_{p} \le \lambda 2^{1/p} \cdot \frac{1}{4} \cdot 2^{1-1/q} = \lambda 2^{-1+1/p-1/q} =: L_{p,q}.

\square

具体例

  • (2,2)(2,2) のとき L2,2=λ/2L_{2,2} = \lambda/2
  • (1,1)(1,1) のとき L1,1=λ/2L_{1,1} = \lambda/2
  • (1,)(1,\infty) のとき L1,=λL_{1,\infty} = \lambda

References