A function that makes quick but useful plots for exploring binary flags as features in a binary classification problem

Motivation

I was recently working on a classification project at work where most of my features were binary flags. I could not find a good pre-built visualization tool for exploring the flags and how they relate to the response variable, so I wrote my own.

# import libraries
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt


# formatting details for plots
MEDIUM_SIZE = 20
BIGGER_SIZE = 22
plt.rc('axes', titlesize=BIGGER_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)     # fontsize of the x and y labels
plt.rc('xtick', labelsize=MEDIUM_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=MEDIUM_SIZE)    # fontsize of the tick labels
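
The `plt.rc` calls above change the defaults for every figure created later in the session. If you would rather keep the larger fonts local to one plot, one option is matplotlib's `plt.rc_context`. The sketch below is only an illustration: the `PLOT_STYLE` name and the `Agg` backend choice are my assumptions, not part of the original setup.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
from matplotlib import pyplot as plt

# Hypothetical name; same sizes as MEDIUM_SIZE / BIGGER_SIZE above
PLOT_STYLE = {
    "axes.titlesize": 22,
    "axes.labelsize": 20,
    "xtick.labelsize": 20,
    "ytick.labelsize": 20,
}

default_title_size = plt.rcParams["axes.titlesize"]

# Settings apply only inside the with-block and are restored afterwards
with plt.rc_context(PLOT_STYLE):
    fig, ax = plt.subplots()
    ax.barh(["flag1", "flag2"], [10, 7])
    ax.set_title("Times Flag is Triggered")
    title_size_inside = ax.title.get_fontsize()
    plt.close(fig)

restored = plt.rcParams["axes.titlesize"] == default_title_size
```

Figures created inside the block pick up the larger fonts, and anything plotted afterwards falls back to whatever defaults were active before.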

The function I wrote is defined below:

def check_flag_hits(data, flags, target='target', rotation=0, output_file=None):
    """
    Explore the relationship of binary features with a binary target.
    
    ~~Inputs~~
    data: pandas dataframe with the data
    flags: list of col names of binary flags
    target: name of target response variable (should also be binary)
    rotation: rotation (in degrees) of the x tick labels on the first two plots
    output_file: name of file to save the plot to (should be '.png')
    
    ~~Outputs~~
    None, just generates the plots
    """
    flag_hit_counts = []
    target_list = []
    target_perc_list = []
    for flag in flags:
        flag_sum = data[flag].sum()
        flag_hit_counts.append(flag_sum)
        target_sum = data[data[flag] == 1][target].sum()
        target_list.append(target_sum)
        target_perc = data[data[flag] == 1][target].mean() * 100
        target_perc_list.append(target_perc)
        
    plt.figure(figsize=(25,8))
    nrows = 1
    ncols = 3
    
    plt.subplot(nrows, ncols, 1)
    plt.barh(flags, flag_hit_counts)
    plt.title("Times Flag is Triggered")
    plt.xticks(rotation=rotation)

    plt.subplot(nrows, ncols, 2)
    plt.barh(flags, target_list)
    plt.title("Count of " + target + " for rows with that Flag")
    plt.xticks(rotation=rotation)

    plt.subplot(nrows, ncols, 3)
    plt.barh(flags, target_perc_list)
    plt.title("Percent of " + target + " for rows with that Flag")
    plt.tight_layout()
    
    if output_file:
        plt.savefig(output_file)
    else:
        plt.show()
    plt.close()
    
    return

# read in some sample data
df = pd.read_csv("sample.csv")
df
    id  flag1  flag2  flag3  response
0    1      1      1      0         0
1    2      1      0      0         0
2    3      0      1      0         0
3    4      0      0      0         0
4    5      1      1      0         0
5    6      1      0      1         0
6    7      0      1      1         0
7    8      0      0      1         0
8    9      1      1      1         0
9   10      1      0      1         0
10  11      0      1      1         1
11  12      0      0      1         1
12  13      1      1      1         1
13  14      1      0      1         1
14  15      0      0      1         1
15  16      0      0      1         1
16  17      1      0      1         1
17  18      1      0      1         1
18  19      0      0      1         1
19  20      0      0      1         1

check_flag_hits(data=df, flags=df.columns[1:4], target='response')

[Figure: output of check_flag_hits, three side-by-side horizontal bar charts with one bar per flag]
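
If you want the numbers behind the three charts, the same per-flag quantities can be collected into a small table. This is a minimal sketch rather than part of the original function: `summarize_flags` is a hypothetical helper name, and the sample data is rebuilt inline (instead of read from `sample.csv`) so the snippet runs on its own.

```python
import pandas as pd

def summarize_flags(data, flags, target="target"):
    """Hypothetical helper: tabulate the three quantities that check_flag_hits plots."""
    rows = []
    for flag in flags:
        hit = data[data[flag] == 1]  # rows where this flag is triggered
        rows.append({
            "flag": flag,
            "flag_hits": int(data[flag].sum()),     # times the flag is triggered
            "target_hits": int(hit[target].sum()),  # positives among those rows
            "target_pct": hit[target].mean() * 100, # positive rate among those rows
        })
    return pd.DataFrame(rows).set_index("flag")

# The 20-row sample shown above, rebuilt inline
sample = pd.DataFrame({
    "id": range(1, 21),
    "flag1": [1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0],
    "flag2": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
    "flag3": [0] * 5 + [1] * 15,
    "response": [0] * 10 + [1] * 10,
})

summary = summarize_flags(sample, ["flag1", "flag2", "flag3"], target="response")
print(summary.round(1))
```

On the sample data, flag3 fires most often (15 rows) and has the highest response rate (about 66.7%), flag1 fires on 10 rows with a 40% rate, and flag2 fires on 7 rows with a rate of about 28.6%.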

That’s all there is to it. For larger datasets, the tick labels on the first and second charts sometimes need to be rotated to stay readable, so I included a `rotation` parameter for that in the function. I hope you find this helpful! If you do, or if you have feedback, please connect with me and message me on LinkedIn.