#!/usr/bin/env python3
"""Fix audio files to meet Razzmatazz WAV requirements."""

import json
import os
import re
import shutil
import subprocess
import sys
import time
import wave
from collections import defaultdict
from shutil import get_terminal_size

# Characters that check_file() flags in a filename and that fix_filename()
# replaces with '_'.
FORBIDDEN_CHARS = re.compile(r'["/\\?*<>:|]')
# Bit depths and channel counts that are kept as-is; process_file() converts
# anything else to 16-bit / 2 channels.
ALLOWED_BIT_DEPTHS = {16, 24, 32}
ALLOWED_CHANNELS = {1, 2}
# Sample rate (Hz) every output file is brought to; other rates are resampled.
NATIVE_SAMPLE_RATE = 48000
# Longest input path (in characters) that check_file() accepts without an issue.
MAX_PATH_LENGTH = 255
# Extensions (lower-case, with dot) picked up when a directory is scanned.
AUDIO_EXTENSIONS = {'.wav', '.mp3', '.flac', '.ogg', '.aiff', '.aif', '.wma', '.m4a', '.aac', '.ape', '.opus'}

# Maps the `sample_fmt` value reported by ffprobe to bits per sample.
# Formats missing from this table are treated as unknown ('?') by
# check_file() and as 16-bit by process_file().
SAMPLE_FMT_TO_BITS = {
    's16': 16, 's16le': 16, 's16be': 16, 'u16': 16,
    's32': 32, 's32le': 32, 's32be': 32,
    's24': 24, 's24le': 24, 's24be': 24,
    'flt': 32, 'fltle': 32, 'fltbe': 32,
    'dbl': 64,
    'u8': 8, 's8': 8,
}


def ffprobe(path):
    """Return the first audio stream that ffprobe reports for *path*.

    ffprobe is asked for JSON stream information and given 30 seconds to
    answer.  The result is the stream's dict as decoded from that JSON.

    Returns None when ffprobe exits with a non-zero status, when no stream
    has ``codec_type == 'audio'``, or when anything goes wrong while running
    ffprobe or decoding its output (best-effort probe, never raises).
    """
    command = [
        'ffprobe', '-v', 'quiet', '-print_format', 'json',
        '-show_streams', path,
    ]
    try:
        proc = subprocess.run(command, capture_output=True, text=True, timeout=30)
        if proc.returncode != 0:
            return None
        streams = json.loads(proc.stdout).get('streams', [])
        # First audio stream wins; None if the container has no audio at all.
        return next(
            (stream for stream in streams if stream.get('codec_type') == 'audio'),
            None,
        )
    except Exception:
        return None


def check_file(path):
    """Collect the reasons why *path* does not meet the WAV requirements.

    Checks, in order: total path length, forbidden characters in the
    filename, and then the audio format.  ``.wav`` files are inspected with
    the ``wave`` module; every other extension is reported as "not WAV" and
    probed through ``ffprobe()``.

    Returns:
        A list of human-readable issue strings.  The list is empty only for
        a ``.wav`` file that passes every check.
    """
    problems = []

    path_len = len(path)
    if path_len > MAX_PATH_LENGTH:
        problems.append(f"path length {path_len} > {MAX_PATH_LENGTH}")

    if FORBIDDEN_CHARS.search(os.path.basename(path)) is not None:
        problems.append("filename has forbidden characters")

    suffix = os.path.splitext(path)[1].lower()

    if suffix != '.wav':
        # Non-WAV input always needs conversion; the probe only adds detail.
        problems.append(f"not WAV ({suffix})")
        stream = ffprobe(path)
        if stream is None:
            problems.append("could not read audio info")
            return problems

        channels = stream.get('channels', 0)
        rate = stream.get('sample_rate', '?')
        depth = SAMPLE_FMT_TO_BITS.get(stream.get('sample_fmt', '?'), '?')

        if channels not in ALLOWED_CHANNELS:
            problems.append(f"{channels} channels (need 1 or 2)")
        # An unknown sample format leaves depth as '?', which is not reported.
        if isinstance(depth, int) and depth not in ALLOWED_BIT_DEPTHS:
            problems.append(f"{depth}-bit (need 16, 24, or 32)")
        if rate not in (None, '?') and int(rate) != NATIVE_SAMPLE_RATE:
            problems.append(f"{rate} Hz sample rate (native is {NATIVE_SAMPLE_RATE})")
        return problems

    try:
        with wave.open(path, 'rb') as wav:
            header = wav.getparams()
            channels = header.nchannels
            depth = header.sampwidth * 8  # sampwidth is in bytes
            rate = header.framerate
            if channels not in ALLOWED_CHANNELS:
                problems.append(f"{channels} channels (need 1 or 2)")
            if depth not in ALLOWED_BIT_DEPTHS:
                problems.append(f"{depth}-bit (need 16, 24, or 32)")
            if rate != NATIVE_SAMPLE_RATE:
                problems.append(f"{rate} Hz sample rate (native is {NATIVE_SAMPLE_RATE})")
    except Exception:
        # Anything the wave module cannot open counts as an invalid WAV.
        problems.append("not a valid WAV or unreadable")

    return problems


def fix_filename(name):
    """Return a sanitized ``.wav`` filename derived from *name*.

    The original extension is dropped.  In the remaining stem every
    forbidden character becomes ``_``, runs of underscores collapse to one,
    and leading/trailing underscores are stripped.  If nothing is left the
    stem ``fixed`` is used.
    """
    stem = os.path.splitext(name)[0]
    stem = FORBIDDEN_CHARS.sub('_', stem)
    stem = re.sub(r'_+', '_', stem).strip('_')
    if not stem:
        stem = 'fixed'
    return f"{stem}.wav"


def process_file(path, outpath, quiet=False):
    """Write a compliant WAV version of *path* to *outpath*.

    A ``.wav`` file with no issues whose name is unchanged is copied as-is.
    Everything else goes through ffmpeg: channel count and bit depth are
    kept when allowed (otherwise 2 channels / 16-bit), the sample rate is
    brought to ``NATIVE_SAMPLE_RATE``, and non-WAV input is re-encoded to
    PCM.  A WAV that only needs renaming is stream-copied.

    Args:
        path: Input audio file.
        outpath: Destination file; its parent directory is created if needed.
        quiet: When true, suppress the per-file OK/FIX/FAIL lines.

    Returns:
        True if the file was copied or converted successfully.  False if the
        input could not be parsed (it is then copied unchanged), or if
        ffmpeg could not be started or exited with an error.
    """
    rel = os.path.relpath(path)

    issues = check_file(path)

    ext = os.path.splitext(path)[1].lower()
    is_wav = (ext == '.wav')

    os.makedirs(os.path.dirname(outpath) or '.', exist_ok=True)

    clean_name = os.path.basename(outpath)
    name_changed = (os.path.basename(path) != clean_name)

    if is_wav and not issues and not name_changed:
        shutil.copy2(path, outpath)
        if not quiet:
            print(f"OK    {rel}")
        return True

    if not quiet:
        parts = issues[:]
        if name_changed:
            parts.append(f"filename → {clean_name}")
        print(f"FIX   {rel}: {', '.join(parts)}")

    # Determine the source format so only the necessary conversions are applied.
    if is_wav:
        try:
            with wave.open(path, 'rb') as wf:
                params = wf.getparams()
                ch = params.nchannels
                bd = params.sampwidth * 8
                sr = params.framerate
        except Exception:
            if not quiet:
                print(f"FAIL  {rel}: cannot parse WAV header, copying as-is")
            shutil.copy2(path, outpath)
            return False
    else:
        info = ffprobe(path)
        if info is None:
            if not quiet:
                print(f"FAIL  {rel}: cannot read audio info, copying as-is")
            shutil.copy2(path, outpath)
            return False
        ch = info.get('channels', 2)
        sr = int(info.get('sample_rate', NATIVE_SAMPLE_RATE))
        fmt = info.get('sample_fmt', 's16')
        bd = SAMPLE_FMT_TO_BITS.get(fmt, 16)

    target_ch = ch if ch in ALLOWED_CHANNELS else 2
    target_bd = bd if bd in ALLOWED_BIT_DEPTHS else 16
    acodec = {16: 'pcm_s16le', 24: 'pcm_s24le', 32: 'pcm_s32le'}[target_bd]

    cmd = ['ffmpeg', '-y', '-i', path]

    needs_recode = (target_bd != bd) or (sr != NATIVE_SAMPLE_RATE) or (target_ch != ch) or (not is_wav)

    if needs_recode:
        if target_ch != ch:
            cmd += ['-ac', str(target_ch)]
        cmd += ['-acodec', acodec]
        if sr != NATIVE_SAMPLE_RATE:
            cmd += ['-ar', str(NATIVE_SAMPLE_RATE)]
            cmd += ['-af', 'aresample=resampler=soxr']
    else:
        # Format is already fine (e.g. only the filename changes): no re-encode.
        cmd += ['-c', 'copy']
    cmd.append(outpath)

    try:
        # stdin is detached so ffmpeg never consumes input meant for this
        # script (or blocks on it) while a batch is running.
        result = subprocess.run(cmd, capture_output=True, text=True,
                                stdin=subprocess.DEVNULL)
    except OSError as exc:
        # ffmpeg missing or not executable: report this file as failed
        # instead of aborting the whole batch with a traceback.
        if not quiet:
            print(f"FAIL  {rel}: could not run ffmpeg — {exc}")
        return False

    if result.returncode != 0:
        if not quiet:
            print(f"FAIL  {rel}: ffmpeg error — {result.stderr.strip()}")
        return False
    return True


class ProgressBar:
    """Single-line text progress bar written to stdout.

    Tracks how many items were handled and how many of them succeeded or
    failed; redraws are throttled to at most one every 0.1 seconds.
    """

    def __init__(self, total):
        """Start the clock for a run of *total* items."""
        self.total = total
        self.current = 0
        self.success = 0
        self.failed = 0
        self.start_time = time.time()
        self.last_draw = 0

    def update(self, n=1, success=True):
        """Advance by *n* items, record one success or failure, and redraw."""
        self.current += n
        if not success:
            self.failed += 1
        else:
            self.success += 1
        self._draw()

    def _draw(self):
        """Render the bar, percentage, counter and ETA on the current line."""
        timestamp = time.time()
        if timestamp - self.last_draw < 0.1:
            # Drawn too recently; skip to keep terminal output cheap.
            return
        self.last_draw = timestamp

        width = get_terminal_size((80, 24))[0]
        running_for = timestamp - self.start_time

        fraction = self.current / self.total if self.total else 0
        # Leave room for the text after the bar, but never go below 10 cells.
        bar_cells = max(width - 30, 10)
        done_cells = int(bar_cells * fraction)
        head_cells = min(1, bar_cells - done_cells)
        todo_cells = bar_cells - done_cells - head_cells
        bar = ''.join(('[', '=' * done_cells, '>' * head_cells, '.' * todo_cells, ']'))

        if self.current > 0:
            # Average time per item so far, times the items still to do.
            eta = running_for / self.current * (self.total - self.current)
        else:
            eta = 0

        text = f"\r{bar} {fraction * 100:5.1f}%  {self.current}/{self.total}  ETA {eta:.0f}s"
        sys.stdout.write(text[:width])
        sys.stdout.flush()

    def done(self):
        """Erase the bar and print the final ok/failed summary with elapsed time."""
        took = time.time() - self.start_time
        blank = ' ' * get_terminal_size((80, 24)).columns
        sys.stdout.write(f"\r{blank}\r")
        sys.stdout.flush()
        print(f"Done — {self.success} ok, {self.failed} failed  ({took:.1f}s)")


def _collect_files(paths):
    """Expand CLI arguments into a sorted, duplicate-free list of input files."""
    found = set()
    for p in paths:
        if os.path.isdir(p):
            for root, _, filenames in os.walk(p):
                for f in filenames:
                    if os.path.splitext(f)[1].lower() in AUDIO_EXTENSIONS:
                        found.add(os.path.join(root, f))
        else:
            # Explicitly named files are taken as given, whatever their extension.
            found.add(p)
    return sorted(found)


def _fit_output_path(out_path, tag=''):
    """Normalize *out_path*, insert *tag* before the extension, and shorten
    the stem so the whole path stays within MAX_PATH_LENGTH characters."""
    folder, name = os.path.split(os.path.normpath(out_path))
    stem, ext = os.path.splitext(name)
    prefix = os.path.join(folder, '')  # folder plus separator, '' if no folder
    budget = MAX_PATH_LENGTH - len(prefix) - len(tag) - len(ext)
    # If the directory alone exhausts the limit nothing can make the path
    # fit; keep one character of the stem so the name stays valid.
    return prefix + stem[:max(budget, 1)] + tag + ext


def _resolve_output_paths(files, out_dir):
    """Map every input file to a unique output path below *out_dir*."""
    cwd = os.getcwd()
    taken = set()
    resolved = {}
    for path in files:
        try:
            rel = os.path.relpath(path, cwd)
        except ValueError:
            # No relative path exists (e.g. different drive): flatten into out_dir.
            subdir = ''
        else:
            # Only a real parent-directory component means "outside cwd";
            # a directory that is merely named '..something' does not.
            outside = rel == os.pardir or rel.startswith(os.pardir + os.sep)
            subdir = '' if outside else os.path.dirname(rel)

        target = os.path.join(out_dir, subdir, fix_filename(os.path.basename(path)))

        # Truncate first, then de-duplicate the truncated name, so neither
        # step can reintroduce a collision created by the other.
        candidate = _fit_output_path(target)
        attempt = 0
        while os.path.normcase(candidate) in taken:
            attempt += 1
            candidate = _fit_output_path(target, f"_{attempt}")
        taken.add(os.path.normcase(candidate))
        resolved[path] = candidate
    return resolved


def _planned_changes(path, out_path):
    """Return the issues of *path* plus a note if its filename will change."""
    notes = check_file(path)
    clean_name = os.path.basename(out_path)
    if os.path.basename(path) != clean_name:
        notes.append(f"filename → {clean_name}")
    return notes


def main():
    """Command-line entry point.

    Collects the input files, assigns each a unique output path under
    ``--out`` and then either previews the work (``--dry-run``, nothing is
    written), processes with one line per file (``--no-progress``), or
    processes behind a progress bar.

    Exits with status 1 if any file failed to process, otherwise 0.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Fix audio files to valid Razzmatazz WAVs.')
    parser.add_argument('paths', nargs='+', help='Audio files or directories')
    parser.add_argument('--out', '-o', default='fixed_wavs', help='Output directory (default: fixed_wavs)')
    parser.add_argument('--dry-run', '-n', action='store_true', help='Preview without converting')
    parser.add_argument('--no-progress', '-P', action='store_true', help='Disable progress bar (one line per file)')
    args = parser.parse_args()

    files = _collect_files(args.paths)

    if not files:
        print("No supported audio files found.")
        print(f"Extensions: {', '.join(sorted(AUDIO_EXTENSIONS))}")
        sys.exit(0)

    resolved = _resolve_output_paths(files, args.out)
    n_total = len(files)

    if args.dry_run:
        # Preview only: report what would happen without copying or converting.
        pending = 0
        for path in files:
            rel = os.path.relpath(path)
            notes = _planned_changes(path, resolved[path])
            if notes:
                pending += 1
                print(f"FIX   {rel}: {', '.join(notes)}")
            else:
                print(f"OK    {rel}")
        print(f"\n{n_total - pending} already valid, {pending} would be processed")
        sys.exit(0)

    if args.no_progress:
        # One line per file, printed by process_file itself.
        ok = failed = 0
        for path in files:
            if process_file(path, resolved[path], quiet=False):
                ok += 1
            else:
                failed += 1
        print(f"\n{ok} processed, {failed} failed")
        sys.exit(1 if failed else 0)

    # --- Progress bar mode ---
    print(f"Scanning {n_total} files...")
    fix_count = sum(1 for path in files if _planned_changes(path, resolved[path]))
    print(f"  {n_total - fix_count} already valid, {fix_count} need processing")

    # Valid files are still copied to the output directory by process_file.
    bar = ProgressBar(n_total)
    for path in files:
        bar.update(success=process_file(path, resolved[path], quiet=True))
    bar.done()
    sys.exit(1 if bar.failed else 0)


if __name__ == '__main__':
    main()
