
Deriving Chinchilla scaling laws from scratch — Part 1

Overview: 01 Task design (3-digit addition, char-level, vocab 14, 1M unique (a, b) pairs) → 02 Fix training (3 attempts, 5 bugs fixed, GPT next-token loss) → 03 Architecture ablations (depth, width, heads, weight tying; chosen: dim=64, 16 heads, tied weights) → 04 Config locked (20 model sizes × 5 data budgets = 100 runs) → Part 2: fit L(N, D) = A/N^α + B/D^β

Where this came from

I came across Vlad Feinberg's article on breaking into a frontier lab a few weeks ago. One of his suggested exercises is to derive scaling laws empirically on a narrow task, using only free compute. The Chinchilla paper shows that compute-optimal training requires balancing model size and training tokens according to a specific power law — but that was derived for large language models on internet text. The question that stuck with me: does the same functional form show up on a deterministic, narrow task that fits entirely on a Kaggle T4?

The task I chose is 3-digit addition. The goal is not to solve addition; a transformer can do that in a few thousand steps. The goal is to characterize empirically how loss scales with model size N and training tokens D, then fit L(N, D) = A / N^α + B / D^β, the form Hoffmann et al. used minus their irreducible-loss term E. Part 1 is the work that makes the sweep valid: getting the training procedure right (three attempts), finding and fixing the bugs that would have silently corrupted the results, running architecture ablations, and locking a configuration. The sweep itself is Part 2.
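For reference, Hoffmann et al.'s full parametric form keeps an irreducible term E, and minimizing it under a fixed compute budget C ≈ 6ND gives the compute-optimal split between N and D. Substituting D = C/(6N) and setting the derivative with respect to N to zero:

```latex
L(N, D) = E + \frac{A}{N^{\alpha}} + \frac{B}{D^{\beta}}, \qquad C \approx 6ND

\frac{\partial}{\partial N}\left[ A N^{-\alpha} + B \left(\frac{6N}{C}\right)^{\beta} \right] = 0
\;\Longrightarrow\;
\alpha A N^{-\alpha-1} = \beta B \left(\frac{6}{C}\right)^{\beta} N^{\beta-1}

N_{\text{opt}}(C) = G \left(\frac{C}{6}\right)^{\frac{\beta}{\alpha+\beta}}, \qquad
D_{\text{opt}}(C) = G^{-1} \left(\frac{C}{6}\right)^{\frac{\alpha}{\alpha+\beta}}, \qquad
G = \left(\frac{\alpha A}{\beta B}\right)^{\frac{1}{\alpha+\beta}}
```

E is a constant, so it does not move the optimum; setting E = 0 gives the form fitted here, and A, B, α, β are the constants the sweep has to estimate.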

The task

Given a + b, predict c. Both operands are integers from 0 to 999, so c ranges from 0 to 1998. There are exactly 1,000,000 unique (a, b) pairs, and we hold out 100,000 as a validation set. The split is combination-aware: each pair is keyed by its canonical sorted form, so if (3, 5) lands in val, its commutative partner (5, 3) is excluded from train. It does not appear in val either; it simply cannot appear in training. This prevents the model from seeing a structurally equivalent example during training and calling it generalisation.
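A minimal sketch of such a split, assuming val pairs are drawn by shuffling all pairs and taking the first pair seen for each canonical key. The function name and the seeding are hypothetical, not the author's code:

```python
import random


def combination_aware_split(max_operand=999, n_val=100_000, seed=0):
    """Hypothetical sketch: hold out n_val pairs so no commutative partner leaks.

    Val gets one representative per canonical key (min(a, b), max(a, b));
    the mirrored pair (b, a) is dropped from both splits.
    """
    pairs = [(a, b) for a in range(max_operand + 1) for b in range(max_operand + 1)]
    rng = random.Random(seed)
    rng.shuffle(pairs)

    val, val_keys = [], set()
    for a, b in pairs:
        if len(val) == n_val:
            break
        key = (min(a, b), max(a, b))  # canonical sorted form
        if key not in val_keys:
            val_keys.add(key)
            val.append((a, b))

    # Any pair whose canonical key was held out is excluded from training.
    train = [(a, b) for a, b in pairs if (min(a, b), max(a, b)) not in val_keys]
    return train, val
```

On a toy range, combination_aware_split(max_operand=9, n_val=10) returns 10 val pairs and between 80 and 90 train pairs. With the defaults the train split ends up well below 900,000, because the mirrored partners of val pairs are dropped from both sides.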

Tokenization is character-level: digits 0–9 (indices 0–9), + (10), = (11), PAD (12), EOS (13). Vocabulary size 14. The sequence 999+999=1998 becomes [9,9,9,10,9,9,9,11,1,9,9,8,13] — 13 tokens, which sets MAX_SEQ_LEN.
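A sketch of a tokenizer with this layout. The class name, PAD_IDX and encode() are hypothetical; token_to_idx, idx_to_token, EOS_IDX and MAX_SEQ_LEN mirror the attributes the generation loop below relies on:

```python
class CharTokenizer:
    """Hypothetical sketch of the character-level tokenizer described above."""

    MAX_SEQ_LEN = 13
    PAD_IDX = 12
    EOS_IDX = 13

    def __init__(self):
        symbols = [str(d) for d in range(10)] + ["+", "="]
        # '0'..'9' -> 0..9, '+' -> 10, '=' -> 11
        self.token_to_idx = {s: i for i, s in enumerate(symbols)}
        self.idx_to_token = {i: s for s, i in self.token_to_idx.items()}

    def encode(self, a, b):
        """Encode f"{a}+{b}={a+b}" plus EOS, right-padded to MAX_SEQ_LEN."""
        ids = [self.token_to_idx[ch] for ch in f"{a}+{b}={a + b}"] + [self.EOS_IDX]
        if len(ids) > self.MAX_SEQ_LEN:
            raise ValueError(f"sequence too long: {len(ids)} tokens")
        return ids + [self.PAD_IDX] * (self.MAX_SEQ_LEN - len(ids))
```

encode(999, 999) fills all 13 slots with no padding, which is why 13 is the smallest MAX_SEQ_LEN that works.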

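The task being deterministic matters for the scaling law. For the answer tokens, the irreducible error E in the Chinchilla formula is approximately zero: there is a single correct answer for every query, so answer loss should scale to zero with enough capacity and data, with no noise floor to estimate separately. The full next-token loss used for training below is different: it also scores the operand digits, which are drawn at random and cannot be predicted from the prefix, so that loss has a non-zero floor. Dropping E from the fit is only justified if loss is measured on the answer tokens (or that floor is subtracted).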
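To see how large the operand floor is for a loss that scores every non-pad token, here is a small calculation under stated assumptions: operands uniform on 0..999, the sequence format above, labels shifted by one, and loss averaged over all scored tokens. Since the sequence is fully determined by (a, b), the best possible total loss per sequence is H(a, b) minus the entropy of the unscored first digit:

```python
import math


def operand_entropy_floor(max_operand=999):
    """Lowest achievable mean next-token loss, in nats per scored token.

    Assumes a, b independent and uniform on 0..max_operand, sequences
    f"{a}+{b}={a + b}" followed by EOS, labels = tokens shifted by one (the
    first token is never scored), and loss averaged over all scored tokens.
    """
    n = max_operand + 1

    # The sequence is a one-to-one function of (a, b), so the total information
    # to predict is H(a, b); the unscored first digit already reveals part of it.
    h_pair = 2 * math.log(n)
    first_digit_counts = {}
    for a in range(n):
        d = str(a)[0]
        first_digit_counts[d] = first_digit_counts.get(d, 0) + 1
    h_first = -sum(c / n * math.log(c / n) for c in first_digit_counts.values())

    def pairs_with_sum(s):
        # Number of (a, b) pairs with a + b == s.
        return s + 1 if s < n else 2 * n - 1 - s

    # Scored tokens per sequence: len(a) + len(b) + len(c) + 2
    # ('+', '=' and EOS add 3 tokens; the unscored first token removes 1).
    mean_len_operand = sum(len(str(a)) for a in range(n)) / n
    mean_len_sum = sum(pairs_with_sum(s) * len(str(s)) for s in range(2 * n - 1)) / n**2
    mean_scored = 2 * mean_len_operand + mean_len_sum + 2

    return (h_pair - h_first) / mean_scored


print(f"{operand_entropy_floor():.3f} nats/token")
```

With the defaults it prints 1.030 nats/token. That is very close to the ~1.03 val losses in the ablations below, which suggests that, if val loss is averaged this way, most of what remains is this floor rather than wrong answers.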

Three attempts to get training right

Getting the training procedure correct took three distinct attempts. Each one produced a silent failure — a model that appeared to train (loss going down, no errors) but could not generate correct answers. The wrong turns are worth documenting in detail because they each point at a different failure mode.

Attempt 1 (failed): number-level tokenization

The first instinct was to treat each number as a single token. Vocabulary: integers 0–999 as input tokens, 0–1998 as output tokens, plus + and =. Sequence: [a, +, b, =, c] — five tokens. Loss computed only on position 4 (the answer).

The reasoning seemed sound: let the model learn that 247 and 381, as atomic units, compose to 628. This worked in the sense that loss went down: a 2-layer model reached ~25% token-level validation accuracy after 50 epochs. But autoregressive generation produced garbage. Root cause: without a causal mask, the model could attend to the answer token while predicting it, so it was never forced to predict from the query alone. That is a teacher-forcing mismatch: the answer was in context during training but absent at generation time.
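For reference, a framework-free sketch of a causal mask in the (1, 1, S, S) layout used for attention masks here. In JAX the same thing is usually written jnp.tril(jnp.ones((S, S), dtype=bool))[None, None]; this generate_causal_mask is a hypothetical stand-in for the project's helper of the same name:

```python
def generate_causal_mask(seq_len):
    """Hypothetical sketch: nested-list boolean mask of shape (1, 1, S, S).

    mask[0][0][i][j] is True when query position i may attend to key
    position j, i.e. only to itself and earlier positions.
    """
    rows = [[j <= i for j in range(seq_len)] for i in range(seq_len)]
    return [[rows]]  # leading batch and head axes of size 1 for broadcasting
```

Row 3 of generate_causal_mask(5), the = position in 1+2=3, allows only 1, +, 2 and =, so the answer digit at position 4 is hidden from the prediction made at =.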

A second problem surfaced here: JAX does not throw on out-of-bounds embedding indices. It silently clamps them. With input_num_classes=1000, the = token (index 1001) shared an embedding vector with 999. The model trained on corrupted embeddings for several runs with no signal from the loss.
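Because JAX indexing clamps rather than raises, a range check in the data pipeline is a cheap guard. It has to run on the host, outside jit, since Python cannot branch on traced values. The helper name is hypothetical:

```python
def check_token_ids(batch, vocab_size):
    """Raise instead of letting an out-of-range id be silently clamped.

    batch: list of token-id sequences (host-side, before conversion to arrays).
    """
    for row in batch:
        for idx in row:
            if not 0 <= idx < vocab_size:
                raise ValueError(f"token id {idx} outside [0, {vocab_size})")
    return batch
```

With the number-level setup, check_token_ids([[999, 1001]], 1000) would have raised on the first batch instead of letting = share the embedding of 999.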

Kept: combination-level val split · transformer architecture · AdamW + cosine decay
Dropped: number-level tokenization

Attempt 2 (failed): character-level tokens, output-only loss

Switched to character-level tokenization and added a causal mask. But we kept computing loss only on answer tokens — positions after =. After 2 epochs, 1+2= predicted EOS immediately.

The failure mode was subtle. With output-only loss, the only gradient signal coming from position 3 (=) was from cases where the answer was 1 digit and EOS followed immediately. The model learned to predict EOS after = because that was the dominant token in the positions it was trained on. It had no incentive to predict the first answer digit from the query alone, because it was never penalized for failing to do that.

Debugging confirmed it: feeding the full correct sequence [1,10,2,11,3,13,...] showed the model correctly predicted 3 at position 4. Feeding [1,10,2,11,PAD,PAD,...] — the actual inference condition — caused it to predict garbage. Teacher forcing mismatch, again, through a different mechanism.

Kept: character-level tokenization · causal mask
Dropped: output-only loss

Attempt 3 (worked): full next-token prediction

The fix: compute loss on all non-pad tokens. Labels = input tokens shifted by 1. Loss mask = 1 where the label is a real token. The model must predict every next token from all previous ones — predicting + from a, predicting = from a+b, and predicting the first answer digit from a+b= with only the query visible. With a causal mask, this forces true autoregressive learning. GPT-style.
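A minimal sketch of how one padded sequence becomes inputs, labels and a loss mask under this objective. PAD_IDX = 12 follows the tokenization above; the function names are hypothetical:

```python
PAD_IDX = 12  # from the token layout: digits 0-9, '+' 10, '=' 11, PAD 12, EOS 13


def make_next_token_example(seq, pad_idx=PAD_IDX):
    """Split one padded sequence into (inputs, labels, loss_mask).

    inputs = seq[:-1], labels = seq[1:] (shifted by one); the mask is 1
    wherever the label is a real token, EOS included, and 0 on PAD.
    """
    inputs = seq[:-1]
    labels = seq[1:]
    loss_mask = [0 if tok == pad_idx else 1 for tok in labels]
    return inputs, labels, loss_mask


def masked_mean(per_token_loss, loss_mask):
    """Average the per-position losses over unmasked positions only."""
    total = sum(loss * m for loss, m in zip(per_token_loss, loss_mask))
    return total / max(sum(loss_mask), 1)
```

For 1+2=3 the mask covers +, 2, =, 3 and EOS, so the model is scored on predicting 3 from 1+2= alone; Attempt 2's mask only covered the positions after =.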

After 2 epochs: 123+456=579 correct, 999+999=1998 correct. Small sums still had errors (1+2=0 instead of 3) while the model was converging, and results improved consistently over a few more epochs. Generation works from the query alone.

Kept: everything; this is the training setup for the scaling sweep
Dropped: nothing

The generation loop, which runs autoregressively from the query string alone:

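```python
import jax.numpy as jnp


def predict(trainer, tokenizer, a, b):
    """Greedy autoregressive decoding from the query "a+b=" alone."""
    query = f"{a}+{b}="
    tokens = [tokenizer.token_to_idx[c] for c in query]
    answer = ''
    # Generate until EOS or until the sequence reaches MAX_SEQ_LEN.
    for _ in range(tokenizer.MAX_SEQ_LEN - len(tokens)):
        inp = jnp.array([tokens]).astype(jnp.int32)
        # generate_causal_mask is the project's own helper, defined elsewhere.
        causal_mask = generate_causal_mask(len(tokens))
        logits = trainer.model.apply(
            {'params': trainer.state.params}, inp, mask=causal_mask, train=False
        )
        # Greedy choice at the last position.
        next_token = int(logits[0, -1, :].argmax())
        if next_token == tokenizer.EOS_IDX:
            break
        # An operator inside the answer means generation has gone wrong; stop.
        if next_token in [tokenizer.token_to_idx['+'], tokenizer.token_to_idx['=']]:
            break
        answer += tokenizer.idx_to_token[next_token]
        tokens.append(next_token)
    return answer
```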

Bugs that mattered

Five bugs were found and fixed across the three attempts. Two of them — the cosine decay alpha and the missing causal mask — would have made the scaling sweep results meaningless, because the model would appear to train while learning almost nothing useful.

01
Wrong cosine decay alpha

In the cosine schedule, alpha is the final multiplier on the learning rate, not an absolute value, so alpha = 0.1 * self.lr (5e-4 at lr = 5e-3) decays the learning rate to lr/2000 after ~100 steps, effectively killing learning for the rest of training. The correct form is alpha = 0.1, which decays to lr/10 over the full horizon. This single bug caused all early runs to plateau far too early. The 4-layer model that appeared to get stuck in the early depth ablations, run before this fix, most likely failed because of this, not because of the layer count.
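The difference is easy to see numerically. Assuming the schedule follows optax.cosine_decay_schedule(init_value, decay_steps, alpha), where alpha multiplies init_value at the end of decay (the text doesn't name the library), a plain-Python reimplementation of that formula:

```python
import math


def cosine_decay(step, init_value, decay_steps, alpha):
    """Cosine decay with a floor, following optax's cosine_decay_schedule formula.

    alpha is a multiplier: the learning rate ends at init_value * alpha.
    """
    t = min(step, decay_steps) / decay_steps
    cosine = 0.5 * (1 + math.cos(math.pi * t))
    return init_value * ((1 - alpha) * cosine + alpha)


lr = 5e-3
steps = 1_000  # hypothetical horizon
print(cosine_decay(steps, lr, steps, alpha=0.1))       # ends at lr * 0.1 = lr / 10
print(cosine_decay(steps, lr, steps, alpha=0.1 * lr))  # ends at lr * 5e-4 = lr / 2000
```

Because alpha is relative, scaling it by lr applies the learning rate twice, and the floor ends up 200 times lower than intended.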

02
Causal mask generated but never passed to model.apply

The mask was computed correctly but the parameter was never forwarded into model.apply. The model trained with full attention — it could see all future tokens including the answer — and then failed completely during autoregressive generation when the answer tokens weren't in context. No error was raised; the loss looked fine.
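A cheap regression test would have caught this: change a later token and confirm that outputs at earlier positions do not move. The sketch below is framework-agnostic; forward is a hypothetical callable, for example a wrapper that runs model.apply and returns per-position argmax ids (for raw float logits, compare with a tolerance instead of ==):

```python
def is_causal(forward, tokens, vocab_size):
    """Return True if outputs before position t never depend on token t.

    forward: callable mapping a token list to a list of per-position outputs.
    """
    base = forward(tokens)
    for t in range(1, len(tokens)):
        perturbed = list(tokens)
        perturbed[t] = (perturbed[t] + 1) % vocab_size  # change one "future" token
        if forward(perturbed)[:t] != base[:t]:
            return False
    return True
```

A model trained with the mask silently dropped fails this check immediately, even though its loss curve looks healthy.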

03
Loss only on output tokens

As described above: the model learned to predict answer tokens only when answer tokens were already visible in context. At inference time, with only the query in context, it had no useful behaviour for the positions it needed to actually predict.

04
Silent JAX out-of-bounds embedding

JAX clamps out-of-bounds embedding indices instead of throwing. With input_num_classes=1000, the = token (index 1001) silently shared an embedding vector with 999 for the entire number-level tokenization phase. The loss gave no indication.

05
PyTorch .unsqueeze in JAX expand_mask

The expand_mask helper used .unsqueeze() (a PyTorch tensor method) instead of jnp.expand_dims(). JAX arrays have no .unsqueeze(), but the causal mask was already shape (1,1,S,S), so the code path that calls it never ran and nothing crashed. It worked by coincidence of the input shape, not by correctness.

Architecture ablations

With the correct training setup locked, we ran ablations across width, head count, depth, and weight tying. All runs used lr=5e-3, grad clip 1.0, and the full next-token loss. Numbers below are from W&B logs.

Width and head count (2 layers, 2 epochs):

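| dim | Heads | Head dim | Val loss |
|---|---|---|---|
| 256 | 8 | 32 | 1.296 |
| 256 | 32 | 8 | 1.270 |
| 128 | 8 | 16 | 1.268 |
| 64 | 8 | 8 | 1.249 |
| 64 | 32 | 2 | 1.139 |
| 128 | 16 | 8 | 1.096 |
| 64 | 16 | 4 | 1.094 |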

At a fixed lr of 5e-3 and 2 epochs, smaller models converge better on this task, which suggests the bottleneck is data rather than capacity. dim=64 with 16 heads (head_dim=4) comes out best, essentially tied with dim=128 and 16 heads (1.094 vs 1.096). Larger dims hurt convergence even when the head count is adjusted. The small head_dim of 4 is unusual for NLP tasks but fits here: the sequence is only 13 tokens and the vocabulary is 14 tokens.

Depth (dim=64, heads=16, 2 epochs):

| Layers | Val loss |
|---|---|
| 2 | 1.094 |
| 6 | 1.077 |
| 4 | 1.061 |

4 layers edges out 2 and 6 at dim=64. The 6-layer model was only measured at 2 epochs, unlike the 4-layer model that was later trained for 50, so it may improve further; 4 layers is stable and converges reliably.

Weight tying (4 layers, dim=64, heads=16):

| Weight tying | Val loss (2 epochs) | Val loss (50 epochs) |
|---|---|---|
| No | 1.061 | 1.053 |
| Yes | 1.038 | 1.030 |

With a vocabulary of 14 tokens, the input embedding and output projection can share the same matrix. We implemented it by storing one learned embedding parameter and using logits = x @ self.embedding.T in the forward pass. Weight tying gives ~0.02 lower val loss at both checkpoints and uses fewer parameters (~200K vs ~206K). Locked in.
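A framework-free sketch of the tying described above: one (vocab_size, d_model) matrix serves as the lookup table on the way in and as the output projection on the way out. The class is hypothetical; in Flax, nn.Embed's attend method computes the same product against the embedding matrix:

```python
import random


class TiedEmbedding:
    """Hypothetical sketch: a single matrix E for both input lookup and logits."""

    def __init__(self, vocab_size, d_model, seed=0):
        rng = random.Random(seed)
        self.E = [[rng.gauss(0.0, 0.02) for _ in range(d_model)]
                  for _ in range(vocab_size)]

    def embed(self, token_id):
        """Input side: row lookup in E."""
        return self.E[token_id]

    def logits(self, x):
        """Output side: x @ E.T, one dot product per vocabulary row."""
        return [sum(xi * ei for xi, ei in zip(x, row)) for row in self.E]

    def num_params(self):
        return len(self.E) * len(self.E[0])
```

Tying removes the separate d_model × vocab_size output matrix; with vocab_size = 14 and d_model = 64 that matrix is 14 × 64 = 896 weights.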

Locked configuration

The configuration below is what the final run (4 layers, dim=64, 16 heads, 50 epochs, ~200K parameters) used, and is fixed for the scaling sweep. The only things that vary across runs are d_model (which scales N) and training tokens D.

| Hyperparameter | Value | Rationale |
|---|---|---|
| n_layers | 4 | Final config from ablations |
| num_heads | 16 (= d_model / 4) | head_dim=4 performed best; ratio kept fixed across sizes |
| ff_size | 4 × d_model | Standard transformer ratio |
| lr | 5e-3 | Best across ablations |
| cosine alpha | 0.1 | Decays to lr/10 over full horizon (bug fix from 0.1 × lr) |
| grad clip | 1.0 | Prevents spikes in early training |
| dropout | 0.15 | |
| optimizer | AdamW | |
| weight tying | yes | Faster convergence, fewer parameters |
| loss | next-token, all non-pad | Forces autoregressive learning |
| causal mask | yes | Required for generation to work |
| vocab size | 14 | Digits + operators + PAD + EOS |
| MAX_SEQ_LEN | 13 | Longest sequence: 999+999=1998⟨eos⟩ |
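The locked table can be captured as a frozen config in which only d_model and the token budget are free, with num_heads and ff_size derived from d_model. This is a sketch: the class, field names and build_grid helper are hypothetical, and the actual 20 d_model values and 5 budgets are not listed here:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SweepConfig:
    """Hypothetical sketch of the locked configuration for one sweep run."""

    d_model: int          # varies: scales N
    train_tokens: int     # varies: the data budget D
    n_layers: int = 4
    lr: float = 5e-3
    cosine_alpha: float = 0.1  # final LR = lr * cosine_alpha
    grad_clip: float = 1.0
    dropout: float = 0.15
    tie_weights: bool = True
    vocab_size: int = 14
    max_seq_len: int = 13

    @property
    def num_heads(self) -> int:
        # Keep head_dim fixed at 4 across model sizes.
        if self.d_model % 4:
            raise ValueError("d_model must be divisible by 4 (head_dim = 4)")
        return self.d_model // 4

    @property
    def ff_size(self) -> int:
        return 4 * self.d_model


def build_grid(d_models, token_budgets):
    """Every model size trained on every data budget."""
    return [SweepConfig(d_model=d, train_tokens=t) for d in d_models for t in token_budgets]
```

Passing 20 sizes and 5 budgets to build_grid yields the 100 runs of the sweep, and the divisibility check enforces head_dim = 4 at every size.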

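The 6ND FLOPs approximation is less accurate here than for large LLMs, but embeddings are not the main reason at the locked size: with tied weights and a 14-token vocabulary, the token embedding at d_model = 64 is 14 × 64 = 896 parameters, well under 1% of the ~200K total. A larger omission is attention-score compute, which scales with n_layers × sequence length × d_model rather than with N; its share, like the embeddings' share, grows as d_model shrinks toward the small end of the sweep. This will be flagged as a limitation in the Part 2 results.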
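A rough accounting makes these effects concrete. This sketch uses Kaplan et al.'s approximations (about 12 · d_model² non-embedding weights per layer when ff_size = 4 × d_model, and about 2 · n_layers · n_ctx · d_model attention-score FLOPs per token in the forward pass, tripled for training). It ignores biases and LayerNorm and assumes learned positional embeddings, which the text doesn't specify:

```python
def compute_breakdown(d_model, n_layers=4, vocab_size=14, n_ctx=13):
    """Rough parameter and per-token training-FLOP shares (Kaplan-style).

    Hypothetical assumptions: ff_size = 4 * d_model, biases and LayerNorm
    ignored, tied token embedding, learned positional embeddings (n_ctx, d_model).
    """
    n_non_embed = 12 * n_layers * d_model ** 2        # attention + MLP weights
    n_embed = vocab_size * d_model + n_ctx * d_model  # token + position tables
    six_n = 6 * n_non_embed                           # the 6N term per token
    attn_scores = 6 * n_layers * n_ctx * d_model      # 3 x forward attention-score FLOPs
    return {
        "n_non_embed": n_non_embed,
        "embed_fraction": n_embed / (n_embed + n_non_embed),
        "attn_over_6n": attn_scores / six_n,  # equals n_ctx / (12 * d_model)
    }


for d in (4, 8, 16, 32, 64):
    b = compute_breakdown(d)
    print(f"d_model={d:3d}  embed share={b['embed_fraction']:.3f}  attn/6N={b['attn_over_6n']:.3f}")
```

Under these assumptions both shares fall roughly as 1/d_model, so the correction matters most for the smallest models in the sweep.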

Config is locked. The sweep is running. Results in Part 2.