from natsort import natsorted
import requests
import os
import subprocess
import sys
from dotenv import load_dotenv

import librosa
import numpy as np
from scipy.signal import correlate

os.environ["LOKY_MAX_CPU_COUNT"] = "4"
load_dotenv()
api_key = os.getenv("ADMIN_SECRET_KEY")
api_url = os.getenv("ONCHILL_API_URL")
api_origin = os.getenv("ONCHILL_API_ORIGIN")
node_url = os.getenv("NODE_URL")

output_file_1 = "storage/app/introduction/file_list_1.txt"
output_file_2 = "storage/app/introduction/file_list_2.txt"

output_mp3_path_1 = "storage/app/introduction/output_1.mp3"  # Chemin du fichier MP3 de sortie
output_mp3_path_2 = "storage/app/introduction/output_2.mp3"  # Chemin du fichier MP3 de sortie

headers = {
    "x-api-key": api_key,
    "origin": api_origin
}

def generate_file_list(ts_directory, file_list_path):
    """
    Génère un fichier de liste contenant les chemins des fichiers .ts triés.
    """
    # Vérifie si le répertoire des segments existe
    if not os.path.exists(ts_directory):
        print(f"❌ Erreur: Le répertoire {ts_directory} n'existe pas.")
        return

    # Récupère tous les fichiers .ts dans le répertoire
    files = [os.path.join(ts_directory, f) for f in os.listdir(ts_directory) if f.endswith(".ts") and "audio" in f]
    if len(files) == 0 : 
        files = [os.path.join(ts_directory, f) for f in os.listdir(ts_directory) if f.endswith(".ts")]

    # Trie les fichiers
    sorted_files = natsorted(files)

    # Crée le fichier file_list.txt
    with open(file_list_path, "w") as file_list:
        for ts_file in sorted_files:
            # Écrit le chemin relatif du fichier ts
            relative_path = os.path.relpath(ts_file, start=os.path.dirname(file_list_path))
            file_list.write(f"file '{relative_path.replace(os.sep, '/')}'\n")
    
    print(f"✅ Fichier de liste généré : {file_list_path}")

def convert_to_mp3(episodeDir, output_file, output_mp3_path):

    if(episodeDir.endswith('.mp4')):
        convert_mp4_to_mp3(episodeDir, output_mp3_path)
    else:
        generate_file_list(episodeDir, output_file)
        convert_ts_to_mp3(output_file, output_mp3_path)

def convert_mp4_to_mp3(input_file, output_file=None):
    """
    Convertit un fichier MP4 en MP3.
    
    :param input_file: Chemin du fichier MP4 à convertir.
    :param output_file: Chemin du fichier de sortie MP3. Si None, remplace l'extension par .mp3.
    :return: Chemin du fichier MP3 généré.
    """
    if not os.path.exists(input_file):
        raise FileNotFoundError(f"Le fichier {input_file} n'existe pas.")
    
    # Définir le nom du fichier de sortie
    if output_file is None:
        output_file = os.path.splitext(input_file)[0] + ".mp3"
    
    # Commande FFmpeg
    command = [
        "ffmpeg",
        "-loglevel", "quiet",
        "-i", input_file,   # Fichier d'entrée
        "-map", "0:a",            # Sélectionne uniquement le flux audio 
        "-vn",              # Pas de flux vidéo
        "-acodec", "libmp3lame",  # Codeur MP3
        "-q:a", "2",        # Qualité audio (0=meilleure, 9=moins bonne)
        "-af", "loudnorm",        # Applique la normalisation audio
        output_file         # Fichier de sortie
    ]
    
    # Exécuter la commande
    try:
        subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        print(f"✅ Conversion réussie : {output_file}")
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"❌ Erreur lors de la conversion : {e.stderr.decode()}")
    
    return output_file

def convert_ts_to_mp3(file_list_path, output_mp3_path):
    """
    Convertit les segments .ts en un fichier MP3 en utilisant ffmpeg et le fichier de liste.
    """
    # Commande ffmpeg
    command = [
        "ffmpeg",
        "-loglevel", "quiet", 
        "-f", "concat",
        "-safe", "0",
        "-i", file_list_path,
        "-vn",
        "-map", "0:a",            # Sélectionne uniquement le flux audio
        "-acodec", "libmp3lame",  # Codec audio MP3
        "-q:a", "2",              # Qualité maximale
        "-af", "loudnorm",        # Applique la normalisation audio
        output_mp3_path
    ]

    # Exécute la commande
    subprocess.run(command, check=True)
    print(f"✅ Conversion terminée. Le fichier MP3 est enregistré sous : {output_mp3_path}")

os.environ["LOKY_MAX_CPU_COUNT"] = "14"

def find_similar_segments_in_zones(file1, file2, segment_duration=10, correlation_threshold=0.2, min_duration=30.0):
    """
    Recherche les segments similaires entre des zones spécifiques du fichier 1 et du fichier 2.

    :param file1: Chemin du premier fichier audio (.mp3).
    :param file2: Chemin du second fichier audio (.mp3).
    :param segment_duration: Durée de chaque segment en secondes.
    :param correlation_threshold: Seuil pour considérer deux segments comme similaires.
    :param min_duration: Durée minimale en secondes pour inclure un segment regroupé.
    :return: Liste de tuples (start_time_file1, end_time_file1, start_time_file2, end_time_file2).
    """
    # Charger les fichiers audio
    y1, sr1 = librosa.load(file1, sr=None)
    y2, sr2 = librosa.load(file2, sr=None)

    if sr1 != sr2:
        raise ValueError("Les fichiers audio doivent avoir le même taux d'échantillonnage.")

    segment_length = sr1 * segment_duration
    similar_segments = []

    # Définir les zones d'intérêt
    mid_point_file1 = len(y1) // 2

    count_find_start_zone = 0
    count_not_find_start_zone = 0

    # Première moitié du fichier 1 et première moitié du fichier 2
    for i in range(0, mid_point_file1, segment_length):
        segment1 = y1[i:i + segment_length]
        start_time1 = i / sr1
        end_time1 = (i + segment_length) / sr1

        corr = correlate(y2[:mid_point_file1], segment1, mode='valid', method='fft')
        peak = np.argmax(corr)
        peak_correlation = corr[peak]

        norm_corr = peak_correlation / (np.linalg.norm(segment1) * np.linalg.norm(y2[max(0, peak):peak + len(segment1)]))

        if norm_corr >= correlation_threshold:
            # start_time2 = peak / sr2
            # end_time2 = (peak + len(segment1)) / sr2
            similar_segments.append((start_time1, end_time1, 'opening'))
            count_find_start_zone += 1
            count_not_find_start_zone = 0
        elif count_find_start_zone <= 6 and count_not_find_start_zone != 0:
            count_find_start_zone = 0
        elif count_find_start_zone > 6 and count_not_find_start_zone > 0:
            break
        else:
            count_not_find_start_zone += 1


    # Deuxième zone : recherche inversée dans les dernières 5 minutes

    count_find_end_zone = 0
    count_not_find_end_zone = 0
    count = 0
    
    reversed_y1 = y1[::-1]
    reversed_y2 = y2[::-1]

    for i in range(0, min(len(reversed_y1), sr1 * 300), segment_length):
        segment1 = reversed_y1[i:i + segment_length]
        start_time1 = (len(y1) - (i + segment_length)) / sr1  # Calcul du timecode en fonction de la position inversée
        end_time1 = (len(y1) - i) / sr1

        segment2_start = max(0, peak)
        segment2_end = min(len(reversed_y2), peak + len(segment1))
        segment2 = reversed_y2[segment2_start:segment2_end]

        norm_segment1 = np.linalg.norm(segment1)
        norm_segment2 = np.linalg.norm(segment2)

        corr = correlate(reversed_y2[:sr2 * 300], segment1, mode='valid', method='fft')
        peak = np.argmax(corr)
        peak_correlation = corr[peak]

        if norm_segment1 > 0 and norm_segment2 > 0:
            norm_corr = norm_corr = peak_correlation / (np.linalg.norm(segment1) * np.linalg.norm(reversed_y2[max(0, peak):peak + len(segment1)]))
        else:
            norm_corr = 0  # Mettre une valeur par défaut si la norme est nulle

        if norm_corr >= correlation_threshold:
            # start_time2 = (len(y2) - (peak + len(segment1))) / sr2
            # end_time2 = (len(y2) - peak) / sr2
            similar_segments.append((start_time1, end_time1, 'ending'))
            count_find_end_zone += 1
            count_not_find_end_zone = 0
        elif count_find_end_zone <= 6 and count_not_find_end_zone != 0:
            count_find_end_zone = 0
        elif count_find_end_zone > 6 and count_not_find_end_zone > 0:
            break
        else:
            count_not_find_end_zone += 1
        

        # print(f"coun: {count}, count_find_end_zone: {count_find_end_zone}, count_not_find_end_zone: {count_not_find_end_zone}")
        # count += 1

     # Regrouper les segments consécutifs
    grouped_consecutive_segments = merge_consecutive_segments(similar_segments)
    grouped_segments = merge_segments_with_gap(grouped_consecutive_segments)

    # Filtrer les segments qui durent moins que la durée minimale
    filtered_segments = [(start, end, name) for start, end, name in grouped_segments if end - start >= min_duration]

    return filtered_segments

def merge_segments_with_gap(segments, max_gap=60.0):
    """
    Regroupe les segments si l'écart entre eux est inférieur ou égal à une valeur donnée.

    :param segments: Liste de tuples (start_time, end_time, name).
    :param max_gap: Ecart maximal entre les segments à regrouper (en secondes).
    :return: Liste de tuples regroupés.
    """
    if not segments:
        return []

    # Trier les segments par start_time
    segments = sorted(segments, key=lambda x: x[0])

    merged_segments = [segments[0]]
    for current_start, current_end, current_name in segments[1:]:
        last_start, last_end, last_name = merged_segments[-1]

        # Vérifier si les segments sont à moins de `max_gap` secondes
        if current_start - last_end <= max_gap and current_name == last_name:
            # Fusionner les segments
            merged_segments[-1] = (last_start, current_end, last_name)
        else:
            # Ajouter le segment séparément
            merged_segments.append((current_start, current_end, current_name))

    return merged_segments

def merge_consecutive_segments(segments):
    """
    Regroupe les timecodes consécutifs.

    :param segments: Liste de tuples (start_time, end_time).
    :return: Liste de tuples regroupés.
    """
    if not segments:
        return []

    # Trier les segments par start_time pour éviter tout désordre
    segments = sorted(segments)

    merged_segments = [segments[0]]
    for current_start, current_end, name in segments[1:]:
        last_start, last_end, last_name = merged_segments[-1]

        # Vérifier si les segments sont consécutifs ou se chevauchent
        if np.isclose(last_end, current_start, atol=1e-3):  # tolérance pour les flottants
            merged_segments[-1] = (last_start, current_end, name)  # Fusionner
        else:
            merged_segments.append((current_start, current_end, name))  # Ajouter séparément

    return merged_segments

def get_segment_dir(episode):

    url = "{}/video/{}/video-link".format(api_url, episode.get('id'))
    response = requests.get(url, headers=headers)
    if response.status_code != 200:
        print("Erreur: ", response.status_code)
        return None

    data = response.json()
    if data is None:
        return None
    return "storage/app/uploads/" + data.get('url').replace('/api/video/', '').replace('.m3u8', '').replace(node_url, '')

def store_segment(episode, segment):

    if segment[2] == 'ending' and segment[1] * 100 / episode.get('seconds') < 95:
        return

    url = "{}/timecode".format(api_url)
    data = {
        "video_id": episode.get('id'),
        "name": segment[2],
        "start": segment[0],
        "end": segment[1],
    }
    response = requests.post(url, headers=headers, json=data)
    if response.status_code != 200:
        print("Erreur: ", response.status_code)
    else:
        video = response.json().get('video')
        timecode = response.json().get('timecode')
        print(f"ID: {video.get('id')}, Episode: : {video.get('name')}, Debut de l'intro: {timecode.get('start')}s , Fin de l'intro: {timecode.get('end')}s")

def delete_all_files():
    if os.path.exists(output_mp3_path_1):
        os.remove(output_mp3_path_1)

    if os.path.exists(output_mp3_path_2):
        os.remove(output_mp3_path_2)

    if os.path.exists(output_file_1):
        os.remove(output_file_1)
    
    if os.path.exists(output_file_2):
        os.remove(output_file_2)


print("🚀 Lancement du script de détection d'introduction")
# Récupérer l'id de la series
if len(sys.argv) < 2:
    print("❌ Usage: py monscript.py <id>")
    sys.exit(1)

id_param = sys.argv[1]
gap_param = int(sys.argv[2]) if len(sys.argv) > 2 else 0


delete_all_files()

# Faire une requête api pour récupérer tous les épisodes
url = "{}/series/{}/videos".format(api_url, id_param)

response = requests.get(url, headers=headers)
if response.status_code != 200:
    print("❌ Erreur: ", response.status_code)

# Filtrer les épisodes qui n'ont pas de timecode
episodes = [item for item in response.json() if not item.get('search_timecode', True)]

# Boucler sur chaque épisode
for index, episode1 in enumerate(episodes):
    if index < int(gap_param):
        continue
    # Si dernier élément kill la boucle 
    if index + 1 >= len(episodes):
        delete_all_files()
        break

    episode2 = episodes[index + 1]

    # Identifier les fichiers vidéo
    if episode1.get('segmentDir') == None:
        episode1['segmentDir'] = get_segment_dir(episode1)


    episode2['segmentDir'] = get_segment_dir(episode2)
    count = 2

    # Si pas de segment chercher celui d'après
    while episode2.get('segmentDir') == None:

        if index + count >= len(episodes):
            break

        print("❌ Le fichier n'a pas de lien")
        episode2 = episodes[index + count]

        episode2['segmentDir'] = get_segment_dir(episode2)
        count += 1

    if episode2.get('segmentDir') == None:
        delete_all_files()
        break

    convert_to_mp3(episode2.get('segmentDir'), output_file_2, output_mp3_path_2)

    # Si le 1er fichier n'a pas de timecode
    if episode1.get('search_timecode') == False:
        if not os.path.exists(output_mp3_path_1):
            convert_to_mp3(episode1.get('segmentDir'), output_file_1, output_mp3_path_1)

        # Chercher les corrélations du 1er fichier
        segments = find_similar_segments_in_zones(output_mp3_path_1, output_mp3_path_2, 5)
        if(len(segments) > 0):
            # Update data 1er fichier
            episode1['search_timecode'] = True
        
        # POST les timecode du 1er fichier
        for segment in segments:
            store_segment(episode1, segment)
            # print(f"File: {segment[0]:.2f}s - {segment[1]:.2f}s")
    
    # Chercher les corrélation du 2ème fichier
    segments = find_similar_segments_in_zones(output_mp3_path_2, output_mp3_path_1, 5)
    
    # Update data 2ème fichier
    if(len(segments) > 0):
        episode2['search_timecode'] = True

    # POST les timecode 2ème fichier
    for segment in segments:
        store_segment(episode2, segment)

    # Supprimer le fichier audio 1
    os.remove(output_mp3_path_1)
    # Renomer le fichier audio 2 en 1
    os.rename(output_mp3_path_2, output_mp3_path_1)
    # Supprimer le file list 1
    if os.path.exists(output_file_1):
        os.remove(output_file_1)
    # Renomer le file list 2 en 1
    if os.path.exists(output_file_1) and os.path.exists(output_file_2):
        os.rename(output_file_2, output_file_1)

print("✅ Toute les timecode des introductions ont été trouvé!")