せっかくHoloLensがあるので機械学習の途中経過を覗いて遊んでみました。かたりぃなです。
学習中のデータの可視化する意義
機械学習についてですが、Chainerなどのフレームワークがどうやって学習を進めているかよくわかっていません。 具体的にはフレームワークによってネットワーク内パラメータがどう変動していくのかが分からないのです。 optimizerがいて順伝播と逆伝播の誤差を~というのは理屈ではそうなのでしょうけれども、実際に中を見ていないのでしっくりきていません。
Chainer自体はよくできたフレームワークなのでexampleの中身動かして「あ、こんな簡単に動くんだね」と感動するのですが、そこから先どうしようかという展望を持ちにくい印象です。 学習が進むことによって「ニューラルネットワークのパラメータにどういう"変化"が起きているのか」を見れたらもっと欲望というか展望が見える気がします。
可視化といえばchainerが出しているログをグラフにしたり、各層を可視化するなどといった手法はありますが、それって静的なものなので、「どういう過程で生み出されたものなのか」が見えにくいです。 少しずつ変化していっているとしても、やっぱりそこを見てみたいですよね。 可視化した画像をブラウザで開いてCtrl+F5連打という力業を試しましたが毎回連打するのは疲れます。
というわけで、機械学習の進行状況をリアルタイムで見れるようにアプリを作ってみます。 せっかくHoloLensあるのでこいつを使います。
実運用ではオーバーヘッドが大きすぎ(GPUからCPUにNNを取り出してからnpz形式でファイル出力)で役に立たないとは思いますが、まずは動いているものを見て興味を持つという点が重要かなと思っています。
結果
こんなの出ました。一定周期で更新されるワイヤーフレームはまるで昭和時代のCGです。 だんだんと山と谷の差が大きくなっていく様子が観察できました。
実世界のテーブルの上とかに表示しておいて山の裏側を見たいってときは回り込めば見えます。 (実機でのキャプチャしようと思ったのですがDirectXのレンダリング結果をオーバーレイしたキャプチャがうまくできないので保留です)
やったこと、やらなかったこと
今回やってみたことと、保留にしたものを列挙します。 大まかに考えると、「可視化すること」が目標で、そこから洗い出された課題などはいったん保留です。
やったこと
- chainerのネットワークの状態を一定周期でファイルに出力する
- ネットワークの種類はautoencoder(可視化して結果がわかりやすそう)
- HoloLens側からファイルをポーリングして一定周期で読み出す
- HoloLens上でDirectXを使ってレンダリング
やらなかったこと
- 機械学習の分野の詳細に立ち入ること(学習が収束するとか、効率の良いネットワークだとか)
- 可視化した結果の正確性
- 過度な高速化(HoloLensでとりあえず表示できればいい)
以降、実装の詳細です。
chainerのサンプルコードを編集
chainerはPC上で実行します。 まずchainerのmnistサンプルコードでスナップショットを出力する部分がありますが、ここを変更します。
trainer.extend(extensions.snapshot(filename='work.npz'), trigger=(1, 'epoch'))
デフォルトでは連番ファイルを出力するようになっているところを変更します。同一ファイルをwrite/readし続けたほうが実験としてはやりやすいので。 また、トリガは毎epoch終了ごととします。 ネットワークが大きいくなると毎epochやっていると重そうですが、今は気にしないことにします。
次にネットワークをautoencoderにします。 autoencoderを選んだ理由は実装が簡単かつ「なんかそれっぽい気がする」ものが見えそうだからです。 データセットにmnistを使うので入力/出力ともに784次元のデータです。
class MLP(chainer.Chain): def __init__(self, n_units): super(MLP, self).__init__( # the size of the inputs to each layer will be inferred encoder = L.Linear(None, n_units), # n_in -> n_units decoder = L.Linear(None, 784) ) def __call__(self, x): h = F.relu(self.encoder(x)) return self.decoder(h)
ここでencoder/decoderと名前を付けましたが、この名前がそのままスナップショット内のネットワークのレイヤ名に使われるので、わかりやすい名前をつけておいたほうが幸せになれます。
最後にデータセットを入出力ともに同じものを使うようにします。 withlabelをfalseにすれば教師ラベルなしの入力データがそのまま取得できます。 ただし、学習のイテレーションでは教師データをTupleDatasetとして与える必要があるので、 入力=出力となるDatasetを生成します。testも同様です。
train, test = chainer.datasets.get_mnist(withlabel=False)
train = tuple_dataset.TupleDataset(train, train)
test = tuple_dataset.TupleDataset(test, test)
ちょっと寄り道して動作確認。 それっぽく動いているかどうかpythonから画像を出力して確認してみます。 plot関係のライブラリを使って画像に出します。
なんか見れた。
寄り道の詳細
ClassifierしたmodelからだとNNを参照するのが大変そう(簡単にできるのだろうか?)なので、networkをとっておいてそこから重み係数を参照することにしました。 また画像として出力するにはユニット数が少ないほうが出力画像に収めやすいので10x10枚の可視化=100ユニットとしています。
import matplotlib.pyplot as plt import numpy as np #~~~ network = MLP(unit) model = L.Classifier(network) #~~~ save_images(network.encoder.data, "plot_test.png") #~~~ # 画像で保存(隠れ層=100という前提) def save_images(x, filename): fig, ax = plt.subplots(10, 10, figsize=(10, 10), dpi=100) for ai, xi in zip(ax.flatten(), x): ai.imshow(xi.reshape(28, 28)) fig.savefig(filename)
本当にこれでいいのか不安ですが、これでchainer側で可視化の準備は整いました。 スナップショットはresultフォルダに出力されます。 ついでにdotファイル吐かせてgraphvizで見れることも確認しました。
chainerのスナップショットをhttpで取得する
まずHTTPサーバをpythonで簡単に立ち上げます。3.5系ではこうなります。
python -m http.server
これで8000番ポートが開くので、httpアクセスすると上記スクリプトを実行したディレクトリに置かれたファイルを参照できます。
次はHoloLensアプリ側の実装です。 と、その前にHoloLensの設定とリモートデバッガの設定をします。
HoloLensの設定
HoloLensのデバイスポータルに接続してデバイスの情報を拾えるようにしておきます。
これをやっておけばPCのブラウザからHoloLensにアクセスできるようになります。 証明書のインストールとかあるので、手順は公式を参照。
HoloLens 用 Device Portal | Microsoft Docs
リモートデバッガを構成する
VisualStudioのデバッグ構成を変更してリモートデバッグの設定をします。 とはいえ毎回IP調べるのは面倒なので固定IPにしたほうが便利です。
HTTPサーバからファイルを取得する
ここからHoloLens側の実装です。 VisualStudio2015のテンプレートからHoloLens DirectX11を使って作業していきます。
httpアクセスしてchainerの学習中スナップショットを取得します。 UWPのhttpclientクラスはキャッシュが効いてしまって実験中は不便なのでHttpCacheReadBehavior::MostRecentを設定します。 あとは普通にresponseを見てdatareadすれば済みます。 エラーチェックしていませんが、ストアに提出するようなアプリのときはもうちょっと真面目にやりましょう。
auto uri = ref new Windows::Foundation::Uri(L"http://192.168.0.14:8000/result/work.npz";); auto filter = ref new Windows::Web::Http::Filters::HttpBaseProtocolFilter(); filter->CacheControl->ReadBehavior = Windows::Web::Http::Filters::HttpCacheReadBehavior::MostRecent; auto client = ref new Windows::Web::Http::HttpClient(filter); auto headers = client->DefaultRequestHeaders; create_task(client->GetAsync(uri)).then([cb](Windows::Web::Http::HttpResponseMessage ^ response) { response->EnsureSuccessStatusCode(); return response->Content->ReadAsBufferAsync(); }).then([cb](Windows::Storage::Streams::IBuffer ^ input) { auto reader = Windows::Storage::Streams::DataReader::FromBuffer(input); auto loaded_buffer = reader->ReadBuffer(input->Length); parse_npz(loaded_buffer, cb); });
この処理を一定周期(1秒間隔とか)で呼び出してあげたうえでzip形式を解釈してzlibで伸長してあげれば完成です。
Direct3Dでレンダリングする
取得したNNのencoderレイヤの各ユニットの重み係数は28x28のシングルチャンネル画像として解釈できます。 今回は単純にW係数を高さ(Y)にとることにします。(つまり、X,Zは0~27の間のいずれかの値で、X,Zが与えられるとYが一意に決まる,Yは重み係数) ポリゴンレンダリングすると色々と大変なのでワイヤーフレームでいきます。ワイヤーフレームのほうが味がありますし。
ラスタライズステージでワイヤーフレームにします。
D3D11_RASTERIZER_DESC rasterizerDesc = { D3D11_FILL_WIREFRAME, // ワイヤフレーム D3D11_CULL_NONE, // カリングなし FALSE, 0, 0.0f, FALSE, FALSE, FALSE, FALSE, FALSE }; ID3D11RasterizerState* rasterizerState = NULL; HRESULT hr = m_deviceResources->GetD3DDevice()->CreateRasterizerState(&rasterizerDesc, &rasterizerState); if (FAILED(hr)) { // TODO : エラー処理 } context->RSSetState(rasterizerState);
しかしこれではやりたいことに対して少し不足しています。 VisualStudioのDirectXテンプレートではTriangleListをシェーダーに入力していますが単純にLineListを入力したいところです。
GeometryShaderを書く
LineListをシェーダに入力するには、まずCPU側のコードは単純に
context->IASetPrimitiveTopology(D3D11_PRIMITIVE_TOPOLOGY_LINELIST);
すれば良いです。
これに合わせてGeometryShadeerのコードを変更します。 3頂点を受け取ってに三角形を吐き出すのではなく、2頂点を受け取って1つのLine(始点,終点)を出力するよう書き換えます。
-[maxvertexcount(3)] -void main(triangle GeometryShaderInput input[3], inout TriangleStream<GeometryShaderOutput> outStream) +[maxvertexcount(2)] +void main(line GeometryShaderInput input[2], inout LineStream<GeometryShaderOutput> outStream)
シェーダー内でやっているジオメトリ生成も同様にして2頂点ずつ生成するようにします。
[unroll(2)] for (int i = 0; i < 2; ++i) { // 略 }
あとはLineListの規則に従って頂点バッファとインデックスバッファを定義して格子状のメッシュを生成すれば完成です。
頂点バッファを動的に更新する
表示するワイヤーフレームができたので、Y軸を動的に更新します。 乱暴ですが新しい頂点情報を受け取るたびにVertexBufferを作り直すことにしました。 UpdateSubResourceのほうが適任な気がしますが、HoloLensのサンプルコードのspatialmappingもこういう実装になっていたりするのでいったん良しとします。 コードは自分自身納得いっていないので省略です。
参考にさせていただいたサイト
chainerのモデルをC++から読み込むのはこちらのサイトが参考になりました。そのままnpzを読み込めました。 ChainerのモデルをC++で読み込む - TadaoYamaokaの日記
感想
chainerのモデルをC++から読みだすことができるようになりました。
「なんとなくdeepLearningやってる」感じが見えたので良しとします。 できれば1つのレイヤ内のすべてのユニットを可視化したかったのですが、視界からはみ出して残念な見栄えになるので諦めました。
とはいえChainerに慣れていない段階だと便利な気がします。HoloLens装着したまま別の作業進めつつ、進捗状況をチラ見できます。 PCの画面を占有しない+進捗確認の作業が単純(頭の向きを変えるだけ)っていうのは正義ですね。
しかしながらワイヤーフレーム表示では学習が進んでデータのレンジが大きくなってきたときに辛いです。実世界に表示したモデルをぐるっと回って見るのであれば光源処理を入れたポリゴンのほうがいい気がします。
あと今回の簡易的な実装ではHoloLens側の処理がやたら重いです。 どれくらい重いかというとHoloLensのジェスチャ操作が効きにくくなるくらい重いです。
原因は詳しく調べてはいませんが、Chainerのsnapshotのnpzを展開するとき何も考えずに全部展開してしまっているのが原因かもしれないと思っています。というのも、snapshotで保存されるnpzはすべての情報が含まれていて、今回使わないdecoderやoptimizerなんかも展開してしまっているので。。。 気が向いたときに調べてみます。
今後の展望
読みだした学習済みモデルを食べさせるC++実装があれば色々できそうです。 あとはdotファイルをうまく解析すれば「c++からchainerのモデルなんでも読めるよ」という夢も叶うかもしれませんね。 書いてて気づきましたがboost::proptreeとcv::dnn使えばできそうな気がします。
アプリづくりの練習としてChainerの出力したdotファイルを読んで、どの層を見るかを選択するUIつけてあげれば面白そうです。
それでは今回はこれにて。