Decoding the predictions of a decision tree model in PySpark

Dipawesh Pawar
4 min read · Sep 17, 2020


example decision tree and the nodes involved in the prediction 21.0 for the sample feature values shown in the top-right corner

If we want to analyse the reasons why a decision tree model predicts that a particular instance belongs to a certain class, we need to parse the decision tree produced during training.

There are several posts that explain how this can be achieved with a scikit-learn decision tree model, but very few for a PySpark decision tree model, and those that exist are hard to follow. So here, I will try to elucidate it.

Note: I have tested the code only for numerical features. If you also have some categorical ones, the code should work, but that needs to be tested.

At many points later in the post, I have included screenshots of code snippets. That would be troublesome if you just want to use the code, so I am also linking one of my Stack Overflow answers to a similar question; you can simply copy the code from there.

Let's define a sample dataframe as below.

import pandas as pd
from pyspark.sql import SparkSession
spark_session = SparkSession.builder.getOrCreate()
data = pd.DataFrame({
'ball': [0, 1, 1, 3, 1, 0, 1, 3],
'keep': [4, 5, 6, 7, 7, 4, 6, 7],
'hall': [8, 9, 10, 11, 2, 6, 10, 11],
'fall': [12, 13, 14, 15, 15, 12, 14, 15],
'mall': [16, 17, 18, 10, 10, 16, 18, 10],
'label': [21, 31, 41, 51, 51, 51, 21, 31]
})
df = spark_session.createDataFrame(data)
df.show()
sample dataframe
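
For reference, with the data above, df.show() prints something like the following (the screenshot is not reproduced here; row order may vary with partitioning):

+----+----+----+----+----+-----+
|ball|keep|hall|fall|mall|label|
+----+----+----+----+----+-----+
|   0|   4|   8|  12|  16|   21|
|   1|   5|   9|  13|  17|   31|
|   1|   6|  10|  14|  18|   41|
|   3|   7|  11|  15|  10|   51|
|   1|   7|   2|  15|  10|   51|
|   0|   4|   6|  12|  16|   51|
|   1|   6|  10|  14|  18|   21|
|   3|   7|  11|  15|  10|   31|
+----+----+----+----+----+-----+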

Let us train a PySpark decision tree model on this sample dataframe. As all the columns are numeric, we just need to assemble them into a single column using VectorAssembler and use that as the features column when training the decision tree.

from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml import Pipeline
# feature columns
f_list = ['ball','keep','mall','hall','fall']
assemble_numerical_features = VectorAssembler(inputCols=f_list, outputCol='features', handleInvalid='skip')
dt = DecisionTreeClassifier(featuresCol='features', labelCol='label')
pipeline = Pipeline(stages=[assemble_numerical_features, dt])
model = pipeline.fit(df)
df = model.transform(df)
dt_m = model.stages[-1]

The trained decision tree model's rules, in string format, are as below:

print(dt_m.toDebugString)
decision tree model rules
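
The exact splits are model-specific and appear only in the screenshot, but toDebugString always prints one node per line in this general shape (the header text varies by Spark version, and the indices, thresholds and predictions below are placeholders, not the real trained splits):

DecisionTreeClassificationModel ...   (model summary header)
  If (feature 3 <= 7.0)
   Predict: 21.0
  Else (feature 3 > 7.0)
   If (feature 0 <= 0.5)
    Predict: 51.0
   Else (feature 0 > 0.5)
    Predict: 31.0

This 'If (...) / Else (...) / Predict: ...' structure is what we will parse next.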

Now the task reduces to parsing these rules. First, we will convert them from a string into a dictionary of nodes and their connections, as below.

function that converts the model rules from string format to a dictionary of nodes and their connections
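
The conversion function itself appears in this post only as a screenshot (the exact code is in the Stack Overflow answer mentioned above). A minimal sketch of the same idea, relying on the 'If / Else / Predict' structure shown earlier, could look like this; the nested layout and the key names 'cond', 'yes', 'no' and 'predict' are one possible choice, used consistently in the sketches in this post:

import re

def parse_tree(debug_string):
    # Convert dt_m.toDebugString into a nested dictionary of nodes.
    # Internal nodes look like {'cond': 'feature 3 > 7.0', 'yes': ..., 'no': ...}
    # and leaves look like {'predict': 21.0}.
    lines = [ln.strip() for ln in debug_string.split('\n') if ln.strip()]
    lines = lines[1:]  # drop the model summary header

    def build(pos):
        line = lines[pos]
        if line.startswith('Predict:'):  # leaf node
            return {'predict': float(line.split(':')[1])}, pos + 1
        # internal node, e.g. 'If (feature 3 <= 7.0)'
        cond = re.match(r'If \((.+)\)', line).group(1)
        yes_branch, pos = build(pos + 1)  # subtree where the condition holds
        # lines[pos] is now the matching 'Else (...)' line; its subtree follows
        no_branch, pos = build(pos + 1)
        return {'cond': cond, 'yes': yes_branch, 'no': no_branch}, pos

    tree, _ = build(0)
    return tree

tree_as_dict = parse_tree(dt_m.toDebugString)
print(tree_as_dict)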

The last line prints the dictionary, producing output like that shown below. The hand-drawn figure at the top shows the same output in tree form.

decision tree rules as dictionary of nodes and their connection

As we can see in the above dictionary, each rule is in the format

feature 3 > 7.0

where 3 denotes the feature's index in the feature vector in the 'features' column of the dataframe (the output of VectorAssembler).

Now, we can map feature indices to feature names using the metadata that VectorAssembler stores in its output column. The stored metadata is as follows:

df.schema['features'].metadata["ml_attr"]["attrs"]
meta data that vector assembler stores in its output column
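
For the dataframe above, this metadata looks roughly as follows (all features are numeric here, and the idx values follow the order given to VectorAssembler):

{'numeric': [{'idx': 0, 'name': 'ball'},
             {'idx': 1, 'name': 'keep'},
             {'idx': 2, 'name': 'mall'},
             {'idx': 3, 'name': 'hall'},
             {'idx': 4, 'name': 'fall'}]}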

The lines below create a dictionary that maps feature indices to feature names.

f_type_to_flist_dict = df.schema['features'].metadata["ml_attr"]["attrs"]
f_index_to_name_dict = {}
for f_type, f_list in f_type_to_flist_dict.items():
    for f in f_list:
        f_index = f['idx']
        f_name = f['name']
        f_index_to_name_dict[f_index] = f_name
dictionary mapping feature index to feature names
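
With the assembler order used above, this comes out as:

{0: 'ball', 1: 'keep', 2: 'mall', 3: 'hall', 4: 'fall'}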

Now let's define a dictionary that maps an operator string to the actual operator, as below. This will help us check whether the current instance satisfies the rule in a decision tree node.

import operator

operators = {
">=": operator.ge,
"<=": operator.le,
">": operator.gt,
"<": operator.lt,
"==": operator.eq,
'and': operator.and_,
'or': operator.or_
}

Now, to get the rule that leads to the prediction for each instance, we just walk through the nodes in the rules dictionary, following the branches whose conditions the current instance satisfies. The code for this is shown below. The generate_rules() function adds a 'rule' column to the input dataframe, containing the chain of conditions that leads to the prediction for that particular instance. As the function is too big to fit in a single screenshot, I am uploading multiple. generate_rules() contains two inner functions, parse_validate_cond() and extract_rule(). The former parses and validates the rule/condition in a node; the latter recursively walks through the nodes for each instance.

generate_rules() part 1
generate_rules() part 2
generate_rules() part 3
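
If the screenshots are hard to read, here is a compact sketch of generate_rules() along the same lines. It assumes the nested node dictionary produced by the parse_tree() sketch above and handles numeric splits only (conditions of the form 'feature <i> <op> <value>'); my full version is in the Stack Overflow answer linked earlier.

import re
from pyspark.sql.functions import struct, udf
from pyspark.sql.types import StringType

def generate_rules(tree_as_dict, df, f_index_to_name_dict, operators):
    # parse_validate_cond: turn a node condition such as 'feature 3 > 7.0'
    # into a readable condition and check whether this row satisfies it
    def parse_validate_cond(cond, row_dict):
        f_idx, op_str, thr = re.match(
            r'feature (\d+) (<=|>=|==|<|>) (.+)', cond).groups()
        f_name = f_index_to_name_dict[int(f_idx)]
        satisfied = operators[op_str](float(row_dict[f_name]), float(thr))
        return '{} {} {}'.format(f_name, op_str, thr), satisfied

    # extract_rule: recursively walk the node dictionary, collecting the
    # conditions along the branch this row actually follows
    def extract_rule(node, row_dict, conds=None):
        conds = conds or []
        if 'predict' in node:  # reached a leaf: join the collected conditions
            return ' and '.join(conds)
        readable, satisfied = parse_validate_cond(node['cond'], row_dict)
        if satisfied:
            return extract_rule(node['yes'], row_dict, conds + [readable])
        return extract_rule(node['no'], row_dict,
                            conds + ['not ({})'.format(readable)])

    # evaluate the tree once per row via a UDF over the feature columns
    feature_cols = [f_index_to_name_dict[i] for i in sorted(f_index_to_name_dict)]
    rule_udf = udf(lambda row: extract_rule(tree_as_dict, row.asDict()),
                   StringType())
    return df.withColumn('rule', rule_udf(struct(*feature_cols)))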

Finally, we can just collect the dataframe rows into a list and inspect the rules that explain each prediction.

df = generate_rules(tree_as_dict,df,f_index_to_name_dict,operators)
result_rows = df.select('ball','keep','hall','fall','mall','prediction','rule').collect()

Output is as follows:

rules that explain the prediction for each instance

How we reach the prediction 21.0 for the first row is presented visually in the hand-drawn figure at the top.

References:

1. https://github.com/tristaneljed/Decision-Tree-Visualization-Spark/blob/master/DT.py
