#!/usr/bin/env python3
"""
silence_cutter.py — Remove silent parts from a video using FFmpeg silencedetect.

Usage:
    python3 silence_cutter.py <input_video> <output_video> [--threshold-db DB]
                                                [--min-duration SEC]
                                                [--padding SEC]

Requires: ffmpeg and ffprobe in /home/claw/.local/tools/ffmpeg/
"""

import argparse
import json
import re
import subprocess
import sys
import tempfile
from pathlib import Path

# Directory that holds the FFmpeg binaries this script shells out to.
_FFMPEG_DIR = "/home/claw/.local/tools/ffmpeg"

# Absolute paths of the two executables used by the subprocess calls.
FFMPEG = _FFMPEG_DIR + "/ffmpeg"
FFPROBE = _FFMPEG_DIR + "/ffprobe"


def get_audio_duration(video_path):
    """
    Return the duration of the media file in seconds via ffprobe.

    The value queried is the container-level ``format=duration`` entry, not
    the duration of an individual audio stream.

    Raises RuntimeError if ffprobe exits with an error or its output does not
    contain a numeric duration.
    """
    cmd = [
        FFPROBE, "-v", "error",
        "-show_entries", "format=duration",
        "-of", "json",
        str(video_path)
    ]
    # errors="replace" keeps undecodable bytes in ffprobe's output (e.g. from
    # unusual file names) from raising while the output is being captured.
    result = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
    if result.returncode != 0:
        raise RuntimeError(
            f"ffprobe failed for {video_path}: {result.stderr.strip()}"
        )

    try:
        return float(json.loads(result.stdout)["format"]["duration"])
    except (ValueError, KeyError, TypeError) as exc:
        # Covers malformed JSON, a missing "format"/"duration" key, and a
        # non-numeric duration value.
        raise RuntimeError(
            f"ffprobe reported no usable duration for {video_path}"
        ) from exc


# Matches both silencedetect events and captures the event kind and its
# timestamp. The number may be negative or in exponent form, which a plain
# digits-and-dots pattern would fail to match.
_SILENCE_EVENT_RE = re.compile(
    r"silence_(start|end):\s*(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)"
)


def _parse_silence_events(log_text):
    """Pair silence_start/silence_end events from FFmpeg log text, in order."""
    silences = []
    pending_start = None

    for kind, value in _SILENCE_EVENT_RE.findall(log_text):
        timestamp = float(value)
        if kind == "start":
            pending_start = timestamp
        elif pending_start is not None:
            # Only an end that follows a start forms an interval, so one
            # missing event cannot shift every later pair.
            silences.append((pending_start, timestamp))
            pending_start = None

    return silences


def detect_silences(video_path, threshold_db=-30, min_duration=0.5):
    """
    Run FFmpeg silencedetect and return list of (start, end) silence intervals.

    Threshold in dB (e.g. -30, -40). min_duration in seconds.

    Intervals are returned in the order FFmpeg reports them. A silence_start
    that is never followed by a silence_end is discarded. FFmpeg's exit
    status is not checked; whatever events appear on stderr are parsed.
    """
    cmd = [
        FFMPEG, "-i", str(video_path),
        "-af", f"silencedetect=noise={threshold_db}dB:d={min_duration}",
        "-f", "null", "-"
    ]
    # silencedetect reports through the log, which FFmpeg writes to stderr.
    # errors="replace" avoids a decode failure on non-UTF-8 bytes in that log.
    result = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
    return _parse_silence_events(result.stderr)


def build_keep_segments(silences, total_duration, padding=0.1):
    """
    Given silence intervals, return the complement: the segments to KEEP.

    Each silence is shrunk by ``padding`` seconds on both sides before it is
    cut, so every kept segment retains that much of the adjacent silence and
    audible content is not clipped. A silence too short to survive the
    shrinking is not cut at all. No padding is applied on a side where the
    silence touches the start (<= 0) or the end (>= total_duration) of the
    file, because there is no kept content on that side to protect.

    silences: iterable of (start, end) pairs in seconds; sorted here.
    total_duration: length of the media in seconds.
    Returns a list of (start, end) pairs, in ascending order.
    """
    segments = []
    last_end = 0.0

    for s, e in sorted(silences):
        # Move the cut points inward so the padding stays with the kept side.
        cut_start = s if s <= 0 else s + padding
        cut_end = e if e >= total_duration else e - padding

        # Clip to the valid range and never step back into an earlier cut.
        cut_start = max(last_end, cut_start)
        cut_end = min(total_duration, cut_end)

        if cut_end <= cut_start:
            # Nothing left to remove once the padding is taken off.
            continue

        if cut_start > last_end:
            segments.append((last_end, cut_start))
        last_end = cut_end

    if last_end < total_duration:
        segments.append((last_end, total_duration))

    return segments


def concatenate_segments(video_path, segments, output_path):
    """
    Cut each keep-segment and stitch them together using FFmpeg concat.

    Every segment is re-encoded once (libx264 video, AAC audio) into a
    temporary directory created next to the output file. The final concat
    step stream-copies those segments, so the video is not re-encoded twice.

    video_path: source media file.
    segments: iterable of (start, end) pairs in seconds.
    output_path: destination file; its parent directory is created if needed.
    Returns True on success. Returns False if there are no segments or if any
    FFmpeg invocation fails, in which case FFmpeg's stderr is echoed to stderr.
    """
    segments = list(segments)
    if not segments:
        # An empty concat list would only make FFmpeg fail.
        return False

    out_dir = Path(output_path).parent
    out_dir.mkdir(parents=True, exist_ok=True)

    # TemporaryDirectory gives a unique name per run and removes the
    # directory and its contents on every exit path, including exceptions.
    with tempfile.TemporaryDirectory(prefix=".concat_tmp_", dir=out_dir) as tmp:
        tmp_dir = Path(tmp)
        list_lines = []

        for i, (start, end) in enumerate(segments):
            segment_file = tmp_dir / f"seg_{i:03d}.mp4"
            cut_cmd = [
                FFMPEG, "-y",
                "-ss", str(start),
                "-i", str(video_path),
                "-t", str(end - start),
                "-c:v", "libx264", "-preset", "fast",
                "-crf", "23",
                "-c:a", "aac", "-b:a", "128k",
                str(segment_file)
            ]
            cut = subprocess.run(
                cut_cmd, capture_output=True, text=True, errors="replace"
            )
            if cut.returncode != 0:
                print(cut.stderr, file=sys.stderr)
                return False
            # The concat demuxer resolves relative entries against the list
            # file's own directory, so only the bare file name is written.
            list_lines.append(f"file '{segment_file.name}'\n")

        list_file = tmp_dir / "segments.txt"
        list_file.write_text("".join(list_lines), encoding="utf-8")

        concat_cmd = [
            FFMPEG, "-y",
            "-f", "concat",
            "-safe", "0",
            "-i", str(list_file),
            # All segments share one encoder configuration, so they can be
            # joined without another lossy encode.
            "-c", "copy",
            str(output_path)
        ]
        result = subprocess.run(
            concat_cmd, capture_output=True, text=True, errors="replace"
        )
        if result.returncode != 0:
            print(result.stderr, file=sys.stderr)

        return result.returncode == 0


def main():
    """
    Command-line entry point: detect silences, then cut and join the rest.

    Parses the arguments, checks that the input exists, prints the detected
    silences and the segments that will be kept, and writes the result to the
    output path. Exits with status 1 if the input is missing, if nothing would
    be kept, or if cutting/concatenation fails.
    """
    parser = argparse.ArgumentParser(description="Remove silent parts from a video.")
    parser.add_argument("input_video", help="Path to input video file")
    parser.add_argument("output_video", help="Path to output video file")
    parser.add_argument("--threshold-db", type=float, default=-30,
                        help="Silence threshold in dB (default: -30)")
    parser.add_argument("--min-duration", type=float, default=0.5,
                        help="Minimum silence duration in seconds (default: 0.5)")
    parser.add_argument("--padding", type=float, default=0.1,
                        help="Padding around kept segments in seconds (default: 0.1)")

    args = parser.parse_args()

    input_path = Path(args.input_video)
    output_path = Path(args.output_video)

    if not input_path.exists():
        print(f"[ERROR] Input file not found: {input_path}", file=sys.stderr)
        sys.exit(1)

    print(f"Detecting silences (threshold={args.threshold_db}dB, min_duration={args.min_duration}s)...")
    silences = detect_silences(input_path, args.threshold_db, args.min_duration)
    print(f"Found {len(silences)} silent segment(s)")

    total_duration = get_audio_duration(input_path)
    segments = build_keep_segments(silences, total_duration, args.padding)
    print(f"Keeping {len(segments)} segment(s)")

    if not segments:
        # Without this check the run would end in a generic concat failure
        # that hides the real cause: there is nothing to write.
        print("[ERROR] Nothing to keep: the whole input was detected as silence",
              file=sys.stderr)
        sys.exit(1)

    for i, (s, e) in enumerate(segments):
        print(f"  Segment {i+1}: {s:.2f}s → {e:.2f}s (duration={e-s:.2f}s)")

    print("Cutting and concatenating...")
    ok = concatenate_segments(input_path, segments, output_path)

    if ok:
        print(f"[OK] Saved: {output_path}")
    else:
        print("[ERROR] FFmpeg concat failed", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()