Dropping Columns with Low Observations
After doing a lot of feature engineering, it's a good idea to take a step back and look at what you've created. If you've used automated techniques on your categorical features, such as exploding or one-hot encoding, you may find that you now have hundreds of new binary features. Feature selection is material for a whole other course, but there are some quick steps you can take to reduce the dimensionality of your data set.
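In this exercise, we are going to remove binary columns that have fewer than 30 observations, that is, fewer than 30 rows where the value is 1. Thirty is a commonly cited rule-of-thumb minimum sample size for statistical inference. With fewer observations than that, any relationship a model finds between the column and the target may be a sheer coincidence, which leads to overfitting!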
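To make the coincidence argument concrete, here is a minimal sketch in plain Python (not PySpark, and not part of the course). It assigns a binary flag to rows completely at random, so the flag has no real relationship with the target, and then measures how far the flagged rows' target mean drifts from the overall mean. The function name `mean_spurious_gap` and all of its parameters are hypothetical, chosen only for this illustration.

```python
import random
import statistics


def mean_spurious_gap(n_ones, n_rows=1000, n_trials=200, seed=0):
    """Average absolute gap between the target mean of randomly flagged rows
    and the overall target mean. The flag is random, so any gap is coincidence."""
    rng = random.Random(seed)
    gaps = []
    for _ in range(n_trials):
        # A target with no relationship to the flag
        target = [rng.gauss(0.0, 1.0) for _ in range(n_rows)]
        # Pick n_ones rows at random to play the role of "value is 1"
        flagged = rng.sample(range(n_rows), n_ones)
        flagged_mean = statistics.fmean([target[i] for i in flagged])
        overall_mean = statistics.fmean(target)
        gaps.append(abs(flagged_mean - overall_mean))
    return statistics.fmean(gaps)


if __name__ == "__main__":
    for n_ones in (5, 30, 300):
        print(n_ones, round(mean_spurious_gap(n_ones), 3))
```

Under these assumptions, the gap shrinks as the number of flagged rows grows, so a column with only a handful of 1 values can look predictive purely by chance.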
NOTE: The data is available in the dataframe, df.
This exercise is part of the course
Feature Engineering with PySpark
Exercise instructions
- Using the provided for loop that iterates through the list of binary columns, calculate the sum of the values in the column using the agg function. Use collect() to run the calculation immediately and save the results to obs_count.
- Compare obs_count to obs_threshold; the if statement should be true if obs_count is less than or equal to obs_threshold.
- Remove columns that have been appended to the cols_to_remove list by using drop(). Recall that the * allows the list to be unpacked.
- Print the starting and ending shape of the PySpark dataframes by using count() for the number of records and len() on df.columns or new_df.columns to find the number of columns.
Hands-on interactive exercise
Have a go at this exercise by completing this sample code.
obs_threshold = 30
cols_to_remove = list()
# Inspect first 10 binary columns in list
for col in binary_cols[0:10]:
    # Count the number of 1 values in the binary column
    obs_count = df.____({col: ____}).____()[0][0]
    # If less than or equal to our observation threshold, mark for removal
    if ____ ____ ____:
        cols_to_remove.append(col)

# Drop columns and print starting and ending dataframe shapes
new_df = df.____(*____)
print('Rows: ' + str(df.____()) + ' Columns: ' + str(____(df.____)))
print('Rows: ' + str(new_df.____()) + ' Columns: ' + str(____(new_df.____)))
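
One way the steps in the template above might be completed, following the exercise instructions, is sketched below as reusable functions. The helper names `find_sparse_columns`, `drop_sparse_columns`, and `shape` are hypothetical and are not part of the course or of PySpark. The sketch assumes `df` is a PySpark DataFrame and `binary_cols` is a list of its 0/1 column names. The `None` check is an added assumption: Spark's sum yields null when a column contains only nulls, and such a column is treated here as having no observations.

```python
def find_sparse_columns(df, binary_cols, obs_threshold=30):
    """Return the binary columns whose count of 1 values is at or below obs_threshold."""
    cols_to_remove = []
    for col in binary_cols:
        # agg({col: 'sum'}) describes a one-row result; collect() triggers the job
        obs_count = df.agg({col: 'sum'}).collect()[0][0]
        # Assumption: an all-null column (sum is None) counts as zero observations
        if obs_count is None or obs_count <= obs_threshold:
            cols_to_remove.append(col)
    return cols_to_remove


def drop_sparse_columns(df, binary_cols, obs_threshold=30):
    """Return a new DataFrame without the sparse binary columns."""
    cols_to_remove = find_sparse_columns(df, binary_cols, obs_threshold)
    # The * unpacks the list into separate column-name arguments for drop()
    return df.drop(*cols_to_remove)


def shape(df):
    """Return (row count, column count); count() runs a job, df.columns does not."""
    return (df.count(), len(df.columns))


# Usage with the exercise's names (df and binary_cols are assumed to exist already):
# new_df = drop_sparse_columns(df, binary_cols[0:10], obs_threshold=30)
# print('Rows: ' + str(shape(df)[0]) + ' Columns: ' + str(shape(df)[1]))
# print('Rows: ' + str(shape(new_df)[0]) + ' Columns: ' + str(shape(new_df)[1]))
```

Note that this loop launches one Spark job per column, which is fine for the ten columns inspected here but may be slow for hundreds of columns.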