#!/usr/bin/env python3
import argparse

import torch

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import Cache, CacheLayerMixin, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast


class DummyTokenizer:
    """Byte-level stand-in tokenizer used when no real tokenizer can be loaded.

    Encoding maps every UTF-8 byte of the text to ``byte % vocab_size``; decoding
    simply joins the ids with spaces. Only meant for smoke tests with random weights.
    """

    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size
        self.pad_token_id = 0
        self.eos_token_id = 1

    def __call__(self, text: str, return_tensors: str = "pt"):
        """Encode ``text`` into a batch of one sequence.

        Returns:
            dict with ``input_ids`` (long tensor of shape (1, n)) and an all-ones
            ``attention_mask`` of the same shape.

        Raises:
            ValueError: if ``return_tensors`` is anything other than ``"pt"``.
        """
        if return_tensors == "pt":
            token_ids = [byte % self.vocab_size for byte in text.encode("utf-8")]
            if not token_ids:
                # An empty prompt still needs at least one token to feed the model.
                token_ids = [self.eos_token_id]
            batch_ids = torch.tensor([token_ids], dtype=torch.long)
            return {"input_ids": batch_ids, "attention_mask": torch.ones_like(batch_ids)}
        raise ValueError("DummyTokenizer only supports return_tensors='pt'.")

    def decode(self, ids, skip_special_tokens: bool = True):
        """Render ``ids`` as space-separated integers; ``skip_special_tokens`` is ignored."""
        return " ".join(map(str, map(int, ids)))


def get_kv_dims(config):
    """Return ``(num_layers, num_kv_heads, head_dim)`` read from a model config.

    Fallbacks:
        * ``head_dim`` defaults to ``hidden_size // num_attention_heads`` when the
          config lacks the attribute *or* sets it to ``None``.
        * ``num_key_value_heads`` defaults to ``num_attention_heads`` (one KV head
          per attention head) when the config lacks it or sets it to ``None``.
    """
    head_dim = getattr(config, "head_dim", None)
    if head_dim is None:
        head_dim = config.hidden_size // config.num_attention_heads
    num_kv_heads = getattr(config, "num_key_value_heads", None)
    if num_kv_heads is None:
        num_kv_heads = config.num_attention_heads
    return config.num_hidden_layers, num_kv_heads, head_dim


def init_kv_tensors(config, batch_size: int, device: torch.device, dtype: torch.dtype, max_cache_len: int):
    """Allocate zero-filled key and value cache tensors covering every layer.

    Returns:
        ``(past_k, past_v)``, each shaped
        ``(num_layers, batch_size, num_kv_heads, max_cache_len, head_dim)``
        with the requested ``device`` and ``dtype``.
    """
    layers, kv_heads, per_head = get_kv_dims(config)
    past_k = torch.zeros(
        (layers, batch_size, kv_heads, max_cache_len, per_head),
        device=device,
        dtype=dtype,
    )
    # Same shape, device and dtype as the key cache, but separate storage.
    past_v = torch.zeros_like(past_k)
    return past_k, past_v


def ensure_model_buffers(model: torch.nn.Module, shape, device: torch.device, dtype: torch.dtype):
    """Make sure ``model`` owns zeroed ``export_new_k``/``export_new_v`` scratch buffers.

    Missing buffers are registered as non-persistent (excluded from ``state_dict``).
    Existing buffers whose shape, device or dtype differ from the request are
    replaced with fresh zero tensors; matching buffers are left untouched.
    """

    def _fresh() -> torch.Tensor:
        return torch.zeros(shape, device=device, dtype=dtype)

    for buffer_name in ("export_new_k", "export_new_v"):
        if not hasattr(model, buffer_name):
            model.register_buffer(buffer_name, _fresh(), persistent=False)
        current = getattr(model, buffer_name)
        is_stale = current.shape != shape or current.device != device or current.dtype != dtype
        if is_stale:
            # Assigning a tensor to a registered buffer name rebinds the buffer itself.
            setattr(model, buffer_name, _fresh())


def scatter_update(target: torch.Tensor, cache_position: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``target`` with ``source`` written at ``cache_position`` along dim 2.

    Out-of-place counterpart of ``index_copy_``. ``cache_position`` holds one index per
    ``source`` position on dim 2; it is broadcast over the other dims of ``source``.
    ``target`` itself is not modified.
    """
    positions = cache_position.view(1, 1, -1, 1)
    positions = positions.expand(source.shape)
    return target.scatter(dim=2, index=positions, src=source)


class CustomCacheLayer(CacheLayerMixin):
    """Fixed-size cache layer backed by one layer slice of caller-provided KV tensors.

    ``keys``/``values`` are ``past_k[layer_idx]``/``past_v[layer_idx]``. ``update`` writes new
    states starting at ``cache_position[0]``. If the owning model has ``export_new_k``/
    ``export_new_v`` buffers, the same states are also written into those buffers, at the
    same positions, under this layer's index.
    """

    is_compileable = True  # Match StaticLayer to avoid SDPA mask shortcut differences

    def __init__(self, past_k: torch.Tensor, past_v: torch.Tensor, layer_idx: int, model: torch.nn.Module, cache_position: torch.Tensor):
        """Bind this layer to slot ``layer_idx`` of the stacked cache tensors.

        Args:
            past_k: Stacked key cache indexed by layer on dim 0. Presumably
                (num_layers, batch, kv_heads, max_cache_len, head_dim) as allocated by
                ``init_kv_tensors``; TODO confirm for other callers.
            past_v: Stacked value cache, same layout as ``past_k``.
            layer_idx: Which layer slice this object manages.
            model: Module that may carry ``export_new_k``/``export_new_v`` buffers.
            cache_position: Positions of this forward call's tokens; only element 0 is used.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.model = model
        # Basic indexing yields views, so in-place writes reach the caller's past_k/past_v.
        self.keys = past_k[layer_idx]
        self.values = past_v[layer_idx]
        self.is_initialized = True
        self.device = self.keys.device
        self.dtype = self.keys.dtype
        self.max_cache_len = self.keys.shape[-2]
        # cumulative_length is the offset BEFORE this call's tokens are written:
        # cache_position[0] is where the first new token goes, i.e. the past sequence length.
        # It is kept as a 1-element tensor (slice, not item()) so it stays a traced value under export.
        self.cumulative_length = cache_position[0:1]

    def lazy_initialization(self, key_states: torch.Tensor):
        """No-op: the storage is supplied at construction time."""
        return

    def update(self, key_states: torch.Tensor, value_states: torch.Tensor, cache_kwargs=None):
        """Write new key/value states into the cache and return the full cache tensors.

        The states are written at positions ``cumulative_length + arange(kv_length)`` along dim 2.
        On MPS an out-of-place scatter is used and ``self.keys``/``self.values`` are rebound to the
        results. Elsewhere ``index_copy_`` writes in place into the first ``batch_size`` rows.

        Args:
            key_states: New keys; dim 0 is batch, dim 2 has length kv_length.
            value_states: New values, same leading layout as ``key_states``.
            cache_kwargs: Ignored (see the comment below).

        Returns:
            ``(self.keys, self.values)`` spanning all ``max_cache_len`` positions.
        """
        # Mirror StaticLayer behavior: derive cache_position from self.cumulative_length.
        # Qwen3 attention calls cache.update(K, V, layer_idx) without forwarding cache_position
        # in cache_kwargs, so we cannot rely on the kwarg.
        kv_length = key_states.shape[-2]
        cache_position = torch.arange(kv_length, device=self.device) + self.cumulative_length
        batch_size = key_states.shape[0]
        if self.keys.device.type == "mps":
            self.keys = scatter_update(self.keys, cache_position, key_states)
            self.values = scatter_update(self.values, cache_position, value_states)
        else:
            try:
                self.keys[:batch_size].index_copy_(2, cache_position, key_states)
                self.values[:batch_size].index_copy_(2, cache_position, value_states)
            except NotImplementedError:
                # Fallback when index_copy_ is not implemented for this backend/mode:
                # advanced-index assignment on the same positions.
                self.keys[:batch_size, :, cache_position] = key_states
                self.values[:batch_size, :, cache_position] = value_states
        # Also record this call's states in the model-level export buffers (indexed by layer on
        # dim 0) so an exporting wrapper can return them as graph outputs.
        if hasattr(self.model, "export_new_k"):
            if self.model.export_new_k.device.type == "mps":
                self.model.export_new_k[self.layer_idx, :batch_size] = scatter_update(
                    self.model.export_new_k[self.layer_idx, :batch_size], cache_position, key_states
                )
                self.model.export_new_v[self.layer_idx, :batch_size] = scatter_update(
                    self.model.export_new_v[self.layer_idx, :batch_size], cache_position, value_states
                )
            else:
                try:
                    self.model.export_new_k[self.layer_idx, :batch_size].index_copy_(2, cache_position, key_states)
                    self.model.export_new_v[self.layer_idx, :batch_size].index_copy_(2, cache_position, value_states)
                except NotImplementedError:
                    self.model.export_new_k[self.layer_idx, :batch_size, :, cache_position] = key_states
                    self.model.export_new_v[self.layer_idx, :batch_size, :, cache_position] = value_states
        return self.keys, self.values

    def get_mask_sizes(self, cache_position: torch.Tensor):
        """Return ``(max_cache_len, 0)`` regardless of ``cache_position``.

        Presumably (kv_length, kv_offset) per the transformers CacheLayerMixin contract: the
        mask always spans the whole fixed-size cache.
        """
        return self.max_cache_len, 0

    def get_seq_length(self):
        """Return the 1-element tensor ``cache_position[0:1]`` captured at construction.

        ``update`` does not advance this value.
        """
        return self.cumulative_length

    def get_max_cache_shape(self):
        """Return the fixed sequence capacity (size of dim -2 of the key slice)."""
        return self.max_cache_len


class CustomCache(Cache):
    """Cache made of one ``CustomCacheLayer`` per slice of the stacked ``past_k``/``past_v``.

    The raw tensors, model and ``cache_position`` are kept as attributes because the
    pytree flatten/unflatten functions rebuild the cache from exactly these fields.
    """

    def __init__(self, past_k: torch.Tensor, past_v: torch.Tensor, model: torch.nn.Module, cache_position: torch.Tensor):
        per_layer = [
            CustomCacheLayer(past_k, past_v, idx, model, cache_position)
            for idx in range(past_k.shape[0])
        ]
        super().__init__(layers=per_layer)
        self.past_k, self.past_v = past_k, past_v
        self.model = model
        self.cache_position = cache_position


def custom_cache_flatten(cache: CustomCache):
    """Split a CustomCache into pytree leaves ``(past_k, past_v, cache_position)`` and its model as context."""
    leaves = (cache.past_k, cache.past_v, cache.cache_position)
    context = cache.model
    return leaves, context


def custom_cache_flatten_with_keys(cache: CustomCache):
    """Keyed variant of ``custom_cache_flatten``: each leaf is paired with its positional SequenceKey."""
    children = (cache.past_k, cache.past_v, cache.cache_position)
    keyed_children = tuple(
        (_pytree.SequenceKey(position), child) for position, child in enumerate(children)
    )
    return keyed_children, cache.model


def custom_cache_unflatten(values, context):
    """Rebuild a CustomCache from the three flattened leaves and the model stored as context."""
    past_k, past_v, cache_position = values
    return CustomCache(past_k, past_v, model=context, cache_position=cache_position)


# Register CustomCache as a pytree node so torch.export can flatten it. The leaf order
# (past_k, past_v, cache_position) must match custom_cache_flatten.
try:
    # Import torch.fx._pytree explicitly instead of relying on it having been loaded
    # as a side effect of other torch imports.
    import torch.fx._pytree as _fx_pytree
    from torch.utils import _pytree

    _pytree.register_pytree_node(
        CustomCache,
        custom_cache_flatten,
        custom_cache_unflatten,
        flatten_with_keys_fn=custom_cache_flatten_with_keys,
    )

    def _custom_cache_flatten_spec(cache, spec):
        """Flatten a CustomCache against an FX TreeSpec using the same leaf order as custom_cache_flatten."""
        return _fx_pytree._tuple_flatten_spec((cache.past_k, cache.past_v, cache.cache_position), spec)

    _fx_pytree.register_pytree_flatten_spec(CustomCache, _custom_cache_flatten_spec)
except Exception as exc:
    raise RuntimeError(f"pytree registration failed: {exc}") from exc


class ExportableWrapper(torch.nn.Module):
    """Export-friendly facade over a text model with tensor-only inputs and outputs.

    The KV cache is passed in as stacked ``past_k``/``past_v`` tensors. The key/value
    states written during the call are returned through the model's ``export_new_k``/
    ``export_new_v`` scratch buffers, which are created (or resized) at construction.
    """

    def __init__(self, text_model: torch.nn.Module, max_cache_len: int, max_batch_size: int):
        super().__init__()
        self.text_model = text_model
        self.config = getattr(text_model, "config", None)
        n_layers, n_kv_heads, d_head = get_kv_dims(self.config)
        embed_weight = text_model.embed_tokens.weight
        ensure_model_buffers(
            text_model,
            (n_layers, max_batch_size, n_kv_heads, max_cache_len, d_head),
            embed_weight.device,
            embed_weight.dtype,
        )
        self.max_cache_len = max_cache_len

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        past_k: torch.Tensor,
        past_v: torch.Tensor,
        cache_position: torch.Tensor,
        position_ids: torch.Tensor,
    ):
        """Run one step of the text model over a fixed-size cache.

        Returns:
            ``(last_hidden_state, export_new_k, export_new_v)``; the scratch buffers are
            zeroed first, so they contain only what this call wrote.
        """
        model = self.text_model
        # Clear stale states from a previous call before the cache layers write into them.
        model.export_new_k.zero_()
        model.export_new_v.zero_()
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=CustomCache(past_k, past_v, model, cache_position),
            use_cache=True,
            cache_position=cache_position,
            position_ids=position_ids,
        )
        return outputs.last_hidden_state, model.export_new_k, model.export_new_v


def cache_tensors_from_cache(
    cache: Cache | None,
    config,
    batch_size: int,
    device: torch.device,
    dtype: torch.dtype,
    max_cache_len: int,
):
    """Copy a transformers ``Cache`` into freshly allocated fixed-size KV tensors.

    Each initialized layer's keys/values are copied into the leading batch rows and
    sequence positions of ``past_k[layer]``/``past_v[layer]``. An incoming layer longer
    than ``max_cache_len`` is truncated, as long as everything past ``max_cache_len`` is
    zero. This is the case for a StaticCache whose tail has not been written yet.

    Args:
        cache: Source cache, or ``None`` for an all-zero result.
        config: Model config used to size the tensors (see ``get_kv_dims``).
        batch_size, device, dtype, max_cache_len: Allocation parameters.

    Returns:
        ``(past_k, past_v)`` as produced by ``init_kv_tensors``, filled from ``cache``.

    Raises:
        ValueError: if a layer holds non-zero entries at or beyond ``max_cache_len``;
            those positions cannot be represented in the fixed-size tensors.
    """
    past_k, past_v = init_kv_tensors(config, batch_size, device, dtype, max_cache_len)
    if cache is None:
        return past_k, past_v
    for layer_idx, layer in enumerate(cache.layers):
        # Skip layers that have not received any states yet.
        if layer.keys is None or layer.values is None or not layer.is_initialized:
            continue
        layer_k = layer.keys
        layer_v = layer.values
        seq_len = layer_k.shape[-2]
        if seq_len > max_cache_len and bool(
            layer_k[..., max_cache_len:, :].any() or layer_v[..., max_cache_len:, :].any()
        ):
            raise ValueError(
                f"Cache layer {layer_idx} holds data beyond max_cache_len={max_cache_len} "
                f"(layer length {seq_len}); increase --max-cache-len."
            )
        copy_len = min(seq_len, max_cache_len)
        past_k[layer_idx, : layer_k.shape[0], :, :copy_len] = layer_k[..., :copy_len, :]
        past_v[layer_idx, : layer_v.shape[0], :, :copy_len] = layer_v[..., :copy_len, :]
    return past_k, past_v


def apply_new_kv_to_cache(
    cache: Cache | None,
    new_k: torch.Tensor,
    new_v: torch.Tensor,
    cache_position: torch.Tensor,
    config,
):
    """Push the states written at ``cache_position`` into ``cache``, layer by layer.

    ``new_k``/``new_v`` are indexed by layer on dim 0. For each layer, the positions in
    ``cache_position`` are gathered along the sequence dim (dim 2 of the per-layer slice)
    and passed to ``cache.update``. A new ``DynamicCache`` is created when ``cache`` is ``None``.

    Returns:
        The updated cache (the same object when one was supplied).
    """
    target = cache if cache is not None else DynamicCache(config=config)
    for idx in range(new_k.shape[0]):
        written_k = new_k[idx].index_select(2, cache_position)
        written_v = new_v[idx].index_select(2, cache_position)
        target.update(written_k, written_v, idx, {"cache_position": cache_position})
    return target


class ExportedTextModelWrapper(torch.nn.Module):
    """Drop-in replacement for the base text model that runs ``torch.export`` programs.

    ``generate()`` hands this module a transformers ``Cache``, but the exported programs
    take fixed-size ``past_k``/``past_v`` tensors and a 2D attention mask of length
    ``max_cache_len``. Each call works in three steps:

    1. Copy the incoming cache into fresh fixed-size tensors.
    2. Run the matching exported program.
    3. Write the key/value states it produced at ``cache_position`` back into the cache.
    """

    def __init__(
        self,
        exported_module: torch.nn.Module,
        base_model: torch.nn.Module,
        max_cache_len: int,
        max_batch_size: int,
        decode_module: torch.nn.Module | None = None,
    ):
        super().__init__()
        # Prefill bucket (q_len > 1, dynamic seq); fallback for all shapes when decode_module is None.
        self.exported_module = exported_module
        # Decode bucket (q_len == 1, static seq). When provided, dispatched for autoregressive steps.
        self.decode_module = decode_module
        self.embed_tokens = getattr(base_model, "embed_tokens", None)
        self.config = getattr(base_model, "config", None)
        self.max_cache_len = max_cache_len
        self.max_batch_size = max_batch_size

    @property
    def device(self):
        """Device of the embedding weights (CPU when no embedding is available)."""
        if self.embed_tokens is not None:
            return self.embed_tokens.weight.device
        return torch.device("cpu")

    @property
    def dtype(self):
        """Dtype of the embedding weights (float32 when no embedding is available)."""
        if self.embed_tokens is not None:
            return self.embed_tokens.weight.dtype
        return torch.float32

    @staticmethod
    def _past_length(past_key_values: Cache | None) -> int:
        """Return the number of tokens already in ``past_key_values`` as a Python int."""
        if past_key_values is None:
            return 0
        seen = past_key_values.get_seq_length()
        if isinstance(seen, torch.Tensor):
            # Some cache layers report the length as a (possibly 1-element) tensor.
            return int(seen.item()) if seen.numel() == 1 else int(seen[0].item())
        return int(seen)

    def _static_attention_mask(self, attention_mask, input_ids: torch.Tensor, valid_len: int) -> torch.Tensor:
        """Return the ``(batch, max_cache_len)`` 2D mask the exported programs were traced with.

        * A 2D tensor already of length ``max_cache_len`` is returned unchanged.
        * A shorter 2D tensor is right-padded with zeros in ``input_ids.dtype``. Any padding
          it already encodes is kept, and the unwritten static-cache slots stay masked.
        * Anything else (``None``, the pre-computed dict of 4D masks, or a 4D tensor that
          ``generate()`` may build) is replaced by a mask whose first ``valid_len``
          positions are 1.

        Unwritten cache slots hold zero keys/values. They must stay masked, otherwise
        attention would attend to them and dilute the distribution.

        Raises:
            ValueError: if a 2D mask is longer than ``max_cache_len``.
        """
        if isinstance(attention_mask, torch.Tensor) and attention_mask.ndim == 2:
            mask_len = attention_mask.shape[-1]
            if mask_len == self.max_cache_len:
                return attention_mask
            if mask_len > self.max_cache_len:
                raise ValueError(
                    f"attention_mask has length {mask_len}, which exceeds max_cache_len={self.max_cache_len}."
                )
            padded = torch.zeros(
                (attention_mask.shape[0], self.max_cache_len),
                dtype=input_ids.dtype,
                device=attention_mask.device,
            )
            padded[:, :mask_len] = attention_mask.to(input_ids.dtype)
            return padded
        positions = torch.arange(self.max_cache_len, device=input_ids.device)
        mask_1d = (positions < valid_len).to(input_ids.dtype)
        return mask_1d.unsqueeze(0).expand(input_ids.shape[0], -1).contiguous()

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.Tensor | None = None,
        use_cache: bool | None = None,
        position_ids: torch.Tensor | None = None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        """Run the exported prefill or decode program for one ``generate()`` step.

        Raises:
            ValueError: if ``input_ids`` is missing, or if the attention mask is too long.
        """
        if input_ids is None:
            raise ValueError("ExportedTextModelWrapper requires input_ids.")

        # generate() does NOT always forward cache_position to a custom text model,
        # so the past length is recomputed from the incoming cache.
        past_seen = self._past_length(past_key_values)
        q_len = input_ids.shape[1]

        if cache_position is None:
            cache_position = torch.arange(past_seen, past_seen + q_len, device=input_ids.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        attention_mask = self._static_attention_mask(attention_mask, input_ids, past_seen + q_len)
        past_k, past_v = cache_tensors_from_cache(
            past_key_values,
            self.config,
            input_ids.shape[0],
            input_ids.device,
            self.dtype,
            self.max_cache_len,
        )
        # Dispatch by query length. The decode bucket (q_len == 1) avoids a tracing bug: the
        # prefill graph is specialized to its traced q_length and mishandles the decode-step
        # mask reshape.
        if q_len == 1 and self.decode_module is not None:
            module = self.decode_module
        else:
            module = self.exported_module
        hidden_states, new_k, new_v = module(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_k=past_k,
            past_v=past_v,
            cache_position=cache_position,
            position_ids=position_ids,
        )
        updated_cache = None
        if use_cache is None or use_cache:
            updated_cache = apply_new_kv_to_cache(past_key_values, new_k, new_v, cache_position, self.config)
        return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=updated_cache)


def pick_device():
    """Return the first available backend name, in priority order "cuda", "mps", "cpu"."""
    accelerators = (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    )
    for backend, is_available in accelerators:
        if is_available():
            return backend
    return "cpu"


def pick_dtype(device: str, dtype_arg: str):
    """Map the ``--dtype`` CLI choice to a torch dtype.

    ``"auto"`` selects float16 on CUDA devices, including indexed ones such as
    ``"cuda:1"``, and float32 everywhere else.

    Raises:
        ValueError: if ``dtype_arg`` is not one of the supported choices.
    """
    explicit = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    if dtype_arg in explicit:
        return explicit[dtype_arg]
    if dtype_arg == "auto":
        # Compare only the device type so "cuda:0" (or a torch.device) counts as CUDA.
        device_type = str(device).split(":", 1)[0]
        return torch.float16 if device_type == "cuda" else torch.float32
    raise ValueError(f"Unsupported dtype: {dtype_arg}")


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="Qwen3-0.6B inference script (Transformers).")
    parser.add_argument(
        "--model",
        default="Qwen/Qwen3-0.6B",
        help="Model id or local path (default: Qwen/Qwen3-0.6B).",
    )
    parser.add_argument("--prompt", default="你好，介绍一下你自己。", help="Input prompt.")
    parser.add_argument("--max-new-tokens", type=int, default=128, help="Max new tokens to generate.")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature.")
    parser.add_argument("--top-p", type=float, default=0.9, help="Top-p sampling.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for generation.")
    parser.add_argument("--device", default=None, help="Device: cuda, mps, or cpu. Auto if not set.")
    parser.add_argument(
        "--dtype",
        default="auto",
        choices=["auto", "float16", "bfloat16", "float32"],
        help="Model dtype (default: auto).",
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Allow loading remote code if the model requires it.",
    )
    parser.add_argument(
        "--local-files-only",
        action="store_true",
        help="Only use local files for model/tokenizer (offline mode).",
    )
    parser.add_argument(
        "--random-weights",
        action="store_true",
        help="Initialize model from config with random weights.",
    )
    parser.add_argument(
        "--export",
        action="store_true",
        help="Export the text model with torch.export before running inference (uses static cache).",
    )
    parser.add_argument(
        "--export-path",
        default=None,
        help="Optional path to save the exported program.",
    )
    parser.add_argument(
        "--export-strict",
        action="store_true",
        help="Use strict=True for torch.export.",
    )
    parser.add_argument(
        "--max-cache-len",
        type=int,
        default=4096,
        help="Maximum cache length for fixed-size KV tensors.",
    )
    parser.add_argument(
        "--max-batch-size",
        type=int,
        default=None,
        help="Maximum batch size for fixed-size KV tensors (defaults to input batch size).",
    )
    return parser


def _build_generation_kwargs(args, pad_token_id, cache_implementation=None) -> dict:
    """Assemble ``model.generate`` keyword arguments from the CLI arguments.

    Sampling with ``temperature``/``top_p`` is used when ``--temperature`` is positive,
    greedy decoding otherwise. ``cache_implementation`` is only set when given.
    """
    gen_kwargs = dict(
        max_new_tokens=args.max_new_tokens,
        pad_token_id=pad_token_id,
        use_cache=True,
    )
    if cache_implementation is not None:
        gen_kwargs["cache_implementation"] = cache_implementation
    if args.temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=args.temperature, top_p=args.top_p)
    else:
        gen_kwargs.update(do_sample=False)
    return gen_kwargs


def _generate(model, inputs, gen_kwargs, seed: int):
    """Run ``model.generate`` under inference mode after reseeding the global torch RNG."""
    with torch.inference_mode():
        torch.manual_seed(seed)
        return model.generate(**inputs, **gen_kwargs)


def _run_export_comparison(args, model, tokenizer, inputs) -> None:
    """Export the text model, generate with it, and compare against the original model.

    Two programs are traced with ``torch.export``:

    * a prefill bucket (q_len = prompt length, dynamic sequence dim);
    * a decode bucket (q_len = 1, static shapes).

    Generation runs twice with the same seed and static cache. The first run temporarily
    swaps in an ``ExportedTextModelWrapper``, which is restored even if generation fails.
    The second run uses the original model. Both decoded outputs are printed, along with
    whether their token ids match.
    """
    base_model = getattr(model, model.base_model_prefix, model)
    max_batch_size = args.max_batch_size or inputs["input_ids"].shape[0]
    exportable = ExportableWrapper(base_model, args.max_cache_len, max_batch_size)
    exportable.eval()
    if inputs["attention_mask"].ndim != 2:
        raise ValueError("Exported text model expects a 2D attention_mask for static cache.")
    device_ = inputs["input_ids"].device
    embed_dtype = exportable.text_model.embed_tokens.weight.dtype

    def _mask_with_valid_len(valid_len: int) -> torch.Tensor:
        """Return a (max_batch_size, max_cache_len) mask whose first ``valid_len`` positions are 1."""
        arange_ = torch.arange(args.max_cache_len, device=device_)
        m1d = (arange_ < valid_len).to(inputs["input_ids"].dtype)
        return m1d.unsqueeze(0).expand(max_batch_size, -1).contiguous()

    def _trace_bucket(q_len_for_trace: int, is_decode: bool):
        """Trace ``exportable`` with the given q_length.

        Decode bucket: q_len=1 with no dynamic seq dim. This avoids the bug where tracing
        bakes in the traced q_length; the graph then takes the SDPA ``is_causal``
        short-circuit path, which is correct for autoregressive decode.
        Prefill bucket: q_len=prompt_len with a dynamic seq dim in [1, max_cache_len-1].
        """
        ids_ = torch.zeros((max_batch_size, q_len_for_trace), dtype=torch.long, device=device_)
        cp_ = torch.arange(q_len_for_trace, device=device_)
        pid_ = cp_.unsqueeze(0)
        mask_ = _mask_with_valid_len(q_len_for_trace)
        pk_, pv_ = init_kv_tensors(
            base_model.config, max_batch_size, device_, embed_dtype, args.max_cache_len,
        )
        if is_decode:
            dyn = {
                "input_ids": {}, "attention_mask": {},
                "past_k": {}, "past_v": {},
                "cache_position": {}, "position_ids": {},
            }
        else:
            seq_dim_ = torch.export.Dim("seq", min=1, max=max(1, args.max_cache_len - 1))
            dyn = {
                "input_ids": {1: seq_dim_}, "attention_mask": {},
                "past_k": {}, "past_v": {},
                "cache_position": {0: seq_dim_}, "position_ids": {1: seq_dim_},
            }
        return torch.export.export(
            exportable, args=(), kwargs={
                "input_ids": ids_, "attention_mask": mask_,
                "past_k": pk_, "past_v": pv_,
                "cache_position": cp_, "position_ids": pid_,
            }, dynamic_shapes=dyn, strict=args.export_strict,
        )

    prompt_len = inputs["input_ids"].shape[1]
    print(f"Tracing prefill bucket (q_len={prompt_len}, dynamic seq)...")
    prefill_exported = _trace_bucket(prompt_len, is_decode=False)
    print("Tracing decode bucket (q_len=1, static seq)...")
    decode_exported = _trace_bucket(1, is_decode=True)

    if args.export_path:
        prefill_exported.save(args.export_path + ".prefill")
        decode_exported.save(args.export_path + ".decode")

    wrapped_text_model = ExportedTextModelWrapper(
        prefill_exported.module(),
        base_model,
        args.max_cache_len,
        max_batch_size,
        decode_module=decode_exported.module(),
    )
    # Use static cache to match fixed-size KV tensors from the exported text model.
    gen_kwargs = _build_generation_kwargs(args, tokenizer.pad_token_id, cache_implementation="static")

    prefix = getattr(model, "base_model_prefix", None)
    original_text_model = base_model
    if prefix:
        setattr(model, prefix, wrapped_text_model)
    try:
        export_output_ids = _generate(model, inputs, gen_kwargs, args.seed)
    finally:
        # Always restore the original text model, even if export-backed generation fails.
        if prefix:
            setattr(model, prefix, original_text_model)
    raw_output_ids = _generate(model, inputs, gen_kwargs, args.seed)

    export_generated = export_output_ids[0, inputs["input_ids"].shape[-1] :]
    raw_generated = raw_output_ids[0, inputs["input_ids"].shape[-1] :]
    export_text = tokenizer.decode(export_generated, skip_special_tokens=True)
    raw_text = tokenizer.decode(raw_generated, skip_special_tokens=True)
    print("Exported model output:")
    print(export_text)
    print("Original model output:")
    print(raw_text)
    match = torch.equal(export_generated, raw_generated)
    print(f"Token ids match: {match}")


def main():
    """Load a causal LM and tokenizer, then generate.

    Without ``--export`` the model generates a continuation of ``--prompt`` and prints
    it. With ``--export`` the text model is exported and its output is compared against
    the eager model (see ``_run_export_comparison``).
    """
    args = _build_arg_parser().parse_args()

    device = args.device or pick_device()
    dtype = pick_dtype(device, args.dtype)

    if args.random_weights:
        config = AutoConfig.from_pretrained(
            args.model,
            trust_remote_code=args.trust_remote_code,
            local_files_only=args.local_files_only,
        )
        # Cast explicitly so --dtype also applies to randomly initialized weights.
        model = AutoModelForCausalLM.from_config(config).to(device=device, dtype=dtype)
    else:
        model = AutoModelForCausalLM.from_pretrained(
            args.model,
            torch_dtype=dtype,
            trust_remote_code=args.trust_remote_code,
            local_files_only=args.local_files_only,
        ).to(device)

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            args.model,
            trust_remote_code=args.trust_remote_code,
            local_files_only=args.local_files_only,
        )
        tokenizer_is_dummy = False
    except Exception as exc:
        # A missing tokenizer is only tolerated for random-weight smoke tests.
        if not args.random_weights:
            raise
        tokenizer = DummyTokenizer(getattr(model.config, "vocab_size", 32000))
        tokenizer_is_dummy = True
        print(f"Tokenizer load failed, using DummyTokenizer: {exc}")
    model.eval()

    if not tokenizer_is_dummy and getattr(tokenizer, "chat_template", None):
        messages = [{"role": "user", "content": args.prompt}]
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(input_text, return_tensors="pt").to(device)
    else:
        inputs = tokenizer(args.prompt, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    if args.export:
        _run_export_comparison(args, model, tokenizer, inputs)
        return

    # Plain inference: generate with the eager model and print only the new tokens.
    gen_kwargs = _build_generation_kwargs(args, tokenizer.pad_token_id)
    output_ids = _generate(model, inputs, gen_kwargs, args.seed)
    generated = output_ids[0, inputs["input_ids"].shape[-1] :]
    print(tokenizer.decode(generated, skip_special_tokens=True))


# Run the CLI only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()
