【Deep Learning】fetch_mldataでmnistのデータを取得できない場合の対処(Chainerの場合)
はじめに
mnistを使った実験を行いたい場合、
これまでは「fetch_mldata」でmnistのデータを取得していましたが、
どうやら取得できなくなっている模様です。
#Irisは取得できています。
そこで、fetch_mldataではなく
Chainerにあるmnistを取得する関数を使おうとしました。
ところが、データ形式がfetch_mldataとは異なっていて困った。
なんとかfetch_mldataの形式に合わせる処理を作ったので紹介します。
fetch_mldataの場合
mnist = fetch_mldata('MNIST original', data_home=".") # mnist.data : 70,000件の28x28=784次元ベクトルデータ mnist.data = mnist.data.astype(np.float32) mnist.data /= 255 # 正規化 # mnist.target : 正解データ mnist.target = mnist.target.astype(np.int32) # 学習用データN個,検証用データを残りの個数に設定 N = 60000 xtrain, xtest = np.split(mnist.data, [N]) ytrain, yans = np.split(mnist.target, [N])
Chainerでmnistを取得しfetch_mldataの形式に変換
各変数がfetch_mldataと対応しています。
train, test = chainer.datasets.get_mnist() train = np.array(train) test = np.array(test) xtrain = [] ytrain = [] for data, label in train: xtrain.append(data) label_list = np.zeros(10) label_list[label] = 1.0 ytrain.append(label_list) xtrain = np.array(xtrain).astype(np.float32) ytrain = np.array(ytrain).astype(np.float32) xtest = [] yans = [] for data, label in test: xtest.append(data) yans.append(label) xtest = np.array(xtest).astype(np.float32) yans = np.array(yans)
以上!