{"id":1105,"date":"2025-04-23T04:23:33","date_gmt":"2025-04-22T19:23:33","guid":{"rendered":"https:\/\/www.dogrow.net\/nnet\/?p=1105"},"modified":"2025-04-24T14:29:01","modified_gmt":"2025-04-24T05:29:01","slug":"blog33%e3%80%90pytorch%e3%81%a7mnist-4%e3%80%91%e3%83%97%e3%83%ad%e3%82%b0%e3%83%a9%e3%83%a0%e3%81%ae%e4%bf%9d%e5%ae%88%e6%80%a7%e3%82%92%e5%90%91%e4%b8%8a%e3%81%95%e3%81%9b%e3%82%8b%e3%80%82","status":"publish","type":"post","link":"https:\/\/www.dogrow.net\/nnet\/blog33%e3%80%90pytorch%e3%81%a7mnist-4%e3%80%91%e3%83%97%e3%83%ad%e3%82%b0%e3%83%a9%e3%83%a0%e3%81%ae%e4%bf%9d%e5%ae%88%e6%80%a7%e3%82%92%e5%90%91%e4%b8%8a%e3%81%95%e3%81%9b%e3%82%8b%e3%80%82\/","title":{"rendered":"(33)\u3010PyTorch\u3067MNIST #4\u3011\u30d7\u30ed\u30b0\u30e9\u30e0\u306e\u4fdd\u5b88\u6027\u3092\u5411\u4e0a\u3055\u305b\u308b\u3002"},"content":{"rendered":"<h1 class=\"my_h\">\u30101\u3011\u3084\u308a\u305f\u3044\u3053\u3068<\/h1>\n<p>\u524d\u56de\u306e\u6295\u7a3f <a href=\"https:\/\/www.dogrow.net\/nnet\/blog32-pytorch%e3%81%a7mnist-%e7%90%86%e8%a7%a3%e3%81%ae%e5%8a%a9%e3%81%91%e3%81%ab%e3%81%aa%e3%82%8b%e3%83%9f%e3%83%8b%e3%83%9e%e3%83%a0%e5%ae%9f%e8%a3%85\/\" target=\"_blank\">(32)\u3010PyTorch\u3067MNIST #3\u3011\u7406\u89e3\u306e\u52a9\u3051\u306b\u306a\u308b\u30df\u30cb\u30de\u30e0\u5b9f\u88c5<\/a> \u3067\u5b9f\u88c5\u3057\u305f\u30d7\u30ed\u30b0\u30e9\u30e0\u3092\u6539\u826f\u3057\u305f\u3044\u3002<\/p>\n<p>\u76ee\u7684\u306f\u3001<br \/>\n\u30fb\u30d7\u30ed\u30b0\u30e9\u30e0\u3092\u90e8\u54c1\u5316\u3057\u3001\u518d\u5229\u7528\u6027\u3092\u4e0a\u3052\u308b\u3053\u3068\u3002<br \/>\n\u30fb\u30cd\u30c3\u30c8\u30ef\u30fc\u30af\u69cb\u6210\u306e\u5909\u66f4\u306b\u67d4\u8edf\u306b\u5bfe\u5fdc\u3067\u304d\u308b\u3053\u3068\u3002<br \/>\n\u30fb\u5b66\u7fd2\u6e08\u307f\u30d1\u30e9\u30e1\u30fc\u30bf\u3092\u30d5\u30a1\u30a4\u30eb\u4fdd\u5b58\u3057\u3001\u5f8c\u3067\u30c6\u30b9\u30c8\u3060\u3051\u3092\u5358\u4f53\u5b9f\u884c\u3067\u304d\u308b\u3053\u3068\u3002<\/p>\n<h1 
class=\"my_h\">\u30102\u3011\u3084\u3063\u3066\u307f\u308b<\/h1>\n<p>\u524d\u6295\u7a3f\u306e\u4e00\u3064\u306e\u30d7\u30ed\u30b0\u30e9\u30e0\u3092\u3001\u4ee5\u4e0b\u306e 5\u30d5\u30a1\u30a4\u30eb\u69cb\u6210\u306b\u5206\u89e3\u3057\u305f\u3002<\/p>\n<table class=\"my_tbl_simple\">\n<tr><td>1<\/td><td>dataset_MNIST.py<\/td><td>MNIST\u30c7\u30fc\u30bf\u30bb\u30c3\u30c8\u306b\u4f9d\u5b58\u3059\u308b\u51e6\u7406\u3092\u96c6\u7d04\u30fb\u96a0\u853d<\/td><\/tr><tr><td>2<\/td><td>net_model.py<\/td><td>\u30cd\u30c3\u30c8\u30ef\u30fc\u30af\u69cb\u6210\u3092\u5b9a\u7fa9<\/td><\/tr><tr><td>3<\/td><td>trainer.py<\/td><td>\u5b66\u7fd2\u30fb\u30c6\u30b9\u30c8\u51e6\u7406\u3092\u5b9f\u88c5\uff08\u5206\u985e\u554f\u984c\u7528\uff09<\/td><\/tr><tr><td>4<\/td><td>control_train.py<\/td><td>\u5b66\u7fd2\u5b9f\u884c\u3092\u5236\u5fa1<\/td><\/tr><tr><td>5<\/td><td>control_test.py<\/td><td>\u30c6\u30b9\u30c8\u5b9f\u884c\u3092\u5236\u5fa1<\/td><\/tr>\n<\/table>\n<h3 class=\"my_h\">Case 1: \u30cd\u30c3\u30c8\u30ef\u30fc\u30af\u69cb\u6210\u3092\u5909\u66f4\u3057\u305f\u3044\u3002<\/h3>\n<p>net_model.py \u3060\u3051\u3092\u5909\u66f4\u3059\u308c\u3070\u3088\u3044\u3002<\/p>\n<h3 class=\"my_h\">Case 2: \u30c7\u30fc\u30bf\u30bb\u30c3\u30c8\u3092\u5909\u66f4\u3057\u305f\u3044\u3002\uff08\uff1d\u5206\u985e\u554f\u984c\u3092\u5909\u66f4\u3057\u305f\u3044\uff09<\/h3>\n<p>\u5165\u529b\u753b\u50cf\u30b5\u30a4\u30ba\u3084\u30e9\u30d9\u30eb\u6570\u304c\u5909\u308f\u308b\u3053\u3068\u306b\u306a\u308b\u3002<br \/>\ndataset_MNIST.py, net_model.py \u3092\u65b0\u3057\u3044\u30c7\u30fc\u30bf\u30bb\u30c3\u30c8\u7528\u306b\u5165\u308c\u66ff\u3048\u308b\u3002<br \/>\ncontrol_train.py, control_test.py\u306e import\u884c\u3092\u3001\u65b0\u3057\u3044\u30c7\u30fc\u30bf\u30bb\u30c3\u30c8\u7528\u306b\u5909\u66f4\u3059\u308b\u3002<\/p>\n<h2 class=\"my_h\">1) \u30d7\u30ed\u30b0\u30e9\u30df\u30f3\u30b0<\/h2>\n<h3 class=\"my_h\">(1) 
dataset_MNIST.py<\/h3>\n<p>MNIST\u30c7\u30fc\u30bf\u30bb\u30c3\u30c8\u306b\u4f9d\u5b58\u3059\u308b\u51e6\u7406\u3092\u96c6\u7d04\u30fb\u96a0\u853d\u3059\u308b\u3002<\/p>\n<pre class=\"brush: python; title: ; notranslate\" title=\"\">\r\nfrom torchvision import datasets, transforms\r\n\r\nclass MyDataset:\r\n    def __init__(self, data_dir='..\/data'):\r\n        self.transform = transforms.Compose(&#x5B;\r\n            transforms.ToTensor(),\r\n            transforms.Normalize((0.1307,), (0.3081,))\r\n        ])\r\n        self.data_dir = data_dir\r\n\r\n    def get_train_dataset(self):\r\n        return datasets.MNIST(self.data_dir, train=True, download=True, transform=self.transform)\r\n\r\n    def get_test_dataset(self):\r\n        return datasets.MNIST(self.data_dir, train=False, download=True, transform=self.transform)\r\n<\/pre>\n<h3 class=\"my_h\">(2) net_model.py<\/h3>\n<p>\u30cd\u30c3\u30c8\u30ef\u30fc\u30af\u69cb\u6210\u3092\u5b9a\u7fa9\u3059\u308b\u3002<\/p>\n<pre class=\"brush: python; title: ; notranslate\" title=\"\">\r\nimport torch.nn as nn\r\nimport torch\r\n\r\nclass Net(nn.Module):\r\n    def __init__(self):\r\n        super(Net, self).__init__()\r\n        self.fc1 = nn.Linear(28 * 28, 128)\r\n        self.fc2 = nn.Linear(128, 64)\r\n        self.fc3 = nn.Linear(64, 10)\r\n\r\n    def forward(self, x):\r\n        x = x.view(-1, 28 * 28)\r\n        x = torch.relu(self.fc1(x))\r\n        x = torch.relu(self.fc2(x))\r\n        x = self.fc3(x)\r\n        return x\r\n<\/pre>\n<h3 class=\"my_h\">(3) trainer.py<\/h3>\n<p>\u5b66\u7fd2\u3001\u30c6\u30b9\u30c8\u51e6\u7406\u3092\u5b9f\u88c5\u3059\u308b\u3002\uff08\u5206\u985e\u30bf\u30b9\u30af\u7528\uff09<\/p>\n<pre class=\"brush: python; title: ; notranslate\" title=\"\">\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.optim as optim\r\n\r\nclass Trainer:\r\n    def __init__(self, model, device, lr=0.001):\r\n        self.model = model.to(device)\r\n        self.device = device\r\n        
self.criterion = nn.CrossEntropyLoss()\r\n        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)\r\n\r\n    def train(self, train_loader, epochs=3):\r\n        self.model.train()\r\n        for epoch in range(epochs):\r\n            for data, target in train_loader:\r\n                data, target = data.to(self.device), target.to(self.device)\r\n                self.optimizer.zero_grad()\r\n                output = self.model(data)\r\n                loss = self.criterion(output, target)\r\n                loss.backward()\r\n                self.optimizer.step()\r\n            print(f&quot;&#x5B;{epoch+1}\/{epochs}] Epoch complete&quot;)\r\n\r\n    def test(self, test_loader):\r\n        self.model.eval()\r\n        correct = 0\r\n        total = 0\r\n        with torch.no_grad():\r\n            for data, target in test_loader:\r\n                data, target = data.to(self.device), target.to(self.device)\r\n                outputs = self.model(data)\r\n                _, predicted = torch.max(outputs.data, 1)\r\n                total += target.size(0)\r\n                correct += (predicted == target).sum().item()\r\n        acc = 100 * correct \/ total\r\n        print(f&quot;Test Accuracy: {acc:.2f}%&quot;)\r\n        return acc\r\n\r\n    def save(self, path):\r\n        torch.save(self.model.state_dict(), path)\r\n        print(f&quot;Model saved to {path}&quot;)\r\n\r\n    def load(self, path):\r\n        self.model.load_state_dict(torch.load(path, map_location=self.device))\r\n        print(f&quot;Model loaded from {path}&quot;)\r\n<\/pre>\n<h3 class=\"my_h\">(4) control_train.py<\/h3>\n<p>\u5b66\u7fd2\u5b9f\u884c\u3092\u5236\u5fa1\u3059\u308b\u3002<\/p>\n<pre class=\"brush: python; title: ; notranslate\" title=\"\">\r\nfrom torch.utils.data import DataLoader\r\nfrom net_model import Net\r\nfrom trainer import Trainer\r\nfrom dataset_MNIST import MyDataset\r\nimport torch\r\n\r\n# \u30c7\u30fc\u30bf\u3068\u74b0\u5883\r\ndataset = 
MyDataset()\r\ntrain_loader = DataLoader(dataset.get_train_dataset(), batch_size=64,   shuffle=True)\r\ntest_loader  = DataLoader(dataset.get_test_dataset(),  batch_size=1000, shuffle=False)\r\n\r\ndevice = torch.device(&quot;cuda&quot; if torch.cuda.is_available() else &quot;cpu&quot;)\r\nmodel = Net()\r\ntrainer = Trainer(model, device)\r\n\r\n# \u5b9f\u884c\r\ntrainer.train(train_loader, epochs=3)\r\ntrainer.test(test_loader)\r\ntrainer.save(&quot;learned_model.pth&quot;)\r\n<\/pre>\n<h3 class=\"my_h\">(5) control_test.py<\/h3>\n<p>\u30c6\u30b9\u30c8\u5b9f\u884c\u3092\u5236\u5fa1\u3059\u308b\u3002<\/p>\n<pre class=\"brush: python; title: ; notranslate\" title=\"\">\r\nfrom torch.utils.data import DataLoader\r\nfrom net_model import Net\r\nfrom trainer import Trainer\r\nfrom dataset_MNIST import MyDataset\r\nimport torch\r\n\r\ndataset = MyDataset()\r\ntest_loader = DataLoader(dataset.get_test_dataset(), batch_size=1000, shuffle=False)\r\n\r\ndevice = torch.device(&quot;cuda&quot; if torch.cuda.is_available() else &quot;cpu&quot;)\r\nmodel = Net()\r\ntrainer = Trainer(model, device)\r\ntrainer.load(&quot;learned_model.pth&quot;)\r\ntrainer.test(test_loader)\r\n<\/pre>\n<h2 class=\"my_h\">2) \u5b9f\u884c<\/h2>\n<pre class=\"my_pre_bgBlack\">\r\n$ python control_train.py \r\n[1\/3] Epoch complete\r\n[2\/3] Epoch complete\r\n[3\/3] Epoch complete\r\nTest Accuracy: 97.20%\r\nModel saved to learned_model.pth\r\n<\/pre>\n<pre class=\"my_pre_bgBlack\">\r\n$ python control_test.py \r\nModel loaded from learned_model.pth\r\nTest Accuracy: 97.20%\r\n<\/pre>\n<p>\u4eca\u56de\u306f\u5b66\u7fd2\u6e08\u307f\u30d1\u30e9\u30e1\u30fc\u30bf\u30d5\u30a1\u30a4\u30eb\u540d\u3092\u56fa\u5b9a\u306b\u3057\u305f\u304c\u3001\u5f8c\u65e5\u3053\u308c\u3082\u6539\u826f\u3057\u305f\u3044\u3002<br \/>\n\u30fb\u30b3\u30de\u30f3\u30c9\u30e9\u30a4\u30f3\u5f15\u6570\u3067\u6e21\u3059\u3002<br \/>\n\u30fb\u74b0\u5883\u5909\u6570\u3067\u6e21\u3059\u3002<br 
\/>\n\u30fb\u8af8\u3005\u306e\u30d1\u30e9\u30e1\u30fc\u30bf\u3092\u307e\u3068\u3081\u3066JSON\u30c7\u30fc\u30bf\u3067\u6e21\u3059\u3002<br \/>\n\u7b49\u3005\u3092\u5f8c\u65e5\u691c\u8a0e\u3059\u308b\u3002<\/p>\n<hr class=\"my_hr_bottom\">\n","protected":false},"excerpt":{"rendered":"<p>\u30101\u3011\u3084\u308a\u305f\u3044\u3053\u3068 \u524d\u56de\u306e\u6295\u7a3f (32)\u3010PyTorch\u3067MNIST #3\u3011\u7406\u89e3\u306e\u52a9\u3051\u306b\u306a\u308b\u30df\u30cb\u30de\u30e0\u5b9f\u88c5 \u3067\u5b9f\u88c5\u3057\u305f\u30d7\u30ed\u30b0\u30e9\u30e0\u3092\u6539\u826f\u3057\u305f\u3044\u3002 \u76ee\u7684\u306f\u3001 \u30fb\u30d7\u30ed\u30b0\u30e9\u30e0\u3092\u90e8\u54c1\u5316\u3057\u3001\u518d\u5229\u7528\u6027\u3092\u4e0a\u3052\u308b\u3053\u3068\u3002 \u30fb\u30cd\u30c3\u30c8\u30ef\u30fc\u30af\u69cb\u6210\u306e\u2026 <span class=\"read-more\"><a href=\"https:\/\/www.dogrow.net\/nnet\/blog33%e3%80%90pytorch%e3%81%a7mnist-4%e3%80%91%e3%83%97%e3%83%ad%e3%82%b0%e3%83%a9%e3%83%a0%e3%81%ae%e4%bf%9d%e5%ae%88%e6%80%a7%e3%82%92%e5%90%91%e4%b8%8a%e3%81%95%e3%81%9b%e3%82%8b%e3%80%82\/\">\u7d9a\u304d\u3092\u8aad\u3080 
&raquo;<\/a><\/span><\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[18,6],"tags":[],"class_list":["post-1105","post","type-post","status-publish","format-standard","hentry","category-mnist","category-6"],"views":621,"amp_enabled":true,"_links":{"self":[{"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/posts\/1105","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/comments?post=1105"}],"version-history":[{"count":16,"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/posts\/1105\/revisions"}],"predecessor-version":[{"id":1217,"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/posts\/1105\/revisions\/1217"}],"wp:attachment":[{"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/media?parent=1105"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/categories?post=1105"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.dogrow.net\/nnet\/wp-json\/wp\/v2\/tags?post=1105"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}