把静态图推理透明包装到 HF generate

torch.export 出的 Qwen3 静态图 → 替换 model.model → 走原版 generate()
分析对象: qwen3_0_6b_infer.py  |  transformers 源码: 本地 main 分支

0. 为什么要走 generate 接口?

把 LLM 部署到固定 shape / 静态 KV cache 的推理后端 (NPU、edge 芯片、ExecuTorch 等) 时, 模型必须通过 torch.export 导出为固定形状的计算图。一种直接做法是导出图后自己写 token-by-token 推理循环 (类似 transformers 官方 TorchExportableModuleWithStaticCache + ExecuTorch 方案), 但这样:

本方案目标是 "透明": 让导出后的图伪装成一个普通 text model, 仍然挂在 Qwen3ForCausalLM 下, 仍然能被 generate() 调用。算法/评测/调用方完全无感, 底层却已经是静态图推理。

核心策略: 只替换 model.model (text model), 不替换 Qwen3ForCausalLM — 保留 lm_headgenerate() 等所有上层逻辑不动, 仅在 text model 这一层完成「Cache 对象 ↔ 固定 KV 张量」的桥接。

1. 整体架构

Qwen3ForCausalLM.generate() | Qwen3ForCausalLM.forward() | +-----------------------------+ | self.model (被替换) | | ExportedTextModelWrapper | | | | 1. Cache → past_k/v | | 2. dispatch → prefill/decode| | 桶, 调 exported_module | | 3. new_k/v → 写回 Cache | +-----------------------------+ | torch.export 出的两个图模块 (prefill / decode 各一个)

1.1 单步推理数据流 (decode step)

下图展示 generate 调用一次 forward 时, ExportedTextModelWrapper 内部如何完成 Cache ↔ 固定 KV 的双向桥接:

generate() 循环 持有 StaticCache cache, ids, mask, cache_pos, pos_ids ExportedTextModelWrapper.forward() ① cache_tensors_from_cache StaticCache.layers[i].keys ↓ copy past_k: (L,B,H,max_len,D) ② exported_module() prefill 桶 / decode 桶 (按 q_len dispatch) CustomCache(past_k, past_v, model, cache_position) ↓ get_seq_length = cache_position[0:1] Qwen3Model.forward attention 用 past_k 读, 新 KV 写回 (双写 past_k + export_new_k) 3 outputs ③ apply_new_kv_to_cache hidden: (B,Q,H_dim) new_k/v: (L,B,H,max_len,D) ↓ index_select at cache_pos slice: (B,H,Q,D) ↓ cache.update(...) StaticCache 累积成长 返回: hidden_state + 更新后的 Cache (cumulative_length += seq_len) → generate 继续下一步, 把这个 Cache 再传回来 下一步 generate (复用同一个 StaticCache)

2. 实现的四轮逻辑

整套方案由四轮配合完成: 导出前把 text model 包成扁平 IO 接口 → 导出分桶 trace → 导出后桥接回 CausalLM → 运行时由 wrapper dispatch 双桶。下面逐轮讲。

第一轮: 导出前 — 把 text model 变成 export 友好的接口

目标

torch.export 无法直接追踪 HF 的动态 Cache 对象 (随 token 数增长), 也不能跟踪可变 buffer 的副作用。需要把 text model 的 IO 重新包成: 纯 tensor in, 纯 tensor out

(a) ExportableWrapper — trace 的真正对象

class ExportableWrapper(torch.nn.Module):
    """把 text model 包成扁平 IO 接口, 作为 torch.export 的追踪对象。"""

    def __init__(self, text_model, max_cache_len, max_batch_size):
        super().__init__()
        self.text_model = text_model
        num_layers, num_kv_heads, head_dim = get_kv_dims(text_model.config)
        # 固定大小的 KV buffer 形状: (层数, batch, KV head 数, 最大缓存长度, head 维度)
        buffer_shape = (num_layers, max_batch_size, num_kv_heads, max_cache_len, head_dim)
        # 在 text_model 上注册 sidecar buffer (用于存放"本次新增 KV", 见 (b))
        ensure_model_buffers(text_model, buffer_shape, device, dtype)

    def forward(self, input_ids, attention_mask, past_k, past_v,
                cache_position, position_ids):
        # 清零 sidecar buffer, 准备记录本次 update 写入的 KV
        self.text_model.export_new_k.zero_()
        self.text_model.export_new_v.zero_()
        # 用固定 KV 张量构造 CustomCache, 让 text model 仍然能用 past_key_values 接口
        cache = CustomCache(past_k, past_v, self.text_model, cache_position)
        outputs = self.text_model(
            input_ids=input_ids, attention_mask=attention_mask,
            past_key_values=cache, use_cache=True,
            cache_position=cache_position, position_ids=position_ids,
        )
        # 三个扁平 tensor 输出: 隐藏状态 + 本次新增的 KV (从 sidecar 取出)
        return outputs.last_hidden_state, \
               self.text_model.export_new_k, self.text_model.export_new_v

(b) CustomCacheLayer.update — 方案的"灵魂", 双写 past_k 和 sidecar

class CustomCacheLayer(CacheLayerMixin):
    is_compileable = True  # 与 StaticLayer 对齐, 避免 SDPA mask 优化路径差异

    def __init__(self, past_k, past_v, layer_idx, model, cache_position):
        super().__init__()
        self.layer_idx = layer_idx
        self.model = model
        # 本层的 KV view, 指向外部传入的固定大小 past_k/past_v
        self.keys = past_k[layer_idx]
        self.values = past_v[layer_idx]
        self.max_cache_len = self.keys.shape[-2]
        # cumulative_length 必须是 tensor (而非 Python int), 这样 torch.export 才能
        # 把它当作数据依赖, 在不同 cache_position 下走出不同的 graph 行为
        self.cumulative_length = cache_position[0:1]

    def update(self, key_states, value_states, cache_kwargs=None):
        # Qwen3 attention 调 cache.update(K, V, layer_idx) 时不传 cache_position
        # kwarg, 因此必须像 StaticLayer 一样自己从 cumulative_length 推
        kv_length = key_states.shape[-2]
        cache_position = torch.arange(kv_length, device=self.device) + self.cumulative_length

        # 写入 1: 把新 KV 写到固定 past_k 的对应位置 (供本层 attention 读)
        self.keys[:batch_size].index_copy_(2, cache_position, key_states)
        self.values[:batch_size].index_copy_(2, cache_position, value_states)

        # 写入 2: 同时写到 sidecar buffer export_new_k (供 export 图作为显式输出)
        # 这样 ExportedProgram 才能把"本次新增 KV"作为返回值导出
        self.model.export_new_k[self.layer_idx, :batch_size].index_copy_(
            2, cache_position, key_states)
        self.model.export_new_v[self.layer_idx, :batch_size].index_copy_(
            2, cache_position, value_states)

        return self.keys, self.values

    def get_seq_length(self):
        # 返回 tensor 而不是 Python int, 让 torch.export 能追踪到数据依赖,
        # 否则 graph 内 mask 创建的 q_offset 会被 bake 为常量
        return self.cumulative_length

(c) pytree 注册 — 让 torch.export 能识别 CustomCache

# torch.export 把所有非 Tensor 类型展平 (flatten) 为 tensor leaves。
# CustomCache 是自定义类, 必须告诉 pytree 如何拆/装。
def custom_cache_flatten(cache):
    # leaves = 数据 tensor, context = 重建时需要的对象 (这里是 model 引用)
    return (cache.past_k, cache.past_v, cache.cache_position), cache.model

def custom_cache_unflatten(values, context):
    past_k, past_v, cache_position = values
    return CustomCache(past_k, past_v, context, cache_position)

_pytree.register_pytree_node(CustomCache, custom_cache_flatten, custom_cache_unflatten, ...)

第二轮: 导出 — 按 shape 分桶 trace 两个 ExportedProgram

目标

torch.export trace 时会把 graph 内部按当时的 q_length 选定一条数据路径并固化。如果只 trace 一次 (q_length = prompt_len), 运行 decode 步 (q_length = 1) 时就会走错路径。解决: 按 shape 分桶 — prefill 和 decode 各 trace 一个 ExportedProgram, 运行时按 q_length dispatch。

(d) trace 工厂 — 用同一个 ExportableWrapper trace 两次, 给不同 dynamic_shapes

def _trace_bucket(q_len_for_trace: int, is_decode: bool):
    """同一个 ExportableWrapper, trace 两次, 给不同 dynamic_shapes。"""
    # 占位输入 (内容无所谓, 形状决定 trace 的特化路径)
    ids_ = torch.zeros((max_batch_size, q_len_for_trace), dtype=torch.long, device=device)
    cp_ = torch.arange(q_len_for_trace, device=device)
    pid_ = cp_.unsqueeze(0)
    mask_ = _mask_with_valid_len(q_len_for_trace)  # prefix-1, 长度 = max_cache_len
    pk_, pv_ = init_kv_tensors(...)

    if is_decode:
        # decode 桶: q_len=1 静态, 与运行时一致, graph 内走 SDPA is_causal 短路
        dyn = {"input_ids": {}, "attention_mask": {}, "past_k": {}, "past_v": {},
               "cache_position": {}, "position_ids": {}}
    else:
        # prefill 桶: seq 动态, prompt 长度可以变
        seq_dim_ = torch.export.Dim("seq", min=1, max=max_cache_len - 1)
        dyn = {"input_ids": {1: seq_dim_}, "attention_mask": {},
               "past_k": {}, "past_v": {},
               "cache_position": {0: seq_dim_}, "position_ids": {1: seq_dim_}}

    return torch.export.export(
        exportable, args=(), kwargs={
            "input_ids": ids_, "attention_mask": mask_,
            "past_k": pk_, "past_v": pv_,
            "cache_position": cp_, "position_ids": pid_,
        }, dynamic_shapes=dyn, strict=args.export_strict,
    )

# trace 两个桶
prefill_exported = _trace_bucket(prompt_len, is_decode=False)
decode_exported = _trace_bucket(1, is_decode=True)

生产部署可按需扩展为多桶 (prefill_64/128/256, decode_1), dispatch 逻辑不变。

第三轮: 导出后 — 把扁平图模块桥接回 CausalLM

目标

exported 图只接受扁平 tensor IO, 但 generate() 会传 Cache 对象、dict 形态的 4D mask 等复杂结构。需要一个 wrapper 做双向转换: 进来时把 Cache 拆成 past_k/v, 出去时把新 KV 写回 Cache。

(e) ExportedTextModelWrapper.forward — 桥接逻辑 + dispatch

class ExportedTextModelWrapper(torch.nn.Module):
    def __init__(self, exported_module, base_model, max_cache_len, max_batch_size,
                 decode_module=None):
        super().__init__()
        # prefill 桶 (q_len > 1, 动态 seq)
        self.exported_module = exported_module
        # decode 桶 (q_len == 1, 静态), 由 forward 中按 q_length dispatch
        self.decode_module = decode_module
        ...

    def forward(self, input_ids=None, attention_mask=None, past_key_values=None,
                cache_position=None, position_ids=None, use_cache=None, **kwargs):
        # ---- 1) 从 Cache 推 past_seen, 构造正确的 cache_position / position_ids ----
        # generate() 不一定把 cache_position 转发给自定义 text model, 自己重建更可靠
        if past_key_values is not None:
            past_seen_t = past_key_values.get_seq_length()
            past_seen = int(past_seen_t.item()) if past_seen_t.numel() == 1 \
                                                  else int(past_seen_t[0].item())
        else:
            past_seen = 0
        q_len = input_ids.shape[1]

        if cache_position is None:
            cache_position = torch.arange(past_seen, past_seen + q_len, device=input_ids.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # ---- 2) 把 generate 传入的 dict 4D mask 转成 exported 图能接受的 2D mask ----
        # 关键: 长度固定为 max_cache_len, 但 1-valued 前缀只到 past_seen + Q,
        # 后面填 0, 让 attention 知道"零填充位是 invalid"。
        if attention_mask is None or isinstance(attention_mask, dict):
            valid_len = past_seen + q_len
            arange = torch.arange(self.max_cache_len, device=input_ids.device)
            mask_1d = (arange < valid_len).to(input_ids.dtype)
            attention_mask = mask_1d.unsqueeze(0).expand(input_ids.shape[0], -1).contiguous()

        # ---- 3) Cache 对象 → 固定 KV 张量 ----
        past_k, past_v = cache_tensors_from_cache(
            past_key_values, self.config, input_ids.shape[0],
            input_ids.device, self.dtype, self.max_cache_len,
        )

        # ---- 4) 按 q_length 选桶, 调用 exported 图 ----
        if input_ids.shape[1] == 1 and self.decode_module is not None:
            module = self.decode_module
        else:
            module = self.exported_module
        hidden_states, new_k, new_v = module(
            input_ids=input_ids, attention_mask=attention_mask,
            past_k=past_k, past_v=past_v,
            cache_position=cache_position, position_ids=position_ids,
        )

        # ---- 5) 新增 KV 写回原 Cache 对象, 让 generate 下一步能继续用 ----
        updated_cache = apply_new_kv_to_cache(
            past_key_values, new_k, new_v, cache_position, self.config,
        )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states, past_key_values=updated_cache,
        )

(f) apply_new_kv_to_cache — 把 sidecar 中的新增 KV 写回任意 Cache 类型

def apply_new_kv_to_cache(cache, new_k, new_v, cache_position, config):
    """把 exported 图输出的 sidecar new_k/v 按 cache_position 抽出本次写入的 slice,
       再通过标准 Cache.update 写回 generate 持有的 Cache 对象 (StaticCache/DynamicCache 等)。"""
    if cache is None:
        cache = DynamicCache(config=config)
    for layer_idx in range(new_k.shape[0]):
        # 从 sidecar 抽出本次新增 KV (sidecar 内除了 cache_position 之外都是零填充)
        layer_new_k = new_k[layer_idx].index_select(2, cache_position)
        layer_new_v = new_v[layer_idx].index_select(2, cache_position)
        # 用 Cache 标准接口写回, 自动兼容 StaticCache / DynamicCache / ...
        cache.update(layer_new_k, layer_new_v, layer_idx,
                     {"cache_position": cache_position})
    return cache

第四轮: 运行时 — 替换 text model, generate() 完全无感

目标

ExportedTextModelWrapper 挂到 model.model 上, 整个 generate 链路其它部分 (sampling、stopping、chat template ...) 都不需要任何修改。

(g) 替换 text model 一行实现"透明"

# 把双桶 ExportedProgram 包进 wrapper
wrapped_text_model = ExportedTextModelWrapper(
    prefill_exported.module(),
    base_model,
    args.max_cache_len,
    max_batch_size,
    decode_module=decode_exported.module(),
)

# 替换前: model.model = Qwen3Model (原始 PyTorch 实现)
# 替换后: model.model = ExportedTextModelWrapper (背后是两个 torch.export 出的图)
setattr(model, model.base_model_prefix, wrapped_text_model)

# generate 完全无感地用导出图推理 — sampling/beam search/chat template 全部照旧
output_ids = model.generate(
    **inputs, do_sample=True, temperature=0.7, top_p=0.9,
    max_new_tokens=128, use_cache=True, cache_implementation="static",
)

3. 验证

--random-weights 模式 (避免下载权重, 跑 CPU/MPS 都行), 对比 替换前 vs 替换后 的 generate 输出:

测试场景结果说明
greedy, 16 tokensPASSexported vs raw 逐 token 一致
greedy, 64 tokensPASS更长序列稳定性
sampling, temperature=0.7PASS相同 seed 下采样路径也一致
$ python qwen3_0_6b_infer.py --random-weights --export \
    --device cpu --dtype float32 \
    --max-new-tokens 16 --max-cache-len 128 --temperature 0

Tracing prefill bucket (q_len=13, dynamic seq)...
Tracing decode bucket (q_len=1, static seq)...
Exported model output:
pdatapdatapdatapdatapdatapdatadings estimates estimates estimates ...
Original model output:
pdatapdatapdatapdatapdatapdatadings estimates estimates estimates ...
Token ids match: True

4. 与 transformers 官方 torch.export 方案对比

维度本方案transformers executorch.py
包装层级只包装 text model包装整个 CausalLM
generate 兼容透明走 generate()不走 generate, 自己实现 token loop
Cache 策略CustomCache + export_new_k/v sidecarStaticCache + early_initialization
KV 传递显式 past_k/v 扁平张量Cache 作为 module attribute, 从图中隐藏
多 shape 支持分桶 dispatch (prefill/decode/...)固定 shape, 单图

5. 已知限制

6. 完整代码下载

📄 qwen3_0_6b_infer.py (完整实现)

运行: python qwen3_0_6b_infer.py --random-weights --export --device cpu --dtype float32 --max-new-tokens 16 --max-cache-len 128 --temperature 0


使用 Hypothesis 添加注释 (右上角侧边栏)