Decoding predictions of a decision tree model in PySpark
If we want to analyse the reasons why a particular instance is predicted by a decision tree model to belong to a certain class, we need to parse the decision tree produced during training.
There are several posts that explain how this can be achieved with a scikit-learn decision tree model, but very few for a PySpark decision tree model, and those that exist are hard to understand. So here, I will try to elucidate it.
Note: I have tested the code only for numerical features. If you also have some categorical ones, the code should work, but that needs to be tested.
In several places later in the post, I have included screenshots of code snippets, which is troublesome if you just want to use the code. So I am providing a link to one of my Stack Overflow answers to a similar question; you can simply copy the code from there.
Let’s define a sample dataframe as below.
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

data = pd.DataFrame({
    'ball': [0, 1, 1, 3, 1, 0, 1, 3],
    'keep': [4, 5, 6, 7, 7, 4, 6, 7],
    'hall': [8, 9, 10, 11, 2, 6, 10, 11],
    'fall': [12, 13, 14, 15, 15, 12, 14, 15],
    'mall': [16, 17, 18, 10, 10, 16, 18, 10],
    'label': [21, 31, 41, 51, 51, 51, 21, 31]
})
df = spark.createDataFrame(data)
df.show()
Let us train a PySpark decision tree model on this sample dataframe. As all columns are numeric, we just need to assemble them into one column using VectorAssembler and use that as the features column for training the decision tree.
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml import Pipeline

# feature columns
f_list = ['ball', 'keep', 'mall', 'hall', 'fall']

assemble_numerical_features = VectorAssembler(inputCols=f_list, outputCol='features',
                                              handleInvalid='skip')
dt = DecisionTreeClassifier(featuresCol='features', labelCol='label')

pipeline = Pipeline(stages=[assemble_numerical_features, dt])
model = pipeline.fit(df)
df = model.transform(df)

# the trained decision tree model is the last stage of the pipeline
dt_m = model.stages[-1]
The trained decision tree model's rules, in string format, can be printed as below.
print(dt_m.toDebugString)
Now the task reduces to parsing these rules. First, we will convert them from a string into a dictionary of nodes and their connections, as below.
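The original code for this step was a screenshot, so here is a minimal sketch of such a parser instead. It assumes the standard `toDebugString` layout of nested `If (...)` / `Else (...)` / `Predict:` lines; the helper name `tree_to_dict` and the exact node-dictionary layout (`'rule'`, `'true'`, `'false'`, `'predict'` keys) are my own choices, not from the original post:

```python
def tree_to_dict(tree_as_string):
    """Parse the debug string of a trained decision tree model into a
    dictionary of nodes. Node ids are assigned in the order nodes are
    visited; each inner node stores its condition and the ids of the
    subtrees taken when the condition is true / false."""
    lines = [ln for ln in tree_as_string.splitlines() if ln.strip()]
    # drop the header line ("DecisionTreeClassificationModel: ...")
    if not lines[0].lstrip().startswith(('If', 'Predict')):
        lines = lines[1:]

    tree_as_dict = {}
    next_id = 0

    def build(pos):
        """Build the subtree starting at lines[pos]; return (node_id, next_pos)."""
        nonlocal next_id
        node_id = next_id
        next_id += 1
        line = lines[pos].strip()
        if line.startswith('Predict:'):
            tree_as_dict[node_id] = {'predict': float(line.split(':')[1])}
            return node_id, pos + 1
        # e.g. "If (feature 3 <= 7.5)" -> condition "feature 3 <= 7.5"
        cond = line[line.index('(') + 1:line.rindex(')')]
        true_id, pos = build(pos + 1)   # subtree under the "If" line
        false_id, pos = build(pos + 1)  # skip the matching "Else (...)" line
        tree_as_dict[node_id] = {'rule': cond, 'true': true_id, 'false': false_id}
        return node_id, pos

    build(0)
    return tree_as_dict

# usage: tree_as_dict = tree_to_dict(dt_m.toDebugString)
```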
The last line will result in the following output. The hand-drawn figure at the top shows this output in tree form.
As we can see in the above dictionary, the rules are in the format
feature 3 > 7.0
where,
3 denotes the feature index into the feature vector present in the 'features' column of the dataframe (the output of VectorAssembler)
Now, we can map each feature index to a feature name using the metadata that VectorAssembler stores in its output column. The stored metadata is as follows:
df.schema['features'].metadata["ml_attr"]["attrs"]
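The screenshot of this metadata is not reproduced here; for the assembler above (all-numeric inputs) it has roughly the following shape, where each entry pairs a position in the assembled vector with the original column name. This is an illustration of the structure, not actual captured output:

```python
# Illustrative shape of df.schema['features'].metadata["ml_attr"]["attrs"];
# with categorical inputs there may also be 'binary' or 'nominal' keys.
# Indices follow the order of the assembler's inputCols.
attrs = {
    'numeric': [{'idx': 0, 'name': 'ball'},
                {'idx': 1, 'name': 'keep'},
                {'idx': 2, 'name': 'mall'},
                {'idx': 3, 'name': 'hall'},
                {'idx': 4, 'name': 'fall'}]
}
```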
The lines below create a dictionary that maps feature indices to feature names:
f_type_to_flist_dict = df.schema['features'].metadata["ml_attr"]["attrs"]

f_index_to_name_dict = {}
for f_type, f_list in f_type_to_flist_dict.items():
for f in f_list:
f_index = f['idx']
f_name = f['name']
f_index_to_name_dict[f_index] = f_name
Now let's define a dictionary that maps an operator string to the actual operator, as below. This will help us check whether the current instance satisfies the rule in a decision tree node.
import operator

operators = {
">=": operator.ge,
"<=": operator.le,
">": operator.gt,
"<": operator.lt,
"==": operator.eq,
'and': operator.and_,
'or': operator.or_
}
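To illustrate how this dictionary gets used: a node condition such as `feature 3 > 7.0` can be split into its parts and evaluated against a feature value. The rule string and the value 5.0 below are made up for illustration, and the operator mapping is repeated so the snippet runs standalone:

```python
import operator

# comparison operators only, repeated here so the snippet is self-contained
operators = {">=": operator.ge, "<=": operator.le, ">": operator.gt,
             "<": operator.lt, "==": operator.eq}

rule = "feature 3 > 7.0"
# split into: the word 'feature', the feature index, the operator symbol, the threshold
_, f_idx, op_str, threshold = rule.split()
satisfied = operators[op_str](5.0, float(threshold))
print(satisfied)  # False, since 5.0 > 7.0 does not hold
```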
Now, to get the rule that led to a prediction for each instance, we can simply walk through the nodes in the dictionary of rules whose conditions the features of the current instance satisfy. The code for this is shown below. The generate_rules() function adds a 'rule' column to the input dataframe, containing the rule that leads to the prediction for that particular instance. As the function is too big to fit in a single screenshot, I am uploading multiple. generate_rules() contains two inner functions, namely parse_validate_cond() and extract_rule(). The former parses and validates the rule/condition in a node, and the latter recursively goes through the nodes for each instance.
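Since the screenshots of generate_rules() are not reproduced here, below is a minimal, pure-Python sketch of its core logic, assuming the node-dictionary format described above (nodes with 'rule', 'true', 'false', or 'predict' keys). The `negate` mapping and the exact signatures are my own assumptions; the sketch works on one instance represented as a plain dict of feature name to value:

```python
import operator

operators = {">=": operator.ge, "<=": operator.le, ">": operator.gt,
             "<": operator.lt, "==": operator.eq}
# operator to report for the branch where the condition does NOT hold
negate = {">=": "<", "<=": ">", ">": "<=", "<": ">=", "==": "!="}

def parse_validate_cond(cond, row, f_index_to_name_dict):
    """Parse a condition like 'feature 3 > 7.0', check whether the given
    row satisfies it, and return a human-readable form of the branch taken."""
    _, f_idx, op_str, threshold = cond.split()
    f_name = f_index_to_name_dict[int(f_idx)]
    satisfied = operators[op_str](row[f_name], float(threshold))
    readable = f"{f_name} {op_str if satisfied else negate[op_str]} {threshold}"
    return satisfied, readable

def extract_rule(tree_as_dict, row, f_index_to_name_dict, node_id=0):
    """Recursively follow the branch the row satisfies, collecting the
    conditions along the way, until a leaf ('predict') node is reached."""
    node = tree_as_dict[node_id]
    if 'predict' in node:
        return [], node['predict']
    satisfied, readable = parse_validate_cond(node['rule'], row,
                                              f_index_to_name_dict)
    child = node['true'] if satisfied else node['false']
    conds, prediction = extract_rule(tree_as_dict, row,
                                     f_index_to_name_dict, child)
    return [readable] + conds, prediction
```

generate_rules() would then apply extract_rule() once per dataframe row (for example via a UDF, or over collected rows) and join the returned conditions with ' and ' to fill the 'rule' column.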
Finally, we can just collect the dataframe rows in a list and check out the rules that explain each prediction.
df = generate_rules(tree_as_dict, df, f_index_to_name_dict, operators)
result_rows = df.select('ball','keep','hall','fall','mall','prediction','rule').collect()
The output is as follows:
How we reach the prediction 21.0 for the first row is presented visually in the hand-drawn figure at the top.
References:
1. https://github.com/tristaneljed/Decision-Tree-Visualization-Spark/blob/master/DT.py