[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
60
43

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

KerasモデルをCloud ML Engineで学習してOnline Predictionしてみた

Last updated at Posted at 2017-07-05

Cloud ML Engine のruntime versionが1.2になったので、Kerasが小細工なしで使えるようになりました。TensorFlowの高レベルAPIもいい感じになって来ていますが、やはりKerasのpretrained modelの多さは魅力的です。とりあえずやり方だけ把握しておこうと、せっかくなので学習だけでなくOnline PredictionもKerasモデルでserveしてみました。

Cloud ML Engineとは
TensorFlowのフルマネージドな実行環境です。分散環境で学習、オートスケールしAPIで推論リクエスト可能なOnline Prediction等、TensorFlowの運用には最高の環境です。

KerasをCloud ML Engine(training)で使う

注意するのは、

  • Kerasのimportをtf.contribからする
  • jobのsubmit時にruntimeVersionを1.2に設定する

位で、後は普通に動きます。こんな感じで書けばOK (こちらを参考にしました)

import tensorflow as tf
from tensorflow.contrib.keras.python import keras
from tensorflow.contrib.keras.python.keras.models import Sequential
from tensorflow.contrib.keras.python.keras.layers.core import Dense, Activation
from sklearn.cross_validation import train_test_split

iris = tf.contrib.learn.datasets.base.load_iris()
train_x, test_x, train_y, test_y = train_test_split(
    iris.data, iris.target, test_size=0.2)

num_classes = 3
train_y = keras.utils.to_categorical(train_y, num_classes)
test_y = keras.utils.to_categorical(test_y, num_classes)

model = Sequential()
model.add(Dense(10, activation='relu', input_shape=(4,)))
model.add(Dense(20, activation='relu'))
model.add(Dense(10, activation='relu'))
model.add(Dense(3, activation='softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])

cb = keras.callbacks.TensorBoard(
    log_dir="gs://BUCKET/keras-mlengine", histogram_freq=1)
# train
model.fit(train_x, train_y,
          batch_size=100,
          epochs=20,
          verbose=2,
          callbacks=[cb],
          validation_data=(test_x, test_y))

# eval
score = model.evaluate(test_x, test_y, verbose=0)
pred = model.predict(test_x)

TensorBoardの出力先はGCSにしておけば後からローカルでもCloud Shellからでも参照できます。
Datalab、Jupyter NotebookからJobを投げたいとき
cloudml-magicをインストールしてください。runtime version 1.2にも対応しておきました。Kerasの利用例はこちら

ログはこんな感じに出力されます。
Untitled.png

SavedModelを作る

Online Predictionはv1からSavedModel形式しか対応していないので、KerasモデルからSavedModelに変換する必要があります。Sessionを取り出してSignature追加すればいいだけですが、ちょっと嵌りました。

from tensorflow.contrib.keras import backend
sess = backend.get_session()
x = sess.graph.get_tensor_by_name('dense_1_input:0')
y = sess.graph.get_tensor_by_name('ArgMax_1:0')
inputs = {"dense_1_input": tf.saved_model.utils.build_tensor_info(x)}
outputs = {"ArgMax_1": tf.saved_model.utils.build_tensor_info(y)}
signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs=inputs,
    outputs=outputs,
    method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
)

# save as SavedModel
b = tf.saved_model.builder.SavedModelBuilder('gs://BUCKET/keras-mlengine/savedmodel')
b.add_meta_graph_and_variables(sess,
                               [tf.saved_model.tag_constants.SERVING],
                               signature_def_map={'serving_default': signature})
b.save()

SavedModelの保存先をGCSに指定すれば、Online Predictionでもそのまま使えます。Consoleから作成したSavedModelのGCSパスを指定すればOnline Prediction環境の出来上がり。

ちなみに入力と出力のTensorを取得するのに、Keras側のGraph生成ルールがわからず、とりあえずTensorBoardで確認してget_tensor_by_nameでTensorを引っ張ってきています。Sequentialplaceholderを突っ込むのが正しいやり方でしょうが、そうするとそこだけローレベルになってしまうので
追記
model.inputsmodel.outputsで入出力Tensorが取得できるようです。
しかしそれでも面倒なのでEstimatorexport_savedmodelみたいなのがKerasにも欲しいですね。

Online Prediction で推論する

ここは従来通りです。Discovery APIを使う場合はこんな感じ。

from oauth2client.client import GoogleCredentials
from googleapiclient import discovery
from googleapiclient import errors

PROJECTID = 'PROJECTID'
projectID = 'projects/{}'.format(PROJECTID)
modelName = 'keras-iris'
modelID = '{}/models/{}'.format(projectID, modelName)

credentials = GoogleCredentials.get_application_default()
ml = discovery.build('ml', 'v1', credentials=credentials)

request_body = {'instances': [{'dense_1_input': [5.4,  3.9,  1.3,  0.4]},
                              {'dense_1_input': [4.4,  3.2,  1.3,  0.2]},
                              {'dense_1_input': [4.3,  3.,  1.1,  0.1]},
                              {'dense_1_input': [5.,  3.5,  1.6,  0.6]},
                              {'dense_1_input': [5.9,  3.,  4.2,  1.5]},
                              {'dense_1_input': [7.7,  3.,  6.1,  2.3]},
                              ]}

request = ml.projects().predict(name=modelID, body=request_body)
try:
    response = request.execute()
except errors.HttpError as err:
    # Something went wrong with the HTTP transaction.
    # To use logging, you need to 'import logging'.
    print('There was an HTTP error during the request:')
    print(err._get_reason())
print(response)

するとこんな感じで帰ってきます。

{u'predictions': [{u'ArgMax_1': 0},
  {u'ArgMax_1': 0},
  {u'ArgMax_1': 0},
  {u'ArgMax_1': 0},
  {u'ArgMax_1': 1},
  {u'ArgMax_1': 2}]}

終わりに

GCPすごい、Cloud ML Engineすごい!・・・のですが、あまり情報がないのが寂しいところ。慣れちゃえばとっても簡単なので、皆さんどんどん使って色々事例を出してくださいね!

60
43
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
60
43

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?