{"id":1065,"date":"2025-04-23T02:10:28","date_gmt":"2025-04-22T17:10:28","guid":{"rendered":"https:\/\/www.dogrow.net\/nnet\/?p=1065"},"modified":"2025-04-24T14:29:42","modified_gmt":"2025-04-24T05:29:42","slug":"blog32-pytorch%e3%81%a7mnist-%e7%90%86%e8%a7%a3%e3%81%ae%e5%8a%a9%e3%81%91%e3%81%ab%e3%81%aa%e3%82%8b%e3%83%9f%e3%83%8b%e3%83%9e%e3%83%a0%e5%ae%9f%e8%a3%85","status":"publish","type":"post","link":"https:\/\/www.dogrow.net\/nnet\/blog32-pytorch%e3%81%a7mnist-%e7%90%86%e8%a7%a3%e3%81%ae%e5%8a%a9%e3%81%91%e3%81%ab%e3%81%aa%e3%82%8b%e3%83%9f%e3%83%8b%e3%83%9e%e3%83%a0%e5%ae%9f%e8%a3%85\/","title":{"rendered":"(32)\u3010PyTorch\u3067MNIST #3\u3011\u7406\u89e3\u306e\u52a9\u3051\u306b\u306a\u308b\u30df\u30cb\u30de\u30e0\u5b9f\u88c5\uff08CPU\u7528\u3001GPU\u7528\uff09"},"content":{"rendered":"<h1 class=\"my_h\">\u30101\u3011\u3084\u308a\u305f\u3044\u3053\u3068<\/h1>\n<p>\u904e\u53bb\u8a18\u4e8b\u3067\u4f7f\u7528\u3057\u305f PyTorch\u516c\u5f0f\u306e MNIST\u30b5\u30f3\u30d7\u30eb\u30d7\u30ed\u30b0\u30e9\u30e0\u306f\u57fa\u790e\u5b66\u7fd2\u306b\u306f\u5c11\u3005\u5927\u304d\u3081\u3060\u3002<br \/>\n\u5fc5\u8981\u6a5f\u80fd\u3060\u3051\u3092\u5207\u308a\u51fa\u3057\u3066\u3001\u30df\u30cb\u30de\u30e0\u5b9f\u88c5\u3057\u3066\u307f\u305f\u3044\u3002<\/p>\n<p><span class=\"my_fc_deeppinkBBig\">\u30d7\u30ed\u30b0\u30e9\u30e0\u306e\u6700\u5c0f\u5316\u306b\u3088\u308a\u3001<br \/>\n\u51e6\u7406\u306e\u30a8\u30c3\u30bb\u30f3\u30b9\u304c\u660e\u78ba\u306b\u308f\u304b\u308b\u3088\u3046\u306b\u306a\u308b\u3060\u308d\u3046\u3002<\/span><\/p>\n<h1 class=\"my_h\">\u30102\u3011\u3084\u3063\u3066\u307f\u305f<\/h1>\n<h3 class=\"my_h\">(1) CPU\u5b9f\u884c\u7528\u306e\u5b9f\u88c5<\/h3>\n<pre class=\"brush: python; title: ; notranslate\" title=\"\">\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.optim as optim\r\nfrom torchvision import datasets, transforms\r\nfrom torch.utils.data import DataLoader\r\n\r\n#\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\r\n# 1. \u30c7\u30fc\u30bf\u306e\u6e96\u5099\r\ntransform = transforms.Compose(&#x5B;\r\n    transforms.ToTensor(),\r\n    transforms.Normalize((0.1307,), (0.3081,))\r\n])\r\n\r\ndataDir = '..\/data'\r\ntrain_dataset = datasets.MNIST(dataDir, train=True, download=True, transform=transform)\r\ntest_dataset  = datasets.MNIST(dataDir, train=False,               transform=transform)\r\ntrain_loader = DataLoader(train_dataset, batch_size=64,   shuffle=True)\r\ntest_loader  = DataLoader(test_dataset,  batch_size=1000, shuffle=False)\r\n\r\n#\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\r\n# 2. \u30cb\u30e5\u30fc\u30e9\u30eb\u30cd\u30c3\u30c8\u30ef\u30fc\u30af\u306e\u5b9a\u7fa9\r\nclass Net(nn.Module):\r\n    def __init__(self):\r\n        super(Net, self).__init__()\r\n        self.fc1 = nn.Linear(28 * 28, 128)\r\n        self.fc2 = nn.Linear(128, 64)\r\n        self.fc3 = nn.Linear(64, 10)\r\n\r\n    def forward(self, x):\r\n        x = torch.flatten(x, 1)      # &#x5B;BatchSize]&#x5B;28]&#x5B;28] \u2192 &#x5B;BatchSize]&#x5B;784] 1\u6b21\u5143\u30d9\u30af\u30c8\u30eb\u306b\u5909\u63db\r\n        x = torch.relu(self.fc1(x))  #  \u2192 &#x5B;BatchSize]&#x5B;128] \u2192 ReLU\r\n        x = torch.relu(self.fc2(x))  #  \u2192 &#x5B;BatchSize]&#x5B;64] \u2192 ReLU\r\n        x = self.fc3(x)              #  \u2192 &#x5B;BatchSize]&#x5B;10] \u3053\u306e\u5f8c\u3001criterion\u3067 softmax\u51e6\u7406\u3055\u308c\u308b\u3002\r\n        return x\r\n\r\nmodel = Net()\r\n\r\n#\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\r\n# 3. \u640d\u5931\u95a2\u6570\u30fb\u6700\u9069\u5316\u624b\u6cd5\u306e\u5b9a\u7fa9\r\ncriterion = nn.CrossEntropyLoss()    # \u5206\u985e\u30bf\u30b9\u30af\u5411\u3051\u306e\u640d\u5931\u95a2\u6570\r\noptimizer = optim.Adam(model.parameters())\r\n\r\n#\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\r\n# 4. \u5b66\u7fd2\u30eb\u30fc\u30d7\r\nfor epoch in range(5):\r\n    model.train()                          # \u5b66\u7fd2\u30e2\u30fc\u30c9\u306b\u8a2d\u5b9a (dropout\u7b49\u3092\u8a31\u53ef)\r\n    for data, target in train_loader:      # \u30a4\u30c6\u30ec\u30fc\u30bf\u304b\u3089\u30c7\u30fc\u30bf\u3092\u53d6\u308a\u51fa\u3059\u3002\r\n        optimizer.zero_grad()              # \u52fe\u914d\u3092\u30bc\u30ed\u30af\u30ea\u30a2\r\n        output = model(data)               # \u9806\u4f1d\u64ad\u3092\u5b9f\u884c\r\n        loss = criterion(output, target)   # \u640d\u5931\u95a2\u6570\u3092\u5b9f\u884c (\u8aa4\u5dee\u3092\u7b97\u51fa)\r\n        loss.backward()                    # \u9006\u4f1d\u64ad\r\n        optimizer.step()                   # \u5404\u30d1\u30e9\u30e1\u30fc\u30bf\u306e\u66f4\u65b0\r\n    print(f&quot;Epoch {epoch+1} complete&quot;)\r\n\r\n#\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\r\n# 5. \u30c6\u30b9\u30c8\r\nmodel.eval()                               # \u30c6\u30b9\u30c8\u30e2\u30fc\u30c9\u306b\u8a2d\u5b9a (dropout\u7b49\u3092\u7981\u6b62)\r\ncorrect = 0                                # \u6b63\u89e3\u6570 = 0\r\ntotal = 0                                  # \u30c6\u30b9\u30c8\u6570 = 0\r\nwith torch.no_grad():                      # \u30c6\u30b9\u30c8\u4e2d\u306f\u52fe\u914d\u3092\u8a08\u7b97\u3057\u306a\u3044\u3088\u3046\u306b\u3059\u308b\u3002\uff08\u30e1\u30e2\u30ea\u7bc0\u7d04\uff09\r\n    for data, target in test_loader:       # \u30a4\u30c6\u30ec\u30fc\u30bf\u304b\u3089\u30c7\u30fc\u30bf\u3092\u53d6\u308a\u51fa\u3059\u3002\r\n        outputs = model(data)              # \u9806\u4f1d\u64ad\r\n        _, predicted = torch.max(outputs.data, 1)      # \u51fa\u529b\u5c64\u306e\u6700\u5927\u5024\u306e\u30a4\u30f3\u30c7\u30c3\u30af\u30b9\uff08=\u30e9\u30d9\u30eb\u756a\u53f7\uff09\u3092\u53d6\u5f97\r\n        total += target.size(0)                        # \u30c6\u30b9\u30c8\u6570\u3092\u66f4\u65b0\r\n        correct += (predicted == target).sum().item()  # \u6b63\u89e3\u6570\u3092\u66f4\u65b0\uff08\u6bd4\u8f03\u7d50\u679c\u304cTrue\u306e\u6570\u3092\u52a0\u7b97\uff09\r\nprint(f&quot;Test Accuracy: {100 * correct \/ total:.2f}%&quot;)\r\n<\/pre>\n<p>\u51e6\u7406\u306e\u30dd\u30a4\u30f3\u30c8\u306b\u3064\u3044\u3066\u3001\u30e1\u30e2\u66f8\u304d\u3057\u3066\u304a\u304f\u3002<\/p>\n<p>9-12\u884c\u76ee\uff1a<br \/>\n\u5165\u529b\u30c7\u30fc\u30bf\u306e\u524d\u51e6\u7406\u3092\u5b9a\u7fa9\u3059\u308b\u3002<br \/>\n\u753b\u50cf\u3092\u30c6\u30f3\u30bd\u30eb\u306b\u5909\u63db\u3057\u3001\u5e73\u5747 0.1307\u3001\u6a19\u6e96\u504f\u5dee 0.3081 \u3067\u6b63\u898f\u5316\u3057\u3066\u3044\u308b\u3002<br \/>\n\u3053\u308c\u306b\u3088\u308a\u5b66\u7fd2\u304c\u53ce\u675f\u3057\u3084\u3059\u304f\u306a\u308a\u3001\u5b89\u5b9a\u5316\u30fb\u9ad8\u901f\u5316\u304c\u671f\u5f85\u3067\u304d\u308b\u3002<\/p>\n<p>15-18\u884c\u76ee\uff1a<br \/>\nMNIST\u30c7\u30fc\u30bf\u30bb\u30c3\u30c8\u3092\u30d5\u30a1\u30a4\u30eb\u304b\u3089\u30ed\u30fc\u30c9\u3057\u3001\u30a4\u30c6\u30ec\u30fc\u30bf\u578b\u30aa\u30d6\u30b8\u30a7\u30af\u30c8\u306b\u683c\u7d0d\u3057\u3066\u3044\u308b\u3002<br \/>\n\u6307\u5b9a\u30c7\u30a3\u30ec\u30af\u30c8\u30ea\u306b\u30c7\u30fc\u30bf\u30bb\u30c3\u30c8\u30d5\u30a1\u30a4\u30eb\u304c\u5b58\u5728\u3057\u306a\u3051\u308c\u3070\u3001\u81ea\u52d5\u7684\u306b\u30c0\u30a6\u30f3\u30ed\u30fc\u30c9\u3057\u3066\u304f\u308b\u3002<br \/>\n\u30a4\u30c6\u30ec\u30fc\u30bf\u304b\u3089\u30c7\u30fc\u30bf\u3092\u53d6\u5f97\u6642\u306b\u3001\u4e0a\u8a18\u306e\u6b63\u898f\u5316\u6e08\u307f\u306e\u30c7\u30fc\u30bf\u304c\u53d6\u5f97\u3067\u304d\u308b\u3002<br \/>\n\u30a4\u30c6\u30ec\u30fc\u30bf\u304b\u3089\u30c7\u30fc\u30bf\u3092\u53d6\u5f97\u6642\u306b\u3001\u5b66\u7fd2\u30c7\u30fc\u30bf\u306f\u30b7\u30e3\u30c3\u30d5\u30eb\u3057\u305f\u5f8c\u3001\u30c6\u30b9\u30c8\u30c7\u30fc\u30bf\u306f\u30b7\u30e3\u30c3\u30d5\u30eb\u305b\u305a\u306b\u53d6\u5f97\u3059\u308b\u3002<\/p>\n<p>\u5b9f\u884c\u7d50\u679c\u306f\u4ee5\u4e0b\u306e\u901a\u308a\u3002<\/p>\n<pre class=\"my_pre_bgBlack\">\r\n$ python mnist_samnple_CPU.py\r\nEpoch 1 complete\r\nEpoch 2 complete\r\nEpoch 3 complete\r\nEpoch 4 complete\r\nEpoch 5 complete\r\nTest Accuracy: 97.48%\r\n<\/pre>\n<h3 class=\"my_h\">(2) GPU\u5b9f\u884c\u7528\u306e\u5b9f\u88c5<\/h3>\n<p>\u5909\u66f4\u7b87\u6240\u3092\u30cf\u30a4\u30e9\u30a4\u30c8\u8868\u793a\u3059\u308b\u3002<\/p>\n<pre class=\"brush: python; highlight: [36,37,38,50,65]; title: ; notranslate\" title=\"\">\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.optim as optim\r\nfrom torchvision import datasets, transforms\r\nfrom torch.utils.data import DataLoader\r\n\r\n#\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\r\n# 1. \u30c7\u30fc\u30bf\u306e\u6e96\u5099\r\ntransform = transforms.Compose(&#x5B;\r\n    transforms.ToTensor(),\r\n    transforms.Normalize((0.1307,), (0.3081,))\r\n])\r\n\r\ndataDir = '..\/data'\r\ntrain_dataset = datasets.MNIST(dataDir, train=True, download=True, transform=transform)\r\ntest_dataset  = datasets.MNIST(dataDir, train=False,               transform=transform)\r\ntrain_loader = DataLoader(train_dataset, batch_size=64,   shuffle=True)\r\ntest_loader  = DataLoader(test_dataset,  batch_size=1000, shuffle=False)\r\n\r\n#\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\r\n# 2. \u30cb\u30e5\u30fc\u30e9\u30eb\u30cd\u30c3\u30c8\u30ef\u30fc\u30af\u306e\u5b9a\u7fa9\r\nclass Net(nn.Module):\r\n    def __init__(self):\r\n        super(Net, self).__init__()\r\n        self.fc1 = nn.Linear(28 * 28, 128)\r\n        self.fc2 = nn.Linear(128, 64)\r\n        self.fc3 = nn.Linear(64, 10)\r\n\r\n    def forward(self, x):\r\n        x = x.view(-1, 28 * 28)\r\n        x = torch.relu(self.fc1(x))\r\n        x = torch.relu(self.fc2(x))\r\n        x = self.fc3(x)\r\n        return x\r\n\r\ndevice = torch.device(&quot;cuda&quot; if torch.cuda.is_available() else &quot;cpu&quot;)\r\nprint(f&quot;Device: {device}&quot;)\r\nmodel = Net().to(device)\r\n\r\n#\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\r\n# 3. \u640d\u5931\u95a2\u6570\u30fb\u6700\u9069\u5316\u624b\u6cd5\u306e\u5b9a\u7fa9\r\ncriterion = nn.CrossEntropyLoss()\r\noptimizer = optim.Adam(model.parameters())\r\n\r\n#\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\r\n# 4. \u5b66\u7fd2\u30eb\u30fc\u30d7\r\nfor epoch in range(5):  # \u30a8\u30dd\u30c3\u30af\u6570\u306f\u304a\u597d\u307f\u3067\u8abf\u6574\r\n    model.train()\r\n    for data, target in train_loader:\r\n        data, target = data.to(device), target.to(device)\r\n        optimizer.zero_grad()\r\n        output = model(data)\r\n        loss = criterion(output, target)\r\n        loss.backward()\r\n        optimizer.step()\r\n    print(f&quot;Epoch {epoch+1} complete&quot;)\r\n\r\n#\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\r\n# 5. \u30c6\u30b9\u30c8\r\nmodel.eval()\r\ncorrect = 0\r\ntotal = 0\r\nwith torch.no_grad():\r\n    for data, target in test_loader:\r\n        data, target = data.to(device), target.to(device)\r\n        outputs = model(data)\r\n        _, predicted = torch.max(outputs.data, 1)\r\n        total += target.size(0)\r\n        correct += (predicted == target).sum().item()\r\nprint(f&quot;Test Accuracy: {100 * correct \/ total:.2f}%&quot;)\r\n<\/pre>\n<p>\u5b9f\u884c\u7d50\u679c\u306f\u4ee5\u4e0b\u306e\u901a\u308a\u3002<br \/>\nCPU\u7248\u3068\u6bd4\u8f03\u3059\u308b\u3068\u3060\u3044\u3076\u9ad8\u901f\u3060\u3002<\/p>\n<pre class=\"my_pre_bgBlack\">\r\n$ python mnist_samnple_GPU.py\r\nDevice: cuda\r\nEpoch 1 complete\r\nEpoch 2 complete\r\nEpoch 3 complete\r\nEpoch 4 complete\r\nEpoch 5 complete\r\nTest Accuracy: 96.98%\r\n<\/pre>\n<hr class=\"my_hr_bottom\">\n","protected":false},"excerpt":{"rendered":"<p>\u30101\u3011\u3084\u308a\u305f\u3044\u3053\u3068 \u904e\u53bb\u8a18\u4e8b\u3067\u4f7f\u7528\u3057\u305f PyTorch\u516c\u5f0f\u306e MNIST\u30b5\u30f3\u30d7\u30eb\u30d7\u30ed\u30b0\u30e9\u30e0\u306f\u57fa\u790e\u5b66\u7fd2\u306b\u306f\u5c11\u3005\u5927\u304d\u3081\u3060\u3002 \u5fc5\u8981\u6a5f\u80fd\u3060\u3051\u3092\u5207\u308a\u51fa\u3057\u3066\u3001\u30df\u30cb\u30de\u30e0\u5b9f\u88c5\u3057\u3066\u307f\u305f\u3044\u3002 \u30d7\u30ed\u30b0\u30e9\u30e0\u306e\u6700\u5c0f\u5316\u306b\u3088\u308a\u3001 \u51e6\u7406\u306e\u30a8\u30c3\u30bb\u30f3\u30b9\u304c\u660e\u2026 <span class=\"read-more\"><a href=\"https:\/\/www.dogrow.net\/nnet\/blog32-pytorch%e3%81%a7mnist-%e7%90%86%e8%a7%a3%e3%81%ae%e5%8a%a9%e3%81%91%e3%81%ab%e3%81%aa%e3%82%8b%e3%83%9f%e3%83%8b%e3%83%9e%e3%83%a0%e5%ae%9f%e8%a3%85\/\">\u7d9a\u304d\u3092\u8aad\u3080 &raquo;<\/a><\/span><\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[18,16],"tags":[],"class_list":["post-1065","post","type-post","status-publish","format-standard","hentry","category-mnist","category-pytorch"],"views":950,"amp_enabled":true,"_links":{"self":[{"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/posts\/1065","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/comments?post=1065"}],"version-history":[{"count":22,"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/posts\/1065\/revisions"}],"predecessor-version":[{"id":1218,"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/posts\/1065\/revisions\/1218"}],"wp:attachment":[{"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/media?parent=1065"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/categories?post=1065"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/tags?post=1065"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}