"""FastAPI application for Qwen3-TTS service."""

from __future__ import annotations

import io
import logging
import subprocess

import numpy as np
import soundfile as sf
from fastapi import FastAPI, File, Form, HTTPException, Query, UploadFile
from fastapi.responses import Response

import dialog
import voice_store
from engine import SAMPLE_RATE, engine
from models import (
    CustomVoiceRequest,
    DialogRequest,
    GenerateRequest,
    HealthResponse,
    VoiceDesignRequest,
    VoiceInfo,
)

logging.basicConfig(
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# ASGI application object; the routes in this module are registered on it.
app = FastAPI(version="1.0.0", title="Qwen3-TTS")

# Built-in speaker names reported by the engine; filled lazily on first lookup.
_builtin_speakers: set[str] | None = None


def _get_builtin_speakers() -> set[str]:
    """Return the engine's built-in speaker names, querying the engine only once.

    The first call stores ``set(engine.get_supported_speakers())`` in the
    module-level ``_builtin_speakers``; later calls return that same set object,
    so callers should treat it as read-only.
    """
    global _builtin_speakers
    if _builtin_speakers is not None:
        return _builtin_speakers
    _builtin_speakers = set(engine.get_supported_speakers())
    return _builtin_speakers


def _audio_response(audio: np.ndarray, sr: int = SAMPLE_RATE, fmt: str = "wav") -> Response:
    """Encode a waveform as an HTTP audio response (WAV, or MP3 via ffmpeg).

    The waveform is always written to an in-memory WAV first. When ``fmt`` is
    ``"mp3"`` (case-insensitive), that WAV is piped through the ``ffmpeg`` CLI
    and re-encoded at 128 kbit/s. Any other ``fmt`` value yields a WAV response.

    Args:
        audio: Waveform samples, passed unchanged to ``soundfile.write``.
        sr: Sample rate in Hz; defaults to the engine's ``SAMPLE_RATE``.
        fmt: Requested output format, ``"wav"`` or ``"mp3"``.

    Returns:
        A ``Response`` carrying ``audio/wav`` or ``audio/mpeg`` content.

    Raises:
        HTTPException: 500 if ffmpeg is unavailable or the MP3 encode fails.
    """
    wav_buf = io.BytesIO()
    sf.write(wav_buf, audio, sr, format="WAV")
    wav_bytes = wav_buf.getvalue()

    if fmt.lower() == "mp3":
        try:
            result = subprocess.run(
                [
                    "ffmpeg", "-hide_banner", "-loglevel", "error",
                    # Explicit input format: pipes cannot be seeked for probing.
                    "-f", "wav", "-i", "pipe:0",
                    "-f", "mp3", "-b:a", "128k", "pipe:1",
                ],
                input=wav_bytes,
                capture_output=True,
            )
        except FileNotFoundError as e:
            # ffmpeg binary not on PATH: report a clear 500 instead of an unhandled error.
            logger.error("ffmpeg executable not found; cannot encode MP3")
            raise HTTPException(500, "MP3 conversion failed: ffmpeg is not available") from e
        if result.returncode != 0:
            # Keep ffmpeg's diagnostics in the server log; the client gets a generic 500.
            logger.error(
                "ffmpeg MP3 conversion failed (exit %d): %s",
                result.returncode,
                result.stderr.decode(errors="replace").strip(),
            )
            raise HTTPException(500, "MP3 conversion failed")
        return Response(content=result.stdout, media_type="audio/mpeg")

    return Response(content=wav_bytes, media_type="audio/wav")


# ── Health / info ───────────────────────────────────────────────────


@app.get("/health", response_model=HealthResponse)
def health() -> HealthResponse:
    """Report liveness, the engine's ``loaded_variant`` and GPU memory figures.

    Memory fields are ``None`` when ``gpu_memory_stats()`` omits the key.
    """
    mem = engine.gpu_memory_stats()
    allocated = mem.get("allocated_mb")
    reserved = mem.get("reserved_mb")
    return HealthResponse(
        status="ok",
        loaded_model=engine.loaded_variant,
        gpu_memory_allocated_mb=allocated,
        gpu_memory_reserved_mb=reserved,
    )


@app.get("/speakers")
def speakers() -> list[str]:
    """Return the engine's built-in speaker names as a sorted list."""
    names = list(_get_builtin_speakers())
    names.sort()
    return names


# ── Voice management ────────────────────────────────────────────────


@app.get("/voices", response_model=list[VoiceInfo])
def list_voices() -> list[VoiceInfo]:
    """Return a ``VoiceInfo`` for each metadata record in the voice store."""
    stored = voice_store.list_voices()
    return [VoiceInfo(**meta) for meta in stored]


@app.post("/voices/clone", response_model=VoiceInfo)
async def clone_voice(
    name: str = Form(...),
    ref_text: str = Form(...),
    language: str = Form("Auto"),
    file: UploadFile = File(...),
) -> VoiceInfo:
    """Create a voice-clone profile from an uploaded reference recording.

    Args:
        name: Unique profile name; 400 if a profile with this name exists.
        ref_text: Transcript of the reference recording.
        language: Stored with the profile (not used for prompt creation here).
        file: Uploaded audio in any format ``soundfile`` can decode.

    Returns:
        The stored profile's metadata as a ``VoiceInfo``.

    Raises:
        HTTPException: 400 if the name is taken, the audio cannot be decoded,
            or it contains no samples.
    """
    # Local import keeps this block self-contained. The model call and profile
    # persistence are blocking, so they must not run directly on the event loop.
    from fastapi.concurrency import run_in_threadpool

    if voice_store.exists(name):
        raise HTTPException(400, f"Voice '{name}' already exists")

    audio_bytes = await file.read()
    try:
        ref_audio, file_sr = sf.read(io.BytesIO(audio_bytes), dtype="float32")
    except Exception as e:
        raise HTTPException(400, f"Could not read audio file: {e}") from e
    if ref_audio.size == 0:
        raise HTTPException(400, "Audio file contains no samples")

    # Model inference blocks; this is an `async def`, so calling it inline would
    # stall every other request until it finishes.
    prompt = await run_in_threadpool(
        engine.create_voice_clone_prompt,
        ref_audio=(ref_audio, file_sr),
        ref_text=ref_text,
    )

    await run_in_threadpool(
        voice_store.save_clone_profile,
        name=name,
        ref_audio=ref_audio,
        sample_rate=file_sr,
        ref_text=ref_text,
        prompt=prompt,
        language=language,
    )
    logger.info("Created clone voice profile: %s", name)
    return VoiceInfo(**voice_store.get_metadata(name))


@app.post("/voices/design", response_model=VoiceInfo)
def design_voice(req: VoiceDesignRequest) -> VoiceInfo:
    """Create a voice profile from a natural-language voice description.

    First, a sample of ``req.sample_text`` is synthesized with the VoiceDesign
    model, guided by ``req.instruct``. A clone prompt is then built from that
    sample, so the profile can later be used like a cloned voice.

    Raises:
        HTTPException: 400 if a profile named ``req.name`` already exists.
    """
    name = req.name
    if voice_store.exists(name):
        raise HTTPException(400, f"Voice '{name}' already exists")

    sample_text = req.sample_text

    # 1. Synthesize a sample utterance from the description.
    generated, sample_sr = engine.generate_voice_design(
        text=sample_text,
        language=req.language,
        instruct=req.instruct,
    )
    sample = generated[0]

    # 2. Turn that sample into a reusable clone prompt (Base model).
    clone_prompt = engine.create_voice_clone_prompt(
        ref_audio=(sample, sample_sr),
        ref_text=sample_text,
    )

    voice_store.save_design_profile(
        name=name,
        ref_audio=sample,
        sample_rate=sample_sr,
        ref_text=sample_text,
        prompt=clone_prompt,
        instruct=req.instruct,
        language=req.language,
    )
    logger.info("Created designed voice profile: %s", name)
    return VoiceInfo(**voice_store.get_metadata(name))


@app.get("/voices/{name}/audio")
def get_voice_audio(name: str, format: str = Query("wav")) -> Response:
    """Return the stored reference recording of voice ``name``.

    The recording is encoded according to ``format`` by ``_audio_response``.
    Responds with 404 if no profile named ``name`` exists.
    """
    if not voice_store.exists(name):
        raise HTTPException(404, f"Voice '{name}' not found")
    ref_path = str(voice_store.get_reference_audio_path(name))
    samples, sample_rate = sf.read(ref_path)
    samples = samples.astype(np.float32)
    return _audio_response(samples, sample_rate, fmt=format)


@app.delete("/voices/{name}")
def delete_voice(name: str) -> dict[str, str]:
    """Remove the stored profile ``name``; 404 if it does not exist."""
    missing = not voice_store.exists(name)
    if missing:
        raise HTTPException(404, f"Voice '{name}' not found")
    voice_store.delete_voice(name)
    logger.info("Deleted voice profile: %s", name)
    outcome = {"status": "deleted", "name": name}
    return outcome


# ── TTS generation ──────────────────────────────────────────────────


@app.post("/tts/generate")
def tts_generate(req: GenerateRequest) -> Response:
    """Synthesize ``req.text`` using the stored clone prompt of ``req.voice``.

    Responds with 404 if the voice profile does not exist.
    """
    voice = req.voice
    if not voice_store.exists(voice):
        raise HTTPException(404, f"Voice '{voice}' not found")

    clone_prompt = voice_store.load_prompt(voice)
    outputs, out_sr = engine.generate_voice_clone(
        text=req.text,
        language=req.language,
        voice_clone_prompt=clone_prompt,
    )
    first = outputs[0]
    return _audio_response(first, out_sr, fmt=req.format)


@app.post("/tts/custom")
def tts_custom(req: CustomVoiceRequest) -> Response:
    """Synthesize ``req.text`` with one of the engine's built-in speakers.

    The speaker name is lower-cased before lookup. Unknown names give a 400.
    """
    normalized = req.speaker.lower()
    known_speakers = _get_builtin_speakers()
    if normalized not in known_speakers:
        raise HTTPException(400, f"Unknown speaker '{req.speaker}'. Use GET /speakers for list.")

    outputs, out_sr = engine.generate_custom_voice(
        text=req.text,
        language=req.language,
        speaker=normalized,
        instruct=req.instruct,
    )
    first = outputs[0]
    return _audio_response(first, out_sr, fmt=req.format)


# ── Dialog ──────────────────────────────────────────────────────────


@app.post("/tts/dialog")
def tts_dialog(req: DialogRequest) -> Response:
    """Render a multi-line dialog into a single audio response.

    A ``ValueError`` from ``dialog.generate_dialog`` (or from the speaker
    lookup it receives) is reported as a 400 with the error text.
    """
    try:
        mixed, mixed_sr = dialog.generate_dialog(
            engine=engine,
            lines=req.lines,
            builtin_speakers=_get_builtin_speakers(),
        )
    except ValueError as err:
        raise HTTPException(400, str(err))
    else:
        return _audio_response(mixed, mixed_sr, fmt=req.format)
