621 views
この記事は最終更新から 397日 が経過しています。
【1】やりたいこと
前回の投稿 (32)【PyTorchでMNIST #3】理解の助けになるミニマム実装 で実装したプログラムを改良したい。
目的は、
・プログラムを部品化し、再利用性を上げること。
・ネットワーク構成の変更に柔軟に対応できること。
・学習済みパラメータをファイル保存し、後でテストだけを単体実行できること。
【2】やってみる
前投稿の一つのプログラムを、以下の 5ファイル構成に分解した。
| 1 | dataset_MNIST.py | MNISTデータセットに依存する処理を集約・隠蔽 |
| 2 | net_model.py | ネットワーク構成を定義 |
| 3 | trainer.py | 学習・テスト処理を実装(分類問題用) |
| 4 | control_train.py | 学習実行を制御 |
| 5 | control_test.py | テスト実行を制御 |
Case 1: ネットワーク構成を変更したい。
net_model.py だけを変更すればよい。
Case 2: データセットを変更したい。(=分類問題を変更したい)
入力画像サイズやラベル数が変わることになる。
dataset_MNIST.py, net_model.py を新しいデータセット用に入れ替える。
control_train.py, control_test.pyの import行を、新しいデータセット用に変更する。
1) プログラミング
(1) dataset_MNIST.py
MNISTデータセットに依存する処理を集約・隠蔽する。
from torchvision import datasets, transforms
class MyDataset:
def __init__(self, data_dir='../data'):
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
self.data_dir = data_dir
def get_train_dataset(self):
return datasets.MNIST(self.data_dir, train=True, download=True, transform=self.transform)
def get_test_dataset(self):
return datasets.MNIST(self.data_dir, train=False, download=True, transform=self.transform)
(2) net_model.py
ネットワーク構成を定義する。
import torch.nn as nn
import torch
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)
def forward(self, x):
x = x.view(-1, 28 * 28)
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
(3) trainer.py
学習、テスト処理を実装する。(分類タスク用)
import torch
import torch.nn as nn
import torch.optim as optim
class Trainer:
def __init__(self, model, device, lr=0.001):
self.model = model.to(device)
self.device = device
self.criterion = nn.CrossEntropyLoss()
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
def train(self, train_loader, epochs=3):
self.model.train()
for epoch in range(epochs):
for data, target in train_loader:
data, target = data.to(self.device), target.to(self.device)
self.optimizer.zero_grad()
output = self.model(data)
loss = self.criterion(output, target)
loss.backward()
self.optimizer.step()
print(f"[{epoch+1}/{epochs}] Epoch complete")
def test(self, test_loader):
self.model.eval()
correct = 0
total = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(self.device), target.to(self.device)
outputs = self.model(data)
_, predicted = torch.max(outputs.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
acc = 100 * correct / total
print(f"Test Accuracy: {acc:.2f}%")
return acc
def save(self, path):
torch.save(self.model.state_dict(), path)
print(f"Model saved to {path}")
def load(self, path):
self.model.load_state_dict(torch.load(path, map_location=self.device))
print(f"Model loaded from {path}")
(4) control_train.py
学習実行を制御する。
from torch.utils.data import DataLoader
from net_model import Net
from trainer import Trainer
from dataset_MNIST import MyDataset
import torch
# データと環境
dataset = MyDataset()
train_loader = DataLoader(dataset.get_train_dataset(), batch_size=64, shuffle=True)
test_loader = DataLoader(dataset.get_test_dataset(), batch_size=1000, shuffle=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net()
trainer = Trainer(model, device)
# 実行
trainer.train(train_loader, epochs=3)
trainer.test(test_loader)
trainer.save("learned_model.pth")
(5) control_test.py
テスト実行を制御する。
from torch.utils.data import DataLoader
from net_model import Net
from trainer import Trainer
from dataset_MNIST import MyDataset
import torch
dataset = MyDataset()
test_loader = DataLoader(dataset.get_test_dataset(), batch_size=1000, shuffle=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net()
trainer = Trainer(model, device)
trainer.load("learned_model.pth")
trainer.test(test_loader)
2) 実行
$ python control_train.py [1/3] Epoch complete [2/3] Epoch complete [3/3] Epoch complete Test Accuracy: 97.20% Model saved to learned_model.pth
$ python control_test.py Model loaded from learned_model.pth Test Accuracy: 97.20%
今回は学習済みパラメータファイル名を固定にしたが、後日これも改良したい。
・コマンドライン引数で渡す。
・環境変数で渡す。
・諸々のパラメータをまとめてJSONデータで渡す。
等々を後日検討する。
アクセス数(直近7日): ※試験運用中、BOT除外簡易実装済2026-05-26: 2回 2026-05-25: 0回 2026-05-24: 0回 2026-05-23: 0回 2026-05-22: 0回 2026-05-21: 0回 2026-05-20: 0回