** Sorry, this note is Japanese only, but please take a look at some code snippets. Hope it helps you a bit!
先日、TF Learn (元 skflow )というものが、TensorFlow を scikit-learn のように Deep Learning できて便利です!という記事を書いたのですが、
あれ・・・
TensorFlow って TensorBoard で学習状況とかを見れたような・・・ これって TF Learn を通して TensorFlow を使う場合どうなるんだろう・・・
・・・
ということで調べてみました!
TensorBoard ってなんだ
TensorFlow で機械学習をしていると「ちゃんうちのコは学習して賢くなってくれてるのかしら・・・」と不安になることがあります。
そんな時のために、TensorFlow は TensorBoard という便利なツールを用意してくれていて、自分が TensorFlow で開発している学習モデルが賢くなっていく様子をグラフとかで見せてくれます。
TensorBoard の詳しい使い方は公式ドキュメントを読んでいただくとして・・・
今回は TF Learn( skflow )で学習した場合に、TensorBoard で visualize するログをどうやって残すのか?という方法を紹介したいと思います。
TensorBoard の公式ドキュメント
TensorBoard: Visualizing Learning
skflow / TF Learn のチュートリアルを調べてみたら・・・
skflow / TF Learn の開発者 Yuan Tang 氏自身のチュートリアルによると、fit メソッドに引数で logdir を渡すことができる、と書いてます。
Introduction to Scikit Flow - A Simplified Interface to TensorFlowSummaries/TensorBoard(英語)
つまりこんな感じ。
from tensorflow.contrib import learn
classifier = learn.DNNClassifier(hidden_units=[10,20,10], n_classes=3)
classifier.fit(x=X_train, y=y_train, steps=200, logdir='./logs')
・・・
・・・あれ??
こんなん(エラー)出ましたけど・・・
TypeError: fit() got an unexpected keyword argument 'logdir'
確かに、DNNClassifier.fit の引数に logdir はないですね・・・
Help on function fit in module tensorflow.contrib.learn.python.learn.estimators.estimator:
fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None, monitors=None, max_steps=None)
Trains a model given training datax
predictions andy
targets.
(以下略)
skflow / TF Learn で TensorBoard のためにログを残す・解決方法
TensorFlow のソースを見たり、ググったりしたけど解決方法が見つからなかったので、ダメ元で Yuan 氏に直接聞いてみると・・・
お返事くれました!やさしい!!
Yuan (Terry) Tang @terrytangyuan Jul 17 00:26
Use model_dir in the constructor instead
examples are outdated
モデル生成時のコンストラクタの引数 model_dir を代わりに使ってね!パラメータTensorFlow のリポジトリの中にあるサンプルは古い Yo!(超訳)
って感じでしょうか。というわけでこんなコードで試すと・・・
from tensorflow.contrib import learn
classifier = learn.DNNClassifier(hidden_units=[10,20,10], n_classes=3, model_dir='./logs')
classifier.fit(x=X_train, y=y_train, steps=200)
・・・
・・・
わー!!データが出てる!!ちゃんと TensorBoard でもグラフが出ました!!
最終的にコード全体はこんな感じです(iris のデータを DNNClassifier で分類)。
TensorFlow 関連の他の記事
datalove.hatenadiary.jp datalove.hatenadiary.jp datalove.hatenadiary.jp datalove.hatenadiary.jp