Where this came from
I came across Vlad Feinberg's article on breaking into a frontier lab a few weeks ago. One of his suggested exercises is to derive scaling laws empirically on a narrow task, using only free compute. The Chinchilla paper shows that compute-optimal training requires balancing model size and training tokens according to a specific power law — but that was derived for large language models on internet text. The question that stuck with me: does the same functional form show up on a deterministic, narrow task that fits entirely on a Kaggle T4?
The task I chose is 3-digit addition. The goal is not to solve addition, which a transformer can do in a few thousand steps. The goal is to empirically characterize how loss scales with model size N and training tokens D, then fit L(N, D) = A / N^α + B / D^β, the form Hoffmann et al. used with the irreducible-loss term E dropped. Part 1 is the work that makes the sweep valid: getting the training procedure right (three attempts), finding and fixing the bugs that would have silently corrupted the results, running architecture ablations, and locking a configuration. The sweep itself is Part 2.
The task
Given a + b, predict c. Both operands are integers from 0 to 999, so c ranges from 0 to 1998. There are exactly 1,000,000 ordered (a, b) pairs. We hold out 100,000 as a validation set. The split is combination-aware: each pair is keyed by its canonical sorted form, so if (3, 5) lands in val, its commutative mirror (5, 3) is removed from train. The mirror is not added to val; it is simply excluded from training. This prevents the model from seeing a structurally equivalent example during training and calling it generalisation.
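One way to build such a split, as a rough sketch rather than the exact code behind these runs: shuffle all ordered pairs, admit a pair to val only if its sorted key is new, then drop every train pair whose key landed in val. The function name `split_pairs`, the seed, and the use of `random.shuffle` are assumptions made for illustration.

```python
import random


def split_pairs(n_values=1000, n_val=100_000, seed=0):
    """Split ordered (a, b) pairs so no sorted key is shared by train and val."""
    pairs = [(a, b) for a in range(n_values) for b in range(n_values)]
    random.Random(seed).shuffle(pairs)

    val, val_keys = [], set()
    for a, b in pairs:
        if len(val) == n_val:
            break
        key = (min(a, b), max(a, b))  # canonical sorted form
        if key not in val_keys:       # the mirror of a val pair never enters val
            val_keys.add(key)
            val.append((a, b))

    # Drop val pairs and their mirrors from train.
    train = [(a, b) for a, b in pairs if (min(a, b), max(a, b)) not in val_keys]
    return train, val
```

With this scheme the training set ends up smaller than 900,000 pairs, because every val pair with a ≠ b also removes its mirror.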
Tokenization is character-level: digits 0–9 (indices 0–9), + (10), = (11), PAD (12), EOS (13). Vocabulary size 14. The sequence 999+999=1998 becomes [9,9,9,10,9,9,9,11,1,9,9,8,13] — 13 tokens, which sets MAX_SEQ_LEN.
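The encoding is small enough to write out in full. This is a sketch under the vocabulary given above; the helper name `encode_example` and the choice to right-pad are assumptions.

```python
PAD, EOS = 12, 13
MAX_SEQ_LEN = 13

# Digits map to themselves; the two operators take the next two indices.
CHAR_TO_ID = {str(d): d for d in range(10)}
CHAR_TO_ID.update({"+": 10, "=": 11})


def encode_example(a, b):
    """Encode 'a+b=c' followed by EOS, right-padded to MAX_SEQ_LEN."""
    text = f"{a}+{b}={a + b}"
    ids = [CHAR_TO_ID[ch] for ch in text] + [EOS]
    return ids + [PAD] * (MAX_SEQ_LEN - len(ids))


print(encode_example(999, 999))  # the longest sequence, no padding needed
print(encode_example(1, 2))      # short sequence, padded on the right
```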
The task being deterministic matters for the scaling law. The irreducible error E in the Chinchilla formula is approximately zero — there is a single correct answer for every input. Loss should scale to zero with enough capacity and data, with no noise floor to estimate separately.
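For reference, the functional form with E kept as an optional term can be written as a one-line function. The coefficient values in the usage lines are made up purely to show the shape of the curve; they are not fitted results.

```python
def parametric_loss(n_params, n_tokens, A, alpha, B, beta, E=0.0):
    """Chinchilla-style loss: E + A / N^alpha + B / D^beta (E defaults to 0)."""
    return E + A / n_params**alpha + B / n_tokens**beta


# Illustrative coefficients only (assumed, not measured).
for d in (1e5, 1e6, 1e7):
    print(d, parametric_loss(2e5, d, A=10.0, alpha=0.3, B=50.0, beta=0.3))
```

At fixed N, adding data only drives the loss down to A / N^α, which is why both axes have to be swept.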
Three attempts to get training right
Getting the training procedure correct took three distinct attempts. Each one produced a silent failure — a model that appeared to train (loss going down, no errors) but could not generate correct answers. The wrong turns are worth documenting in detail because they each point at a different failure mode.
The first instinct was to treat each number as a single token. Vocabulary: integers 0–999 as input tokens, 0–1998 as output tokens, plus + and =. Sequence: [a, +, b, =, c] — five tokens. Loss computed only on position 4 (the answer).
The reasoning seemed sound — let the model learn that 247 and 381 as atomic units compose to 628. This worked in the sense that loss went down: a 2-layer model reached ~25% token-level validation accuracy after 50 epochs. But generation was broken. When we tried to generate autoregressively, the model predicted garbage. Root cause: without a causal mask, the model attended to the answer token while predicting the answer token. It had never been forced to predict from the query alone — teacher forcing mismatch.
A second problem surfaced here: JAX does not throw on out-of-bounds embedding indices. It silently clamps them. With input_num_classes=1000, the = token (index 1001) shared an embedding vector with 999. The model trained on corrupted embeddings for several runs with no signal from the loss.
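The clamping is easy to reproduce with plain indexing into a `jnp` array, and a host-side check is one cheap way to catch it before training. This is a sketch: the table here is a stand-in for an embedding matrix, and `check_token_ids` is a hypothetical helper, not part of JAX.

```python
import jax.numpy as jnp
import numpy as np

# Stand-in for a 1000-row embedding table with 4 features per row.
table = jnp.arange(1000 * 4, dtype=jnp.float32).reshape(1000, 4)

ids = jnp.array([999, 1001])  # 1001 is out of bounds
rows = table[ids]             # no error: the lookup clamps 1001 to 999
print(bool(jnp.array_equal(rows[0], rows[1])))


def check_token_ids(ids, num_classes):
    """Hypothetical guard: validate ids on the host before the lookup."""
    ids = np.asarray(ids)
    if ids.min() < 0 or ids.max() >= num_classes:
        raise ValueError(
            f"token ids span [{ids.min()}, {ids.max()}], "
            f"expected [0, {num_classes - 1}]"
        )
```

Running the guard once per dataset, outside any jitted function, costs almost nothing.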
For the second attempt, we switched to character-level tokenization and added a causal mask, but kept computing loss only on answer tokens (the positions after =). After 2 epochs, the query 1+2= produced EOS immediately.
The failure mode was subtle. With output-only loss, the only gradient signal coming from position 3 (=) was from cases where the answer was 1 digit and EOS followed immediately. The model learned to predict EOS after = because that was the dominant token in the positions it was trained on. It had no incentive to predict the first answer digit from the query alone, because it was never penalized for failing to do that.
Debugging confirmed it: feeding the full correct sequence [1,10,2,11,3,13,...] showed the model correctly predicted 3 at position 4. Feeding [1,10,2,11,PAD,PAD,...] — the actual inference condition — caused it to predict garbage. Teacher forcing mismatch, again, through a different mechanism.
The fix, and the third attempt: compute loss on all non-pad tokens. Labels = input tokens shifted by 1. Loss mask = 1 where the label is a non-PAD token. The model must predict every next token from all previous ones: + from a, = from a+b, and the first answer digit from a+b= with only the query visible. With a causal mask, this forces true autoregressive learning. GPT-style.
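The shifted labels and loss mask can be sketched in a few lines. NumPy is used here so the arithmetic is easy to inspect; the same operations exist in `jax.numpy`. The function name and the (batch, seq, vocab) layout are assumptions.

```python
import numpy as np

PAD = 12


def next_token_loss(logits, tokens):
    """Mean cross-entropy over every position whose label is not PAD.

    logits: (batch, seq, vocab) model outputs; tokens: (batch, seq) input ids.
    """
    pred = logits[:, :-1]                 # position i predicts token i + 1
    labels = tokens[:, 1:]                # inputs shifted left by one
    mask = (labels != PAD).astype(np.float64)

    # Numerically stable log-softmax over the vocabulary axis.
    shifted = pred - pred.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    nll = -np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    return (nll * mask).sum() / mask.sum()
```

With all-zero logits over a 14-token vocabulary the result is ln(14), which makes a handy sanity check before the first training step.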
After 2 epochs: 123+456=579 correct. 999+999=1998 correct. Errors were still present on small operands (1+2=0 instead of 3) because the model was still converging. After a few more epochs, accuracy improved consistently. Generation works from the query alone.
The generation loop, which runs autoregressively from the query string alone:
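A minimal sketch of such a loop, using greedy decoding. `logits_fn` is a hypothetical callable standing in for the trained model: it takes a padded token array and returns per-position logits. `oracle_logits` is a fake model that always knows the answer, included only so the loop can be run without a checkpoint.

```python
import numpy as np

PAD, EOS, MAX_SEQ_LEN, VOCAB = 12, 13, 13, 14
CHAR_TO_ID = {**{str(d): d for d in range(10)}, "+": 10, "=": 11}
ID_TO_CHAR = {i: c for c, i in CHAR_TO_ID.items()}


def generate(logits_fn, query):
    """Greedy decoding from the query alone, e.g. '1+2='."""
    tokens = [CHAR_TO_ID[ch] for ch in query]
    while len(tokens) < MAX_SEQ_LEN:
        padded = np.array(tokens + [PAD] * (MAX_SEQ_LEN - len(tokens)))
        logits = logits_fn(padded)                         # (MAX_SEQ_LEN, VOCAB)
        next_id = int(np.argmax(logits[len(tokens) - 1]))  # last real position
        if next_id == EOS:
            break
        tokens.append(next_id)
    return "".join(ID_TO_CHAR.get(t, "?") for t in tokens)


def oracle_logits(padded):
    """Stand-in for a trained model: one-hot logits for the correct next token."""
    text = "".join(ID_TO_CHAR[int(t)] for t in padded if t != PAD)
    a, b = text.split("=")[0].split("+")
    full = [CHAR_TO_ID[c] for c in f"{a}+{b}={int(a) + int(b)}"] + [EOS]
    logits = np.zeros((MAX_SEQ_LEN, VOCAB))
    for i in range(len(full) - 1):
        logits[i, full[i + 1]] = 1.0
    return logits


print(generate(oracle_logits, "1+2="))
```

The key detail is that the positions after the query hold PAD, never the answer, so the model sees exactly what it will see at inference time.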
Bugs that mattered
Five bugs were found and fixed across the three attempts. Two of them — the cosine decay alpha and the missing causal mask — would have made the scaling sweep results meaningless, because the model would appear to train while learning almost nothing useful.
Setting alpha = 0.1 * self.lr in the cosine decay schedule decays the learning rate to lr/2000 after ~100 steps, effectively killing learning for the rest of training. Alpha is the floor expressed as a fraction of the peak learning rate, so with lr=5e-3 the buggy value is 5e-4 and the floor is lr × 5e-4 = lr/2000. The correct form is alpha = 0.1, which decays to lr/10 over the full horizon. This single bug caused all early runs to plateau far too early. The 4-layer model that appeared to get stuck in the depth ablations almost certainly failed because of this, not because of the layer count.
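The effect of the two alpha values is easy to check with a hand-rolled cosine schedule. This sketch assumes the common definition in which alpha scales the peak rate to give the floor; `decay_steps=1000` is an arbitrary illustrative horizon, not the value used in these runs.

```python
import math


def cosine_decay(step, peak_lr, decay_steps, alpha):
    """Cosine decay from peak_lr down to a floor of alpha * peak_lr."""
    progress = min(step, decay_steps) / decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return peak_lr * ((1.0 - alpha) * cosine + alpha)


peak_lr, decay_steps = 5e-3, 1000
buggy = cosine_decay(decay_steps, peak_lr, decay_steps, alpha=0.1 * peak_lr)
fixed = cosine_decay(decay_steps, peak_lr, decay_steps, alpha=0.1)
print(peak_lr / buggy)  # ratio of peak to floor with the buggy alpha
print(peak_lr / fixed)  # ratio of peak to floor with the corrected alpha
```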
The causal mask was computed correctly, but it was never passed into model.apply. The model trained with full attention, so it could see all future tokens including the answer, and then failed completely during autoregressive generation when the answer tokens weren't in context. No error was raised; the loss looked fine.
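Because the loss gives no signal, this kind of bug is best caught with a direct causality check: perturb the last input position and confirm that earlier outputs do not move. The sketch below uses a toy single-head attention in NumPy as the forward pass; `leaks_future` is a hypothetical helper, and a real check would wrap the actual model call instead.

```python
import numpy as np


def causal_mask(seq_len):
    """Lower-triangular mask: position i may attend to positions <= i."""
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))


def self_attention(x, mask=None):
    """Toy single-head attention with identity projections. x: (seq, dim)."""
    scores = x @ x.T / np.sqrt(x.shape[-1])
    if mask is not None:
        scores = np.where(mask, scores, -1e9)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x


def leaks_future(forward, seq_len=13, dim=8, seed=0):
    """True if changing the last token changes any earlier output."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(seq_len, dim))
    x_perturbed = x.copy()
    x_perturbed[-1] += 1.0
    return not np.allclose(forward(x)[:-1], forward(x_perturbed)[:-1])


print(leaks_future(lambda x: self_attention(x)))                       # mask forgotten
print(leaks_future(lambda x: self_attention(x, causal_mask(len(x)))))  # mask applied
```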
With output-only loss, as described above, the model learned to predict answer tokens only when answer tokens were already visible in context. At inference time, with only the query in context, it had no useful behaviour for the positions it needed to actually predict.
JAX clamps out-of-bounds embedding indices instead of throwing. With input_num_classes=1000, the = token (index 1001) silently shared an embedding vector with 999 for the entire number-level tokenization phase. The loss gave no indication.
The expand_mask helper used .unsqueeze() (PyTorch API) instead of jnp.expand_dims(). The causal mask happened to already be shape (1,1,S,S), so the call was a no-op rather than a crash — it worked by coincidence of the input shape, not by correctness.
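A JAX version of the helper might look like the following. The original helper is not shown here, so the branching on `ndim` is an assumption about its intent: broadcast a mask of shape (S, S) or (B, S, S) to the 4-D shape attention expects.

```python
import jax.numpy as jnp


def expand_mask(mask):
    """Expand a mask to 4 dimensions: (batch, heads, seq, seq)."""
    if mask.ndim < 2:
        raise ValueError("mask needs at least two dimensions (seq, seq)")
    if mask.ndim == 3:
        mask = jnp.expand_dims(mask, axis=1)  # (B, S, S) -> (B, 1, S, S)
    while mask.ndim < 4:
        mask = jnp.expand_dims(mask, axis=0)  # (S, S) -> (1, 1, S, S)
    return mask


causal = jnp.tril(jnp.ones((13, 13), dtype=bool))
print(expand_mask(causal).shape)
```

Calling it with a 2-D and a 3-D mask in a unit test exercises both branches, which the (1,1,S,S) input never did.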
Architecture ablations
With the correct training setup locked, we ran ablations across width, head count, depth, and weight tying. All runs used lr=5e-3, grad clip 1.0, and the full next-token loss. Numbers below are from W&B logs.
Width and head count (2 layers, 2 epochs):
| Dim | Heads | Head dim | Val loss |
|---|---|---|---|
| 256 | 8 | 32 | 1.296 |
| 256 | 32 | 8 | 1.270 |
| 128 | 8 | 16 | 1.268 |
| 64 | 8 | 8 | 1.249 |
| 64 | 32 | 2 | 1.139 |
| 128 | 16 | 8 | 1.096 |
| 64 | 16 | 4 | 1.094 |
Smaller models converge better on this task at this training budget, which suggests the bottleneck is data, not capacity. dim=64 with 16 heads (head_dim=4) has the lowest val loss, although dim=128 with 16 heads (head_dim=8) is within 0.002 of it. dim=256 is worst at both head counts tested. The small head_dim of 4 is unusual for NLP tasks but fits here: the sequence is only 13 tokens and the vocabulary is 14 tokens.
Depth (dim=64, heads=16, 2 epochs):
| Layers | Val loss |
|---|---|
| 2 | 1.094 |
| 6 | 1.077 |
| 4 | 1.061 |
4 layers edges out 2 and 6 at dim=64. The 6-layer result was only measured at 2 epochs so it may improve further, but 4 layers is stable and converges reliably.
Weight tying (4 layers, dim=64, heads=16):
| Weight tying | Val loss (2 epochs) | Val loss (50 epochs) |
|---|---|---|
| No | 1.061 | 1.053 |
| Yes | 1.038 | 1.030 |
With a vocabulary of 14 tokens, the input embedding and output projection can share the same matrix. We implemented it by storing one learned embedding parameter and using logits = x @ self.embedding.T in the forward pass. Weight tying gives ~0.02 lower val loss at both checkpoints and uses fewer parameters (~200K vs ~206K). Locked in.
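Stripped of the surrounding module, the mechanism is one matrix used in two places. The NumPy sketch below is illustrative only: the function names and the initialisation scale are assumptions, and in the real model the hidden states come out of the transformer stack, not straight from the embedding.

```python
import numpy as np

VOCAB, D_MODEL = 14, 64
rng = np.random.default_rng(0)

# The single shared parameter: rows are token embeddings.
embedding = rng.normal(scale=0.02, size=(VOCAB, D_MODEL))


def embed(token_ids):
    """Input side: look up one row per token."""
    return embedding[token_ids]


def output_logits(hidden):
    """Output side: project hidden states back onto the same rows."""
    return hidden @ embedding.T


tokens = np.array([1, 10, 2, 11])  # "1+2="
hidden = embed(tokens)             # stand-in for the transformer output
print(output_logits(hidden).shape)
```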
Locked configuration
The configuration below is what the final run (4 layers, dim=64, 16 heads, 50 epochs, ~200K parameters) used, and is fixed for the scaling sweep. The only things that vary across runs are d_model (which scales N) and training tokens D.
| Hyperparameter | Value | Rationale |
|---|---|---|
| n_layers | 4 | Final config from ablations |
| num_heads | 16 (= d_model / 4) | head_dim=4 performed best; ratio kept fixed across sizes |
| ff_size | 4 × d_model | Standard transformer ratio |
| lr | 5e-3 | Best across ablations |
| cosine alpha | 0.1 | Decays to lr/10 over full horizon (bug fix from 0.1 × lr) |
| grad clip | 1.0 | Prevents spikes in early training |
| dropout | 0.15 | |
| optimizer | AdamW | |
| weight tying | yes | Faster convergence, fewer parameters |
| loss | next-token, all non-pad | Forces autoregressive learning |
| causal mask | yes | Required for generation to work |
| vocab size | 14 | Digits + operators + PAD + EOS |
| MAX_SEQ_LEN | 13 | Longest sequence: 999+999=1998⟨eos⟩ |
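The table maps naturally onto a frozen config object in which only d_model is free and the head count and feed-forward size are derived from it. This is a sketch: the class and field names are assumptions, while the values are the ones in the table.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SweepConfig:
    d_model: int               # the only architectural knob varied in the sweep
    n_layers: int = 4
    head_dim: int = 4          # kept fixed, so num_heads scales with d_model
    ff_multiplier: int = 4
    lr: float = 5e-3
    cosine_alpha: float = 0.1  # floor is lr / 10
    grad_clip: float = 1.0
    dropout: float = 0.15
    weight_tying: bool = True
    vocab_size: int = 14
    max_seq_len: int = 13

    def __post_init__(self):
        if self.d_model % self.head_dim != 0:
            raise ValueError("d_model must be a multiple of head_dim")

    @property
    def num_heads(self):
        return self.d_model // self.head_dim

    @property
    def ff_size(self):
        return self.ff_multiplier * self.d_model


cfg = SweepConfig(d_model=64)
print(cfg.num_heads, cfg.ff_size)
```

Freezing the dataclass makes accidental per-run overrides fail loudly instead of silently changing the sweep.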
The 6ND FLOPs approximation is less accurate here than for large LLMs — at small model sizes, the embedding parameters are a large fraction of total parameters (~30%+), which the approximation doesn't account for. This will be flagged as a limitation in the Part 2 results.
Config is locked. The sweep is running. Results in Part 2.