Chapter 10

Causal AI

Combining machine learning with causal reasoning

Machine learning is extraordinarily good at prediction. Given data, it can learn patterns that generalize remarkably well — to new images, new text, new users.

But there's something it consistently fails at: decisions.

When you use a model's predictions to make a decision, you're intervening on the world. You're not passively predicting; you're changing things. And the data the model was trained on — historical data, collected without your intervention — may be deeply misleading about what will happen when you act.

Causal AI is the project of combining machine learning's power with causal inference's precision. It's one of the most important frontiers in AI, and the direct intellectual ancestor of what Reinforce OS is building toward.


Why ML Alone Isn't Enough

The spurious correlation problem

A model trained on historical data learns associations — not causes. If ice cream sales and drowning rates are correlated (both increase in summer), a model trained on historical data might "learn" to predict drowning from ice cream sales. In a world where this never changes, that prediction is fine. But if you intervened to ban ice cream, the model would predict a drop in drowning — badly wrong.

This is the core problem: predictions from observational data don't transfer to interventional contexts.
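The ice cream example can be simulated directly. In this sketch (synthetic data, all numbers illustrative), a regression finds a strong observational association between ice cream sales and drownings, yet intervening on ice cream changes nothing, because season is the common cause:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# Hidden common cause: summer (1) vs. winter (0).
summer = rng.integers(0, 2, size=n)

# Both variables respond to season, not to each other.
ice_cream = 10 * summer + rng.normal(0, 1, size=n)
drownings = 5 * summer + rng.normal(0, 1, size=n)

# Observational fit: drownings regressed on ice cream sales.
slope = np.polyfit(ice_cream, drownings, 1)[0]
print(f"observational slope: {slope:.2f}")     # clearly positive

# Intervention: ban ice cream (set sales to zero). Drownings are
# generated by season alone, so their distribution is unchanged.
drownings_after_ban = 5 * summer + rng.normal(0, 1, size=n)
print(f"mean drownings before ban: {drownings.mean():.2f}")
print(f"mean drownings after ban:  {drownings_after_ban.mean():.2f}")
```

A model fit to the observational data would predict that the ban reduces drownings by `slope * 10` per unit of summer; the simulation shows the true interventional effect is zero.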

Distribution shift

An ML model trained on data from one distribution often fails when deployed in a new distribution. A recommendation model trained on pre-COVID behavior fails to recommend well during lockdowns. A demand forecasting model trained in normal times fails during supply chain shocks.

The root cause: the model learned spurious correlations that happen to hold in the training distribution but break when conditions change. A model that learned causal relationships would be more robust — causes generalize across contexts; correlations don't.

The shortcut problem

Language models learn to predict text. Some of the patterns they learn are genuinely linguistic — grammar, coherence, factual knowledge. But many are shortcuts: surface patterns that work in training data but don't reflect understanding.

This is the same problem at scale. Without causal structure, a model doesn't know which patterns are invariant and which are spurious.


Invariant Risk Minimization

Invariant Risk Minimization (IRM) (Arjovsky et al., 2019) is a training objective that explicitly targets invariant features — features whose relationship to the outcome holds across different environments.

The setup: you have data from multiple environments $e \in \mathcal{E}$ (different hospitals, different countries, different time periods). Standard ERM (Empirical Risk Minimization) minimizes total loss. IRM instead finds a representation $\Phi(x)$ such that the optimal classifier on top of $\Phi(x)$ is the same in every environment:

\min_{\Phi, w} \sum_{e \in \mathcal{E}} \mathcal{L}^e(w \circ \Phi) \quad \text{s.t.} \quad w \in \arg\min_{w'} \mathcal{L}^e(w' \circ \Phi) \; \forall e

The intuition: if the same simple predictor works across all training environments, it's probably capturing a genuine causal relationship — not a spurious one that happens to hold in one context.

IRM doesn't always outperform ERM in practice (it's a hard optimization problem), but the principle is sound: invariance across environments is evidence of causation.
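The practical IRMv1 surrogate penalizes, per environment, the squared gradient of the risk with respect to a fixed scalar "dummy" classifier $w = 1$: if $w = 1$ is already optimal everywhere, the penalty vanishes. A minimal NumPy sketch on synthetic data (the environment design and all numbers are illustrative, not from the paper):

```python
import numpy as np

rng = np.random.default_rng(1)

def irm_penalty(phi, y):
    """Squared gradient of the per-environment squared-error risk with
    respect to a scalar dummy classifier w, evaluated at w = 1."""
    grad = np.mean(2 * (phi - y) * phi)
    return grad ** 2

def make_env(n, spurious_corr):
    cause = rng.normal(0, 1, n)
    y = cause + rng.normal(0, 0.1, n)                      # invariant mechanism
    shortcut = spurious_corr * y + rng.normal(0, 0.1, n)   # env-dependent shortcut
    return cause, shortcut, y

envs = [make_env(50_000, c) for c in (0.5, 1.5)]   # two training environments

pen_invariant = sum(irm_penalty(cause, y) for cause, _, y in envs)
pen_spurious = sum(irm_penalty(short, y) for _, short, y in envs)

print(f"penalty, invariant feature: {pen_invariant:.4f}")   # near zero
print(f"penalty, spurious feature:  {pen_spurious:.4f}")    # large
```

The invariant feature supports the same unit-coefficient predictor in both environments, so its penalty is near zero; the shortcut's optimal coefficient changes across environments, so the penalty flags it.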


Causal Discovery

All the methods so far assume you know the causal graph. But what if you don't?

Causal discovery algorithms try to infer the DAG structure from data alone. This is hard — for $n$ variables, there are super-exponentially many possible DAGs — but several approaches have made progress.

Constraint-based methods (PC algorithm)

  1. Start with a fully connected undirected graph
  2. Test conditional independences: remove the edge $X - Y$ if $X \perp Y \mid Z$ for some conditioning set $Z$
  3. Orient v-structures: if $X - Z - Y$ with $X$ and $Y$ nonadjacent and $X \not\perp Y \mid Z$, then $X \to Z \leftarrow Y$ (collider)
  4. Apply orientation rules to propagate edge directions

The PC algorithm is consistent in the large-sample limit under the faithfulness assumption: every conditional independence in the distribution corresponds to a d-separation in the graph.

Weakness: conditional independence tests are unreliable in finite samples, especially with many variables.
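Step 2's conditional independence test can be sketched with partial correlation. Assuming a linear-Gaussian chain $X \to Z \to Y$ (an illustrative setup), regressing out $Z$ makes the $X$-$Y$ association vanish, so PC would delete that edge:

```python
import numpy as np

rng = np.random.default_rng(2)
n = 20_000

# Chain X -> Z -> Y: X and Y are dependent, but independent given Z.
x = rng.normal(0, 1, n)
z = 2 * x + rng.normal(0, 1, n)
y = 3 * z + rng.normal(0, 1, n)

def partial_corr(a, b, c):
    """Correlation of a and b after linearly regressing out c
    (a single conditioning variable, enough for this sketch)."""
    ra = a - np.polyval(np.polyfit(c, a, 1), c)
    rb = b - np.polyval(np.polyfit(c, b, 1), c)
    return np.corrcoef(ra, rb)[0, 1]

marginal = np.corrcoef(x, y)[0, 1]     # strongly positive
conditional = partial_corr(x, y, z)    # near zero: remove the X - Y edge

print(f"corr(X, Y)     = {marginal:.3f}")
print(f"corr(X, Y | Z) = {conditional:.3f}")
```

The finite-sample weakness noted above shows up here too: with small $n$, `conditional` fluctuates around zero and the delete-or-keep decision depends on an arbitrary significance threshold.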

Score-based methods (GES)

Greedy Equivalence Search scores DAGs by how well they fit the data (BIC, BDe) and greedily searches the space of equivalence classes, first adding edges, then pruning them. More reliable than constraint-based methods in practice, but it still struggles with large variable sets.
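A Gaussian BIC score for a candidate DAG decomposes over nodes: each node contributes its regression fit on its parents, minus a complexity penalty. A minimal sketch (linear-Gaussian data, score up to constants and a factor of two, which doesn't affect comparisons):

```python
import numpy as np

rng = np.random.default_rng(4)
n = 5_000

# Ground truth: X -> Z -> Y (linear Gaussian).
x = rng.normal(0, 1, n)
z = 2 * x + rng.normal(0, 1, n)
y = 3 * z + rng.normal(0, 1, n)
data = {"X": x, "Z": z, "Y": y}

def node_bic(child, parents):
    """BIC contribution of one node: Gaussian log-likelihood of the
    residuals minus a penalty per free parameter."""
    t = data[child]
    if parents:
        A = np.column_stack([data[p] for p in parents] + [np.ones(n)])
        resid = t - A @ np.linalg.lstsq(A, t, rcond=None)[0]
    else:
        resid = t - t.mean()
    k = len(parents) + 2   # slope coefficients + intercept + noise variance
    return -n * np.log(resid.var()) - k * np.log(n)

def dag_bic(dag):
    return sum(node_bic(c, ps) for c, ps in dag.items())

true_dag  = {"X": [], "Z": ["X"], "Y": ["Z"]}
wrong_dag = {"X": [], "Z": [], "Y": ["X"]}   # misses Z entirely

print(dag_bic(true_dag) > dag_bic(wrong_dag))   # True
```

GES's greedy moves amount to repeatedly asking which single-edge change most improves this decomposable score.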

Continuous optimization (NOTEARS)

Zheng et al. (2018) reformulated DAG learning as a continuous optimization problem by encoding the acyclicity constraint as a smooth differentiable function:

h(W) = \text{tr}(e^{W \circ W}) - d = 0

where $W$ is the weighted adjacency matrix, $\circ$ is the elementwise product, and $d$ is the number of variables. This turns DAG search into gradient descent — practical for hundreds of variables.

NOTEARS and its successors (DAGMA, NoCurl) make causal discovery scalable. They're used in genomics, neuroscience, and systems biology.
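The constraint $h(W)$ is easy to compute directly; a minimal sketch (matrix exponential via truncated power series, adequate for small $d$) confirms it is zero exactly when the graph is acyclic:

```python
import numpy as np

def acyclicity(W, terms=30):
    """NOTEARS constraint h(W) = tr(exp(W * W)) - d, with the matrix
    exponential computed by truncated power series."""
    M = W * W                  # elementwise square (Hadamard product)
    d = W.shape[0]
    term = np.eye(d)
    expm = term.copy()
    for k in range(1, terms):
        term = term @ M / k
        expm = expm + term
    return np.trace(expm) - d

dag = np.array([[0.0, 1.0],
                [0.0, 0.0]])   # X -> Y: acyclic, h = 0
cycle = np.array([[0.0, 1.0],
                  [1.0, 0.0]])  # X <-> Y: cyclic, h > 0

print(acyclicity(dag))     # 0.0
print(acyclicity(cycle))   # positive (2*cosh(1) - 2, about 1.086)
```

The weighted $k$-cycles of $W \circ W$ appear in the trace of the $k$-th power series term, so $h(W) > 0$ if and only if some cycle exists, and the function is smooth in $W$, which is what makes gradient descent possible.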

Limitations

No causal discovery algorithm can distinguish $X \to Y$ from $Y \to X$ purely from observational data in general — the joint distribution is the same under both DAGs unless you exploit functional-form assumptions. Methods that exploit non-Gaussianity (LiNGAM) or nonlinear additive noise (ANM) can sometimes orient edges, but require strong assumptions.

Causal discovery works best as a hypothesis generator, not a definitive answer. Use it to narrow down candidate DAGs, then use domain knowledge and experiments to confirm.


Counterfactual Prediction with ML

Beyond discovering causal structure, ML can improve counterfactual prediction — estimating what would happen under interventions we haven't tried.

Causal representations

If a model learns representations that correspond to causal variables — rather than entangled mixtures of cause and effect — it can reason about interventions. The do-operator becomes a manipulation of the representation: $\text{do}(X = x)$ sets the representation of $X$ to $x$ while leaving other representations unchanged.

This is the goal of causal representation learning: unsupervised or self-supervised learning that produces causally interpretable representations.

Structural Causal Models + Neural Networks

A Structural Causal Model (SCM) specifies each variable as a function of its parents and independent noise:

X_i = f_i(\text{Pa}(X_i), \varepsilon_i)

If you parameterize each $f_i$ with a neural network and train on observational data, you get a neural SCM that can:

  • Answer interventional queries: $P(Y \mid \text{do}(X = x))$
  • Answer counterfactual queries: "what would $Y$ have been if $X$ had been $x'$, given that I observed $X = x$ and $Y = y$?"

These are the three levels of Pearl's ladder of causation:

  1. Association: $P(Y \mid X = x)$ — passive observation
  2. Intervention: $P(Y \mid \text{do}(X = x))$ — active manipulation
  3. Counterfactual: $P(Y_{X=x'} \mid X = x, Y = y)$ — imagining an alternative past

Standard ML operates only at level 1. Causal AI climbs to levels 2 and 3.
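The three levels can be made concrete with a toy linear SCM (all structural functions hard-coded here for illustration; a neural SCM would learn the $f_i$ from data):

```python
import numpy as np

rng = np.random.default_rng(3)
n = 200_000

# Toy SCM with a confounder: U -> X, U -> Y, X -> Y.
#   X := U + eps_x,   Y := X + U + eps_y
u = rng.normal(0, 1, n)
x = u + rng.normal(0, 1, n)
y = x + u + rng.normal(0, 1, n)

# Level 1 (association): E[Y | X near 1], estimated by sample selection.
conditional_mean = y[np.abs(x - 1) < 0.05].mean()   # about 1.5 (confounded)

# Level 2 (intervention): mutilate the SCM. Cut U -> X and set X := 1;
# the mechanism for Y is untouched.
y_do = 1.0 + u + rng.normal(0, 1, n)
interventional_mean = y_do.mean()                   # about 1.0 (causal)

# Level 3 (counterfactual), on a simpler known mechanism Y := 2X + eps:
# abduction (recover eps from the observation), action (set X := x'),
# prediction (recompute Y with the same eps).
def counterfactual_y(x_obs, y_obs, x_cf):
    eps = y_obs - 2 * x_obs
    return 2 * x_cf + eps

print(f"E[Y | X = 1]     = {conditional_mean:.2f}")
print(f"E[Y | do(X = 1)] = {interventional_mean:.2f}")
print(f"Y_(X=2) given (X=1, Y=3): {counterfactual_y(1.0, 3.0, 2.0)}")  # 5.0
```

The gap between the level-1 and level-2 answers (1.5 vs. 1.0) is exactly the confounding bias from $U$; the level-3 step works only because the noise can be recovered from the observation, which is what abduction means.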


Bareinboim's Transportability Theory

A practical problem: you ran an experiment in population A. Can you use those results in population B?

Not necessarily. Population B might have different distributions of confounders, mediators, or effect modifiers. Naively applying results across populations is one of the most common errors in applied science.

Transportability theory (Bareinboim & Pearl) formalizes when and how results can be transported. Given:

  • The causal graph for both populations
  • Which variables differ between populations (captured by selection nodes SS)
  • Data from both populations (experimental in A, observational in B)

...you can derive conditions under which the causal effect in B is identified, and the adjustment formula to compute it.

This is directly relevant to Reinforce OS: results from randomized experiments on one population of users might not directly apply to another user segment — but with the right adjustment, they can inform estimates.
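In the simplest case, where the populations differ only in the distribution of a covariate $Z$ (and $Z$ blocks the effect of the selection node), the transport formula reduces to reweighting $z$-specific experimental effects by the target population's $P^*(z)$. A sketch with hypothetical numbers:

```python
# Experiment in population A measured the effect of X within strata of Z
# (say, Z = heavy vs. light caffeine user). All numbers are hypothetical.
effect_given_z = {"heavy": 2.0, "light": 0.5}   # E[Y|do(X=1),z] - E[Y|do(X=0),z]

p_z_in_A = {"heavy": 0.8, "light": 0.2}   # stratum shares in population A
p_z_in_B = {"heavy": 0.3, "light": 0.7}   # stratum shares in population B

def transported_effect(effect, p_z):
    """Transport formula for this special case: z-specific effects
    reweighted by the target population's distribution of Z."""
    return sum(effect[z] * p_z[z] for z in effect)

print(transported_effect(effect_given_z, p_z_in_A))   # about 1.70 in A
print(transported_effect(effect_given_z, p_z_in_B))   # about 0.95 in B
```

The naive error is reporting A's overall effect (1.70) as B's; the transported estimate (0.95) differs substantially because B has far fewer heavy users. The general theory tells you when such an adjustment set exists at all.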


[Figure: Pearl's ladder of causation — rung 1, Association ("What correlates with X?"); rung 2, Intervention ("What happens if I do X?"); rung 3, Counterfactual ("What if I had done X instead?").]
Pearl's ladder of causation: ML lives at rung 1. Causal AI climbs to rungs 2 and 3.

Large Language Models and Causality

A natural question: do large language models (LLMs) like GPT-4 "understand" causality?

The evidence is mixed. LLMs can often answer causal questions correctly, solve Winograd-schema-style problems, and reason about counterfactuals in language. But they fail in systematic ways:

  • They confuse correlation with causation in novel contexts
  • Their "causal" answers often reflect training data frequencies, not structural reasoning
  • They can't reliably distinguish $P(Y \mid X)$ from $P(Y \mid \text{do}(X))$

The underlying issue: LLMs are trained to predict text (level 1 of the ladder). Without explicit causal structure in training or architecture, they can't reliably climb to levels 2 and 3.

Hybrid approaches — LLMs augmented with explicit causal graphs, SCMs, or causal reasoning modules — are an active research direction.


The Vision for Causal AI in Personal Optimization

Here's what Causal AI looks like as an end state for personal behavior optimization:

  1. Causal discovery from your behavioral logs: what affects your sleep? What causes your focus to drop?
  2. Heterogeneous effect estimation: does caffeine work differently on days when you exercised vs. didn't?
  3. Counterfactual simulation: "if I had gone to bed at 10pm instead of midnight, what would my focus score have been today?"
  4. Sequential decision optimization: given your current state (sleep score, stress level, schedule), what's the best set of behaviors for the next week?
  5. Transportability: when your data is sparse, borrow from users with similar profiles — with appropriate calibration

This is the arc from Chapter 1 (why randomize?) to here: from basic experimental design to a fully adaptive, causally-grounded personal optimization engine.

The pieces exist. The challenge is putting them together robustly, at scale, with real user data.


Summary

  • Standard ML learns associations; causal AI learns relationships that hold under intervention
  • Distribution shift and spurious correlations are symptoms of learning non-causal features
  • IRM targets invariant features that generalize across environments
  • Causal discovery algorithms (PC, GES, NOTEARS) infer DAG structure from data, with important limitations
  • Neural SCMs can answer interventional and counterfactual queries — levels 2 and 3 of Pearl's ladder
  • Transportability theory formalizes when experimental results generalize across populations
  • LLMs can mimic causal reasoning but don't reliably perform it; hybrid approaches are promising

This is the frontier. The methods are advancing fast. The applications — in health, policy, and behavior optimization — are just beginning.