Qstairs

現役AIベンチャーCTOの知見、画像認識(人工知能、Deep Learning)を中心とした技術ノウハウをアップしていきます

広告

【画像認識】Deep Learningフレームワーク「Chainer」で自前データで識別する方法

f:id:qstairs:20160601221351j:plain

はじめに

前回は自前で用意したデータを使用した「Chainer」での学習方法について紹介しました。
qstairs.hatenablog.com

今回は、前回学習した結果(モデル)を使用した識別処理について紹介します。

識別の流れ

識別の流れは大きく分けて以下のようになります。
2,3,4は変数名が変わった程度で学習時と同じです。

  1. 学習したモデルを読み込む
  2. 識別データリスト読み込み
  3. 識別データ格納
  4. 識別データ形式変換
  5. 識別
  6. 識別結果集計



ソースコード

識別処理のソースコードのみ載せます。
importしている「CNN_model」は前回の記事を参考にしてください。

# -*- coding:utf-8 -*-
import numpy as np
import chainer
from chainer import cuda, Function, FunctionSet, gradient_check, Variable, optimizers, serializers
from chainer.training import extensions
import argparse
import random
import chainer.functions as F
import cv2
import CNN_model as cnn
import time

cuda.get_device(0).use()    # GPUを使用することを前提とする

# 引数指定
parser = argparse.ArgumentParser(description='Train Sample')
parser.add_argument('--test_list', '-test', default='test.txt', type=str, help='Test data list')
parser.add_argument('--model', '-m', default='model_100', type=str, help='Learning model')
parser.add_argument('--epoch', '-e', type=int, default=100, help='Number of epochs to train')
args = parser.parse_args()

# 学習モデル読み込み
model = cnn.CNN()
serializers.load_npz(args.model, model.model)

# テストデータリストファイルから一行ずつ読み込む(学習時と同じ)
test_list = []
for line in open(args.test_list):
    pair = line.strip().split()
    test_list.append((pair[0], np.float32(pair[1])))

# 画像データとラベルデータを取得する(学習時と同じ)
x_test = []    # 画像データ格納
y_test = []    # ラベルデータ格納
for filepath, label in test_list:
    img = cv2.imread(filepath, 0)   # グレースケールで読み込む
    x_test.append(img)
    y_test.append(label)

# 学習で使用するsoftmax_cross_entropyは
# 学習データはfloat32,ラベルはint32にする必要がある。
x_test = np.array(x_test).astype(np.float32)
y_test = np.array(y_test).astype(np.int32)
# 画像を(学習枚数、チャンネル数、高さ、幅)の4次元に変換する
x_test = x_test.reshape(len(x_test), 1, 100, 100) / 255

N = len(y_test)

batchsize = 1
datasize = len(x_test)

# 学習開始
test_start = time.time()
false_count = 0
perm = np.random.permutation(N) # データセットの順番をシャッフル
for i in range(0,datasize, batchsize):
    x_batch = cuda.to_gpu(x_test[[i]]) # バッチサイズ分のデータを取り出す
    y_batch = cuda.to_gpu(y_test[[i]])

    result = model.forward(x_batch, y_batch, train=False) # 学習ではないのでtrainをFalse
    if np.argmax(result.data) != y_batch[0]:
        false_count += 1
        print "No.{} Wrong! correct:{}, result:{}".format(i, y_batch[0], np.argmax(result.data))
    else:
        print "No.{} Got it! correct:{}, result:{}".format(i, y_batch[0], np.argmax(result.data))

print "data num:{} false num:{} accuracy={}".format(datasize, false_count, 1 - (false_count / datasize))
print "test time:{}".format(time.time()-test_start)

実行結果

実行した結果が以下になります。
なんと一つも間違えていない!!
(識別対象が単純でデータの生成方法も同じですからね^_^; )

~\chainer>python test.py -test .\chainer_data\test.txt -m model_100
No.0 Got it! correct:0, result:0
No.1 Got it! correct:0, result:0
No.2 Got it! correct:0, result:0
No.3 Got it! correct:0, result:0
No.4 Got it! correct:0, result:0
No.5 Got it! correct:0, result:0
No.6 Got it! correct:0, result:0
No.7 Got it! correct:0, result:0
No.8 Got it! correct:0, result:0
No.9 Got it! correct:0, result:0
No.10 Got it! correct:0, result:0
~略~
No.7487 Got it! correct:2, result:2
No.7488 Got it! correct:2, result:2
No.7489 Got it! correct:2, result:2
No.7490 Got it! correct:2, result:2
No.7491 Got it! correct:2, result:2
No.7492 Got it! correct:2, result:2
No.7493 Got it! correct:2, result:2
No.7494 Got it! correct:2, result:2
No.7495 Got it! correct:2, result:2
No.7496 Got it! correct:2, result:2
No.7497 Got it! correct:2, result:2
No.7498 Got it! correct:2, result:2
No.7499 Got it! correct:2, result:2
data num:7500 false num:0 accuracy=1
test time:29.1349999905

最後に

これまで4回にわたって「Chainer」による学習、識別方法を紹介してきました。
予想以上にコード量が少なくて驚きます!

こんなに簡単にDeep Learningが使えるならもっと面白いことしたいんですが、
やっぱりデータがないと何もできませんね(-_-;

回帰モデルの作成もしてみたいのですが、
テストデータをどう取得するかが悩みどころです。

面白そうな学習データが見つかった際に、
紹介できればと思います。



【関連記事】

  • 「Chainer」の構築

qstairs.hatenablog.com

  • 「Chainer」での学習の下準備

qstairs.hatenablog.com

  • 「Chainer」での自前データによる学習

qstairs.hatenablog.com

広告