(33)【PyTorchでMNIST #4】プログラムの保守性を向上させる。

投稿者: | 2025年4月23日

622 views

この記事は最終更新から 397日 が経過しています。

【1】やりたいこと

前回の投稿 (32)【PyTorchでMNIST #3】理解の助けになるミニマム実装 で実装したプログラムを改良したい。

目的は、
・プログラムを部品化し、再利用性を上げること。
・ネットワーク構成の変更に柔軟に対応できること。
・学習済みパラメータをファイル保存し、後でテストだけを単体実行できること。

【2】やってみる

前投稿の一つのプログラムを、以下の 5ファイル構成に分解した。

1dataset_MNIST.pyMNISTデータセットに依存する処理を集約・隠蔽
2net_model.pyネットワーク構成を定義
3trainer.py学習・テスト処理を実装(分類問題用)
4control_train.py学習実行を制御
5control_test.pyテスト実行を制御

Case 1: ネットワーク構成を変更したい。

net_model.py だけを変更すればよい。

Case 2: データセットを変更したい。(=分類問題を変更したい)

入力画像サイズやラベル数が変わることになる。
dataset_MNIST.py, net_model.py を新しいデータセット用に入れ替える。
control_train.py, control_test.pyの import行を、新しいデータセット用に変更する。

1) プログラミング

(1) dataset_MNIST.py

MNISTデータセットに依存する処理を集約・隠蔽する。

from torchvision import datasets, transforms

class MyDataset:
    def __init__(self, data_dir='../data'):
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        self.data_dir = data_dir

    def get_train_dataset(self):
        return datasets.MNIST(self.data_dir, train=True, download=True, transform=self.transform)

    def get_test_dataset(self):
        return datasets.MNIST(self.data_dir, train=False, download=True, transform=self.transform)

(2) net_model.py

ネットワーク構成を定義する。

import torch.nn as nn
import torch

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

(3) trainer.py

学習、テスト処理を実装する。(分類タスク用)

import torch
import torch.nn as nn
import torch.optim as optim

class Trainer:
    def __init__(self, model, device, lr=0.001):
        self.model = model.to(device)
        self.device = device
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

    def train(self, train_loader, epochs=3):
        self.model.train()
        for epoch in range(epochs):
            for data, target in train_loader:
                data, target = data.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()
            print(f"[{epoch+1}/{epochs}] Epoch complete")

    def test(self, test_loader):
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(self.device), target.to(self.device)
                outputs = self.model(data)
                _, predicted = torch.max(outputs.data, 1)
                total += target.size(0)
                correct += (predicted == target).sum().item()
        acc = 100 * correct / total
        print(f"Test Accuracy: {acc:.2f}%")
        return acc

    def save(self, path):
        torch.save(self.model.state_dict(), path)
        print(f"Model saved to {path}")

    def load(self, path):
        self.model.load_state_dict(torch.load(path, map_location=self.device))
        print(f"Model loaded from {path}")

(4) control_train.py

学習実行を制御する。

from torch.utils.data import DataLoader
from net_model import Net
from trainer import Trainer
from dataset_MNIST import MyDataset
import torch

# データと環境
dataset = MyDataset()
train_loader = DataLoader(dataset.get_train_dataset(), batch_size=64,   shuffle=True)
test_loader  = DataLoader(dataset.get_test_dataset(),  batch_size=1000, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net()
trainer = Trainer(model, device)

# 実行
trainer.train(train_loader, epochs=3)
trainer.test(test_loader)
trainer.save("learned_model.pth")

(5) control_test.py

テスト実行を制御する。

from torch.utils.data import DataLoader
from net_model import Net
from trainer import Trainer
from dataset_MNIST import MyDataset
import torch

dataset = MyDataset()
test_loader = DataLoader(dataset.get_test_dataset(), batch_size=1000, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net()
trainer = Trainer(model, device)
trainer.load("learned_model.pth")
trainer.test(test_loader)

2) 実行

$ python control_train.py 
[1/3] Epoch complete
[2/3] Epoch complete
[3/3] Epoch complete
Test Accuracy: 97.20%
Model saved to learned_model.pth
$ python control_test.py 
Model loaded from learned_model.pth
Test Accuracy: 97.20%

今回は学習済みパラメータファイル名を固定にしたが、後日これも改良したい。
・コマンドライン引数で渡す。
・環境変数で渡す。
・諸々のパラメータをまとめてJSONデータで渡す。
等々を後日検討する。


アクセス数(直近7日): ※試験運用中、BOT除外簡易実装済
  • 2026-05-26: 2回
  • 2026-05-25: 0回
  • 2026-05-24: 0回
  • 2026-05-23: 0回
  • 2026-05-22: 0回
  • 2026-05-21: 0回
  • 2026-05-20: 0回
  • コメントを残す

    メールアドレスが公開されることはありません。 が付いている欄は必須項目です