Qstairs

現役AIベンチャーCTOの知見、画像認識(人工知能、Deep Learning)を中心とした技術ノウハウをアップしていきます

広告

【Deep Learning】fetch_mldataでmnistのデータを取得できない場合の対処(Chainerの場合)

f:id:qstairs:20160601221351j:plain

はじめに

mnistを使った実験を行いたい場合、
これまでは「fetch_mldata」でmnistのデータを取得していましたが、
どうやら取得できなくなっている模様です。
#Irisは取得できています。

そこで、fetch_mldataではなく
Chainerにあるmnistを取得する関数を使おうとしました。

ところが、データ形式がfetch_mldataとは異なっていて困った。
なんとかfetch_mldataの形式に合わせる処理を作ったので紹介します。

fetch_mldataの場合

mnist = fetch_mldata('MNIST original', data_home=".")
# mnist.data : 70,000件の28x28=784次元ベクトルデータ
mnist.data = mnist.data.astype(np.float32)
mnist.data /= 255  # 正規化

# mnist.target : 正解データ
mnist.target = mnist.target.astype(np.int32)

# 学習用データN個,検証用データを残りの個数に設定
N = 60000
xtrain, xtest = np.split(mnist.data,      [N])
ytrain, yans = np.split(mnist.target,    [N])

Chainerでmnistを取得しfetch_mldataの形式に変換

各変数がfetch_mldataと対応しています。

train, test = chainer.datasets.get_mnist()
train = np.array(train)
test = np.array(test)
xtrain = []
ytrain = []
for data, label in train:
    xtrain.append(data)
    label_list = np.zeros(10)
    label_list[label] = 1.0
    ytrain.append(label_list)
xtrain = np.array(xtrain).astype(np.float32)
ytrain = np.array(ytrain).astype(np.float32)

xtest = []
yans = []
for data, label in test:
    xtest.append(data)
    yans.append(label)
xtest = np.array(xtest).astype(np.float32)
yans = np.array(yans)


以上!

広告