PyTorchの勾配更新方法の解説
今回も PyTorch に関する記事です。
この記事では、requires_grad
、zero_grad
などについて説明します。
私自身も勉強中ということもあり間違い等あるかもしれません。その際は Twitter などで教えてください。
requires_grad とは
【学び直し】Pytorch の基本と MLP で MNIST の分類・可視化の実装までで紹介したように、requires_grad
は自動微分を行うためのフラグです。
単純に tensor を定義した場合はデフォルトで False になっています。
x = torch.ones([3, 32, 32])
x.requires_grad
# >>> False
一方で、ネットワークを定義した場合のパラメータはデフォルトでrequires_grad=True
です。
意外とこれを知らずにわざわざ学習時にrequires_grad=True
を設定していることがあります。
実際に定義したネットワークで確かめてみました。
# Networkを定義
class MLPNet (nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(1 * 28 * 28, 512)
self.fc2 =nn.Linear(512, 512)
self.fc3 = nn.Linear(512, 10)
self.dropout1=nn.Dropout2d(0.2)
self.dropout2=nn.Dropout2d(0.2)
def forward(self, x):
x = F.relu(self.fc1(x))
x = self.dropout1(x)
x = F.relu(self.fc2(x))
x = self.dropout2(x)
return F.relu(self.fc3(x))
net = MLPNet().to(device)
for name, p in net.named_parameters():
print(name, p.requires_grad)
# >>> fc1.weight True
# >>> fc1.bias True
# >>> fc2.weight True
# >>> fc2.bias True
# >>> fc3.weight True
# >>> fc3.bias True
このようにrequires_grad=True
になってます。
ちなみにnamed_parameters()
とすることでレイヤーの名前とパラメータを参照できます。
ドキュメントによるとネットワークで各層を定義した際に用いられるnn.Prameter
がデフォルトでrequires_grad=True
なので上記の結果になります。
requires_grad=True
が求められるのは、backward
で勾配を計算したいところです。
逆に、勾配の更新を行わないところは明示的にrequires_grad=False
とする必要があります。
optim について
optim は pytorch で学習を行う際に用いる最適化関数です。
今回も簡単な式で挙動を確認します。
import torch
import torch.optim as optim
x = torch.tensor(3.0, requires_grad=True)
w = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)
yy = torch.tensor(5.0, requires_grad=True)
y = w * x + b
opt = optim.SGD([w,b], lr=0.01, momentum=0.9)
optim の引数には学習したいパラメータと学習率などを渡します。
ここで、学習したいパラメータは iteration の形である必要があるためカッコでくくってます。
実際に optim で最適化してみます。
criterion = nn.MSELoss()
for epoch in range(3):
opt.zero_grad()
y = w * x + b
loss = criterion(y, yy)
loss.backward()
opt.step()
まず、opt.zero_grad()
で勾配をゼロにリセットします。
これをしないと勾配が蓄積されたままになってしまうので、正しい方向に更新されなくなります。
数式でかくと、パラメータ$w$を更新する際は以下のようになります(SGD の場合)。
$$ w ← w - lr \frac{\partial L}{\partial w} $$
opt.zero_grad()
を使わないと、偏微分部分が前の勾配との和になってしまうので、正しい方向に更新されません。
簡単に手計算で確かめられるので、一度計算してみると何が行われているかイメージできます。
パラメータの更新について
PyTorch を使い始めてまず混乱するのが、なぜ optim の操作で model の値が更新されるのかです。
結論からいうと、model パラメータのアドレスを optim の初期化のときに渡しているからです。
opt = optim.SGD([w,b], lr=0.01, momentum=0.9)
ここまでは大丈夫だと思います。
気をつけないといけないのが、学習の途中でパラメータを optim 以外の方法で更新する場合です。
例えば、GAN で重みの移動平均を計算する場合など。
そういう時は.data
で値を更新します。
model は先ほどの MLP だとして、まずは普通に値を更新します。
net = MLPNet().to(device)
# 最適化方法
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
net.fc1.weight = nn.Parameter(torch.ones([512, 784]))
print(net.fc1.weight)
print(optimizer.param_groups[0]['params'][0])
この時の出力は以下のようになります。
Parameter containing:
tensor([[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]], requires_grad=True)
Parameter containing:
tensor([[ 0.0323, -0.0328, -0.0338, ..., -0.0061, -0.0316, 0.0097],
[-0.0309, -0.0281, 0.0229, ..., 0.0164, -0.0236, 0.0347],
[-0.0184, -0.0026, 0.0323, ..., -0.0353, -0.0310, 0.0208],
...,
[ 0.0015, 0.0148, 0.0143, ..., -0.0112, -0.0215, 0.0123],
[ 0.0332, -0.0148, 0.0235, ..., -0.0001, -0.0148, -0.0104],
[ 0.0044, -0.0282, -0.0292, ..., -0.0311, -0.0068, 0.0349]],
device='cuda:0', requires_grad=True)
このように optimizer の方の値は変わっていないです。
なので、optimizer 以外で重みを更新する場合は以下のようにします。
# .dataで更新する
net.fc1.weight.data = nn.Parameter(torch.ones([512, 784]))
print(net.fc1.weight)
print(optimizer.param_groups[0]['params'][0])
こうすると optimizer の方の値も書き換わります。
Parameter containing:
tensor([[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]], requires_grad=True)
Parameter containing:
tensor([[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]], requires_grad=True)
まとめ
PyTorch は便利で、なんとなく書いてもできてしまいます。
しかし、ちょっと発展的な内容を実装しようと思うと、勾配がどのように更新されるかなどを知らないと実装できないです。
今回実際に、細かく処理の経過を出力したりすることでだいぶ理解が進みました。