softmax の数値安定性について
Softmax の定義
softmax:Rd→Rd は次のように定義される。
softmax(x)=i=1∑dexp(xi)1exp(x1)…exp(xd)
この写像は、入力値を正のベクトルに変換し、全ての成分の合計が 1 になるように正規化する。
つまり、任意のベクトルを確率ベクトルへ移す写像。
数値的な問題
Softmax を計算する際、exp(x) を計算する必要があるため、x の値が大きいとオーバーフローが発生しやすくなる。
例えば、以下のように softmax 関数を定義すると、x の値が大きいときにオーバーフローが発生する。
import numpy as np
def softmax(x):
return np.exp(x) / np.sum(np.exp(x))
print(f"{softmax(np.array([0.5, 0.1, 0.2]))}")
>>> [0.41474187 0.27800979 0.30724834]
# オーバーフローが発生
print(f"{softmax(np.array([1000.0, 0.1, 0.2]))}")
>>> [nan 0. 0.]
数値安定性のための工夫
Softmax の数値安定性のために、任意の C∈R に対して以下が成り立つことを利用する。
softmax(x)j=i=1∑dexp(xi)exp(xj)=exp(−C)i=1∑dexp(xi)exp(−C)exp(xj)=i=1∑dexp(xi−C)exp(xj−C)
ここで、C=i=1maxxi とすると任意の j に対して exp(xj−C)≤1 となり、オーバーフローを防ぐことができる。
import numpy as np
def softmax_stab(x):
C = np.max(x)
return np.exp(x - C) / np.sum(np.exp(x - C))
print(f"{softmax_stab(np.array([0.5, 0.1, 0.2]))}")
>>> [0.41474187 0.27800979 0.30724834]
# オーバーフローが発生しない
print(f"{softmax_stab(np.array([1000.0, 0.1, 0.2]))}")
>>> [1. 0. 0.]