2019年4月7日 星期日

PyTorch - 搭建神經網絡 - Building Model

在 PyTorch 中搭建神經網絡有一個固定格式,參考PyTorch 文本
格式固定為 class Model,其中包含__init__(self) & forward(self, x) 如下:
import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.conv2 = nn.Conv2d(20, 20, 5) def forward(self, x): x = F.relu(self.conv1(x)) return F.relu(self.conv2(x))
來試試上一篇回歸的問題,使用這種方式搭建模型。

1. 建立 DATA

import torch import matplotlib.pyplot as plt X = torch.unsqueeze(torch.linspace(-1, 1, 200), dim=1) # x data (tensor), shape=(100, 1) Y = 2*X.pow(2) + 0.3*torch.rand(X.size()) # noisy y data (tensor), shape=(100, 1) plt.scatter(X.data.numpy(), Y.data.numpy()) plt.show()
將X,Y 丟進 Variable 做梯度計算用。
X=Variable(torch.Tensor(X.reshape(200,1))) Y=Variable(torch.Tensor(Y.reshape(200,1)))

2. 建立神經網絡模型

先定義所有的層屬性(__init __()),然後再一層層搭建(forward(x))層於層的關係鏈接。
這邊先在 __init__() 中寫好要使用的 nn,然後在 forward()中將串接過程使用激活函數。
import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.nn1 = nn.Linear(1, 15) #第一層 Linear NN self.nn2 = nn.Linear(15, 1) #第二層 Linear NN def forward(self, x): x = F.relu(self.nn1(x)) #對第一層 NN 使用Relu激活 x = self.nn2(x) #第二層直接輸出 return x model = Model() print(model) #將模型print出來看看
Model(
  (nn1): Linear(in_features=1, out_features=15, bias=True)
  (nn2): Linear(in_features=15, out_features=1, bias=True)
)
搭建好了之後,一樣選擇優化器&損失函數:
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.005) loss_function = torch.nn.MSELoss()

3. 直接將訓練過程可視化

epochs = 500 for epoch in range(epochs): prediction = model(X) loss = loss_function(prediction, Y) optimizer.zero_grad() loss.backward() optimizer.step() if epoch % 20 == 0: # plot and show learning process plt.cla() plt.scatter(X.data.numpy(), Y.data.numpy()) plt.plot(X.data.numpy(), prediction.data.numpy(), 'r-', lw=5) plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'}) plt.savefig('D:\\img'+'%s'%epoch+'.jpg') plt.pause(0.1)
製作成GIF如下:

沒有留言:

張貼留言