"""Dialog script parser and multi-character audio assembly.

Script format:
    # Comments start with #
    [narrator_morgan] In a world where code compiles on the first try...
    [young_hero] Is that even possible?
    [narrator_morgan] No. No it is not. :: pause=1500
    [Ryan] Let me show you something.

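Character names reference saved voice profiles OR built-in speakers
(built-in speaker names are matched case-insensitively).
Lines ending with `:: pause=<ms>` override the default pause inserted after
that line. Lines that are not of the form `[character] text` are ignored.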
"""

from __future__ import annotations

import re
from dataclasses import dataclass

import numpy as np

from engine import TTSEngine, SAMPLE_RATE
from models import DialogLine
import voice_store

# "[character] text": group 1 is the character name (everything up to the
# first "]"), group 2 is the rest of the line, which must be non-empty.
LINE_RE = re.compile(r"^\[([^\]]+)\]\s*(.+)$")
# Trailing ":: pause=<ms>" directive anchored at the end of the body;
# group 1 is the pause length in milliseconds.
PAUSE_RE = re.compile(r"::\s*pause\s*=\s*(\d+)\s*$")


@dataclass
class _ResolvedLine:
    """A dialog line with its generation strategy resolved.

    Built by ``generate_dialog`` so lines can be batched by model type and
    later put back into script order.
    """
    index: int  # position in the input `lines`; key used to reassemble clips in order
    character: str  # name as written in the script (lower-cased only for built-in lookups)
    text: str  # text to synthesize for this line
    pause_ms: int  # silence appended after this line's clip (not after the last line)
    is_builtin: bool  # True → CustomVoice model, False → Base model (clone)


def parse_script(text: str, default_pause_ms: int = 500) -> list[DialogLine]:
    """Turn a dialog script into a list of ``DialogLine`` records.

    Each non-blank, non-comment line of the form ``[character] text`` becomes
    one ``DialogLine``. Lines of any other shape are skipped silently. A
    trailing ``:: pause=<ms>`` directive is removed from the text and used as
    that line's pause; otherwise ``default_pause_ms`` applies. A line whose
    body is only the pause directive yields a ``DialogLine`` with empty text.

    Args:
        text: The full script.
        default_pause_ms: Pause used for lines without a pause directive.

    Returns:
        The parsed lines, in script order.
    """
    parsed: list[DialogLine] = []
    for source_line in text.splitlines():
        stripped = source_line.strip()
        # Blank lines and "#" comments carry no dialog.
        if stripped == "" or stripped[0] == "#":
            continue

        tagged = LINE_RE.match(stripped)
        if tagged is None:
            # Not a "[character] text" line; ignore it.
            continue
        speaker, spoken = (part.strip() for part in tagged.groups())

        directive = PAUSE_RE.search(spoken)
        if directive is None:
            gap_ms = default_pause_ms
        else:
            gap_ms = int(directive.group(1))
            spoken = spoken[: directive.start()].strip()

        parsed.append(DialogLine(character=speaker, text=spoken, pause_ms=gap_ms))
    return parsed


def generate_dialog(
    engine: TTSEngine,
    lines: list[DialogLine],
    builtin_speakers: set[str],
) -> tuple[np.ndarray, int]:
    """Generate combined audio for a dialog, minimizing model swaps.

    Lines spoken by built-in speakers are synthesized in one batch with the
    CustomVoice model. Lines spoken by saved voice profiles are synthesized
    with the Base (clone) model, one batch per character, so each character's
    voice prompt is loaded once. The clips are then reassembled in the
    original order, with ``pause_ms`` of silence after every line except the
    last. Lines with no spoken text contribute only their pause.

    Args:
        engine: TTS engine used for both built-in and cloned voices.
        lines: Dialog lines in playback order (e.g. from ``parse_script``).
        builtin_speakers: Built-in speaker names; matched case-insensitively.

    Returns:
        ``(audio, SAMPLE_RATE)``: the concatenated waveform and its sample rate.

    Raises:
        ValueError: If ``lines`` is empty, or a character is neither a
            built-in speaker nor a saved voice profile.
    """
    if not lines:
        # np.concatenate([]) would fail with an opaque numpy error; raise the
        # same exception type with an actionable message instead.
        raise ValueError("Dialog contains no lines to generate")

    # Resolve each line (speaker names are lowercase in the model)
    builtin_lower = {s.lower() for s in builtin_speakers}
    resolved: list[_ResolvedLine] = []
    for i, line in enumerate(lines):
        is_builtin = line.character.lower() in builtin_lower
        if not is_builtin and not voice_store.exists(line.character):
            raise ValueError(
                f"Character '{line.character}' is not a built-in speaker "
                f"or saved voice profile"
            )
        resolved.append(_ResolvedLine(
            index=i,
            character=line.character,
            text=line.text,
            pause_ms=line.pause_ms,
            is_builtin=is_builtin,
        ))

    audio_clips: dict[int, np.ndarray] = {}

    # A line with no spoken text (e.g. a script line that is only
    # ":: pause=<ms>") contributes silence only; never send "" to the model.
    pending: list[_ResolvedLine] = []
    for r in resolved:
        if r.text.strip():
            pending.append(r)
        else:
            audio_clips[r.index] = np.zeros(0, dtype=np.float32)

    # Built-in speaker lines (CustomVoice model), one batch
    builtin_lines = [r for r in pending if r.is_builtin]
    if builtin_lines:
        wavs, _ = engine.generate_custom_voice(
            text=[r.text for r in builtin_lines],
            language=["Auto"] * len(builtin_lines),
            speaker=[r.character.lower() for r in builtin_lines],
        )
        for r, wav in zip(builtin_lines, wavs):
            audio_clips[r.index] = wav

    # Saved voice lines (Base model): group by character in one pass so each
    # voice prompt is loaded once; within a group, script order is kept.
    clone_groups: dict[str, list[_ResolvedLine]] = {}
    for r in pending:
        if not r.is_builtin:
            clone_groups.setdefault(r.character, []).append(r)
    # Generate characters in sorted order, as before.
    for char in sorted(clone_groups):
        char_lines = clone_groups[char]
        prompt = voice_store.load_prompt(char)
        wavs, _ = engine.generate_voice_clone(
            text=[r.text for r in char_lines],
            language=["Auto"] * len(char_lines),
            voice_clone_prompt=prompt,
        )
        for r, wav in zip(char_lines, wavs):
            audio_clips[r.index] = wav

    # Assemble in original order, with silence gaps between lines only
    last = len(resolved) - 1
    segments: list[np.ndarray] = []
    for i, r in enumerate(resolved):
        segments.append(audio_clips[r.index])
        if i < last:
            silence_samples = int(SAMPLE_RATE * r.pause_ms / 1000)
            if silence_samples > 0:
                segments.append(np.zeros(silence_samples, dtype=np.float32))

    return np.concatenate(segments), SAMPLE_RATE
