Weights&Biasesで機械学習の実験管理を行う方法

2022-08-13
2022-08-13

久々に機械学習記事です。

今回はローカルの実験管理ツールについてです。

ずばり調査したのは Weights & Biases です。

なぜ weights&biases かというと元々は Mlflow の利用を考えていたのですが、たまたま Twitter で weights&biases が便利だという話を聞いたので調査してみました。

結論から言うと超便利だったので今後のローカル実験管理は weights&biases で行くことに決めました。

間違ったとこもあるかもですが、とりあえずわかったことを中心にまとめてみました。

weights&biases の使い方

Quickstart の内容をまずは試してみるのをおすすめします。

weights&biases は利用するのに GitHub か Google で登録が必要がなります。

ブラウザで登録し、login すると API キーが参照できるのでそれをコピーしておきます。

続いてはローカルで weights&biases のパッケージである wandb の準備をします。

コマンドは以下の通り(※ poetry 想定です)

poetry add wandb

install 後、wandb loginをコマンドラインで入力すると、対話形式で API キーを聞かれるので先ほどコピーした値を貼り付けます。

これでコードで wandb を使用する準備ができました。

各種設定、パラメータの管理

まずは機械学習に欠かせないパラメータ管理です。

正直これが使いたくて今回実験管理ツールを探していました。

実際のコードです。

一番単純にやる場合は以下のコードのみです

import wandb
configs={"epoch":100, "learning_rate":0.001}
wandb.init(project="test-project", config=configs)

wandb をインポートし、wandb.initでプロジェクトに接続(なければ作成)し、新しい実験(実行)の記録が開始されます。 この init の段階で記録したいパラメータを渡しておきます。

wandb ではプロジェクトの下にいくつもの実験が記録される仕組みです。 これにより、あるタスクの訓練をいろいろモデルやパラメータを変えて試行錯誤した結果を追跡できます。

私の場合は、後のコード内でconfigsの値を使いやすくするためにpydanticを使用しています。

import wandb
from pydantic import BaseSettings
class Settings(BaseSettings):
    learning_rate: float = 0.001
    epoch: int = 100
settings = Settings()
wandb.init(project="test-project", config=settings.dict())

分類系の情報をまとめて作成

wandb には機械学習ライブラリやフレームワークごとに便利な plot 機能が用意されています。

今回は scikit-learn で分類問題を学習する際に便利なメソッドを紹介します。

使い方は超簡単です。モデル学習後に以下の情報をメソッドに入力するだけです。

wandb.sklearn.plot_classifier(
        clf, # 学習済み分類器
        X_train, # 訓練用データ
        x_val, # 検証用データ
        y_train, # 訓練正解データ
        y_val, # 検証正解データ
        y_pred, # 検証用データの予測ラベル結果
        y_proba, # 検証用データの予測確率
        labels, # 正解ラベルの種類
        model_name="MLP",
        feature_names=None,
    )

データの準備と学習さえ終わっていればどれも準備は簡単です。

これを実行すると混同行列、ROC 曲線、Precision Recall、ターゲットクラスの割合、accuracy/f1/precision/recall の値が全て可視化・保存されます。

コードもスッキリしてみやすくなります。

その他

まだ調査途中で実際に使ってないものもありますが、メモがてら書いておきます

model 保存方法

artifact としてファイルなどを保存できるので、そのやり方で学習済みモデルを保存します。

実際のコードは以下の通りです。

art = wandb.Artifact(f"{model_name}-clf-{wandb.run.id}", type="model")
art.add_file("path/to/model/")
run.log_artifact(art)

ドキュメントをざっと読んだ感じですと、wandb の Artifact は単純にファイルがプロジェクトに紐づけられて保存されているっぽいです。 type は何でもいいらしいです。

使用するときは download メソッドでローカルに落として使います。

使ってみた感想

めちゃめちゃ使いやすくてびっくりです。

大してドキュメントを読んだわけではないですが、直感的に使えます。

一番いいのは既存の学習コードをあまり変えなくていいところと、グラフを作成するコートが省けることです。

今後も引き続き使っていきます。