Replace decode_layer with the manual-scope e2e kernel #449
Code Review
This pull request removes the entire decode forward implementation file, which previously contained the JIT-compiled decode_fwd function for Qwen3-14B full-layer decode, its PyTorch golden reference golden_decode_fwd, and associated testing utilities. Since there are no review comments, I have no feedback to provide.
95c0f86 to fb02737
decode_layer.py is now the single-layer manual-scope e2e kernel (qwen3_decode_mpmd): contiguous KV cache, fixed BATCH=16, per-head Qwen3 QK-norm, deferred input/QK RMSNorm, and no in-kernel LM head — driven per-layer from the host (qwen3_manual_generate.py). The fused multi-layer decode_fwd.py and the paged @pl.jit.inline decode_layer it called are removed. The manual_scope kernel this derives from was optimized by Hzfengsy (https://github.com/Hzfengsy).
Refactor the single-layer kernel into an @pl.jit.inline _decode_layer body parameterized by layer_idx (multi-layer stacked weights + per-layer KV-cache offsets); keep qwen3_decode_mpmd as the single-layer entry (verified correct on device, 99.5%); add decode_fwd: a device-side pl.range loop over all NUM_LAYERS plus in-kernel rms_lm_head, fusing the whole decode forward into one dispatch. decode_fwd COMPILES (52/57 functions) but its fused execution currently returns all-zeros — a pypto pl.range loop-carry / inline dependency limitation: the inlined body's reads of the loop-carried hidden do not establish a dependency on its producer (copy_hidden), so they race and read pre-write zeros. Documented in the code. The single-layer path (qwen3_decode_mpmd) is unaffected and the CI compile-only fallback does not exercise decode_fwd. Needs a framework-level fix.
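The layer_idx parameterization described above (stacked weights plus per-layer KV-cache offsets) comes down to index arithmetic. The following is a minimal sketch of that arithmetic in plain Python, not the kernel's actual code: the tiny sizes, the helper names weight_rows and kv_row, and the layer-major cache layout are all assumptions made for illustration.

```python
# Illustrative sizes only; the real kernel's shapes may differ.
NUM_LAYERS = 40      # layer count quoted in the PR text
HIDDEN = 4           # tiny stand-in for the hidden size (assumption)
BATCH = 16           # fixed batch quoted in the PR text
NUM_KV_HEADS = 2     # stand-in value (assumption)
MAX_SEQ = 8          # stand-in value (assumption)


def weight_rows(layer_idx: int, hidden: int = HIDDEN) -> range:
    """Rows owned by one layer inside a [num_layers*hidden, ...] stacked tensor."""
    start = layer_idx * hidden
    return range(start, start + hidden)


def kv_row(layer_idx: int, batch: int, kv_head: int, pos: int) -> int:
    """Flat row index into a contiguous KV cache.

    Assumed layout (hypothetical): layer-major, then (batch, kv_head), then
    sequence position, i.e. BATCH * NUM_KV_HEADS * MAX_SEQ rows per layer.
    """
    rows_per_layer = BATCH * NUM_KV_HEADS * MAX_SEQ
    return layer_idx * rows_per_layer + (batch * NUM_KV_HEADS + kv_head) * MAX_SEQ + pos
```

With this layout a single-layer entry is simply the layer_idx = 0 case, while a fused loop would pass the loop variable as layer_idx.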
Summary
models/qwen3/14b/decode_layer.py is now the single-layer manual-scope e2e kernel (qwen3_decode_mpmd): contiguous KV cache, fixed BATCH=16, per-head Qwen3 QK-norm, deferred input/QK RMSNorm, no in-kernel LM head. It is driven per-layer from the host (qwen3_manual_generate.py). The fused multi-layer decode_fwd.py and the paged @pl.jit.inline decode_layer it called are removed.

The manual_scope kernel this derives from was optimized by Hzfengsy; credit to him for the original optimized implementation.
The old decode_fwd.py + paged decode_layer provided serving mechanisms the e2e kernel does not. These are intentionally dropped:

- Paged KV cache: the old decode_layer used block_table/slot_mapping (block-allocated paged cache). The e2e uses a contiguous per-(batch, kv-head) cache sized BATCH × NUM_KV_HEADS × MAX_SEQ. → No dynamic block allocation; every sequence reserves MAX_SEQ rows (memory-inefficient for variable-length / many concurrent sequences).
- Device-side fused multi-layer decode: decode_fwd looped all layers (for layer_idx in pl.range(num_layers)) inside one compiled program and ran the LM head, so one decode step = one device dispatch. The e2e is a single layer; the 40 layers are now driven from a Python host loop (~40 dispatches + host round-trips per token). → Loss of dispatch fusion / low host overhead.
- Variable batch / continuous batching: the old path read user_batch from the input and padded to BATCH_TILE via valid_shape, so one compiled program served any batch ≤ capacity. The e2e is fixed BATCH=16. → No dynamic batch sizes or continuous batching of mixed-length requests.
- In-kernel LM head: decode_fwd ended with rms_lm_head (final RMSNorm + vocab matmul on device). The e2e has no LM head; the host driver does it in torch. → No on-device LM head.
- Multi-layer stacked weights: the old decode_layer sliced per-layer weights via layer_idx from [num_layers*hidden, …] tensors. The e2e takes single-layer weights; the host loop feeds per-layer weights each call.
- decode_layer test/debug harness: the old file shipped test_decode_layer / golden_decode_layer / build_tensor_specs and a rich CLI via the golden run_jit harness: --batch, --lm-head {full,skip,single}, --enable-pmu, --export-kernel-insight / --kernel-insight-func. The e2e ships a simpler data-dir golden driver (--smoke, --enable-l2-swimlane, -p/-d/--data-dir, ...).

Serving integration breaks until re-wired: cli.main -> npu_executor -> decode_fwd is the live serving path; with decode_fwd removed it will not start. npu_executor._compile_decode_fwd_callable and the paged KvCacheManager path are now dead. Re-wiring npu_executor/cli.main to drive the e2e host-loop is not part of this PR. (prefill_fwd.py stays in-tree but its integrated decode counterpart is gone.)
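To put a rough number on the memory trade-off behind dropping the paged KV cache, the sketch below compares rows reserved by a contiguous BATCH × NUM_KV_HEADS × MAX_SEQ cache with rows a block-allocated cache would hold for the same sequences. Only BATCH=16 comes from the PR; NUM_KV_HEADS, MAX_SEQ, BLOCK_SIZE, the sequence lengths, and both helper names are assumptions chosen for illustration.

```python
import math

BATCH = 16          # fixed batch quoted in the PR text
NUM_KV_HEADS = 8    # assumption for illustration
MAX_SEQ = 4096      # assumption for illustration
BLOCK_SIZE = 128    # assumption: rows per block in a paged cache


def contiguous_rows() -> int:
    """Rows reserved up front: every sequence gets MAX_SEQ rows per KV head."""
    return BATCH * NUM_KV_HEADS * MAX_SEQ


def paged_rows(seq_lens: list[int]) -> int:
    """Rows held by a block-allocated cache: whole blocks, sized to actual length."""
    blocks = sum(math.ceil(n / BLOCK_SIZE) for n in seq_lens)
    return blocks * BLOCK_SIZE * NUM_KV_HEADS


# A hypothetical mixed workload: 12 short sequences and 4 long ones.
seq_lens = [100] * 12 + [3000] * 4
reserved = contiguous_rows()
used = paged_rows(seq_lens)
print(f"contiguous: {reserved} rows, paged: {used} rows, ratio: {reserved / used:.1f}x")
```

The gap grows as sequence lengths become more uneven and shrinks to nothing when every sequence actually reaches MAX_SEQ, which is the case the fixed-size layout is best suited to.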
Validation

- python decode_layer.py --smoke -> compiles clean (52 functions).
- On device (a2a3): PASS, with 99.49% of elements within rtol/atol=3e-3 (mismatches=420/81920, max_abs_err=0.0078). A strict 100% torch.allclose is impossible for bf16 outputs (1 bf16 ULP at value 1 = 0.0039 > atol 3e-3), so the driver uses a >=98% ratio check; the residual ~0.5% is 1-2 ULP bf16 quantization.
- qwen3_manual_generate.py (--kernel-file decode_layer.py) works.

Companion change (serving repo)
The serving repo's examples/model/qwen3_14b/qwen3_manual_generate.py default --kernel-file is updated to decode_layer.py (was qwen3_manual_scope.py).

Follow-up (this branch): WIP device-side fused decode_fwd

This branch also rebuilds the fused multi-layer forward on the e2e kernel
(replacing the removed paged decode_fwd.py). The single-layer body is refactored into an @pl.jit.inline _decode_layer parameterized by layer_idx (multi-layer stacked weights + per-layer KV-cache offsets); qwen3_decode_mpmd stays the single-layer entry; and decode_fwd is a device-side pl.range loop over all NUM_LAYERS + in-kernel rms_lm_head, i.e. the full decode in one dispatch.

- qwen3_decode_mpmd (single-layer) is verified correct on device (99.5%).
- decode_fwd compiles (52/57 functions) but its fused execution currently returns all-zeros. This is a pypto pl.range loop-carry / inline dependency limitation: the inlined body's reads of the loop-carried hidden do not establish a dependency on their producer (copy_hidden), so they race and read pre-write zeros. The issue is isolated precisely (a copy-only decode_fwd reproduces hidden exactly; every kernel-level workaround fails the same way). It is documented in the code as a KNOWN BUG and needs a framework-level fix. CI is unaffected: the compile-only fallback never exercises decode_fwd.
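The 99.5% figure quoted for qwen3_decode_mpmd comes from the ratio check described under Validation. The sketch below shows why a ratio check is the practical choice for bf16. It assumes the driver applies the same element-wise rule as torch.allclose, |a - e| <= atol + rtol * |e|. The helpers bf16_spacing, within_tol and pass_ratio are hypothetical names, not the driver's code. bf16 keeps 7 explicit mantissa bits, so adjacent values are 2^-8 ≈ 0.0039 apart just below 1.0 and 2^-7 ≈ 0.0078 apart from 1.0 upward; the latter matches the reported max_abs_err=0.0078.

```python
import math


def bf16_spacing(x: float) -> float:
    """Gap between adjacent bf16 values in the binade containing |x| (normal range)."""
    exponent = math.floor(math.log2(abs(x)))
    return 2.0 ** (exponent - 7)  # bf16 has 7 explicit mantissa bits


def within_tol(actual: float, expected: float, rtol: float = 3e-3, atol: float = 3e-3) -> bool:
    """Element-wise rule used by torch.allclose: |a - e| <= atol + rtol * |e|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


def pass_ratio(actual, expected, rtol: float = 3e-3, atol: float = 3e-3) -> float:
    """Fraction of elements inside tolerance, to be compared against a threshold."""
    ok = sum(within_tol(a, e, rtol, atol) for a, e in zip(actual, expected))
    return ok / len(expected)


# A one-step bf16 error just above 1.0 exceeds the 0.006 budget at that value,
# so a strict all-elements check fails even for a correctly rounded result.
one_step_above_one = 1.0 + bf16_spacing(1.0)
print(within_tol(one_step_above_one, 1.0))

# The counts reported in the PR: 420 mismatches out of 81920 elements.
reported_ratio = 1 - 420 / 81920
print(f"{reported_ratio:.4f}", reported_ratio >= 0.98)
```

A threshold on the pass ratio tolerates these isolated one-step roundings while still catching systematic errors, which would push far more than 2% of elements out of tolerance.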