superi

主にキャリアと金融を嗜むメディア

scikit-learnの決定木をexport_graphvizで可視化する


スポンサーリンク

データマイニングで定番の決定木分析をやってみたいと思います。

決定木の説明に関しては他に譲るとして、ここではpythonの機械学習ライブラリである、scikit-learnを利用して決定木分析を行い、graphvizという可視化ライブラリでグラフを描画します。

その時にハマったところを中心に話を進めたいと思います。

#scikit-learn付属データの読み込み

from sklearn.datasets import load_iris

iris = load_iris()

モデル作成には、tree.DecisionTreeClassifierクラスのfit関数を使います。

#決定木のクラスの読み込み

from sklearn import tree

#木の深さを指定し、モデルを生成

clf = tree.DecisionTreeClassifier(max_depth=3)
clf = clf.fit(iris.data, iris.target)

生成したモデルをgraphvizを利用して描画していきます。

pipでgraphvizpydotをインストールしておく必要があります。

pip install graphviz

pip install pydot

インストールしたライブラリを読み込み、

from graphviz import Digraph
import pydot
from sklearn.externals.six import StringIO
from IPython.display import Images

tree.export_graphviz関数を使って、dot形式で分類結果を吐き出します。

そして、pdfとしてグラフを表示します。

#dot形式でエクスポート

dot_data = StringIO()

tree.export_graphviz(clf, out_file=dot_data, class_names=iris.target_names,feature_names=iris.feature_names)
graph = pydot.graph_from_dot_data(dot_data.getvalue())

#グラフ化
graph.write_pdf("graph.pdf")
Image(graph.create_png())

ここでエラーになりました。graphvizを扱うにはpipでインストールするだけでなく、アプリケーション自体が必要です。以下Homebrewでinstallしましょう。

brew install graphviz

インストールできたら、先ほどのコードを再度実行します。

が、またエラーになりました。

TypeError: export_graphviz() got an unexpected keyword argument 'class_names'

これはscikit-learnのバージョンが古いためです。

anacondaを使っている人は以下コマンドでアップデートできます。

conda update scikit-learn

これで再度コードを実行しなおすと次の出力がされるかと思います。

f:id:ukichang:20160620011555p:plain

 

>>こちらも合わせてどうぞ 

www.superi.jp

www.superi.jp

 参考

scikit-learn で決定木分析 (CART 法) – Python でデータサイエンス

Git のインストールメモ 【Homebrew】 | gworks web site

python 2.7 - "class_names" in export_graphviz unexpected keyword error - Stack Overflow

sklearn.tree.export_graphviz — scikit-learn 0.17.1 documentation