Techniques for masking positions in sequence models to prevent attention to padded elements or future tokens (causal masking) during autoregressive generation; includes mask construction and integration with attention score computation. Masks are crucial for correct training and inference in Transformers.
Self-serve tutorial - low prerequisites, straightforward concepts.
Attention is only “correct” if it looks at the right places. In real Transformer training/inference, two kinds of “wrong places” appear constantly: (1) padding tokens that are not real data, and (2) future tokens that the model must not peek at during autoregressive generation. Sequence masking is the practical mechanism that prevents both failures—by editing attention scores before softmax so disallowed positions receive (almost) zero probability.
An attention mask is a binary allow/block matrix M (1 = allow, 0 = block) applied to attention logits before softmax, typically via adding a large negative number (≈ −∞) where M=0. Padding masks block attention to padded keys (and sometimes padded queries). Causal masks block attention to future keys (k > q) for autoregressive decoding. In practice you broadcast masks to (B,H,Lq,Lk) and combine them with logical AND (or additive −∞) before computing softmax.
In attention, every query position q produces a distribution over key positions k:

S[q,k] = (q_q · k_k) / √d,   weights[q,:] = softmax(S[q,:])
The softmax turns each row of scores into weights that sum to 1. That “sum to 1” is the core reason masking matters: if your sequence contains invalid tokens (padding) or forbidden tokens (future positions during autoregressive decoding), softmax will still allocate some probability mass to them unless you explicitly prevent it.
Masking is the standard way to encode structural constraints into attention: “these keys may be attended to, those keys must not.”
An attention mask is a matrix with entries:

M[q,k] = 1 if query q may attend to key k, 0 if it may not.
Shape-wise, the conceptual mask is often (Lq, Lk), where Lq is the number of query positions and Lk is the number of key positions.

In self-attention, typically Lq = Lk = L.
Masks are applied to the attention logits (scores) before the softmax. The most common implementation is additive masking:

S'[q,k] = S[q,k] + B[q,k], where B[q,k] = 0 if allowed and −∞ if blocked.

Then

weights[q,:] = softmax(S'[q,:])

Because exp(−∞) = 0, masked positions get zero probability.

In real code, you don’t literally use −∞; you use a large negative constant like −1e9 (float32) or −1e4 (fp16), depending on dtype and stability. The concept is “so negative that softmax assigns ~0.”
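A minimal sketch of this pattern, assuming NumPy (`masked_softmax` and the −1e9 constant are illustrative choices, not a library API):

```python
import numpy as np

def masked_softmax(scores, allow_mask, neg=-1e9):
    """Softmax over the last axis; positions where allow_mask == 0 are
    pushed to a large negative logit before normalization."""
    masked = np.where(allow_mask.astype(bool), scores, neg)
    masked = masked - masked.max(axis=-1, keepdims=True)  # numerical stability
    w = np.exp(masked)
    return w / w.sum(axis=-1, keepdims=True)

s = np.array([2.0, 1.0, 0.0])
m = np.array([1, 1, 0])   # block the last key
w = masked_softmax(s, m)  # blocked key gets ~0 probability; row sums to 1
```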
Padding masks and causal masks solve different problems and are often combined.
Think of attention as a directed weighted graph from queries to keys. Masking removes edges before normalization.
That’s it: masking is not a learned component; it’s a rule baked into the computation.
In autoregressive language modeling, at position t the model must predict the next token using only the past context x_1, …, x_t.

Self-attention without constraints would let position t attend to future tokens during training (because they are present in the same sequence). That creates label leakage: the model can “cheat” by looking at the answer.
Causal masking prevents this by allowing attention only to the prefix.
For a sequence of length L with 0-indexed positions:

M_causal[q,k] = 1 if k ≤ q, 0 if k > q

So the mask is lower-triangular.
Let L = 6 and write 1=allow, 0=block. Rows are queries q, columns are keys k.
Causal mask M_causal (L=6):
k: 0 1 2 3 4 5
q=0 1 0 0 0 0 0
q=1 1 1 0 0 0 0
q=2 1 1 1 0 0 0
q=3 1 1 1 1 0 0
q=4 1 1 1 1 1 0
q=5 1 1 1 1 1 1

This is the “causal triangle.” It encodes “no looking ahead.”
A clean definition is:

M_causal[q,k] = 1[k ≤ q]

where 1[·] is the indicator function.

If you use an additive mask that adds −∞ where blocked:

B_causal[q,k] = 0 if k ≤ q, −∞ if k > q

Then you compute softmax(S + B_causal) row-wise.
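Assuming NumPy, both forms of the causal mask can be built in a couple of lines (`np.tril` gives the lower triangle; NEG stands in for −∞):

```python
import numpy as np

L = 6
# Binary causal mask: allow k <= q, i.e. lower-triangular ones.
M_causal = np.tril(np.ones((L, L), dtype=np.int64))

# Additive version: 0 where allowed, a large negative where blocked.
NEG = -1e9
B_causal = np.where(M_causal == 1, 0.0, NEG)
```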
Causal masking is a task constraint, not a Transformer requirement.
Real batches contain sequences with different lengths. To form a rectangular tensor, shorter sequences are padded to a common length L.
Padding tokens are not real data. If attention can attend to them, several things go wrong:
Padding masks ensure padded keys do not receive attention.
In attention, the output at query position q is a weighted sum over the values at key positions k.

If key k is padding, then its value vₖ is meaningless; we must prevent it from contributing.

So most padding masks are applied along the key dimension: for each query, block the same set of padded key positions.
Often you start with a 1D “valid tokens” vector p for each sequence:

p[k] = 1 if token k is real, 0 if it is padding.

This is shape (L,) (or (B, L) for a batch).

To turn that into an attention mask you broadcast across queries:

M_pad[q,k] = p[k] for every q

So each row is identical: no query may attend to padded keys.
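A sketch of that broadcast, assuming NumPy (the vector p is the example used below):

```python
import numpy as np

# Validity vector: 4 real tokens padded to L = 6.
p = np.array([1, 1, 1, 1, 0, 0])
L = p.shape[0]

# Broadcast across queries: every row of the (L, L) mask equals p,
# so no query may attend to the padded keys 4 and 5.
M_pad = np.broadcast_to(p, (L, L))
```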
Suppose we have a batch element with length 4 padded to L = 6:
Positions: 0 1 2 3 4 5
Tokens: A B C D <pad> <pad>
Validity: 1 1 1 1 0 0
Padding-only mask (1=allow, 0=block) applied on keys:
M_pad (allow real keys only):
k: 0 1 2 3 4 5
q=0 1 1 1 1 0 0
q=1 1 1 1 1 0 0
q=2 1 1 1 1 0 0
q=3 1 1 1 1 0 0
q=4 1 1 1 1 0 0
q=5 1 1 1 1 0 0

Causal mask (from earlier):
M_causal:
k: 0 1 2 3 4 5
q=0 1 0 0 0 0 0
q=1 1 1 0 0 0 0
q=2 1 1 1 0 0 0
q=3 1 1 1 1 0 0
q=4 1 1 1 1 1 0
q=5 1 1 1 1 1 1

Combined mask uses logical AND (allow only if both allow):
Result:
M_combined = M_causal AND M_pad:
k: 0 1 2 3 4 5
q=0 1 0 0 0 0 0
q=1 1 1 0 0 0 0
q=2 1 1 1 0 0 0
q=3 1 1 1 1 0 0
q=4 1 1 1 1 0 0
q=5 1 1 1 1 0 0

Notice how:
- padded keys k=4 and k=5 are blocked for every query, and
- for the real tokens (q ≤ 3, k ≤ 3), the causal triangle is unchanged.
There are two related concerns: masking padded keys, so that no query attends to padding, and masking padded queries, since the outputs at padded query positions are themselves meaningless.

In practice, padding keys is non-negotiable; padding queries is optional but can save compute and avoid odd numerical artifacts.
If your implementation uses additive masks (0 for allow, −∞ for block), then combining is simple:

B_combined = B_causal + B_pad

Because −∞ + 0 = −∞ and −∞ + (−∞) = −∞, a position blocked by either mask stays blocked.
This “sum of additive masks” is a very common pattern in Transformer codebases.
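A small NumPy sketch of the additive-combination pattern (NEG is an illustrative finite stand-in for −∞):

```python
import numpy as np

NEG = -1e9
L = 6
p = np.array([1, 1, 1, 1, 0, 0])  # validity: 4 real tokens, 2 pads

B_causal = np.where(np.tril(np.ones((L, L))) == 1, 0.0, NEG)
B_pad = np.where(p[None, :] == 1, 0.0, NEG)  # broadcasts over queries

# Adding additive masks keeps a position blocked if either mask blocks it.
# (Positions blocked by both end up at 2*NEG, which is still "very negative".)
B_combined = B_causal + B_pad
```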
For one batch element and one head, attention typically follows:

1) Compute scores: S = Q Kᵀ / √d

2) Apply the mask to get masked logits: S' = S + B

3) Softmax over keys: W[q,:] = softmax(S'[q,:])

4) Weighted sum of values: output = W V
Masking must happen before softmax. Masking after softmax is not equivalent because the distribution would no longer be normalized correctly.
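A quick numeric check of why the order matters, assuming NumPy:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

s = np.array([2.0, 1.0, 0.0])
m = np.array([1.0, 1.0, 0.0])  # block key 2

# Correct: mask before softmax; the allowed weights renormalize to 1.
before = softmax(np.where(m == 1, s, -1e9))

# Wrong: softmax first, then zero out. The row no longer sums to 1,
# because the blocked key already absorbed probability mass.
after = softmax(s) * m
```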
In multi-head attention with batching, you commonly work with these shapes: Q of shape (B, H, Lq, d), K and V of shape (B, H, Lk, d), and scores S = QKᵀ/√d of shape (B, H, Lq, Lk).
You want a mask that can broadcast to (B, H, Lq, Lk).
Here are typical mask shapes and how they broadcast:
Target scores S: (B, H, Lq, Lk)
Causal mask: (1, 1, Lq, Lk) or (Lq, Lk)
Padding key mask: (B, 1, 1, Lk) from (B, Lk)
Combined mask: (B, 1, Lq, Lk) after broadcast AND/add

Key principle: every mask dimension must either match the corresponding scores dimension or be 1, so that it can broadcast against scores of shape (B, H, Lq, Lk).
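The broadcast chain can be sketched in NumPy (shapes here are chosen for illustration: B=2, H=3, L=4, lengths [2, 4]):

```python
import numpy as np

Bsz, H, L = 2, 3, 4
S = np.zeros((Bsz, H, L, L))                      # attention scores

causal = np.tril(np.ones((L, L)))[None, None]     # (1, 1, L, L)
lengths = np.array([2, 4])
valid = np.arange(L)[None, :] < lengths[:, None]  # (B, L) key validity
pad_key = valid[:, None, None, :]                 # (B, 1, 1, L)

allow = np.logical_and(causal, pad_key)           # broadcasts to (B, 1, L, L)
S_masked = np.where(allow, S, -1e9)               # broadcasts to (B, H, L, L)
```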
| Representation | Allowed? | Typical values | Combine rule | Pros | Cons |
|---|---|---|---|---|---|
| Binary allow mask | 1=allow | {0,1} or {False,True} | AND | Conceptually clear | You still need to convert to additive for logits |
| Additive bias mask | 0=allow | {0, −∞} (or large negative) | add | Directly added to scores | Must choose safe negative constant |
Softmax is typically computed using a stabilized form:

softmax(s)_k = exp(s_k − max(s)) / Σ_j exp(s_j − max(s))

If masked positions are very negative, exp(s_k − max(s)) ≈ 0 for them, which is what you want.
But two pitfalls:
- If every key in a row is blocked, the row’s max is itself the blocking value (or −∞), and softmax can return NaNs.
- Literal −∞, or a constant too large in magnitude for the dtype (e.g., −1e9 in fp16), can overflow to inf/NaN during the computation.

Typical engineering fixes:
- use a large but finite negative constant matched to the dtype (e.g., −1e9 for float32, −1e4 for fp16), and
- guarantee every query row has at least one allowed key, or explicitly ignore fully masked query rows.
Let S be the scores and M be a binary mask (1=allow, 0=block).

Convert to an additive mask:

B = (1 − M) × (−C), for a large constant C

Then apply:

S' = S + B

Then:

W[q,:] = softmax(S'[q,:])

Finally output:

output = W V
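Putting the four steps together, a minimal single-head sketch in NumPy (an illustration, not a production implementation; `attention` is a hypothetical helper name):

```python
import numpy as np

def attention(Q, K, V, allow_mask, C=1e9):
    """Single-head masked attention for one sequence.
    Q: (Lq, d); K, V: (Lk, d); allow_mask: (Lq, Lk), 1 = allow."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                # 1) scores
    Sp = S + (1 - allow_mask) * (-C)        # 2) additive mask
    Sp = Sp - Sp.max(axis=-1, keepdims=True)
    W = np.exp(Sp)
    W = W / W.sum(axis=-1, keepdims=True)   # 3) softmax over keys
    return W @ V                            # 4) weighted sum of values

rng = np.random.default_rng(0)
L, d = 4, 8
Q, K, V = (rng.normal(size=(L, d)) for _ in range(3))
out = attention(Q, K, V, np.tril(np.ones((L, L))))
# Query 0 may only attend to key 0, so out[0] equals V[0].
```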
Masks are a “plumbing” idea that becomes crucial in:
Once you can build and broadcast these masks correctly, multi-head attention and full Transformers become much easier to implement without silent correctness bugs.
We have a single-head self-attention with L=3. Scores (already scaled by 1/√d) are:
S =
[ [ 2.0, 1.0, 0.0 ],
[ 1.0, 3.0, 2.0 ],
[ 0.0, 1.0, 4.0 ] ]
We want causal attention: query q can attend only to keys k ≤ q. Use additive mask with −∞ for blocked positions.
Step 1: Write the causal binary mask M (1=allow, 0=block) for L=3:
M =
[ [1,0,0],
[1,1,0],
[1,1,1] ]
Step 2: Convert to additive mask B (0 for allow, −∞ for block):
B =
[ [0, −∞, −∞],
[0, 0, −∞],
[0, 0, 0] ]
Step 3: Add B to scores: S' = S + B:
S' =
[ [ 2.0, −∞, −∞ ],
[ 1.0, 3.0, −∞ ],
[ 0.0, 1.0, 4.0 ] ]
Step 4: Softmax each row.
Row q=0: softmax([2.0, −∞, −∞]) = [1, 0, 0]
Row q=1: softmax([1.0, 3.0, −∞])
Compute stabilized:
max = 3.0
exp([1−3, 3−3, −∞]) = exp([−2, 0, −∞]) = [e^(−2), 1, 0]
Sum = e^(−2) + 1
So weights = [ e^(−2)/(1+e^(−2)), 1/(1+e^(−2)), 0 ]
Numerically e^(−2)≈0.1353 ⇒ weights ≈ [0.1192, 0.8808, 0]
Row q=2: softmax([0.0, 1.0, 4.0])
max = 4.0
exp([−4, −3, 0]) = [e^(−4), e^(−3), 1] ≈ [0.0183, 0.0498, 1]
Sum ≈ 1.0681
weights ≈ [0.0171, 0.0466, 0.9363]
Insight: Masking changes the probability simplex that softmax operates on: it removes forbidden keys before normalization. Notice q=1 would have put some mass on k=2 without masking (because score 2.0 is high), but causal masking forces that mass to be redistributed among allowed keys.
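The worked example can be reproduced mechanically in NumPy (here with a true −∞ additive mask, which is safe because every row keeps at least one allowed key):

```python
import numpy as np

S = np.array([[2.0, 1.0, 0.0],
              [1.0, 3.0, 2.0],
              [0.0, 1.0, 4.0]])
B = np.where(np.tril(np.ones((3, 3))) == 1, 0.0, -np.inf)

Sp = S + B                                 # masked logits
Sp = Sp - Sp.max(axis=-1, keepdims=True)   # stabilize (row max is finite)
W = np.exp(Sp)                             # exp(-inf) -> 0
W = W / W.sum(axis=-1, keepdims=True)
# W matches the hand computation:
# [[1.0,    0.0,    0.0   ],
#  [0.1192, 0.8808, 0.0   ],
#  [0.0171, 0.0466, 0.9363]]  (approximately)
```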
We have B=2 sequences padded to L=5, and we run H=4 heads of decoder self-attention (so we need causal masking).
Sequence lengths: [3, 5]
So valid tokens p (1=real, 0=pad) are [1,1,1,0,0] for the first sequence and [1,1,1,1,1] for the second.

We want a final additive mask B_combined broadcastable to scores S of shape (B,H,L,L).
Step 1: Start with padding validity p of shape (B,L):
p =
[ [1,1,1,0,0],
[1,1,1,1,1] ] shape (2,5)
Step 2: Convert padding validity into a key-mask over attention scores.
We want shape (B,1,1,Lk) so it broadcasts across heads and queries.
M_pad_key[b, 0, 0, k] = p[b, k]
So M_pad_key has shape (2,1,1,5).
Step 3: Build causal mask once for L=5.
Binary causal mask M_causal has shape (1,1,Lq,Lk) = (1,1,5,5):
For q,k in 0..4:
M_causal[0, 0, q, k] = 1 if k ≤ q else 0.
Step 4: Combine binary masks with AND via broadcasting:
M_combined = M_causal AND M_pad_key
Broadcast reasoning: (1,1,5,5) AND (2,1,1,5): the batch dimension broadcasts 1 → 2, and the query dimension broadcasts 1 → 5.

Result: (2,1,5,5)
Then it can broadcast to (2,4,5,5) across heads.
Step 5: Convert to additive mask B_combined.
A common conversion is:
B_combined = (1 − M_combined) * (−C)
where C is a large constant like 1e9 (float32) or 1e4 (fp16).
Shape is (2,1,5,5), broadcastable to (2,4,5,5).
Step 6: Apply to scores.
If scores S are (B,H,L,L) = (2,4,5,5), then masked logits are:
S' = S + B_combined
Because B_combined broadcasts over H, every head uses the same structural constraint.
Insight: Most masking bugs are shape bugs. If you remember two canonical shapes—causal as (1,1,L,L) and padding-as-keys as (B,1,1,L)—then “AND then broadcast to (B,H,L,L)” becomes almost mechanical.
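The step-by-step construction above, sketched in NumPy (`Bsz` stands for the batch size B, which names the additive mask here):

```python
import numpy as np

Bsz, H, L = 2, 4, 5
lengths = np.array([3, 5])

# Steps 1-2: padding validity and key mask.
p = (np.arange(L)[None, :] < lengths[:, None]).astype(np.float32)  # (2, 5)
M_pad_key = p[:, None, None, :]                                    # (2, 1, 1, 5)

# Step 3: causal mask, built once.
M_causal = np.tril(np.ones((L, L), dtype=np.float32))[None, None]  # (1, 1, 5, 5)

# Step 4: AND via elementwise product of {0,1} masks.
M_combined = M_causal * M_pad_key                                  # (2, 1, 5, 5)

# Step 5: additive form with a large finite constant.
C = 1e9
B_combined = (1.0 - M_combined) * (-C)

# Step 6: apply to scores; broadcasts over the head dimension.
S = np.zeros((Bsz, H, L, L))
Sp = S + B_combined                                                # (2, 4, 5, 5)
```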
Suppose in a batch you include sequences with length 0 after truncation (or you mistakenly treat all positions as padding). For one batch element, the padding validity is p=[0,0,0,0]. You build M_pad and apply it to attention logits, then softmax returns NaNs.
Step 1: Observe what happens to one query row.
If all keys are masked, then masked logits look like:
S'_{q,:} = [−∞, −∞, −∞, −∞]
Step 2: Softmax is undefined in this case.
Stabilized softmax subtracts max, but max is −∞, leading to indeterminate forms like exp(−∞ − (−∞)). Many kernels output NaN.
Step 3: Fix options.
Option A (data hygiene): never create length-0 sequences; filter them out.
Option B (force at least one allowed key): if a row would be fully masked, unmask a safe position (often k=0) just to avoid NaNs.
Option C (mask padded queries): don’t compute attention outputs for padded query positions (or set their outputs to 0) and ensure loss ignores them.
Step 4: Preferred fix in Transformers.
Most training pipelines ensure each example has at least 1 real token. Additionally, they ignore padded positions in the loss, so padded queries don’t matter.
Insight: Masking is a correctness constraint, but it can create undefined softmax rows if you accidentally mask everything. Robust systems treat this as an invariant: every query row must have at least one valid key, or the query itself is ignored.
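One defensive variant, assuming NumPy (`safe_masked_softmax` is an illustrative name, in the spirit of Option C: fully masked rows come back as all zeros instead of NaN, so the caller can ignore those query positions):

```python
import numpy as np

def safe_masked_softmax(scores, allow_mask, neg=-1e9):
    """Masked softmax that returns all-zero rows (instead of NaN)
    when every key in a row is blocked."""
    allow = allow_mask.astype(bool)
    masked = np.where(allow, scores, neg)
    masked = masked - masked.max(axis=-1, keepdims=True)
    w = np.exp(masked) * allow          # re-zero blocked positions
    denom = w.sum(axis=-1, keepdims=True)
    has_valid = allow.any(axis=-1, keepdims=True)
    return np.where(has_valid, w / np.maximum(denom, 1e-20), 0.0)

s = np.array([[1.0, 2.0], [3.0, 4.0]])
m = np.array([[1, 1], [0, 0]])          # second row fully masked
w = safe_masked_softmax(s, m)
# w[1] is [0, 0] rather than NaN.
```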
Sequence masking edits attention logits before softmax so disallowed positions receive ~0 probability.
Causal masks enforce autoregressive behavior via a lower-triangular (L×L) structure: allow k ≤ q, block k > q.
Padding masks block attention to padded keys; they usually start from a (B,L) validity vector and broadcast to (B,1,1,Lk).
Masks are commonly implemented as additive biases: 0 for allowed, −∞ (or a large negative) for blocked; combined masks add together.
Correct broadcasting targets scores of shape (B,H,Lq,Lk); typical causal mask is (1,1,Lq,Lk).
Masking keys/values is essential; masking queries is optional but can prevent NaNs and wasted compute when queries are padding.
Watch out for the “all keys masked” row, which can cause NaNs in softmax.
Applying the mask after softmax (this breaks normalization and still allows masked positions to influence the distribution indirectly).
Using the wrong mask orientation (masking queries instead of keys, or transposing (Lq×Lk) so the triangle points the wrong way).
Broadcasting to the wrong shape (e.g., (B,L) added directly to (B,H,L,L) without expanding dims), causing silent incorrect masking.
Accidentally masking all keys for some queries (often due to length bugs), leading to NaNs.
Construct the binary causal mask M for L=4 (1=allow, 0=block). Then indicate which entries are blocked for query q=2.
Hint: Causal means allow k ≤ q. Write a 4×4 lower-triangular matrix of ones.
For L=4 (q,k ∈ {0,1,2,3}):
M =
[ [1,0,0,0],
[1,1,0,0],
[1,1,1,0],
[1,1,1,1] ]
For query q=2, key k=3 is blocked (M[2,3]=0). Keys 0, 1, 2 are allowed.
You have a batch with B=3 sequences padded to L=6. Their lengths are [6, 2, 4]. Build the padding validity matrix p of shape (B,L) using 1 for real tokens and 0 for pads. Then state the shape you would broadcast it to for masking keys in attention scores of shape (B,H,L,L).
Hint: Each row has 'length' ones followed by zeros. Key masking typically becomes (B,1,1,L).
Validity p (B=3, L=6), where each row has 'length' ones followed by zeros:
p =
[ [1,1,1,1,1,1],
[1,1,0,0,0,0],
[1,1,1,1,0,0] ]
To mask keys for scores (B,H,L,L), broadcast to (B,1,1,L) = (3,1,1,6).
Given attention logits for one query row s = [5, 1, 0] and a binary mask m = [1, 0, 1] (so key 1 is blocked), compute the masked softmax weights exactly in terms of exponentials, and approximately as decimals.
Hint: Set the blocked logit to −∞, then softmax over the remaining two positions. Use stabilization by subtracting max=5.
Masking gives s' = [5, −∞, 0].
Softmax weights:
Stabilize by subtracting 5:
Denominator = exp(0) + exp(−5) = 1 + e^(−5)
So
w₀ = 1 / (1 + e^(−5))
w₂ = e^(−5) / (1 + e^(−5))
Numerically e^(−5)≈0.006737:
w₀≈ 1 / 1.006737 ≈ 0.993307
w₁=0
w₂≈ 0.006693
Next, use masking inside full attention blocks and architectures:
Related ideas you’ll likely encounter alongside masking: