"""Single-slot VRAM cache for Qwen3-TTS model variants.

Only one model variant is loaded at a time. Switching variants unloads the
current model, frees VRAM, then loads the new one (~10-15s swap).
"""

from __future__ import annotations

import logging
import threading
from enum import Enum
from typing import Any

import numpy as np
import torch
from qwen_tts import Qwen3TTSModel

logger = logging.getLogger(__name__)

# Output audio sample rate in Hz.  NOTE(review): not referenced anywhere in
# this module's visible code — presumably consumed by callers when writing
# WAV files; confirm it matches the model's actual output rate.
SAMPLE_RATE = 24000

# Default generation kwargs (from Qwen3-TTS examples)
# Spread into every model.generate_* call below; per-call arguments are
# layered on top of these.
DEFAULT_GEN_KWARGS: dict[str, Any] = dict(
    max_new_tokens=2048,
    do_sample=True,
    top_k=50,
    top_p=1.0,
    temperature=0.9,
    repetition_penalty=1.05,
)


class ModelVariant(str, Enum):
    """Model repo identifiers for the Qwen3-TTS variants this engine can host.

    Mixes in ``str`` so a member's value (the repo id string) is usable
    directly wherever a string is expected, e.g. ``from_pretrained``.
    """

    # Used by generate_voice_clone / create_voice_clone_prompt.
    BASE = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
    # Used by generate_voice_design (instruction-driven voice creation).
    VOICE_DESIGN = "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign"
    # Used by generate_custom_voice / get_supported_speakers.
    CUSTOM_VOICE = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"


class TTSEngine:
    """Thread-safe single-slot model manager.

    At most one Qwen3-TTS variant occupies VRAM at a time.  Every public
    generation method serializes on an internal lock, so concurrent callers
    queue up instead of racing the unload/load swap.
    """

    def __init__(self, device: str = "cuda:0") -> None:
        """Create an engine targeting *device*.

        No model is loaded until the first generation call.
        """
        self._device = device
        self._lock = threading.Lock()
        self._model: Qwen3TTSModel | None = None
        self._loaded_variant: ModelVariant | None = None

    @property
    def loaded_variant(self) -> str | None:
        """Repo id of the currently loaded variant, or ``None`` if the slot is empty."""
        return self._loaded_variant.value if self._loaded_variant else None

    def _load(self, variant: ModelVariant) -> Qwen3TTSModel:
        """Return the model for *variant*, swapping out the current one first if needed.

        Caller must hold ``self._lock``.  Tries ``flash_attention_2`` first
        and falls back to the default attention implementation on failure.
        """
        if self._loaded_variant == variant and self._model is not None:
            return self._model

        # Unload the current model and release its VRAM before loading the
        # next one, so both never coexist on the GPU.
        if self._model is not None:
            logger.info("Unloading %s", self._loaded_variant)
            del self._model
            self._model = None
            self._loaded_variant = None
            torch.cuda.empty_cache()

        logger.info("Loading %s", variant.value)
        kwargs: dict[str, Any] = dict(
            device_map=self._device,
            dtype=torch.bfloat16,
        )
        try:
            kwargs["attn_implementation"] = "flash_attention_2"
            model = Qwen3TTSModel.from_pretrained(variant.value, **kwargs)
        except Exception as exc:
            # Broad catch is deliberate: there is no stable exception type
            # for "flash-attn not installed/supported".  Log the cause so a
            # genuine load failure is not disguised as a flash-attn fallback.
            logger.warning(
                "flash_attention_2 unavailable (%s), falling back to default", exc
            )
            kwargs.pop("attn_implementation")
            # The failed attempt may have left partial allocations behind;
            # free them before retrying to reduce OOM risk (no-op w/o CUDA).
            torch.cuda.empty_cache()
            model = Qwen3TTSModel.from_pretrained(variant.value, **kwargs)

        self._model = model
        self._loaded_variant = variant
        logger.info("Loaded %s", variant.value)
        return model

    # ── Public generation methods ───────────────────────────────────

    def generate_voice_clone(
        self,
        text: str | list[str],
        language: str | list[str],
        *,
        ref_audio: str | tuple[np.ndarray, int] | None = None,
        ref_text: str | list[str] | None = None,
        voice_clone_prompt: Any | None = None,
    ) -> tuple[list[np.ndarray], int]:
        """Synthesize *text* in the voice of a reference sample (BASE variant).

        Supply either a precomputed *voice_clone_prompt* (takes precedence)
        or both *ref_audio* and *ref_text*.

        Returns:
            ``(waveforms, sample_rate)`` as produced by the model.

        Raises:
            ValueError: if neither a prompt nor a complete audio/text
                reference pair is provided.
        """
        # Fail fast with a clear message instead of forwarding None refs to
        # the model, which would surface as an opaque error inside qwen_tts.
        if voice_clone_prompt is None and (ref_audio is None or ref_text is None):
            raise ValueError(
                "Provide either voice_clone_prompt or both ref_audio and ref_text"
            )
        with self._lock:
            model = self._load(ModelVariant.BASE)
            kwargs: dict[str, Any] = dict(text=text, language=language, **DEFAULT_GEN_KWARGS)
            if voice_clone_prompt is not None:
                kwargs["voice_clone_prompt"] = voice_clone_prompt
            else:
                kwargs["ref_audio"] = ref_audio
                kwargs["ref_text"] = ref_text
            return model.generate_voice_clone(**kwargs)

    def create_voice_clone_prompt(
        self,
        ref_audio: str | tuple[np.ndarray, int],
        ref_text: str,
    ) -> Any:
        """Precompute a reusable voice-clone prompt from a reference pair (BASE variant)."""
        with self._lock:
            model = self._load(ModelVariant.BASE)
            return model.create_voice_clone_prompt(
                ref_audio=ref_audio,
                ref_text=ref_text,
            )

    def generate_voice_design(
        self,
        text: str | list[str],
        language: str | list[str],
        instruct: str | list[str],
    ) -> tuple[list[np.ndarray], int]:
        """Synthesize *text* with a voice described by *instruct* (VOICE_DESIGN variant).

        Returns ``(waveforms, sample_rate)``.
        """
        with self._lock:
            model = self._load(ModelVariant.VOICE_DESIGN)
            return model.generate_voice_design(
                text=text, language=language, instruct=instruct, **DEFAULT_GEN_KWARGS
            )

    def generate_custom_voice(
        self,
        text: str | list[str],
        language: str | list[str],
        speaker: str | list[str],
        instruct: str | list[str] = "",
    ) -> tuple[list[np.ndarray], int]:
        """Synthesize *text* with a preset *speaker* (CUSTOM_VOICE variant).

        *instruct* optionally adds style directions.  Returns
        ``(waveforms, sample_rate)``.
        """
        with self._lock:
            model = self._load(ModelVariant.CUSTOM_VOICE)
            return model.generate_custom_voice(
                text=text, language=language, speaker=speaker,
                instruct=instruct, **DEFAULT_GEN_KWARGS,
            )

    def get_supported_speakers(self) -> list[str]:
        """Return preset speaker names supported by the CUSTOM_VOICE variant."""
        with self._lock:
            model = self._load(ModelVariant.CUSTOM_VOICE)
            return model.get_supported_speakers()

    def gpu_memory_stats(self) -> dict[str, float]:
        """Return allocated/reserved CUDA memory in MB, or ``{}`` without CUDA.

        NOTE(review): reads torch's *current* device, which may differ from
        ``self._device`` in multi-GPU setups — confirm if that matters.
        Intentionally lock-free: these are read-only queries.
        """
        if not torch.cuda.is_available():
            return {}
        return {
            "allocated_mb": round(torch.cuda.memory_allocated() / 1024**2, 1),
            "reserved_mb": round(torch.cuda.memory_reserved() / 1024**2, 1),
        }


# Module-level singleton
# Shared engine instance for importing modules; TTSEngine's internal lock
# serializes model loads and generation calls across threads.
engine = TTSEngine()
