import numpy as np
import pandas as pd
from toolz import curry
import seaborn as sns
from matplotlib import pyplot as plt
import statsmodels.formula.api as smf
import cvxpy as cp
import toolz as f
from sklearn.linear_model import Lasso
import warnings
warnings.filterwarnings('ignore')
from matplotlib import style
style.use("ggplot")
Conformal Inference for Synthetic Controls#
Synthetic Control Refresher#
Synthetic Control (SC) is a particularly useful causal inference technique for when you have a single treated unit and very few control units, but repeated observations of each unit through time (although there are plenty of SC extensions to the Big Data world). The canonical use case is when you want to know the impact of a treatment in one geography (like a state) and you use the other, untreated states as controls. In our Synthetic Control chapter, we motivated the technique by trying to estimate the effect of Proposition 99 (a bill passed in 1988 that increased the cigarette tax in California) on cigarette sales.
In order to do that, we have to estimate what would have happened to California had it not passed Proposition 99. This boils down to estimating the counterfactual $Y^N_{1t}$: the outcome of the treated unit (indexed by $j=1$) at time $t$ had there been no intervention ($N$).
There are many methods to do that, among which we have Synthetic Controls. Synthetic Controls tries to model $Y^N_{1t}$ as a weighted combination of the outcomes of the control units.
data = pd.read_csv("data/smoking.csv")
data = data.pivot(index="year", columns="state", values="cigsale")
data = data.rename(columns={c: f"state_{c}" for c in data.columns}).rename(columns={"state_3": "california"})
data.shape
(31, 39)
plt.figure(figsize=(10,5))
plt.plot(data.drop(columns=["california"]), color="C1", alpha=0.5)
plt.plot(data["california"], color="C0", label="California")
plt.vlines(x=1988, ymin=40, ymax=300, linestyle=":", lw=2, label="Proposition 99", color="black")
plt.legend()
plt.ylabel("Cigarette Sales")
Text(0, 0.5, 'Cigarette Sales')

That is why we combine multiple control units: even when no single control is good enough on its own, we can craft a synthetic control that closely resembles the treated unit.
In order to find the combination of states that best approximates the pre-treatment trend of California, the Synthetic Control method runs a horizontal regression, where the rows are time periods and the columns are states. It tries to find the weights that, when multiplied by the control states, best approximate the treated state.
Since we have more states (39; some were discarded from the analysis) than time periods (31), an unconstrained regression would simply overfit, which is why Synthetic Control imposes two restrictions:
Weights must sum to 1;
Weights must be non-negative;
Or, in mathematical terms, let $\mathbf{X}$ be the matrix whose columns are the control states' pre-treatment outcomes and $\mathbf{y}$ the vector of California's pre-treatment outcomes. The weights solve

$$\hat{w} = \underset{w}{\mathrm{argmin}}\; \|\mathbf{y} - \mathbf{X}w\|^2_2 \quad \text{s.t.} \quad \sum_j w_j = 1 \;\text{ and }\; w_j \geq 0 \;\; \forall j$$
Combined, these constraints mean we are defining the synthetic control as a convex combination of the control units. They also mean we are not doing any dangerous extrapolation, and that our synthetic control will tend to use only a small subset of the control units (the weights are typically sparse).
Here is what this looks like in code, as a sklearn estimator:
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
import cvxpy as cp
class SyntheticControl(BaseEstimator, RegressorMixin):

    def __init__(self):
        pass

    def fit(self, X, y):
        X, y = check_X_y(X, y)

        # weights that minimize the squared distance between
        # the weighted control units and the treated unit
        w = cp.Variable(X.shape[1])
        objective = cp.Minimize(cp.sum_squares(X@w - y))

        # restrict the weights to a convex combination:
        # they must sum to one and be non-negative
        constraints = [cp.sum(w) == 1, w >= 0]

        problem = cp.Problem(objective, constraints)
        problem.solve(verbose=False)

        self.X_ = X
        self.y_ = y
        self.w_ = w.value
        self.is_fitted_ = True
        return self

    def predict(self, X):
        check_is_fitted(self)
        X = check_array(X)
        return X @ self.w_
Let’s apply this method to our data, fitting it in the pre-intervention period (prior to 1988).
model = SyntheticControl()
train = data[data.index < 1988]
model.fit(train.drop(columns=["california"]), train["california"]);
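Before plotting predictions, we can peek at the fitted weights to confirm the convex-combination story. This quick sanity check isn't part of the original analysis, but the weights should sum to one, be non-negative (up to solver tolerance), and concentrate on just a few states:
weights = pd.Series(model.w_, index=train.drop(columns=["california"]).columns)

print("sum of weights:", weights.sum().round(4))  # ~1 by construction
print("min weight:", weights.min().round(4))      # >= 0, up to solver tolerance

# typically only a handful of states get non-negligible weight
print(weights[weights > 0.01].round(3))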
We can now plot, side by side, the trend for California and for the synthetic control we've just created. The difference between these two lines is the estimated effect of Proposition 99 on California.
plt.plot(data["california"], label="California")
plt.plot(data["california"].index, model.predict(data.drop(columns=["california"])), label="SC")
plt.vlines(x=1988, ymin=40, ymax=120, linestyle=":", lw=2, label="Proposition 99", color="black")
plt.legend();

From the look of this plot, Proposition 99 seems to have had a pretty big effect in reducing cigarette sales.
pred_data = data.assign(**{"residuals": data["california"] - model.predict(data.drop(columns=["california"]))})
plt.plot(pred_data["california"].index, pred_data["residuals"], label="Estimated Effect")
plt.hlines(y=0, xmin=1970, xmax=2000, lw=2, color="Black")
plt.vlines(x=1988, ymin=-25, ymax=5, linestyle=":", lw=2, label="Proposition 99", color="Black")
plt.legend();

Inference for Grown Ups#
In the Synthetic Control chapter, we showed an inference procedure where we permuted units, pretending control units were treated. This is also referred to as a placebo test: we check the estimated effect for units that haven't gone through the treatment. If the estimated effect in the treated unit is bigger than most of the placebo effects, we say the effect estimate is significant.
plt.figure(figsize=(10,5))
for state in data.columns:
    model_iter = SyntheticControl()
    train_iter = data[data.index < 1988]
    model_iter.fit(train_iter.drop(columns=[state]), train_iter[state])

    effect = data[state] - model_iter.predict(data.drop(columns=[state]))
    is_california = state == "california"

    plt.plot(effect,
             color="C0" if is_california else "C1",
             alpha=1 if is_california else 0.5,
             label="California" if is_california else None)

plt.hlines(y=0, xmin=1970, xmax=2000, lw=2, color="Black")
plt.vlines(x=1988, ymin=-50, ymax=100, linestyle=":", lw=2, label="Proposition 99", color="Black")
plt.ylabel("Effect Estimate")
plt.legend();

In our example, we can see that the post-treatment difference for California is quite extreme when compared to the other states. However, there are also some states with terrible pre-treatment fit, which translates into a huge error in the post-intervention period. The usual guideline is to remove units with high pre-treatment error, but deciding how high is too high is more complicated. Not only that, this procedure assumes a random assignment of the intervention, which is hard to believe for this kind of policy intervention (see Abadie, 2021).
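As an illustration of that guideline, here is one (admittedly ad hoc) way to implement it: compute each placebo's pre-treatment mean squared error and drop the units that fit much worse than California. The 2x cutoff below is an arbitrary choice for the sake of the example, not something prescribed by the method.
def pre_treatment_mse(df, state, intervention_start=1988):
    # fit a synthetic control for `state` and measure its pre-treatment error
    pre = df[df.index < intervention_start]
    m = SyntheticControl().fit(pre.drop(columns=[state]), pre[state])
    return np.mean((pre[state] - m.predict(pre.drop(columns=[state]))) ** 2)

mse = pd.Series({s: pre_treatment_mse(data, s) for s in data.columns})

# keep only the placebos whose pre-treatment fit is comparable to California's
good_fit_states = mse[mse <= 2 * mse["california"]].index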
One alternative method for inference is to recast the problem of effect estimation as one of counterfactual prediction. If you think about it, all we are trying to do is predict the counterfactual $Y^N_{1t}$ in the post-intervention period.
To understand this procedure, let’s first look at how we would do Hypothesis Tests and get P-Values.
Hypothesis Test and P-Values#
Let's say we are interested in testing a hypothesis about the trajectory of effects $\theta_t$ in the post-treatment period:

$$H_0: \theta_t = \theta^0_t \quad \text{for all } t > T_0,$$

where $T_0$ is the last pre-intervention period. For instance, if we wish to test for no effect whatsoever, we can set $\theta^0_t = 0$ for all $t > T_0$.
The key idea is to then generate data following the null hypothesis we want to test, $Y^{H_0}_{1t} = Y_{1t} - \theta^0_t$, and check the residuals of a model for $Y^N_{1t}$ fitted on that data.
The first step is to generate data under the null hypothesis. This is achieved by simply subtracting the postulated null from the outcome of the treated unit, just like in the equation above. Here is the code to do that.
def with_effect(df, state, null_hypothesis, start_at, window):
    # subtract the postulated effect from the treated unit's outcome
    # inside the [start_at, start_at + window) interval
    window_mask = (df.index >= start_at) & (df.index < (start_at + window))

    y = np.where(window_mask, df[state] - null_hypothesis, df[state])

    return df.assign(**{state: y})
plt.plot(with_effect(data, "california", 0, 1988, 2000-1988+1)["california"], label="H0: 0")
plt.plot(with_effect(data, "california", -4, 1988, 2000-1988+1)["california"], label="H0: -4")
plt.ylabel("Y0 Under the Null")
plt.legend();

If we postulate the null of no effect, the data under that null is simply the observed data itself, since $Y_{1t} - 0 = Y_{1t}$.
The next part of the inference procedure is to fit a model for the counterfactual $Y^N_{1t}$ on the data generated under the null, using both the pre- and the post-intervention periods.
The function to do that first uses the with_effect function we created earlier to generate data under the null. Then, it fits the model on this data. Next, we compute the residuals $\hat{u}_t = Y^{H_0}_{1t} - \hat{Y}^N_{1t}$.
@curry
def residuals(df, state, null, intervention_start, window, model):

    null_data = with_effect(df, state, null, intervention_start, window)

    # fit the model on the data under the null, using all time periods
    model.fit(null_data.drop(columns=[state]), null_data[state])

    y0_est = pd.Series(model.predict(null_data.drop(columns=[state])), index=null_data.index)

    residuals = null_data[state] - y0_est

    test_mask = (null_data.index >= intervention_start) & (null_data.index < (intervention_start + window))

    return pd.DataFrame({
        "y0": null_data[state],
        "y0_est": y0_est,
        "residuals": residuals,
        "post_intervention": test_mask
    })[lambda d: d.index < (intervention_start + window)] # discard points after the defined effect window
With our data, to get the residuals under the null of no effect, we can do the following:
model = SyntheticControl()
residuals_df = residuals(data,
                         "california",
                         null=0.0,
                         intervention_start=1988,
                         window=2000-1988+1,
                         model=model)
residuals_df.head()
| year | y0 | y0_est | residuals | post_intervention |
|---|---|---|---|---|
| 1970 | 123.000000 | 112.529475 | 10.470525 | False |
| 1971 | 121.000000 | 114.315723 | 6.684277 | False |
| 1972 | 123.500000 | 119.302289 | 4.197711 | False |
| 1973 | 124.400002 | 121.265554 | 3.134447 | False |
| 1974 | 126.699997 | 124.356696 | 2.343301 | False |
The result is a dataframe containing the estimated residuals for each time period, which is what we will use going forward. Remember, the idea is to check whether the residuals in the post-intervention period are too big. If they are, the data we observed is unlikely to have come from this null, where the effect is zero. To get a visual sense of this, we can inspect the error of our model in the post-intervention period.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 4))
residuals_df[["y0", "y0_est"]].plot(ax=ax1)
ax1.set_title("Y0 under H0: 0");
residuals_df[["residuals"]].plot(ax=ax2);
ax2.set_title("Residuals under H0: 0");

We can already see that the model fitted under $H_0: \theta_t = 0$ tracks the treated unit well before the intervention, but leaves large residuals in the post-intervention period. This suggests the observed data is unlikely under this null.
Test Statistic#
This visual evidence is interesting for building intuition, but we need to be a bit more precise here. That is done by defining a test statistic $S$, which summarizes how big the residuals are and, hence, how unlikely the data we saw is under the null.
Here, we focus on the following family of test statistics:

$$S_q(\hat{u}) = \left(\frac{1}{T_1} \sum_{t > T_0} |\hat{u}_t|^q \right)^{1/q},$$

where $T_1$ is the number of post-intervention periods. With $q=1$, this is simply the mean absolute post-intervention residual.
Notice that this statistic is computed using only the post-intervention period, with $t > T_0$.
def test_statistic(u_hat, q=1, axis=0):
    return (np.abs(u_hat) ** q).mean(axis=axis) ** (1/q)
print("H0:0 ", test_statistic(residuals_df.query("post_intervention")["residuals"]))
H0:0 12.602929955114083
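As a quick sanity check (not part of the original procedure), we can compute the same statistic under a few other constant-effect nulls; a null closer to the true effect trajectory should leave smaller post-intervention residuals and, hence, a smaller statistic. The values -10 and -20 here are arbitrary illustrations:
for null in [0, -10, -20]:
    resid = residuals(data, "california",
                      null=null, intervention_start=1988,
                      window=2000 - 1988 + 1, model=SyntheticControl())
    print(f"H0:{null} ", test_statistic(resid.query("post_intervention")["residuals"]))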
High values of this test statistic indicate poor post-intervention fit and, hence, rejection of the null. However, we could also get a pretty big test statistic simply because our model is poorly fitted, even if the null is true. The p-value construction below deals with this by comparing the post-intervention statistic against the statistics obtained from block permutations of the residuals.
P-Value#
To compute the p-value, we block-permute the residuals, calculating the test statistic in each permutation. This procedure is better understood with the following picture:
Once we do that, we will end up with one test statistic per block permutation, the first of which is computed on the original, unpermuted residuals.
Let $\Pi$ be the set of all block permutations $\pi$ and $\hat{u}_\pi$ the correspondingly permuted residual vector. The p-value is then the fraction of permutations whose test statistic is at least as large as the original one:

$$\hat{p} = \frac{1}{|\Pi|} \sum_{\pi \in \Pi} \mathbb{1}\left\{S(\hat{u}_\pi) \geq S(\hat{u})\right\}$$
To implement this, we will make use of the np.roll function, which takes an array and shifts it circularly, much like we've represented in the image above.
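To see the circular shift in action, here is a tiny example:
u = np.array([1, 2, 3, 4, 5])

np.roll(u, 1)  # array([5, 1, 2, 3, 4]): the last element wraps around to the front
np.roll(u, 2)  # array([4, 5, 1, 2, 3])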
def p_value(resid_df, q=1):
    u = resid_df["residuals"].values
    post_intervention = resid_df["post_intervention"].values

    # one circular shift of the residuals per time period,
    # keeping only the post-intervention positions
    block_permutations = np.stack([np.roll(u, permutation, axis=0)[post_intervention]
                                   for permutation in range(len(u))])

    statistics = test_statistic(block_permutations, q=q, axis=1)

    # the first permutation (roll of 0) is the original residual vector
    p_val = np.mean(statistics >= statistics[0])

    return p_val
We can now compute the p-value under the null of no effect:
p_value(residuals_df)
0.16129032258064516
Remember, this is the p-value for the null hypothesis which states that the effect is zero in all post-treatment periods, $H_0: \theta_t = 0$ for all $t > T_0$. Since 0.16 is above the usual 10% significance level, we can't reject this null.
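Nothing stops us from testing other nulls with the same machinery. For example, we can get the p-value for a constant effect of -4 packs in every post-treatment year (-4 is just an illustrative value, as before). Inverting this logic, testing many nulls and keeping the ones we can't reject, is exactly how we will build confidence intervals next:
# p-value under the null of a constant -4 effect in all post-treatment years
p_value(residuals(data, "california",
                  null=-4.0, intervention_start=1988,
                  window=2000 - 1988 + 1, model=SyntheticControl()))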
Confidence Intervals#
To understand how we can place a confidence interval around the effect of each post-treatment period, let's first try to understand how we would define the confidence interval for a single time period. If we have a single period, the null hypothesis is specified by a single number $\theta^0$, so we can test a fine grid of nulls, one for each plausible effect size, and record the p-value of each:
P-value(H_0: -20) = 0.01
P-value(H_0: -19) = 0.01
P-value(H_0: -18) = 0.02
...
P-value(H_0: 18) = 0.03
P-value(H_0: 19) = 0.03
P-value(H_0: 20) = 0.02
With the functions we've defined, this can be achieved by first appending the period of interest (1988, in this example) to the end of the pre-intervention period, creating what is called an augmented dataset. Then, we iterate over the fine grid of nulls, computing the p-value of a post-intervention window of size 1 which starts at the period of interest.
def p_val_grid(df, state, nulls, intervention_start, period, model):

    # augmented dataset: the pre-intervention data plus the period of interest
    df_aug = pd.concat([df[df.index < intervention_start], df.loc[[period]]])

    p_vals = {null: p_value(residuals(df_aug,
                                      state,
                                      null=null,
                                      intervention_start=period,
                                      window=1,
                                      model=model)) for null in nulls}

    return pd.DataFrame(p_vals, index=[period]).T
model = SyntheticControl()
nulls = np.linspace(-20, 20, 100)
p_values_df = p_val_grid(
    data,
    "california",
    nulls=nulls,
    intervention_start=1988,
    period=1988,
    model=model
)
p_values_df
|  | 1988 |
|---|---|
| -20.000000 | 0.052632 |
| -19.595960 | 0.052632 |
| -19.191919 | 0.052632 |
| -18.787879 | 0.052632 |
| -18.383838 | 0.052632 |
| ... | ... |
| 18.383838 | 0.052632 |
| 18.787879 | 0.052632 |
| 19.191919 | 0.052632 |
| 19.595960 | 0.052632 |
| 20.000000 | 0.052632 |

100 rows × 1 columns
As you can see, the result is a table where the index is the null hypothesis and the values are the associated p-values. Notice the smallest possible p-value here is $1/19 \approx 0.0526$, since the augmented dataset has 19 periods (18 pre-intervention years plus 1988) and, hence, only 19 block permutations.
To build the confidence interval, all we need to do is filter out the nulls with a p-value below our significance level $\alpha$: the smallest and largest surviving nulls are the bounds of the $1-\alpha$ confidence interval.
def confidence_interval_from_p_values(p_values, alpha=0.1):
    big_p_values = p_values[p_values.values >= alpha]
    return pd.DataFrame({
        f"{int(100-alpha*100)}_ci_lower": big_p_values.index.min(),
        f"{int(100-alpha*100)}_ci_upper": big_p_values.index.max(),
    }, index=[p_values.columns[0]])
confidence_interval_from_p_values(p_values_df)
|  | 90_ci_lower | 90_ci_upper |
|---|---|---|
| 1988 | -12.323232 | 8.686869 |
This gives us the confidence interval for the effect in 1988.
We can also plot the p-values against the nulls, to see where these bounds come from.
plt.plot(p_values_df[1988], p_values_df.index)
plt.xlabel("P-Value")
plt.ylabel("H0")
plt.vlines(0.1, nulls.min(), nulls.max(), color="black", ls="dotted", label="0.1")
plt.hlines(confidence_interval_from_p_values(p_values_df)["90_ci_upper"], 0, 1, color="C1", ls="dashed")
plt.hlines(confidence_interval_from_p_values(p_values_df)["90_ci_lower"], 0, 1, color="C1", ls="dashed", label="90% CI")
plt.legend()
plt.title("Confidence Interval for the Effect in 1988");

All that's left to do is repeat the procedure above for each time period. That means, for each post-intervention year, appending it to the end of the pre-intervention period to create the augmented dataset and then computing the confidence interval just like we've done above.
from joblib import Parallel, delayed

def compute_period_ci(df, state, nulls, intervention_start, period, model, alpha=0.1):
    p_vals = p_val_grid(df=df,
                        state=state,
                        nulls=nulls,
                        intervention_start=intervention_start,
                        period=period,
                        model=model)

    return confidence_interval_from_p_values(p_vals, alpha=alpha)

def confidence_interval(df, state, nulls, intervention_start, window, model, alpha=0.1, jobs=4):
    # one CI per post-intervention period; each period requires len(nulls)
    # model fits, so we process the periods in parallel
    return pd.concat(Parallel(n_jobs=jobs)(
        delayed(compute_period_ci)(df, state, nulls, intervention_start, period, model, alpha)
        for period in range(intervention_start, intervention_start + window)))
We are now ready to compute the confidence intervals for all the post-intervention periods:
model = SyntheticControl()
nulls = np.linspace(-60, 20, 100)
ci_df = confidence_interval(
    data,
    "california",
    nulls=nulls,
    intervention_start=1988,
    window=2000 - 1988 + 1,
    model=model
)
ci_df
|  | 90_ci_lower | 90_ci_upper |
|---|---|---|
| 1988 | -12.323232 | 8.686869 |
| 1989 | -16.363636 | 2.222222 |
| 1990 | -17.171717 | 5.454545 |
| 1991 | -19.595960 | -5.858586 |
| 1992 | -22.828283 | -7.474747 |
| 1993 | -32.525253 | -12.323232 |
| 1994 | -36.565657 | -15.555556 |
| 1995 | -43.030303 | -14.747475 |
| 1996 | -41.414141 | -16.363636 |
| 1997 | -48.686869 | -13.131313 |
| 1998 | -46.262626 | -13.939394 |
| 1999 | -46.262626 | -16.363636 |
| 2000 | -51.111111 | -17.979798 |
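One way to read this table (a quick check on top of the original analysis, using the ci_df we just built) is to look for years whose 90% interval lies entirely below zero; those are the years where we can reject the null of no effect at the 10% level. By this criterion, the effect becomes significant from 1991 onward:
# years where the 90% CI excludes zero, i.e., a significant negative effect
ci_df[ci_df["90_ci_upper"] < 0]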
plt.figure(figsize=(10,5))
plt.fill_between(ci_df.index, ci_df["90_ci_lower"], ci_df["90_ci_upper"], alpha=0.2, color="C1")
plt.plot(pred_data["california"].index, pred_data["residuals"], label="California", color="C1")
plt.hlines(y=0, xmin=1970, xmax=2000, lw=2, color="Black")
plt.vlines(x=1988, ymin=-50, ymax=10, linestyle=":", color="Black", lw=2, label="Proposition 99")
plt.legend()
plt.ylabel("Gap in per-capita cigarette sales (in packs)");

Reference#
This Appendix is based on the paper An Exact and Robust Conformal Inference Method for Counterfactual and Synthetic Controls, by Victor Chernozhukov, Kaspar Wüthrich, and Yinchu Zhu. I would like to give special thanks to Kaspar, who clarified a lot of the questions I had.
For additional resources on Synthetic Controls, check out Using Synthetic Controls: Feasibility, Data Requirements, and Methodological Aspects, by Alberto Abadie (2021).
Contribute#
Causal Inference for the Brave and True is an open-source material on causal inference, the statistics of science. It uses only free software, based in Python. Its goal is to be accessible monetarily and intellectually. If you found this book valuable and you want to support it, please go to Patreon. If you are not ready to contribute financially, you can also help by fixing typos, suggesting edits or giving feedback on passages you didn’t understand. Just go to the book’s repository and open an issue. Finally, if you liked this content, please share it with others who might find it useful and give it a star on GitHub.