Replace decode_layer with the manual-scope e2e kernel #449

Open
lwDavid wants to merge 2 commits into hw-native-sys:main from lwDavid:replace-decode-layer-with-e2e

@lwDavid lwDavid commented Jun 3, 2026

Summary

models/qwen3/14b/decode_layer.py is now the single-layer manual-scope e2e kernel
(qwen3_decode_mpmd): contiguous KV cache, fixed BATCH=16, per-head Qwen3 QK-norm,
deferred input/QK RMSNorm, and no in-kernel LM head. It is driven per-layer from
the host (qwen3_manual_generate.py). The fused multi-layer decode_fwd.py and the
paged @pl.jit.inline decode_layer it called are removed.
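
For orientation, the host-driven flow described above can be sketched in plain Python. This is only an illustration of the control flow, not the code in qwen3_manual_generate.py: `decode_step`, `run_layer`, and `lm_head` are hypothetical names, and the callables stand in for a single-layer kernel launch and the host-side LM head.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def decode_step(
    hidden: Vector,
    layer_weights: Sequence[object],
    run_layer: Callable[[Vector, object, int], Vector],
    lm_head: Callable[[Vector], Vector],
) -> Tuple[Vector, int]:
    """Run one decode step from the host: one kernel dispatch per layer."""
    dispatches = 0
    for layer_idx, weights in enumerate(layer_weights):
        # Each call stands in for one launch of the single-layer kernel,
        # fed with that layer's weights only (hypothetical interface).
        hidden = run_layer(hidden, weights, layer_idx)
        dispatches += 1
    # The LM head runs on the host after the last layer, not in the kernel.
    logits = lm_head(hidden)
    return logits, dispatches


def toy_layer(hidden: Vector, weight: float, layer_idx: int) -> Vector:
    """Toy stand-in for a layer: add a scalar weight to every element."""
    return [x + weight for x in hidden]


def toy_lm_head(hidden: Vector) -> Vector:
    """Toy stand-in for the LM head: a single 'logit' equal to the sum."""
    return [sum(hidden)]
```

With 40 per-layer weight entries, `decode_step` reports 40 dispatches per token, which is the host overhead the removed fused path avoided.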

The manual_scope kernel this derives from was optimized by Hzfengsy; credit to
him for the original optimized implementation.

⚠️ Capabilities / mechanisms removed by this replacement

The old decode_fwd.py + paged decode_layer provided serving mechanisms the
e2e kernel does not. These are intentionally dropped:

  1. Paged KV cache — old decode_layer used block_table / slot_mapping
    (block-allocated paged cache). The e2e uses a contiguous per-(batch, kv-head)
    cache sized BATCH × NUM_KV_HEADS × MAX_SEQ. → No dynamic block allocation; every
    sequence reserves MAX_SEQ rows (memory-inefficient for variable-length / many
    concurrent sequences).

  2. Device-side fused multi-layer decode: decode_fwd looped all layers
    (for layer_idx in pl.range(num_layers)) inside one compiled program and ran
    the LM head, so one decode step = one device dispatch. The e2e is a single
    layer; the 40 layers are now driven from a Python host loop (~40 dispatches +
    host round-trips per token). → Loss of dispatch fusion / low host overhead.

  3. Variable batch / continuous batching — old path read user_batch from the
    input and padded to BATCH_TILE via valid_shape, so one compiled program served
    any batch ≤ capacity. The e2e is fixed BATCH=16. → No dynamic batch sizes or
    continuous batching of mixed-length requests.

  4. In-kernel LM head: decode_fwd ended with rms_lm_head (final RMSNorm +
    vocab matmul on device). The e2e has no LM head; the host driver does it in
    torch. → No on-device LM head.

  5. Multi-layer stacked weights — old decode_layer sliced per-layer weights via
    layer_idx from [num_layers*hidden, …] tensors. The e2e takes single-layer
    weights; the host loop feeds per-layer weights each call.

  6. decode_layer test/debug harness — the old file shipped test_decode_layer /
    golden_decode_layer / build_tensor_specs and a rich CLI via the golden
    run_jit harness: --batch, --lm-head {full,skip,single}, --enable-pmu,
    --export-kernel-insight / --kernel-insight-func. The e2e ships a simpler
    data-dir golden driver (--smoke, --enable-l2-swimlane, -p/-d/--data-dir,
    >=98% ratio check). → Those debug/profiling switches and the run_jit golden path
    are gone.

  7. Serving integration breaks until re-wired: cli.main -> npu_executor ->
    decode_fwd is the live serving path; with decode_fwd removed it will not start.
    npu_executor._compile_decode_fwd_callable and the paged KvCacheManager path are
    now dead. Re-wiring npu_executor / cli.main to drive the e2e host-loop is not
    part of this PR. (prefill_fwd.py stays in-tree but its integrated decode
    counterpart is gone.)
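
To make the memory cost in item 1 concrete, here is a rough sketch of how many cache rows each scheme reserves per layer. Only BATCH=16 comes from this PR; the KV-head count, MAX_SEQ, block size, and sequence lengths passed in are illustrative assumptions, and both helper functions are hypothetical.

```python
import math
from typing import Sequence


def contiguous_rows(batch: int, num_kv_heads: int, max_seq: int) -> int:
    """Rows reserved by a contiguous cache: every sequence gets max_seq rows."""
    return batch * num_kv_heads * max_seq


def paged_rows(seq_lens: Sequence[int], num_kv_heads: int, block_size: int) -> int:
    """Rows reserved by a paged cache: whole blocks covering each actual length."""
    blocks = sum(math.ceil(n / block_size) for n in seq_lens)
    return blocks * block_size * num_kv_heads


# Assumed values for illustration only (not taken from the PR).
BATCH = 16            # from the PR
NUM_KV_HEADS = 8      # assumption
MAX_SEQ = 4096        # assumption
BLOCK_SIZE = 128      # assumption

short_requests = [100] * BATCH  # sixteen short sequences
reserved_contiguous = contiguous_rows(BATCH, NUM_KV_HEADS, MAX_SEQ)
reserved_paged = paged_rows(short_requests, NUM_KV_HEADS, BLOCK_SIZE)
```

Under these assumed numbers the contiguous layout reserves 524288 rows against 16384 for the paged layout, a 32x difference for short sequences; the gap closes as sequences approach MAX_SEQ.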

Validation

  • python decode_layer.py --smoke -> compiles clean (52 functions).
  • On device (a2a3): PASS — 99.49% of elements within rtol/atol=3e-3
    (mismatches=420/81920, max_abs_err=0.0078). A strict 100% torch.allclose is
    impossible for bf16 outputs (bf16 spacing is 2^-8 ≈ 0.0039 just below 1.0 and
    2^-7 ≈ 0.0078 from 1.0 up to 2.0, both above atol 3e-3), so the
    driver uses a >=98% ratio check; the residual ~0.5% is 1-2 ULP bf16 quantization.
  • End-to-end token generation via qwen3_manual_generate.py (--kernel-file decode_layer.py) works.
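
The ratio check can be sketched as follows. This is not the driver's code: it assumes the torch.allclose-style criterion |actual - expected| <= atol + rtol * |expected|, and `bf16_ulp`, `pass_ratio`, and `passes` are hypothetical helper names. The ULP helper relies on bf16 having 7 explicit mantissa bits.

```python
import math
from typing import Sequence

BF16_MANTISSA_BITS = 7  # explicit mantissa bits in bfloat16


def bf16_ulp(x: float) -> float:
    """Spacing between adjacent bf16 values around a normal, non-zero x."""
    _, exp = math.frexp(abs(x))  # abs(x) = m * 2**exp with 0.5 <= m < 1
    return 2.0 ** (exp - 1 - BF16_MANTISSA_BITS)


def pass_ratio(
    actual: Sequence[float],
    expected: Sequence[float],
    rtol: float = 3e-3,
    atol: float = 3e-3,
) -> float:
    """Fraction of elements satisfying |a - e| <= atol + rtol * |e|."""
    within = sum(
        abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected)
    )
    return within / len(expected)


def passes(
    actual: Sequence[float], expected: Sequence[float], min_ratio: float = 0.98
) -> bool:
    """Accept the result when at least min_ratio of elements are in tolerance."""
    return pass_ratio(actual, expected) >= min_ratio
```

For the run reported above, 420 mismatches out of 81920 elements gives 1 - 420/81920 ≈ 0.9949, which clears the 0.98 threshold. A one-ULP difference at 1.0 (0.0078125) exceeds the combined tolerance of 0.006 there, which is why such elements count as mismatches.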

Companion change (serving repo)

The serving repo's examples/model/qwen3_14b/qwen3_manual_generate.py default
--kernel-file is updated to decode_layer.py (was qwen3_manual_scope.py).


Follow-up (this branch): WIP device-side fused decode_fwd

This branch also rebuilds the fused multi-layer forward on the e2e kernel
(replacing the removed paged decode_fwd.py). The single-layer body is refactored
into an @pl.jit.inline _decode_layer parameterized by layer_idx (multi-layer
stacked weights + per-layer KV-cache offsets); qwen3_decode_mpmd stays the
single-layer entry; and decode_fwd is a device-side pl.range loop over all
NUM_LAYERS + in-kernel rms_lm_head — the full decode in one dispatch.

  • qwen3_decode_mpmd (single-layer) is verified correct on device (99.5%).
  • decode_fwd compiles (52/57 functions) but its fused execution currently
    returns all-zeros — a pypto pl.range loop-carry / inline dependency
    limitation: the inlined body's reads of the loop-carried hidden do not depend on
    their producer (copy_hidden), so they race and read pre-write zeros. Isolated precisely
    (copy-only decode_fwd reproduces hidden exactly; every kernel-level workaround
    fails the same way). Documented in the code as a KNOWN BUG; needs a
    framework-level fix. CI is unaffected — the compile-only fallback never
    exercises decode_fwd.
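
The indexing and the loop-carried value that the fused decode_fwd relies on can be sketched in plain Python. This is not pypto code: the row-stacked weight layout follows the [num_layers*hidden, …] description earlier in this PR, while the KV-cache offset formula, the toy layer body, and all function names are assumptions made for illustration.

```python
from typing import List, Sequence


def layer_rows(layer_idx: int, hidden_size: int) -> slice:
    """Rows of a [num_layers * hidden_size, ...] stacked tensor for one layer."""
    return slice(layer_idx * hidden_size, (layer_idx + 1) * hidden_size)


def kv_cache_offset(layer_idx: int, batch: int, num_kv_heads: int, max_seq: int) -> int:
    """First cache row of a layer, assuming layers are stacked along rows."""
    return layer_idx * batch * num_kv_heads * max_seq


def fused_decode(
    hidden: List[float],
    stacked_weights: Sequence[float],
    num_layers: int,
    hidden_size: int,
) -> List[float]:
    """Loop over all layers, carrying hidden from one iteration to the next."""
    for layer_idx in range(num_layers):
        weights = stacked_weights[layer_rows(layer_idx, hidden_size)]
        # Toy layer body: each iteration must read the hidden written by the
        # previous one; this is the loop-carried dependency.
        hidden = [h + w for h, w in zip(hidden, weights)]
    return hidden
```

In sequential Python each iteration trivially sees the previous iteration's write. The bug described above is that the device-side loop does not enforce this ordering, so the layer body reads hidden before its producer has written it.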

coderabbitai Bot commented Jun 3, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 340a9977-fc84-4ab9-8b8f-31f0d7af5910

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

This PR removes the entire decode_fwd.py module from the Qwen3-14B model directory. The deletion eliminates 586 lines containing a JIT-compiled decode implementation and four public exports: decode_fwd, build_tensor_specs, golden_decode_fwd, and make_pass_rate_compare.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Module Removal: models/qwen3/14b/decode_fwd.py | Complete file deletion removing JIT-compiled forward pass for Qwen3-14B decode, tensor spec construction, PyTorch reference implementation, and pass-rate comparison utilities (586 lines removed). |

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

Poem

🐰 A module departs with grace,
Decode_fwd leaves empty space,
JIT and specs now fade away,
The rabbit cheers this tidy day! ✨

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The PR title references replacing decode_layer with a manual-scope e2e kernel, which is related to the removal of decode_fwd.py; however, the description focuses primarily on a ptoas toolchain bump from v0.36 to v0.40 with test fixes and tolerance adjustments, making the title partially descriptive but not capturing the main work. |
| Docstring Coverage | ✅ Passed | No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | The pull request description comprehensively explains the changes, including the replacement of the fused decode_fwd with a single-layer manual-scope e2e kernel, detailed rationale for mechanisms being removed, validation results, and implications for the codebase. |

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request removes the entire decode forward implementation file, which previously contained the JIT-compiled decode_fwd function for Qwen3-14B full-layer decode, its PyTorch golden reference golden_decode_fwd, and associated testing utilities. Since there are no review comments, I have no feedback to provide.

@lwDavid lwDavid self-assigned this Jun 3, 2026
@lwDavid lwDavid added the enhancement New feature or request label Jun 3, 2026
@lwDavid lwDavid moved this to Done in pto project Jun 3, 2026
@lwDavid lwDavid requested a review from zhangqi-chen June 3, 2026 08:05
@lwDavid lwDavid force-pushed the replace-decode-layer-with-e2e branch 3 times, most recently from 95c0f86 to fb02737 Compare June 3, 2026 08:12
decode_layer.py is now the single-layer manual-scope e2e kernel
(qwen3_decode_mpmd): contiguous KV cache, fixed BATCH=16, per-head Qwen3
QK-norm, deferred input/QK RMSNorm, and no in-kernel LM head — driven
per-layer from the host (qwen3_manual_generate.py). The fused multi-layer
decode_fwd.py and the paged @pl.jit.inline decode_layer it called are removed.

The manual_scope kernel this derives from was optimized by Hzfengsy
(https://github.com/Hzfengsy).
@lwDavid lwDavid marked this pull request as draft June 3, 2026 09:00
Refactor the single-layer kernel into an @pl.jit.inline _decode_layer body
parameterized by layer_idx (multi-layer stacked weights + per-layer KV-cache
offsets); keep qwen3_decode_mpmd as the single-layer entry (verified correct on
device, 99.5%); add decode_fwd: a device-side pl.range loop over all NUM_LAYERS
plus in-kernel rms_lm_head, fusing the whole decode forward into one dispatch.

decode_fwd COMPILES (52/57 functions) but its fused execution currently returns
all-zeros — a pypto pl.range loop-carry / inline dependency limitation: the
inlined body's reads of the loop-carried hidden do not establish a dependency on
its producer (copy_hidden), so they race and read pre-write zeros. Documented in
the code. The single-layer path (qwen3_decode_mpmd) is unaffected and the CI
compile-only fallback does not exercise decode_fwd. Needs a framework-level fix.
@lwDavid lwDavid marked this pull request as ready for review June 3, 2026 10:18
Labels

enhancement New feature or request

Projects

Status: Done

1 participant