Dropping Columns with Low Observations
After doing a lot of feature engineering, it's a good idea to take a step back and look at what you've created. If you've used automation techniques on your categorical features, such as exploding or one-hot encoding, you may find that you now have hundreds of new binary features. Feature selection is material for a whole other course, but there are some quick steps you can take to reduce the dimensionality of your data set.
In this exercise, we are going to remove columns that have fewer than 30 observations. 30 is a commonly cited rule-of-thumb minimum number of observations for statistical significance. With fewer than that, any relationship a model finds may be sheer coincidence, which leads to overfitting!
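To get a feel for the coincidence argument, here is a small back-of-the-envelope sketch. It assumes, purely for illustration, that the target is a binary label and that each observation's label is an independent coin flip; the function name chance_all_same_label is hypothetical and not part of the course. Under those assumptions it computes how likely it is that every row where a rare binary feature equals 1 happens to share the same label, which would make the feature look perfectly predictive by accident.

```python
def chance_all_same_label(n_obs: int, p_positive: float = 0.5) -> float:
    """Probability that n_obs independent binary labels all land in one class.

    Assumption (illustrative only): labels are independent draws with
    P(label = 1) = p_positive.
    """
    all_positive = p_positive ** n_obs
    all_negative = (1.0 - p_positive) ** n_obs
    return all_positive + all_negative


# Compare a very rare feature with one that meets the 30-observation rule of thumb
for n in (3, 10, 30):
    print(n, chance_all_same_label(n))
```

With only a handful of observations the chance of an accidental perfect split is substantial, while at 30 observations it becomes vanishingly small under these assumptions.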
NOTE: The data is available in the dataframe, df.
This exercise is part of the course
Feature Engineering with PySpark
Exercise instructions
- Using the provided for loop that iterates through the list of binary columns, calculate the sum of the values in the column using the agg function. Use collect() to run the calculation immediately and save the results to obs_count.
- Compare obs_count to obs_threshold; the if statement should be true if obs_count is less than or equal to obs_threshold.
- Remove the columns that have been appended to the cols_to_remove list by using drop(). Recall that the * allows the list to be unpacked.
- Print the starting and ending shape of the PySpark dataframes by using count() for the number of records and len() on df.columns or new_df.columns to find the number of columns.
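The logic of these steps can be sketched in plain Python before touching Spark. This is not PySpark code: the toy_columns dictionary, its column names, and the drop helper are hypothetical stand-ins for a dataframe and its drop() method, used only to show the threshold comparison and how * unpacks a list into separate arguments.

```python
# Hypothetical toy data: column name -> list of 0/1 values (stands in for a dataframe)
toy_columns = {
    "has_pool": [1] * 45 + [0] * 55,
    "has_moat": [1] * 2 + [0] * 98,
    "has_garage": [1] * 30 + [0] * 70,
}

obs_threshold = 30
cols_to_remove = []

for col, values in toy_columns.items():
    # For a binary column, the sum equals the number of 1 values
    obs_count = sum(values)
    # Flag the column when it is at or below the threshold
    if obs_count <= obs_threshold:
        cols_to_remove.append(col)


def drop(*names):
    """Hypothetical helper: return the toy data without the named columns."""
    return {k: v for k, v in toy_columns.items() if k not in names}


# The * turns the list into separate positional arguments:
# drop(*["a", "b"]) is the same call as drop("a", "b")
new_columns = drop(*cols_to_remove)

print("Removed:", cols_to_remove)
print("Columns before:", len(toy_columns), "after:", len(new_columns))
```

Note that a column with exactly 30 observations is removed here, because the comparison is less than or equal to the threshold.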
Hands-on interactive exercise
Have a go at this exercise by completing this sample code.
obs_threshold = 30
cols_to_remove = list()
# Inspect first 10 binary columns in list
for col in binary_cols[0:10]:
    # Count the number of 1 values in the binary column
    obs_count = df.____({col: ____}).____()[0][0]
    # If less than our observation threshold, remove
    if ____ ____ ____:
        cols_to_remove.append(col)
# Drop columns and print starting and ending dataframe shapes
new_df = df.____(*____)
print('Rows: ' + str(df.____()) + ' Columns: ' + str(____(df.____)))
print('Rows: ' + str(new_df.____()) + ' Columns: ' + str(____(new_df.____)))