- ニューラルネットワークのライブラリのChainerですが、去年のうちに大分変更がありました.
- というかバージョンアップ早すぎてびびる
- この記事書く際にふとリファレンス見たらいつの間にか1.20.0のドキュメントができてた(GitHubのリリースノートの最新は現時点ではまだ1.19.0)
Chainer1.19.0版MNISTのコードを紹介します.
# -*- coding: utf-8 -*- from __future__ import print_function import chainer import chainer.functions as F import chainer.links as L from chainer import training from chainer.training import extensions #Network definition class MLP(chainer.Chain): def __init__(self, n_units, n_out): super(MLP, self).__init__( l1=L.Linear(None, n_units), l2=L.Linear(None, n_units), l3=L.Linear(None, n_out), ) def __call__(self, x): h1 = F.relu(self.l1(x)) h2 = F.relu(self.l2(h1)) return self.l3(h2) def main(): unit = 1000 batchsize = 100 epoch = 20 model = L.Classifier(MLP(unit, 10)) optimizer = chainer.optimizers.Adam() optimizer.setup(model) train, test = chainer.datasets.get_mnist() train_iter = chainer.iterators.SerialIterator(train, batchsize) test_iter = chainer.iterators.SerialIterator(test, batchsize, repeat=False, shuffle=False) updater = training.StandardUpdater(train_iter, optimizer) trainer = training.Trainer(updater, (epoch, 'epoch'), out='result') trainer.extend(extensions.Evaluator(test_iter, model)) trainer.extend(extensions.dump_graph('main/loss')) trainer.extend(extensions.snapshot(), trigger=(epoch, 'epoch')) trainer.extend(extensions.LogReport()) trainer.extend(extensions.PrintReport( ['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time'])) trainer.extend(extensions.ProgressBar()) trainer.run() if __name__ == "__main__": main()
ざっくりとした説明とかはこちら
ChainerでMNIST