【PyTorch】モデルの可視化・保存方法について学ぶ

2020-06-06
2020-06-06

本記事では、PyTorch でよく使うモデルの可視化や保存方法を紹介します。

また、たまに使うけどよくわからないregister_buffertorch.lerpについても調べてみました。

本記事では、前回使用した MLP モデルを使っていきます。

torchsummay でモデルを可視化

torchsummaryというモジュールを利用することで、モデルを可視化することができます。

複雑なモデルを定義していると入力や出力の shape がわからなくなったり、「これメモリに乗るのかな」ということがあります。

そういう時にこのtorchsummaryを利用します。

インストールはpipでできます。

pip install torchsummary

使い方はこんな感じです。前回のコードを流用します。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets, transforms
# 追加============================
import os
from torchsummary import summary
# ===============================
from datetime import datetime
print(torch.__version__) # 1.5.0
# colabでgoogle driveをマウントしてない場合のパス
root="content/"
# dataの変換方法を定義
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# dataをダウンロード
train_set = datasets.MNIST(root=root, train=True, transform=trans, download=True)
test_set = datasets.MNIST(root=root, train=False, transform=trans, download=True)
# cpuかgpuか
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# dataloaderを定義
train_loader = DataLoader(train_set, batch_size=100, shuffle=True)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False)
# Networkを定義
class MLPNet (nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1 * 28 * 28, 512)
        self.fc2 =nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 10)
        self.dropout1=nn.Dropout2d(0.2)
        self.dropout2=nn.Dropout2d(0.2)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        return F.relu(self.fc3(x))
net = MLPNet().to(device)
# torchsummaryを使った可視化
summary(net, input_size=(1,1 * 28 * 28))

出力は以下のようになります。

----------------------------------------------------------------
title: 【PyTorch】モデルの可視化・保存方法について学ぶ
createdAt: '2020-06-06'
updatedAt: '2020-06-06'
tags: ['PyTorch', 'Python', '機械学習']
draft: false
description:  'PyTorchを使った少々実践的な内容をまとめました。モデルの可視化や保存方法について説明します。また、たまに見かけるtorch.lerpやregister_bufferについてもコード付きで紹介します。'
thumbnail: '/img/twitter-card.png'
---
# 【PyTorch】モデルの可視化・保存方法について学ぶ
-------------------------------------------------------------
title: 【PyTorch】モデルの可視化・保存方法について学ぶ
createdAt: '2020-06-06'
updatedAt: '2020-06-06'
tags: ['PyTorch', 'Python', '機械学習']
draft: false
description:  'PyTorchを使った少々実践的な内容をまとめました。モデルの可視化や保存方法について説明します。また、たまに見かけるtorch.lerpやregister_bufferについてもコード付きで紹介します。'
thumbnail: '/img/twitter-card.png'
---
# 【PyTorch】モデルの可視化・保存方法について学ぶ
-------------------------------------------------------------

非常にわかりやすいです。

特に他人にモデルの説明をするときにあると重宝します。

注意としては、今回のモデルのように入力が 1 次元の場合はそのまま入力サイズにinput_size=(1*28*28)とするとエラーになります。

なので、チャネルの次元を加えてinput_size=(1, 1*28*28)とします。

学習済みモデルの保存

公式に詳しく書いてありますが念のため。

まずモデルの保存を行う目的は2つあります。

  • 学習済みモデルを使って推論を行う
  • 保存済みモデルの学習を再開する

目的によって、保存しておくべき内容が違います。

次に PyTorch でモデルを保存する方法について確認してきます。

PyTorch ではモデルを保存する方法が 2 通りあります。

  • モデル全体を保存する
  • モデルのパラメータを保存する

さらに保存する際には GPU か CPU なのかを注意する必要があります。

ややこしいですが、認識しておく必要があります。

まずは普通にモデルを保存してみます。

net.apply(init_weights) # 追加
# loss関数
criterion = nn.CrossEntropyLoss()
# 最適化方法
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
# log用フォルダを毎回生成
# tensorboardの可視化用
now = datetime.now()
log_path = "./runs/" + now.strftime("%Y%m%d-%H%M%S") + "/"
print(log_path)
# tensorboard用のwriter
writer = SummaryWriter(log_path)
epochs = 30
for epoch in range(epochs):
    train_loss = 0
    train_acc = 0
    val_loss = 0
    val_acc = 0
    # train dataで訓練
    net.train()
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.view(-1, 28*28*1).to(device), labels.to(device)
        # 勾配を0にリセット
        optimizer.zero_grad()
        # 順伝搬
        out = net(images)
        # loss計算
        loss = criterion(out, labels)
        # 計算したlossとaccの値を入れる
        train_loss += loss.item()
        train_acc += (out.max(1)[1] == labels).sum().item()
        # 誤差逆伝搬
        loss.backward()
        # 重みの更新
        optimizer.step()
        # 平均のlossとacc計算
        avg_train_loss = train_loss / len(train_loader.dataset)
        avg_train_acc = train_acc / len(train_loader.dataset)
    # validation dataで評価
    net.eval()
    with torch.no_grad():
        for (images, labels) in test_loader:
            images, labels = images.view(-1, 28*28*1).to(device), labels.to(device)
            out = net(images)
            loss = criterion(out, labels)
            val_loss += loss.item()
            acc = (out.max(1)[1] == labels).sum()
            val_acc += acc.item()
    avg_val_loss = val_loss / len(test_loader.dataset)
    avg_val_acc = val_acc / len(test_loader.dataset)
    # print log
    print ('Epoch [{}/{}], Loss: {loss:.4f}, val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}'
                   .format(epoch+1, epochs, loss=avg_train_loss, val_loss=avg_val_loss, val_acc=avg_val_acc))
    # tensorboard用
    writer.add_scalars('loss', {'train_loss':avg_train_loss, 'val_loss':avg_val_loss},epoch+1)
    writer.add_scalars('accuracy', {'train_acc':avg_train_acc, 'val_acc':avg_val_acc}, epoch+1)
writer.close()
# 追加部分
dir_name = 'output'
if not os.path.exists(dir_name):
    os.mkdir(dir_name)
model_save_path = os.path.join(dir_name, "model_full.pt")
# モデル保存
torch.save(net, model_save_path)
# モデルロード
model_full = torch.load(model_save_path)

以下コードの部分が保存用のコードです。

# モデル保存
torch.save(net, model_save_path)
# モデルロード
model_full = torch.load(model_save_path)

これが一番単純な方法です。

しかし、この保存方法は公式で推奨されてません。

非推奨の理由をいろいろ調べてみると、どうもこの方法でやると保存時の GPU にロード時も読み込まれてしまうらしい。

つまり GPU がない場合は詰んでしまう可能性がある。

あとはもう一つの方法に比べてサイズが大きい。

なので保存時は公式推奨のstate_dict()の方法で行う。

# モデル保存
torch.save(net.state_dict(), model_save_path)
# モデルロード
model.load_state_dict(torch.load(model_save_path))

一応、GPU で保存してしまっても CPU で読み出す方法はあるらしいが、失敗するのが怖いので CPU で保存しておくのが無難。

やり方は以下のように保存時にto('cpu)をつける。

torch.save(net.to('cpu').state_dict(), model_save_path)
model_cpu.load_state_dict(torch.load(model_save_path))

学習を再開するための checkpoint を作りたい場合は以下のようにします。

if epoch % 3 == 0:
        file_name = 'epoch_{}.pt'.format(epoch)
        path = os.path.join(checkPoint_dir, file_name)
        torch.save({
            'epoch' : epoch,
            'model_state_dict' : net.state_dict(),
            'optimaizer_state_dict': optimizer.state_dict(),
            'loss': avg_train_loss
        }, path)

保存するタイミングは適当に決めます。

こちらが素直に学習した場合の出力です。

Epoch [0/10], Loss: 0.0060, val_loss: 0.0019, val_acc: 0.9452
Epoch [1/10], Loss: 0.0020, val_loss: 0.0014, val_acc: 0.9600
Epoch [2/10], Loss: 0.0015, val_loss: 0.0012, val_acc: 0.9645
Epoch [3/10], Loss: 0.0012, val_loss: 0.0009, val_acc: 0.9715
Epoch [4/10], Loss: 0.0010, val_loss: 0.0009, val_acc: 0.9738
Epoch [5/10], Loss: 0.0009, val_loss: 0.0008, val_acc: 0.9743
Epoch [6/10], Loss: 0.0008, val_loss: 0.0008, val_acc: 0.9754
Epoch [7/10], Loss: 0.0007, val_loss: 0.0008, val_acc: 0.9762
Epoch [8/10], Loss: 0.0007, val_loss: 0.0007, val_acc: 0.9773
Epoch [9/10], Loss: 0.0006, val_loss: 0.0007, val_acc: 0.9775

ロードは以下のようにします。

tmp_path = 'checkPoint/epoch_3.pt'
if os.path.exists(tmp_path):
    checkpoint = torch.load(tmp_path)
    net.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimaizer_state_dict'])
    epoch_num = checkpoint['epoch']
    loss = checkpoint['loss']

そしてこちらがepoch=3の時の checkpoint をロードした時の結果です。

Epoch [3/10], Loss: 0.0010, val_loss: 0.0010, val_acc: 0.9699
Epoch [4/10], Loss: 0.0009, val_loss: 0.0009, val_acc: 0.9728
Epoch [5/10], Loss: 0.0008, val_loss: 0.0008, val_acc: 0.9750
Epoch [6/10], Loss: 0.0007, val_loss: 0.0007, val_acc: 0.9771
Epoch [7/10], Loss: 0.0006, val_loss: 0.0007, val_acc: 0.9783
Epoch [8/10], Loss: 0.0006, val_loss: 0.0007, val_acc: 0.9791
Epoch [9/10], Loss: 0.0005, val_loss: 0.0006, val_acc: 0.9798

乱数固定しなかったので微妙にずれてますが、学習が再開されていることがわかります。

register_buffer とは

次はregister_bufferです。

たまに論文実装のコードをみるとモデルに書いてあります。

公式の説明によると

This is typically used to register a buffer that should not to be considered a model parameter.

とあります。

model のパラメーター ではないけどモデルに持っておきたい値を保存する際に使うようです。

利用シーンとしては Batchnormalization の計算のためのバッチごとの計算結果を保持するのに使われます。

ネットワークに少し追加をして実験しました。

# Networkを定義
class MLPNet (nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1 * 28 * 28, 512)
        self.fc2 =nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 10)
        self.dropout1=nn.Dropout2d(0.2)
        self.dropout2=nn.Dropout2d(0.2)
        # 追加部分
        self.mean_val = 0 # 比較用
        self.register_buffer('count', torch.ones(2,2))
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        return F.relu(self.fc3(x))

普通にクラス変数を定義した場合とregister_bufferの場合を書いてみました。

学習中にこの 2 つをインクリメントして、保存後に値をみるという検証です。

register_bufferを使うとパラメータ同様に保存されることを確かめます。

学習後にモデルの中身をみると

print(net.mean_val)
print(net.count)
# >>>
# 6000
# tensor([[6001., 6001.],
#       [6001., 6001.]], device='cuda:0')

ちゃんとインクリメントされた値が入ってます。

これをいったんstate_dict()で保存し、保存したモデルを再度呼び出します。

dir_name = 'output'
if not os.path.exists(dir_name):
    os.mkdir(dir_name)
model_save_path = os.path.join(dir_name, "model.pt")
torch.save(net.state_dict(), model_save_path)
model = MLPNet()
model.load_state_dict(torch.load(model_save_path))
print(model.mean_val)
print(model.count)
# >>>
# 0
# tensor([[6001., 6001.],
#        [6001., 6001.]])

結果をみると、普通にモデル内に定義した変数の値は保持されていません。一方で、register_bufferの方は保存した時の値がちゃんと残ってます。

したがって、state_dictなどで後からモデルを呼び出す際に、パラメータじゃないけど必要な値をモデルに入れておきたいを値を使う際に役立ちます。

torch.lerp とは

torch.lerpは線形補完を行う関数です。

線形補完は式で表すと以下のようになります。

$$ out_i = v_1 + w (v_2 - v_1) $$

torch.lerp(torch.tensor([1,1],dtype=float), torch.tensor([4,4],dtype=float), 0.5)
# >>> tensor([2.5000, 2.5000], dtype=torch.float64)

上の式に代入すると同じ結果になります。 2つのベクトルの間を$w$で動かす感じです。

まとめ

今回は少し発展的な PyTorch の内容を説明しました。

  • torchsummary の使い方
  • モデルの保存方法
  • register_buffer の使い方
  • torch.lerp について

深層学習は数学的な難しさもありますし、層が増えるとその分コードでモデルを構築するのも難しいです。

フレームワークやライブラリをうまく使って本質的な問題解決に時間を当てたいですね。