scikit-learnの決定木をexport_graphvizで可視化する
データマイニングで定番の決定木分析をやってみたいと思います。
決定木の説明に関しては他に譲るとして、ここではpythonの機械学習ライブラリである、scikit-learnを利用して決定木分析を行い、graphvizという可視化ライブラリでグラフを描画します。
その時にハマったところを中心に話を進めたいと思います。
#scikit-learn付属データの読み込み
from sklearn.datasets import load_iris
iris = load_iris()
モデル作成には、tree.DecisionTreeClassifierクラスのfit関数を使います。
#決定木のクラスの読み込み
from sklearn import tree
#木の深さを指定し、モデルを生成
clf = tree.DecisionTreeClassifier(max_depth=3)
clf = clf.fit(iris.data, iris.target)
生成したモデルをgraphvizを利用して描画していきます。
pipでgraphvizとpydotをインストールしておく必要があります。
pip install graphviz
pip install pydot
インストールしたライブラリを読み込み、
from graphviz import Digraph
import pydot
from sklearn.externals.six import StringIO
from IPython.display import Images
tree.export_graphviz関数を使って、dot形式で分類結果を吐き出します。
そして、pdfとしてグラフを表示します。
#dot形式でエクスポート
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data, class_names=iris.target_names,feature_names=iris.feature_names)
graph = pydot.graph_from_dot_data(dot_data.getvalue())#グラフ化
graph.write_pdf("graph.pdf")
Image(graph.create_png())
ここでエラーになりました。graphvizを扱うにはpipでインストールするだけでなく、アプリケーション自体が必要です。以下Homebrewでinstallしましょう。
brew install graphviz
インストールできたら、先ほどのコードを再度実行します。
が、またエラーになりました。
TypeError: export_graphviz() got an unexpected keyword argument 'class_names'
これはscikit-learnのバージョンが古いためです。
anacondaを使っている人は以下コマンドでアップデートできます。
conda update scikit-learn
これで再度コードを実行しなおすと次の出力がされるかと思います。
>>こちらも合わせてどうぞ
参考
scikit-learn で決定木分析 (CART 法) – Python でデータサイエンス
Git のインストールメモ 【Homebrew】 | gworks web site
python 2.7 - "class_names" in export_graphviz unexpected keyword error - Stack Overflow
sklearn.tree.export_graphviz — scikit-learn 0.17.1 documentation