把 LLM 部署到固定 shape / 静态 KV cache 的推理后端 (NPU、edge 芯片、ExecuTorch 等) 时, 模型必须通过 torch.export 导出为固定形状的计算图。一种直接做法是导出图后自己写 token-by-token 推理循环 (类似 transformers 官方 TorchExportableModuleWithStaticCache + ExecuTorch 方案), 但这样:
model.generate(...)本方案目标是 "透明": 让导出后的图伪装成一个普通 text model, 仍然挂在 Qwen3ForCausalLM 下, 仍然能被 generate() 调用。算法/评测/调用方完全无感, 底层却已经是静态图推理。
核心策略: 只替换 model.model (text model), 不替换 Qwen3ForCausalLM — 保留 lm_head、generate() 等所有上层逻辑不动, 仅在 text model 这一层完成「Cache 对象 ↔ 固定 KV 张量」的桥接。
下图展示 generate 调用一次 forward 时, ExportedTextModelWrapper 内部如何完成 Cache ↔ 固定 KV 的双向桥接:
整套方案由四轮配合完成: 导出前把 text model 包成扁平 IO 接口 → 导出分桶 trace → 导出后桥接回 CausalLM → 运行时由 wrapper dispatch 双桶。下面逐轮讲。
目标
torch.export 无法直接追踪 HF 的动态 Cache 对象 (随 token 数增长), 也不能跟踪可变 buffer 的副作用。需要把 text model 的 IO 重新包成: 纯 tensor in, 纯 tensor out。
(a) ExportableWrapper — trace 的真正对象
class ExportableWrapper(torch.nn.Module):
"""把 text model 包成扁平 IO 接口, 作为 torch.export 的追踪对象。"""
def __init__(self, text_model, max_cache_len, max_batch_size):
super().__init__()
self.text_model = text_model
num_layers, num_kv_heads, head_dim = get_kv_dims(text_model.config)
# 固定大小的 KV buffer 形状: (层数, batch, KV head 数, 最大缓存长度, head 维度)
buffer_shape = (num_layers, max_batch_size, num_kv_heads, max_cache_len, head_dim)
# 在 text_model 上注册 sidecar buffer (用于存放"本次新增 KV", 见 (b))
ensure_model_buffers(text_model, buffer_shape, device, dtype)
def forward(self, input_ids, attention_mask, past_k, past_v,
cache_position, position_ids):
# 清零 sidecar buffer, 准备记录本次 update 写入的 KV
self.text_model.export_new_k.zero_()
self.text_model.export_new_v.zero_()
# 用固定 KV 张量构造 CustomCache, 让 text model 仍然能用 past_key_values 接口
cache = CustomCache(past_k, past_v, self.text_model, cache_position)
outputs = self.text_model(
input_ids=input_ids, attention_mask=attention_mask,
past_key_values=cache, use_cache=True,
cache_position=cache_position, position_ids=position_ids,
)
# 三个扁平 tensor 输出: 隐藏状态 + 本次新增的 KV (从 sidecar 取出)
return outputs.last_hidden_state, \
self.text_model.export_new_k, self.text_model.export_new_v
(b) CustomCacheLayer.update — 方案的"灵魂", 双写 past_k 和 sidecar
class CustomCacheLayer(CacheLayerMixin):
is_compileable = True # 与 StaticLayer 对齐, 避免 SDPA mask 优化路径差异
def __init__(self, past_k, past_v, layer_idx, model, cache_position):
super().__init__()
self.layer_idx = layer_idx
self.model = model
# 本层的 KV view, 指向外部传入的固定大小 past_k/past_v
self.keys = past_k[layer_idx]
self.values = past_v[layer_idx]
self.max_cache_len = self.keys.shape[-2]
# cumulative_length 必须是 tensor (而非 Python int), 这样 torch.export 才能
# 把它当作数据依赖, 在不同 cache_position 下走出不同的 graph 行为
self.cumulative_length = cache_position[0:1]
def update(self, key_states, value_states, cache_kwargs=None):
# Qwen3 attention 调 cache.update(K, V, layer_idx) 时不传 cache_position
# kwarg, 因此必须像 StaticLayer 一样自己从 cumulative_length 推
kv_length = key_states.shape[-2]
cache_position = torch.arange(kv_length, device=self.device) + self.cumulative_length
# 写入 1: 把新 KV 写到固定 past_k 的对应位置 (供本层 attention 读)
self.keys[:batch_size].index_copy_(2, cache_position, key_states)
self.values[:batch_size].index_copy_(2, cache_position, value_states)
# 写入 2: 同时写到 sidecar buffer export_new_k (供 export 图作为显式输出)
# 这样 ExportedProgram 才能把"本次新增 KV"作为返回值导出
self.model.export_new_k[self.layer_idx, :batch_size].index_copy_(
2, cache_position, key_states)
self.model.export_new_v[self.layer_idx, :batch_size].index_copy_(
2, cache_position, value_states)
return self.keys, self.values
def get_seq_length(self):
# 返回 tensor 而不是 Python int, 让 torch.export 能追踪到数据依赖,
# 否则 graph 内 mask 创建的 q_offset 会被 bake 为常量
return self.cumulative_length
(c) pytree 注册 — 让 torch.export 能识别 CustomCache
# torch.export 把所有非 Tensor 类型展平 (flatten) 为 tensor leaves。
# CustomCache 是自定义类, 必须告诉 pytree 如何拆/装。
def custom_cache_flatten(cache):
# leaves = 数据 tensor, context = 重建时需要的对象 (这里是 model 引用)
return (cache.past_k, cache.past_v, cache.cache_position), cache.model
def custom_cache_unflatten(values, context):
past_k, past_v, cache_position = values
return CustomCache(past_k, past_v, context, cache_position)
_pytree.register_pytree_node(CustomCache, custom_cache_flatten, custom_cache_unflatten, ...)
目标
torch.export trace 时会把 graph 内部按当时的 q_length 选定一条数据路径并固化。如果只 trace 一次 (q_length = prompt_len), 运行 decode 步 (q_length = 1) 时就会走错路径。解决: 按 shape 分桶 — prefill 和 decode 各 trace 一个 ExportedProgram, 运行时按 q_length dispatch。
(d) trace 工厂 — 用同一个 ExportableWrapper trace 两次, 给不同 dynamic_shapes
def _trace_bucket(q_len_for_trace: int, is_decode: bool):
"""同一个 ExportableWrapper, trace 两次, 给不同 dynamic_shapes。"""
# 占位输入 (内容无所谓, 形状决定 trace 的特化路径)
ids_ = torch.zeros((max_batch_size, q_len_for_trace), dtype=torch.long, device=device)
cp_ = torch.arange(q_len_for_trace, device=device)
pid_ = cp_.unsqueeze(0)
mask_ = _mask_with_valid_len(q_len_for_trace) # prefix-1, 长度 = max_cache_len
pk_, pv_ = init_kv_tensors(...)
if is_decode:
# decode 桶: q_len=1 静态, 与运行时一致, graph 内走 SDPA is_causal 短路
dyn = {"input_ids": {}, "attention_mask": {}, "past_k": {}, "past_v": {},
"cache_position": {}, "position_ids": {}}
else:
# prefill 桶: seq 动态, prompt 长度可以变
seq_dim_ = torch.export.Dim("seq", min=1, max=max_cache_len - 1)
dyn = {"input_ids": {1: seq_dim_}, "attention_mask": {},
"past_k": {}, "past_v": {},
"cache_position": {0: seq_dim_}, "position_ids": {1: seq_dim_}}
return torch.export.export(
exportable, args=(), kwargs={
"input_ids": ids_, "attention_mask": mask_,
"past_k": pk_, "past_v": pv_,
"cache_position": cp_, "position_ids": pid_,
}, dynamic_shapes=dyn, strict=args.export_strict,
)
# trace 两个桶
prefill_exported = _trace_bucket(prompt_len, is_decode=False)
decode_exported = _trace_bucket(1, is_decode=True)
生产部署可按需扩展为多桶 (prefill_64/128/256, decode_1), dispatch 逻辑不变。
目标
exported 图只接受扁平 tensor IO, 但 generate() 会传 Cache 对象、dict 形态的 4D mask 等复杂结构。需要一个 wrapper 做双向转换: 进来时把 Cache 拆成 past_k/v, 出去时把新 KV 写回 Cache。
(e) ExportedTextModelWrapper.forward — 桥接逻辑 + dispatch
class ExportedTextModelWrapper(torch.nn.Module):
def __init__(self, exported_module, base_model, max_cache_len, max_batch_size,
decode_module=None):
super().__init__()
# prefill 桶 (q_len > 1, 动态 seq)
self.exported_module = exported_module
# decode 桶 (q_len == 1, 静态), 由 forward 中按 q_length dispatch
self.decode_module = decode_module
...
def forward(self, input_ids=None, attention_mask=None, past_key_values=None,
cache_position=None, position_ids=None, use_cache=None, **kwargs):
# ---- 1) 从 Cache 推 past_seen, 构造正确的 cache_position / position_ids ----
# generate() 不一定把 cache_position 转发给自定义 text model, 自己重建更可靠
if past_key_values is not None:
past_seen_t = past_key_values.get_seq_length()
past_seen = int(past_seen_t.item()) if past_seen_t.numel() == 1 \
else int(past_seen_t[0].item())
else:
past_seen = 0
q_len = input_ids.shape[1]
if cache_position is None:
cache_position = torch.arange(past_seen, past_seen + q_len, device=input_ids.device)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# ---- 2) 把 generate 传入的 dict 4D mask 转成 exported 图能接受的 2D mask ----
# 关键: 长度固定为 max_cache_len, 但 1-valued 前缀只到 past_seen + Q,
# 后面填 0, 让 attention 知道"零填充位是 invalid"。
if attention_mask is None or isinstance(attention_mask, dict):
valid_len = past_seen + q_len
arange = torch.arange(self.max_cache_len, device=input_ids.device)
mask_1d = (arange < valid_len).to(input_ids.dtype)
attention_mask = mask_1d.unsqueeze(0).expand(input_ids.shape[0], -1).contiguous()
# ---- 3) Cache 对象 → 固定 KV 张量 ----
past_k, past_v = cache_tensors_from_cache(
past_key_values, self.config, input_ids.shape[0],
input_ids.device, self.dtype, self.max_cache_len,
)
# ---- 4) 按 q_length 选桶, 调用 exported 图 ----
if input_ids.shape[1] == 1 and self.decode_module is not None:
module = self.decode_module
else:
module = self.exported_module
hidden_states, new_k, new_v = module(
input_ids=input_ids, attention_mask=attention_mask,
past_k=past_k, past_v=past_v,
cache_position=cache_position, position_ids=position_ids,
)
# ---- 5) 新增 KV 写回原 Cache 对象, 让 generate 下一步能继续用 ----
updated_cache = apply_new_kv_to_cache(
past_key_values, new_k, new_v, cache_position, self.config,
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states, past_key_values=updated_cache,
)
(f) apply_new_kv_to_cache — 把 sidecar 中的新增 KV 写回任意 Cache 类型
def apply_new_kv_to_cache(cache, new_k, new_v, cache_position, config):
"""把 exported 图输出的 sidecar new_k/v 按 cache_position 抽出本次写入的 slice,
再通过标准 Cache.update 写回 generate 持有的 Cache 对象 (StaticCache/DynamicCache 等)。"""
if cache is None:
cache = DynamicCache(config=config)
for layer_idx in range(new_k.shape[0]):
# 从 sidecar 抽出本次新增 KV (sidecar 内除了 cache_position 之外都是零填充)
layer_new_k = new_k[layer_idx].index_select(2, cache_position)
layer_new_v = new_v[layer_idx].index_select(2, cache_position)
# 用 Cache 标准接口写回, 自动兼容 StaticCache / DynamicCache / ...
cache.update(layer_new_k, layer_new_v, layer_idx,
{"cache_position": cache_position})
return cache
目标
把 ExportedTextModelWrapper 挂到 model.model 上, 整个 generate 链路其它部分 (sampling、stopping、chat template ...) 都不需要任何修改。
(g) 替换 text model 一行实现"透明"
# 把双桶 ExportedProgram 包进 wrapper
wrapped_text_model = ExportedTextModelWrapper(
prefill_exported.module(),
base_model,
args.max_cache_len,
max_batch_size,
decode_module=decode_exported.module(),
)
# 替换前: model.model = Qwen3Model (原始 PyTorch 实现)
# 替换后: model.model = ExportedTextModelWrapper (背后是两个 torch.export 出的图)
setattr(model, model.base_model_prefix, wrapped_text_model)
# generate 完全无感地用导出图推理 — sampling/beam search/chat template 全部照旧
output_ids = model.generate(
**inputs, do_sample=True, temperature=0.7, top_p=0.9,
max_new_tokens=128, use_cache=True, cache_implementation="static",
)
用 --random-weights 模式 (避免下载权重, 跑 CPU/MPS 都行), 对比 替换前 vs 替换后 的 generate 输出:
| 测试场景 | 结果 | 说明 |
|---|---|---|
| greedy, 16 tokens | PASS | exported vs raw 逐 token 一致 |
| greedy, 64 tokens | PASS | 更长序列稳定性 |
| sampling, temperature=0.7 | PASS | 相同 seed 下采样路径也一致 |
$ python qwen3_0_6b_infer.py --random-weights --export \
--device cpu --dtype float32 \
--max-new-tokens 16 --max-cache-len 128 --temperature 0
Tracing prefill bucket (q_len=13, dynamic seq)...
Tracing decode bucket (q_len=1, static seq)...
Exported model output:
pdatapdatapdatapdatapdatapdatadings estimates estimates estimates ...
Original model output:
pdatapdatapdatapdatapdatapdatadings estimates estimates estimates ...
Token ids match: True
| 维度 | 本方案 | transformers executorch.py |
|---|---|---|
| 包装层级 | 只包装 text model | 包装整个 CausalLM |
| generate 兼容 | 透明走 generate() | 不走 generate, 自己实现 token loop |
| Cache 策略 | CustomCache + export_new_k/v sidecar | StaticCache + early_initialization |
| KV 传递 | 显式 past_k/v 扁平张量 | Cache 作为 module attribute, 从图中隐藏 |
| 多 shape 支持 | 分桶 dispatch (prefill/decode/...) | 固定 shape, 单图 |
cache_tensors_from_cache() 每步分配 (layers, batch, heads, max_cache_len, dim) 张量, decode 阶段开销显著CacheLayerMixin 的方法签名可能变化