softmax の数値安定性について
Softmax の定義
は次のように定義されます。
この写像は、入力値を正のベクトルに変換し、全ての成分の合計が 1 になるように正規化します。 つまり、任意のベクトルを確率ベクトルへ移す写像です。
数値的な問題
Softmax を計算する際、 を計算する必要があるため、 の値が大きいとオーバーフローが発生しやすくなります。 例えば、以下のように softmax 関数を定義すると、 の値が大きいときにオーバーフローが発生します。
import numpy as np
def softmax(x):
return np.exp(x) / np.sum(np.exp(x))
print(f"{softmax(np.array([0.5, 0.1, 0.2]))}")
>>> [0.41474187 0.27800979 0.30724834]
# オーバーフローが発生する
print(f"{softmax(np.array([1000.0, 0.1, 0.2]))}")
>>> [nan 0. 0.]
数値安定性のための工夫
Softmax の数値安定性のために、以下が成り立つことを利用します。
ここで、 とすると任意の に対して となり、オーバーフローを防ぐことができます。
import numpy as np
def softmax_stab(x):
C = np.max(x)
return np.exp(x - C) / np.sum(np.exp(x - C))
print(f"{softmax_stab(np.array([0.5, 0.1, 0.2]))}")
>>> [0.41474187 0.27800979 0.30724834]
# オーバーフローが発生しない
print(f"{softmax_stab(np.array([1000.0, 0.1, 0.2]))}")
>>> [1. 0. 0.]