EM-Vectorization
The EM algorithm for the two-coin problem estimates the biases of two coins when outcomes are mixed and unlabeled. Vectorization replaces nested loops with matrix operations, letting PyTorch handle $[B, N]$ computations in parallel. The E-step assigns soft probabilities (responsibilities) of each outcome belonging to coin A or B. The M-step updates the coin biases, and the best initialization is chosen via log-likelihood maximization.
Vectorization in the EM Algorithm: A Two-Coin Example
Introduction
The Expectation-Maximization (EM) algorithm is a core method for estimating parameters in probabilistic models with latent variables. In the classic two-coin problem, each observed flip comes from either Coin A or Coin B (we do not observe which). Each coin has an unknown bias: Coin A lands heads with probability $\pi_A$, and Coin B lands heads with probability $\pi_B$.
We observe a sequence of flips $o_1, o_2, \dots, o_N$ where each $o_n \in \{0,1\}$ (1 = heads, 0 = tails). The EM algorithm estimates $\pi_A$ and $\pi_B$ from these incomplete observations. This post explains how to vectorize EM so we can run many independent EM initializations in parallel using tensor operations (PyTorch), why that is faster, and how the math maps to the code.
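To make the setup concrete, here is a small data-generation sketch. The true biases, the 50/50 choice of coin, and all variable names are illustrative assumptions, not values from the post:

```python
import torch

torch.manual_seed(0)
N = 200                                   # number of observed flips
true_pi_A, true_pi_B = 0.8, 0.3           # hypothetical true biases (unknown to EM)

# Latent variable: which coin produced each flip (EM never observes this)
from_A = torch.rand(N) < 0.5
p_heads = torch.where(from_A, torch.tensor(true_pi_A), torch.tensor(true_pi_B))
outcomes = torch.bernoulli(p_heads)       # [N] tensor of 1s (heads) and 0s (tails)
```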
Step 1: Probability of a Single Observation
For a single outcome $o \in \{0,1\}$ and a coin with bias $\pi$, the probability of observing $o$ is:

$$P(o \mid \pi) = \pi\, o + (1 - \pi)(1 - o)$$
This single expression covers both cases:
- If $o = 1$, the expression becomes $\pi$.
- If $o = 0$, the expression becomes $1 - \pi$.
This algebraic trick lets us compute probabilities for heads or tails without branching.
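As a quick sanity check, the same trick works elementwise on tensors; a minimal sketch with illustrative values:

```python
import torch

pi = torch.tensor(0.7)                 # coin bias
o = torch.tensor([1.0, 0.0, 1.0])      # observed flips (1 = heads, 0 = tails)

# Branchless probability: pi where o == 1, 1 - pi where o == 0
p = pi * o + (1 - pi) * (1 - o)        # tensor([0.7000, 0.3000, 0.7000])
```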
Step 2: Vectorization Setup (Shapes and Broadcasting)
We will perform $B$ EM runs in parallel (different random inits). Let:
- $N$ = number of observed flips.
- $B$ = number of parallel EM runs.
Represent data and parameters as tensors:
- Outcomes tensor (row vector): $\mathbf{o} \in \mathbb{R}^{1 \times N}$.
- Parameter vectors: $\boldsymbol{\pi}_A, \boldsymbol{\pi}_B \in \mathbb{R}^{B}$.
We use `unsqueeze` to reshape the parameter vectors for broadcasting.
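A minimal sketch of this setup, with illustrative sizes and randomly initialized parameters:

```python
import torch

B, N = 64, 100                                    # illustrative sizes
outcomes = torch.randint(0, 2, (1, N)).float()    # [1, N] row vector of flips
pi_A = torch.rand(B)                              # [B] one bias guess per EM run
pi_B = torch.rand(B)

pi_A_col = pi_A.unsqueeze(1)                      # [B, 1]
pi_B_col = pi_B.unsqueeze(1)                      # [B, 1]
```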
Broadcasting expands these to shape $[B, N]$ so elementwise operations act across all runs and all observations at once.
Step 3: The E-Step (Expectation Step)
In the E-step, we compute the "responsibilities" — i.e., how much each coin (A or B) is responsible for generating each observed outcome.
For a single data point $o_n$:
- Probability that coin A generated $o_n$: $P(o_n \mid \pi_A) = \pi_A\, o_n + (1 - \pi_A)(1 - o_n)$
- Probability that coin B generated $o_n$: $P(o_n \mid \pi_B) = \pi_B\, o_n + (1 - \pi_B)(1 - o_n)$
The intuition:
- If $o_n = 1$ (heads), the formula reduces to $\pi_A$ or $\pi_B$.
- If $o_n = 0$ (tails), the formula reduces to $1 - \pi_A$ or $1 - \pi_B$.
In vectorized form:
- `prob_A` becomes a $[B, N]$ matrix, where each row corresponds to a different EM run and each column to a different data point.
- The same holds for `prob_B`.
Then we normalize to compute responsibilities:

$$\gamma_{A,n} = \frac{P(o_n \mid \pi_A)}{P(o_n \mid \pi_A) + P(o_n \mid \pi_B)}, \qquad \gamma_{B,n} = \frac{P(o_n \mid \pi_B)}{P(o_n \mid \pi_A) + P(o_n \mid \pi_B)}$$

This gives us two $[B, N]$ matrices of weights, $\gamma_A$ and $\gamma_B$.
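A minimal sketch of the E-step under these definitions (names like `eps` are illustrative; the post's exact code may differ). For intuition, if a flip is heads and one run currently has $\pi_A = 0.8$ and $\pi_B = 0.3$, then $\gamma_A = 0.8 / 1.1 \approx 0.73$ for that entry.

```python
import torch

eps = 1e-10
B, N = 64, 100
outcomes = torch.randint(0, 2, (1, N)).float()    # [1, N]
pi_A, pi_B = torch.rand(B), torch.rand(B)         # [B] current bias estimates

# Per-coin probabilities of each outcome, broadcast to [B, N]
prob_A = pi_A.unsqueeze(1) * outcomes + (1 - pi_A.unsqueeze(1)) * (1 - outcomes)
prob_B = pi_B.unsqueeze(1) * outcomes + (1 - pi_B.unsqueeze(1)) * (1 - outcomes)

# Responsibilities: normalize so gamma_A + gamma_B == 1 elementwise
denom = prob_A + prob_B + eps
gamma_A = prob_A / denom
gamma_B = prob_B / denom
```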
Step 4: The M-Step (Maximization Step)
In the M-step, we update the coin biases using weighted averages.
For coin A:

$$\pi_A^{\text{new}} = \frac{\sum_{n=1}^{N} \gamma_{A,n}\, o_n}{\sum_{n=1}^{N} \gamma_{A,n}}$$

For coin B:

$$\pi_B^{\text{new}} = \frac{\sum_{n=1}^{N} \gamma_{B,n}\, o_n}{\sum_{n=1}^{N} \gamma_{B,n}}$$
Interpretation:
- The numerator counts how many "effective heads" were assigned to each coin.
- The denominator counts the total "effective flips" assigned to each coin.
- Together, they compute the weighted fraction of heads.
In code:
```python
Total_heads_A = torch.sum(gamma_A * outcomes, dim=1)
Total_tails_A = torch.sum(gamma_A * (1 - outcomes), dim=1)
pi_A = Total_heads_A / (Total_heads_A + Total_tails_A + eps)
```
The same applies for coin B.
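For completeness, a mirrored sketch for coin B, reusing `gamma_B`, `outcomes`, and `eps` from the sketches above:

```python
Total_heads_B = torch.sum(gamma_B * outcomes, dim=1)
Total_tails_B = torch.sum(gamma_B * (1 - outcomes), dim=1)
pi_B = Total_heads_B / (Total_heads_B + Total_tails_B + eps)
```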
Step 5: Log-Likelihood Calculation
To choose the best run (since EM can converge to local optima), we compute the log-likelihood for each run:

$$\log L = \sum_{n=1}^{N} \log\big(0.5\, P(o_n \mid \pi_A) + 0.5\, P(o_n \mid \pi_B)\big)$$
This measures how well the current parameters explain the observed data.
In code:
- Compute `prob_A` and `prob_B` again.
- Average them: `total_prob = 0.5 * prob_A + 0.5 * prob_B`
- Take logs: `log_probs = torch.log(total_prob + eps)`
- Sum across data points (along `dim=1`).
The output is one log-likelihood per run, with shape $[B]$.
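Putting those steps together, a minimal sketch (assuming `prob_A`, `prob_B` of shape $[B, N]$ and `eps` as in the earlier sketches, and the uniform 0.5/0.5 mixture used in this post):

```python
total_prob = 0.5 * prob_A + 0.5 * prob_B                   # [B, N]
log_likelihood = torch.log(total_prob + eps).sum(dim=1)    # [B], one value per run
```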
Step 6: Selecting the Best Run
Since EM is sensitive to initialization, we repeat the process $B$ times with different random initial guesses.
At the end, we select the run with the highest log-likelihood:

$$b^{*} = \arg\max_{b}\, \log L_b$$
This gives us the most plausible estimates of $\pi_A$ and $\pi_B$ among the runs we tried.
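In code, this is a single `argmax` over the per-run log-likelihoods (variable names are illustrative):

```python
best = torch.argmax(log_likelihood)    # index of the best EM run
best_pi_A = pi_A[best]
best_pi_B = pi_B[best]
```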
Step 7: Vectorization Summary
The key insight of vectorization is that instead of looping over:
- each data point
- and each initialization
we let PyTorch broadcasting handle the alignment of shapes.
- `outcomes`: shape $[1, N]$
- `pi_A.unsqueeze(1)`: shape $[B, 1]$
- Result: a broadcast $[B, N]$ matrix
This means:
- One matrix stores all probabilities for coin A across all runs and all data points.
- Another for coin B.
- Then, all operations (sums, divisions, logs) are applied in parallel.
This drastically reduces Python-level loops and leverages optimized tensor math on GPU/MPS.
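To tie the steps together, here is a self-contained sketch of the whole vectorized loop under the assumptions above (uniform 0.5/0.5 mixture, the branchless Bernoulli probability from Step 1, illustrative names and defaults); it is not the post's exact implementation:

```python
import torch

def vectorized_em(outcomes, B=100, n_iters=50, eps=1e-10, seed=0):
    """Run B random EM initializations in parallel for the two-coin problem.

    Assumes outcomes is a tensor of 0s and 1s with shape [N] and a uniform
    0.5/0.5 mixture over the two coins.
    """
    torch.manual_seed(seed)
    o = outcomes.float().unsqueeze(0)            # [1, N]
    pi_A = torch.rand(B)                         # [B] random initial biases
    pi_B = torch.rand(B)

    for _ in range(n_iters):
        # E-step: responsibilities, broadcast to [B, N]
        prob_A = pi_A.unsqueeze(1) * o + (1 - pi_A.unsqueeze(1)) * (1 - o)
        prob_B = pi_B.unsqueeze(1) * o + (1 - pi_B.unsqueeze(1)) * (1 - o)
        denom = prob_A + prob_B + eps
        gamma_A, gamma_B = prob_A / denom, prob_B / denom

        # M-step: weighted fraction of heads per run
        pi_A = (gamma_A * o).sum(dim=1) / (gamma_A.sum(dim=1) + eps)
        pi_B = (gamma_B * o).sum(dim=1) / (gamma_B.sum(dim=1) + eps)

    # Score each run and keep the best initialization
    prob_A = pi_A.unsqueeze(1) * o + (1 - pi_A.unsqueeze(1)) * (1 - o)
    prob_B = pi_B.unsqueeze(1) * o + (1 - pi_B.unsqueeze(1)) * (1 - o)
    log_lik = torch.log(0.5 * prob_A + 0.5 * prob_B + eps).sum(dim=1)   # [B]
    best = torch.argmax(log_lik)
    return pi_A[best], pi_B[best], log_lik[best]
```

On data like the simulated `outcomes` from the introduction sketch, the best run's estimates should land near the true biases, possibly with the roles of the two coins swapped, since the mixture is symmetric in A and B.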
Final Thoughts
By vectorizing the EM algorithm:
- We can run dozens or hundreds of EM initializations in parallel.
- The code is simpler (no nested loops).
- The math maps directly to tensor operations.
This design is critical when working with larger datasets or more complex mixture models, where performance and stability matter.