読者です 読者をやめる 読者になる 読者になる

サイエンティストとマーケターのはざま

Pythonとか広告とかデータ分析とかとか


Python scikit-learnで決定木分析

データマイニングで定番の決定木分析をやってみたいと思います。

決定木の説明に関しては他に譲るとして、ここではpythonの機械学習ライブラリである、scikit-learnを利用して決定木分析を行い、graphvizという可視化ライブラリでグラフを描画します。

その時にハマったところを中心に話を進めたいと思います。

#scikit-learn付属データの読み込み

from sklearn.datasets import load_iris

iris = load_iris()

モデル作成には、tree.DecisionTreeClassifierクラスのfit関数を使います。

#決定木のクラスの読み込み

from sklearn import tree

#木の深さを指定し、モデルを生成

clf = tree.DecisionTreeClassifier(max_depth=3)
clf = clf.fit(iris.data, iris.target)

生成したモデルをgraphvizを利用して描画していきます。

pipでgraphvizとpydotをインストールしておく必要があります。

pip install graphviz

pip install pydot

インストールしたライブラリを読み込み、

from graphviz import Digraph
import pydot
from sklearn.externals.six import StringIO
from IPython.display import Images

tree.export_graphviz関数を使って、dot形式で分類結果を吐き出します。

そして、pdfとしてグラフを表示します。

#dot形式でエクスポート

dot_data = StringIO()

tree.export_graphviz(clf, out_file=dot_data, class_names=iris.target_names,feature_names=iris.feature_names)
graph = pydot.graph_from_dot_data(dot_data.getvalue())

#グラフ化
graph.write_pdf("graph.pdf")
Image(graph.create_png())

ここでエラーになりました。具体的なエラーはメモするの忘れたのですが、結論、graphvizを扱うにはpipでインストールするだけでなく、アプリケーション自体が必要みたいです。以下Homebrewでinstallしましょう。

brew install graphviz

インストールしたらdotで動作確認しておきましょう。

インストールできたら、先ほどのコードを再度実行します。

はい、またエラーになりました。

TypeError: export_graphviz() got an unexpected keyword argument 'class_names'

これはscikit-learnのバージョンが古いためです。

anacondaを使っている人は以下コマンドでアップデートできます。

conda update scikit-learn

これで再度コードを実行しなおすと下記のような出力ができるかと思います。

f:id:ukichang:20160620011555p:plain

決定木自体は難しくないですが、パッケージまわりでトラップ多いですね。

 

 参考

scikit-learn で決定木分析 (CART 法) – Python でデータサイエンス

Git のインストールメモ 【Homebrew】 | gworks web site

python 2.7 - "class_names" in export_graphviz unexpected keyword error - Stack Overflow

sklearn.tree.export_graphviz — scikit-learn 0.17.1 documentation