Writing an LLM compiler from scratch [Part 3]: Autotuning — A Search Loop Over Tile-IR Rewrites
This is the third and final article in a series on building a hackable ML compiler from scratch. The previous parts built a six-IR pipeline (Torch → Tensor → Loop → Tile → Kernel → CUDA) and lowered TinyLlama / Qwen2.5-7B through it.
In those parts, block sizes, register tiles, staging decisions, and similar knobs were picked by a heuristic that didn't generalize beyond the matmul shapes it was fitted on.
This part swaps those heuristics for a search loop: a single-player MCTS (SP-MCTS) that explores the cross-product of rule parameters, benchmarks each candidate, and persists winners in a SQLite cache keyed by a structural op hash. The cache is replayed on subsequent compiles.
On an RTX 5090, the tuned stack reaches a geomean of 0.96× PyTorch eager speed (versus 0.87× for the heuristic and 0.91× for torch.compile), and 32 of 84 kernel shapes beat PyTorch's hand-optimized kernels. The best kernels, tall-skinny matmuls, are 5.6× faster than PyTorch.
Passes
| Pass | Forks |
|---|---|
| tileify | none |
| chunk_matmul_k | one per legal K-chunk size (divisors of K, 16..128) |
| split_matmul_k | apply or skip (turn K into a parallel reduction) |
| cooperative_reduce | none |
| blockify_launch | one per threads-per-block ∈ {64, 128, 256, 512} |
| chunk_reduce | none |
| stage_inputs | which inputs to stage in smem (2^k combinations) |
| register_tile | one per (F_M, F_N) divisor pair |
| permute_reg_tile | inner-loop order ∈ {km, mk} |
| double_buffer | apply or skip (split stage buffers for overlap) |
| tma_copy | apply or skip on sm_90+ |
| async_copy | apply or skip on sm_80+ (cp.async) |
| pad_smem | none |
| pipeline_k_outer | apply or skip |
| mark_unroll | none |
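To make two of the fork rules above concrete, here is a minimal Python sketch of how their options could be enumerated. The helper names `k_chunk_forks` and `stage_input_forks` are hypothetical, and the real passes may apply extra legality checks that are not modeled here.

```python
from itertools import combinations

def k_chunk_forks(K, lo=16, hi=128):
    """Candidate K-chunk sizes: divisors of K within [lo, hi]."""
    return [d for d in range(lo, hi + 1) if K % d == 0]

def stage_input_forks(inputs):
    """Every subset of inputs that could be staged in smem (2^k subsets)."""
    return [
        set(subset)
        for r in range(len(inputs) + 1)
        for subset in combinations(inputs, r)
    ]

print(k_chunk_forks(2048))                  # [16, 32, 64, 128]
print(len(stage_input_forks(["a", "b"])))   # 4
```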
A dense matmul with six staging-relevant inputs, three legal K-chunks, four threads-per-block values, eight register-tile shapes, two pipelining choices, and two double-buffering choices spans 2^6 × 3 × 4 × 8 × 2 × 2 = 24,576 terminals.
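The size of that search space can be checked by enumerating the cross-product directly. In this sketch only the number of options per pass comes from the example above; the option values are placeholders.

```python
from itertools import product

# One entry per forking pass; only the option counts matter here.
forks = {
    "stage_inputs": list(product([False, True], repeat=6)),  # 2^6 subsets
    "chunk_matmul_k": range(3),
    "blockify_launch": [64, 128, 256, 512],
    "register_tile": range(8),
    "pipeline_k_outer": [False, True],
    "double_buffer": [False, True],
}

# Each element of the cross-product is one terminal (a fully decided kernel).
terminals = sum(1 for _ in product(*forks.values()))
print(terminals)  # 24576
```

Benchmarking every terminal is impractical at this scale, which is why a guided search is used instead of a sweep.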
Search loop
SP-MCTS with max-Q propagation, normalized UCB1, and a patience termination criterion (stop after N consecutive measured terminals without a new best):
    def sp_mcts(root, patience, c):
        best_reward = 0.0
        visits_at_best = 0  # root.visits when the current best was measured
        while root.visits - visits_at_best < patience:
            # SELECT: descend to a frontier node by UCB1 over normalized max-Q.
            node = root
            while node.children and node.has_unfinished_descendant():
                node = max(
                    (ch for ch in node.children if ch.has_unfinished_descendant()),
                    key=lambda ch: ucb(ch, node, c),
                )
            # EXPAND / SIMULATE: advance one rule, which either spawns forks
            # or yields a terminal kernel to benchmark.
            result = advance_one_rule(node.candidate)
            if result.forks:
                node.children = [Node(fork, parent=node) for fork in result.forks]
                continue
            reward = 1.0 / bench_latency(result.cuda_op)
            # BACKPROP: walk parent links, bump visits, max-update best_reward.
            n = node
            while n is not None:
                n.visits += 1
                n.best_reward = max(n.best_reward, reward)
                n = n.parent
            # Reset the patience window whenever a new best is measured.
            if reward > best_reward:
                best_reward = reward
                visits_at_best = root.visits
        return best_reward
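The loop above calls a `ucb` helper that is not shown. Below is one plausible form of "normalized UCB1 over max-Q", written as a self-contained sketch. The normalization (dividing by the best sibling reward) and the minimal `Node` class are assumptions for illustration, not necessarily the formula or types used in the repo.

```python
import math
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Node:
    # Minimal stand-in for the search-tree node (assumed fields).
    name: str
    parent: Optional["Node"] = None
    visits: int = 0
    best_reward: float = 0.0  # max-Q: best reward seen in this subtree
    children: list = field(default_factory=list)

def ucb(child, parent, c):
    """UCB1 with the exploit term normalized to [0, 1] (assumed form)."""
    if child.visits == 0:
        return math.inf  # try unvisited children first
    # Normalizing by the best sibling keeps the score independent of
    # latency units, so one value of c works across kernels.
    top = max(ch.best_reward for ch in parent.children)
    exploit = child.best_reward / top if top > 0 else 0.0
    explore = c * math.sqrt(math.log(parent.visits) / child.visits)
    return exploit + explore

root = Node("root", visits=10)
fast = Node("fast", root, visits=5, best_reward=1.0 / 20.0)  # 20 µs
slow = Node("slow", root, visits=5, best_reward=1.0 / 40.0)  # 40 µs
fresh = Node("fresh", root)
root.children = [fast, slow, fresh]
```

With equal visit counts the faster child scores higher, and an unvisited child always wins the comparison, which forces every fork to be measured at least once.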
Structural keys
The cache is keyed by structural digests of each kernel. A key is produced by running eight normalization passes over the op and hashing the result. The passes include dropping size-1 free axes, renaming SSA values sequentially, sorting commutative arguments, canonicalizing external buffer names, and collapsing op clusters (sub ↔ add for FMA, mod ↔ divide for SFU, and the compare family).
Under this transformation, the following ops become identical and the same scheduling decisions will be applied:
    # Op A
    for i in range(M):
        for j in range(1):
            tmp = load(X[i])
            result = tmp + bias[i]
            Y[i, j] = result

    # Op B
    # different names and '-' instead of '+'
    for i in range(M):
        a = load(input0[i])
        b = load(input1[i])
        c = a - b
        output0[i] = c
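The following sketch shows how a few of those normalization passes could produce the same key for Op A and Op B. The tuple-based op encoding, the `CLUSTER` table, and the `structural_key` function are all hypothetical simplifications (index expressions are omitted, and only a subset of the passes is modeled).

```python
import hashlib

# Assumed clusters: ops that map to the same hardware unit share one name.
CLUSTER = {"add": "addsub", "sub": "addsub", "mod": "divmod", "div": "divmod"}
COMMUTATIVE = {"add", "mul"}

def structural_key(op):
    # Drop size-1 axes and keep only extents, so axis names disappear.
    extents = [n for _, n in op["axes"] if n != 1]
    ssa, bufs, lines = {}, {}, []
    for dest, name, args in op["body"]:
        if name in ("load", "store"):
            buf, *vals = args
            # Canonicalize external buffer names in order of first use.
            new_args = [bufs.setdefault(buf, f"buf{len(bufs)}")]
            new_args += [ssa[v] for v in vals]
        else:
            new_args = [ssa[a] for a in args]
            if name in COMMUTATIVE:
                new_args.sort()
        # Sequential SSA rename; stores define no value.
        new_dest = ssa.setdefault(dest, f"v{len(ssa)}") if dest else "_"
        lines.append(f"{new_dest}={CLUSTER.get(name, name)}({','.join(new_args)})")
    text = f"{extents}|{';'.join(lines)}"
    return hashlib.sha256(text.encode()).hexdigest()

op_a = {"axes": [("i", 128), ("j", 1)],
        "body": [("tmp", "load", ["X"]), ("bv", "load", ["bias"]),
                 ("result", "add", ["tmp", "bv"]),
                 (None, "store", ["Y", "result"])]}
op_b = {"axes": [("i", 128)],
        "body": [("a", "load", ["input0"]), ("b", "load", ["input1"]),
                 ("c", "sub", ["a", "b"]),
                 (None, "store", ["output0", "c"])]}
print(structural_key(op_a) == structural_key(op_b))  # True
```

Both ops reduce to the same canonical text, so a schedule tuned for one would be replayed for the other.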
An example CLI run from the repo:
    # Eager 25 µs, Deplodock 38.9 µs (0.64× eager)
    deplodock run --bench -c \
        "a=torch.randn(1,32,2048);b=torch.randn(2048,5632);torch.matmul(a,b)"

    # Tune (default patience 60). 207 variants explored in 67.7s,
    # best 22.54 µs at BM=32, BN=64, F_M=8, F_N=2 (worst was 293.75 µs).
    deplodock tune -v -c \
        "a=torch.randn(1,32,2048);b=torch.randn(2048,5632);torch.matmul(a,b)"

    # Re-run with the cached knobs: 22.7 µs (1.10× eager)
    deplodock run --bench -c \
        "a=torch.randn(1,32,2048);b=torch.randn(2048,5632);torch.matmul(a,b)"
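The tune-then-replay flow above depends on the winner cache. A minimal sketch of such a cache using Python's built-in `sqlite3` module might look like this. The table layout, the function names, and the key string are assumptions for illustration, not the repo's actual schema.

```python
import json
import sqlite3

def open_cache(path=":memory:"):
    """Open (or create) the winner cache; schema is hypothetical."""
    db = sqlite3.connect(path)
    db.execute(
        "CREATE TABLE IF NOT EXISTS winners ("
        " key TEXT PRIMARY KEY, knobs TEXT NOT NULL, latency_us REAL NOT NULL)"
    )
    return db

def record(db, key, knobs, latency_us):
    """Store knobs for `key` only if they beat the cached latency."""
    row = db.execute(
        "SELECT latency_us FROM winners WHERE key = ?", (key,)
    ).fetchone()
    if row is None or latency_us < row[0]:
        db.execute(
            "INSERT OR REPLACE INTO winners VALUES (?, ?, ?)",
            (key, json.dumps(knobs, sort_keys=True), latency_us),
        )
        db.commit()

def lookup(db, key):
    """Return cached knobs for `key`, or None on a cache miss."""
    row = db.execute(
        "SELECT knobs FROM winners WHERE key = ?", (key,)
    ).fetchone()
    return json.loads(row[0]) if row else None

db = open_cache()
key = "example-structural-key"  # placeholder for a real structural digest
record(db, key, {"BM": 32, "BN": 64, "F_M": 8, "F_N": 2}, 22.54)
record(db, key, {"BM": 8, "BN": 8, "F_M": 1, "F_N": 1}, 293.75)  # slower, ignored
print(lookup(db, key))
```

On a later compile, a hit in `lookup` would let the compiler skip the search and apply the stored knobs directly.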