Causal modeling has been recognized as a potential solution to many challenging problems in machine learning (ML). Here, we propose a counterfactual approach to remove or reduce the influence of confounders on the predictions generated by a deep neural network (DNN). Rather than attempting to prevent DNNs from directly learning the confounding signal, we remove confounding from the feature representations learned by DNNs in anticausal prediction tasks. By training an accurate DNN using softmax activation at the classification layer, and then adopting the representation learned by the last layer prior to the output layer as our features, we ensure that the learned features will, by construction, fit a logistic regression model well and be linearly associated with the labels. Then, in order to generate classifiers that are free from the influence of the observed confounders, we: (i) use linear models to regress each learned feature on the labels and on the confounders, and estimate the respective regression coefficients and model residuals; (ii) generate new counterfactual features by adding the estimated residuals back to a linear predictor that no longer includes the confounder variables; and (iii) train and evaluate a logistic classifier using the counterfactual features as inputs. We validate the proposed methodology using colored versions of the MNIST and fashion-MNIST datasets, and show how the approach can effectively combat confounding and improve generalization in the context of dataset shift. In our experiments, the causality-aware approach compared favorably against a variation of the SMOTE \cite{chawla2002} balancing approach. Finally, we also describe how to use conditional independence tests to evaluate whether the counterfactual approach has effectively removed the confounder signals from the predictions.
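Steps (i) and (ii) can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: the function names, the simulated data standing in for DNN features, and the dummy coding of the labels are all assumptions made for illustration. Adding the residuals back to a linear predictor without the confounder terms is algebraically the same as subtracting the fitted confounder contribution from each feature, which is what the sketch does.

```python
import numpy as np

def fit_confounder_coefs(X, y_onehot, C):
    """Step (i): least-squares fit of each feature column on labels and confounders.

    Returns only the confounder coefficients (shape: n_confounders x n_features).
    Hypothetical helper, not taken from the paper.
    """
    n = X.shape[0]
    # Intercept + label dummies (first class dropped to avoid collinearity) + confounders.
    design = np.column_stack([np.ones(n), y_onehot[:, 1:], C])
    coefs, *_ = np.linalg.lstsq(design, X, rcond=None)
    return coefs[-C.shape[1]:, :]

def counterfactual_features(X, C, gamma):
    """Step (ii): intercept + label effect + residual, i.e. X minus the confounder part."""
    return X - C @ gamma

# Simulated stand-in for learned features (assumption: 2 classes, 1 confounder).
rng = np.random.default_rng(0)
n, p = 500, 4
y = rng.integers(0, 2, size=n)
Y = np.eye(2)[y]                                   # one-hot labels
C = (y + rng.normal(size=n)).reshape(-1, 1)        # confounder associated with the label
X = np.outer(y, rng.normal(size=p)) + C @ rng.normal(size=(1, p)) + rng.normal(size=(n, p))

gamma = fit_confounder_coefs(X, Y, C)
X_cf = counterfactual_features(X, C, gamma)

# Re-fitting on the adjusted features should give confounder coefficients near zero.
gamma_after = fit_confounder_coefs(X_cf, Y, C)
```

Step (iii) would then fit any standard logistic regression to `X_cf`. For held-out data, where labels are not available at prediction time, one plausible choice (an assumption here) is to reuse the training-set `gamma` when adjusting the test features.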