Focal Loss は本当に不均衡データに効果があるのか検証した [①実装編]

LightGBM, 機械学習

こんな人にオススメ!

・LightGBMにcustom(自作した) objectiveを実装したい人
・focal loss の実装について知りたい人

はじめに

今回は、競艇というより、かなり機械学習よりな記事になります。

LightGBMの損失関数を独自に作成したFocal Lossに設定し、学習してみました。いろいろ苦戦した箇所もあるので、そういった内容も含めて記載したいと思います。

Focal Loss

(多分これだと思うのですが)元になった論文はこちらです。

もともとは画像の機械学習分野で、物体検知における不均衡なクラス分類の改善案として提案されたもののようです。

数式はこんな感じです。以下、FLの出力をloss と呼ぶことにします。

(1)   \begin{equation*}FL(p_t)=-(1-p_t)^\gamma\log(p_t)\end{equation*}

p_tは確率を想定してるため0〜1に規格化された数値です。これはCross Entropy(以下、CE)CE(p_t)=-\log(p_t)を拡張した関数になります。

-\log(p_t)の箇所ですが、この箇所はCEと共通ですね。p_tと出力の関係はこうなります。

p_t-\log(p_t)
高い(正解に近い)0(に近づく)
低い(不正解に近い)∞(に大きくなる)

次に(1-p_t)^\gammaです。\gammaはハイパーパラメータで任意に設定する必要があります。

p_t(1-p_t)^\gamma
高い(正解に近い)0(に近づく)
低い(不正解に近い)1(に近づく)

この項を加えることで、正解に近いラベルに対してはよりlossを下げる効果があります。

つまり、正解に近い予測値の場合はそれ以上学習する事を抑えることになり、不正解なデータに対しての学習を進めやすくなります。

Focal Loss と Cross Entropy の比較

論文内の FL の図を載せておきます。\gamma=0 で CEと一致します。 “well-classifed examples" と書かれた領域はloss がCEに比べて低くなっています。loss が低いということは、そのデータについては学習を進めなくてもいいよ、という事になります。つまり、もっと不正解ラベルにフォーカスしましょうよ、というのがFLの本質ですね。

ポイント!

・Focal Loss は Cross Entropy の拡張である
・Focal Loss は不正解ラベルにフォーカスして不均衡データの学習改善に期待できる

LightGBM custom objective

ここでは LightGBMでの実装に触れます。2020年5月時点でのAPIの理解についても記載します。また、私がscikit-learn APIしか使っていないため、LGBMClassifier を基準に説明する事にします。

損失関数は objective parameter に実装します。ごっちゃになりますが、fit 関数の eval_metric ではありません! eval_metric は early_stopping_rounds などの判断に使われるだけです。ただ、結果的に同じ評価で early_stopping_rounds を判断することが多いので、eval_metric に独自に作成した損失関数を設定することは多いと思います。

ではどういう形式の関数を objective に設定すれば良いのか。公式HPでは次のように記載されています。

第1引数y_true正解ラベル
第2引数y_predLightGBMの生出力
出力1grad損失関数の1階微分
出力2hess損失関数の2階微分

正解ラベルがmulti classの場合

公式では、y_predarray-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task) と記載されています。

これ、分かりにくかったのですが、Nデータnラベルではこういう並びです。

[0ラベル出力1, 0ラベル出力2, …, 0ラベル出力N, 1ラベル出力1, …, nラベル出力N]

入力について

この件は今でも本当かと疑っていますが、lgb.Dataset を使わないと y_true と y_pred が逆になります! もしかしたらどこか未来のversion で改修されるかもしれませんが、現時点ではこういう仕様でした。。なので、下記のような形で変換すればいいです。

func に勾配を返却する関数を入力します。is_lgbdataset は lgb.Dataset を使っていれば true とします。

出力(grad, hess)について

微分した関数での出力は[n_sample]必要になります。multi class の場合は、y_pred と同じ形式で返却する必要があります。

簡単な微分は計算すれば良いのですが、面倒なので scipy の derivative を使います。loss_func に損失関数の計算が書かれた関数を input すれば良いです。

この実装でハマった点について補足しておきます。。

x は規格化されていない!

普段 predict や predict_proba で受け取っている値は規格化されていますが、この x は規格化されていません! 弱学習機の寄せ集めの生の出力が入力されてきます。冷静に考えれば回帰もあるので当然その通りなのですが、最初0〜1に規格化されていると思って少しハマりました。。

規格化は loss_func 内でしろ!

規格化の方法ではよく sigmoid 関数が使われますが、はじめその規格化を calc_grad_hess 内で実装していました。これはダメです!

上にも書いた通り、x は弱学習機等からの生の出力です。その値に対しての勾配を計算する必要があります。当然、規格化も考慮した上での勾配です。規格化をするのも損失関数内のお仕事です。これに気づくまで結構時間を費やしました。。

multi class で softmax の使用は注意しろ!

これも結構ハマりました。。softmax というのは multi class で良い感じに出力の合計を足して 1 になるように規格化してくれる関数です。

ただこれを損失関数内で定義した事で勾配計算がうまくいかなくなる場合があります。私の場合はこうでした。

コンピュータでの微分計算というのは(loss(x+\delta x) - loss(x)) / \delta x\delta x を有限な微小な値として入力し計算させます。

y_pred の初回値(多分、弱学習機がゼロの状態)は全て0の値になります。y_pred = [0, 0, 0, 0, 0, …] です。これを loss 内に入力して、multi class を見やすいように [[0, 0, 0], [0, 0, 0], … ] ※3クラスの場合 と分解し、softmax に入力して [[0.333.., 0.333.., 0.333..], [0.333.., 0.333.., 0.333..], … ] を得ます。その後focal loss などの計算をして、loss_y を得ます。

微分計算のため、\delta x=0.01 で計算したとしましょう。同じように multi class は [[0.01, 0.01, 0.01], [0.01, 0.01, 0.01], … ] となり、softmax後は[[0.333.., 0.333.., 0.333..], [0.333.., 0.333.., 0.333..], … ] を得ます。そしてloss_yを得ます。

あれ?値が同じ?? とすぐに気づく人は素晴らしい。そうなんです。初回値に限り、softmax のせいで勾配が消失しますloss(x+\delta x) - loss(x)=0です。初回に重みが更新されないということは、その先ずっと更新されず、学習が1ミリも進みません。

なので、損失関数内の規格化はsigmoid でしておく方が無難です。勿論、softmaxが活躍するケースはあると思いますが。

そして、結局、multi class の FL ではどうあがいても derivative を使った微分計算ができなかったので、真面目に微分して実装しました。

FLの微分とLGBM実装

この微分計算とFLの実装に関しては、合っている保証がないので使用については自己責任でお願いします!

1階微分と2階微分の計算過程は恥を承知で公開します。※間違いを発見された場合は是非コメントください!

そして、FLの実装は次のようにしました。

focal_loss 関数は通常の FL 計算、focal_loss_grad 関数は FLの1階微分と2階微分の計算をしています。

これらの関数を LGBM で使用するサンプルは次です。

まとめ

LightGBMの custom objective (独自損失関数) について、Focal loss を例に実装方法とハマりポイントについてまとめました。

はじめは Focal Loss の検証について書く予定でブログを書き始めたのですが、この実装までにかなりハマってしまい、Focal Loss が動いた!というだけで燃え尽きてしまいました笑

検証を簡単にはしてみたのですが、思いの外大した効果がない印象でした。とはいえ最終的に検証結果をまとめたいとは思っていますので、「②検証編」がいつになるかは分かりませんが、いつか書きます。いつか・・・。

では今回のまとめです。

ポイント!

・Focal Loss は不均衡データの学習改善に期待できる
・Focal Loss を2階微分まで計算した(合っている保証はない)
・LightGBMの custom objective の実装方法をまとめた
・softmax を使った損失関数のscipy / derivativeを使った微分計算には注意する

LightGBM, 機械学習