(38)【PyTorchでCIFAR-10 #1】自動認識してみる。

投稿者: | 2025年4月23日

875 views

【1】やりたいこと

PyTorchを使い、CIFAR-10の自動認識プログラムを作って動かしたい。
データセットの構造は、前回 (32) CIFAR-10画像セットを見てみる(Python版)で解析済みだ。

https://www.cs.utoronto.ca/~kriz/cifar.html

【2】やってみる

1) プログラム・ソースコード

MNISTのときと同様に、PyTorchでは torchvision.datasets.CIFAR10 なる便利クラスも用意されている。

でも・・・
データ入力部分をブラックボックスにしていると
自作のデータセットを使いたいときに困るはず。

そこで・・・
今回は CIFAR-10用の Datasetクラスを自作することにした。
クラス名は CustomCIFAR10 とする。

(1) CIFAR10Dataset.py

import pickle
import numpy as np
import torch
from torch.utils.data import Dataset
import os

#//////////////////////////////////////////////////////////////////////////////
# CIFAR-10 Dataset クラス定義
class CustomCIFAR10(Dataset):
    #//////////////////////////////////////////////////////////////////////////
    def __init__(self, data_dir, train=True, transform=None):
        self.transform = transform
        self.data = []
        self.labels = []
        if train:
            files = [f"data_batch_{i}" for i in range(1, 6)]
        else:
            files = ["test_batch"]
        for file in files:
            filepath = os.path.join(data_dir, file)
            images, labels = self.load_cifar10_batch(filepath)
            self.data.append(images)
            self.labels += labels
        self.data = np.concatenate(self.data)
    #//////////////////////////////////////////////////////////////////////////
    def __len__(self):
        return len(self.labels)
    #//////////////////////////////////////////////////////////////////////////
    def __getitem__(self, idx):
        image = self.data[idx]
        label = self.labels[idx]
        image = torch.tensor(image, dtype=torch.float32)
        if self.transform:
            image = self.transform(image)
        return image, label
    #//////////////////////////////////////////////////////////////////////////
    def load_cifar10_batch(self, file):
        with open(file, 'rb') as fo:
            dict = pickle.load(fo, encoding='bytes')
        data = dict[b'data']
        labels = dict[b'labels']
        data = data.reshape(len(data), 3, 32, 32).astype("float32") / 255.0
        return data, labels

(2) main.py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from CIFAR10Dataset import CustomCIFAR10

#//////////////////////////////////////////////////////////////////////////////
# CNN モデル定義
class SimpleCNN(nn.Module):
    #//////////////////////////////////////////////////////////////////////////
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 10)
    #//////////////////////////////////////////////////////////////////////////
    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

#//////////////////////////////////////////////////////////////////////////////
# メイン関数
def main():
    # 設定
    data_dir = '../data/cifar-10-batches-py'
    batch_size = 64
    num_epochs = 3
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # データ変換
    transform = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    # データセットとローダー
    trainset = CustomCIFAR10(data_dir, train=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

    testset = CustomCIFAR10(data_dir, train=False, transform=transform)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

    # モデル、損失関数、最適化
    model = SimpleCNN().to(device)
    criterion = nn.CrossEntropyLoss()      # クラス分類向けの損失関数
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # 学習
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.detach().item()
            if i % 100 == 99:
                print(f"[{epoch+1}, {i+1:5d}] loss: {running_loss / 100:.3f}")
                running_loss = 0.0

    # テスト
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Accuracy on test images: {100 * correct / total:.2f}%")

#//////////////////////////////////////////////////////////////////////////////
# 実行
if __name__ == "__main__":
    main()

2) レイヤー構成

【入力層】
 ↓ [3][32][32]
【CNN】        ※上下左右端にパディング1 → 出力サイズは変わらず。
 ↓ [32[32][32]
【Max Pooling】    ※「窓2×2」で「ストライド2」→ 半分のサイズに縮小される。
 ↓ [32][16][16]
【CNN】        ※上下左右端にパディング1 → 出力サイズは変わらず。
 ↓ [64[16][16]
【Max Pooling】    ※「窓2×2」で「ストライド2」→ 半分のサイズに縮小される。
 ↓ [64][8][8]
【Full Connection】
 ↓ [256]
【Full Connection】
 ↓ [10]
【出力層】

3) 実行結果

このプログラムの正解率は、3 epochsで 70%前後だった。

$ python main.py
[1,   100] loss: 1.806
[1,   200] loss: 1.465
[1,   300] loss: 1.331
[1,   400] loss: 1.291
[1,   500] loss: 1.175
[1,   600] loss: 1.153
[1,   700] loss: 1.083
[2,   100] loss: 0.982
[2,   200] loss: 0.937
[2,   300] loss: 0.939
[2,   400] loss: 0.937
[2,   500] loss: 0.897
[2,   600] loss: 0.901
[2,   700] loss: 0.850
[3,   100] loss: 0.759
[3,   200] loss: 0.767
[3,   300] loss: 0.756
[3,   400] loss: 0.754
[3,   500] loss: 0.759
[3,   600] loss: 0.728
[3,   700] loss: 0.759
Finished Training
Accuracy on test images: 71.40%

アクセス数(直近7日): ※試験運用中、BOT除外簡易実装済
  • 2026-05-26: 0回
  • 2026-05-25: 0回
  • 2026-05-24: 0回
  • 2026-05-23: 0回
  • 2026-05-22: 0回
  • 2026-05-21: 0回
  • 2026-05-20: 0回
  • コメントを残す

    メールアドレスが公開されることはありません。 が付いている欄は必須項目です