#!/usr/bin/env python3
import sys
import json
import time
import os
import numpy as np
import traceback
import pyaudio
import threading
import queue

# Suppress NeMo stderr spam while preserving critical errors
class SelectiveStderr:
    """File-like stderr wrapper that drops messages containing noise keywords.

    Any write whose text contains one of ``suppress_keywords`` is discarded;
    everything else is forwarded to the wrapped stream. Attributes that are
    not defined here (``isatty``, ``fileno``, ``encoding``, ...) are looked up
    on the wrapped stream so callers can treat this object like the original.
    """

    def __init__(self, original_stderr):
        """Wrap ``original_stderr``, the stream that receives unfiltered text."""
        self.original_stderr = original_stderr
        # Aggressively filter NeMo/PyTorch startup spam for faster perceived startup
        self.suppress_keywords = [
            'nemo_logging', 'Transcribing', 'it/s', 'conditional node', 'redirects.py',
            'multiprocessing', 'Redirects are currently not supported', 'FutureWarning',
            'DeprecationWarning', 'UserWarning', 'torchaudio', 'torch.', 'lightning',
            'hydra', 'omegaconf', 'Downloading', 'Fetching', 'Loading checkpoint',
            'Config', 'You are using a CUDA device', 'cuda:', 'device:', 'GPU available'
        ]
        # True while the last forwarded text has not been terminated by "\n".
        # print() emits the text and its trailing newline as two separate
        # writes, so the newline must be let through to close that line.
        self._line_open = False

    def write(self, message):
        """Forward ``message`` unless it is noise; return the characters consumed.

        Whitespace-only writes are dropped (so suppressed messages leave no
        blank lines behind) except when they terminate a line that was just
        forwarded without its own newline.
        """
        if any(keyword in message for keyword in self.suppress_keywords):
            return len(message)  # Suppress spam

        if message.strip():
            self.original_stderr.write(message)
            self._line_open = not message.endswith("\n")
        elif message and self._line_open:
            self.original_stderr.write(message)
            self._line_open = not message.endswith("\n")
        return len(message)

    def flush(self):
        """Flush the wrapped stream."""
        self.original_stderr.flush()

    def __getattr__(self, name):
        """Delegate any other file attribute to the wrapped stream."""
        # Guard against infinite recursion if the instance is only partially
        # initialised (original_stderr not set yet).
        if name == "original_stderr":
            raise AttributeError(name)
        return getattr(self.original_stderr, name)

class ParakeetBridge:
    def __init__(self):
        """Set up default state, then overlay user settings and the saved word count."""
        # Keep a handle on the current stderr; load_model() later replaces sys.stderr
        stderr_handle = sys.stderr
        self.original_stderr = stderr_handle
        stderr_handle.flush()

        # Fallback settings - these must exist BEFORE load_user_settings() runs
        self.vad_sensitivity, self.base_silence_threshold = 0.10, 3.5
        self.default_device_index, self.default_phrase_limit = 0, 55.0

        # Values found in user_settings.json replace the fallbacks above
        self.load_user_settings()

        # Models and compute device (filled in by load_model)
        self.model = None
        self.vad_model = None
        self.device = "cpu"

        # Capture configuration derived from the (possibly overridden) settings
        self.device_index = self.default_device_index
        self.phrase_time_limit = self.default_phrase_limit
        self.silence_threshold = self.base_silence_threshold
        self.sample_rate = 16000
        self.chunk_size = 1024

        # Live capture state
        self.is_listening = False
        self.audio_queue = queue.Queue()
        self.pyaudio_instance = None
        self.stream = None
        self.last_speech_time = None
        self.last_transcription_time = 0  # time of the most recent transcription (post-transcription cooldown)

        # Running word total used for rate limiting and usage tracking
        self.total_words_transcribed = 0
        self.load_word_count()
    
    def load_user_settings(self):
        """Load VAD sensitivity, microphone index, phrase limit, and silence duration from user_settings.json.

        The first existing file among several candidate locations is used.
        Each value is validated; a missing or invalid entry falls back to its
        built-in default (0.10, 0, 55.0, 3.5). ``self.silence_threshold`` is
        reset to the loaded silence duration. If no file exists nothing is
        changed. If the file cannot be read or parsed, the current values are
        kept and the error is reported on ``self.original_stderr``.
        """
        def as_number(value, fallback, is_valid):
            """Return ``value`` if it is a real number accepted by ``is_valid``, else ``fallback``."""
            # bool is a subclass of int; true/false is never a meaningful threshold
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                return fallback
            return value if is_valid(value) else fallback

        try:
            # Search multiple locations for settings file (cross-platform support)
            script_dir = os.path.dirname(os.path.abspath(__file__))
            home_dir = os.path.expanduser("~")
            possible_paths = [
                # CrossPlatform: Same directory as script (running from bin)
                os.path.join(script_dir, "user_settings.json"),
                # CrossPlatform: Up two levels to SynergyX.Desktop
                os.path.join(script_dir, "..", "..", "user_settings.json"),
                # CrossPlatform: App working directory
                "user_settings.json",
                # Linux/macOS: ~/.config/SynergyX
                os.path.join(home_dir, ".config", "SynergyX", "user_settings.json"),
                # macOS: ~/Library/Application Support/SynergyX
                os.path.join(home_dir, "Library", "Application Support", "SynergyX", "user_settings.json"),
                # Windows: %APPDATA%/SynergyX
                os.path.join(os.environ.get('APPDATA', ''), "SynergyX", "user_settings.json") if os.name == 'nt' else "",
            ]

            settings_path = next(
                (os.path.abspath(path) for path in possible_paths if path and os.path.exists(path)),
                None,
            )
            if settings_path is None:
                # Settings file not found - keep current values
                return

            # utf-8-sig accepts files both with and without a UTF-8 byte-order mark
            with open(settings_path, 'r', encoding='utf-8-sig') as f:
                settings = json.load(f)
            if not isinstance(settings, dict):
                raise ValueError("top-level JSON value must be an object")

            vad = as_number(settings.get("ParakeetVadSensitivity", 0.10), 0.10,
                            lambda v: 0 <= v <= 1)
            phrase_limit = as_number(settings.get("ParakeetPhraseLimit", 55.0), 55.0,
                                     lambda v: 0 < v < float("inf"))
            silence = as_number(settings.get("ParakeetSilenceDuration", 3.5), 3.5,
                                lambda v: 0 < v < float("inf"))

            device_index = settings.get("ParakeetMicrophoneIndex", 0)
            if isinstance(device_index, float) and device_index.is_integer():
                device_index = int(device_index)
            # null is kept as None, as before (presumably lets PyAudio pick its
            # default input device - confirm against the host app)
            index_ok = device_index is None or (
                isinstance(device_index, int) and not isinstance(device_index, bool)
            )

            # Assign only after everything parsed, so a bad file never half-applies
            self.vad_sensitivity = vad
            self.default_device_index = device_index if index_ok else 0
            self.default_phrase_limit = phrase_limit
            self.base_silence_threshold = silence
            self.silence_threshold = self.base_silence_threshold  # Reset to base value
        except Exception as e:
            # Best-effort: never fail startup over settings, but do not hide the reason
            self.original_stderr.write(f"[Python] ⚠️ Settings load error: {str(e)} (keeping current values)\n")
            self.original_stderr.flush()
    
    def load_word_count(self):
        """Load persistent word count from file.

        Reads ``word_count.json`` two directories above this script. A missing
        file leaves ``self.total_words_transcribed`` untouched. The stored
        ``total_words`` must be a non-negative integer; anything else (or an
        unreadable file) is reported on ``self.original_stderr`` and ignored,
        so the counter always stays an int that ``+=`` can be applied to.
        """
        try:
            script_dir = os.path.dirname(os.path.abspath(__file__))
            word_count_path = os.path.join(script_dir, "..", "..", "word_count.json")
            if not os.path.exists(word_count_path):
                return

            with open(word_count_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            total = data.get("total_words", 0) if isinstance(data, dict) else None
            # bool is an int subclass; reject it along with strings, floats and null
            if isinstance(total, bool) or not isinstance(total, int) or total < 0:
                raise ValueError(f"invalid total_words value: {total!r}")
            self.total_words_transcribed = total
        except Exception as e:
            self.original_stderr.write(f"[Python] ⚠️ Word count load error: {str(e)}\n")
            self.original_stderr.flush()
    
    def save_word_count(self):
        """Save word count to persistent file.

        Writes ``{"total_words": N}`` to ``word_count.json`` two directories
        above this script. The JSON is written to a sibling ``.tmp`` file and
        then moved into place with ``os.replace``, so a failure part-way
        through never truncates the previously saved count. Errors are
        reported on ``self.original_stderr`` and otherwise ignored.
        """
        try:
            script_dir = os.path.dirname(os.path.abspath(__file__))
            word_count_path = os.path.join(script_dir, "..", "..", "word_count.json")
            temp_path = word_count_path + ".tmp"
            try:
                with open(temp_path, 'w', encoding='utf-8') as f:
                    json.dump({"total_words": self.total_words_transcribed}, f)
                os.replace(temp_path, word_count_path)
            finally:
                # Only still present when the write or the replace failed
                if os.path.exists(temp_path):
                    os.unlink(temp_path)
        except Exception as e:
            self.original_stderr.write(f"[Python] ⚠️ Word count save error: {str(e)}\n")
            self.original_stderr.flush()
    
    def reset_word_count(self):
        """Zero the running word total, persist it, and note the reset on the original stderr."""
        self.total_words_transcribed = 0
        self.save_word_count()

        log = self.original_stderr
        log.write("[Python] 🔄 Word count reset to 0\n")
        log.flush()
    
    def load_model(self):
        """Load the Parakeet ASR model and, when available, the Silero VAD model.

        Installs a SelectiveStderr filter on ``sys.stderr`` (only once, even
        if this method is called repeatedly), silences Python logging and
        warnings, then imports NeMo and torch lazily. The ASR model is
        restored from a local ``.nemo`` file if one exists relative to the
        working directory, otherwise obtained with ``from_pretrained``.
        VAD loading is best-effort: on failure ``self.vad_model`` is left as
        ``None`` and transcription proceeds without VAD gating.

        Returns:
            bool: True when the ASR model loaded; False when NeMo/torch could
            not be imported or loading raised.
        """
        # Resolve the real, unfiltered stream first. On a repeated call
        # sys.stderr is already the filter; wrapping it again would stack
        # filters and route the failure reports below through the keyword
        # filter (which would drop any message mentioning e.g. "torch.").
        current_stderr = sys.stderr
        if isinstance(current_stderr, SelectiveStderr):
            original_stderr = current_stderr.original_stderr
        else:
            original_stderr = current_stderr
            # ⚡ PERFORMANCE: Install stderr filter BEFORE importing to suppress ALL startup spam
            sys.stderr = SelectiveStderr(original_stderr)

        try:
            # Suppress ALL logging BEFORE importing (CRITICAL for speed perception)
            import logging
            import warnings
            warnings.filterwarnings('ignore')
            os.environ['NEMO_LOGGING_LEVEL'] = 'CRITICAL'
            os.environ['HYDRA_FULL_ERROR'] = '0'
            os.environ['PYTHONWARNINGS'] = 'ignore'
            os.environ['TQDM_DISABLE'] = '1'
            os.environ['TORCH_HOME'] = os.path.join(os.path.expanduser('~'), '.cache', 'torch')

            # Configure all loggers BEFORE imports
            for logger_name in ('nemo_logger', 'nemo', 'NeMo', 'pytorch_lightning', 'lightning', 'torch'):
                logging.getLogger(logger_name).setLevel(logging.CRITICAL)
            logging.getLogger().setLevel(logging.CRITICAL)

            try:
                import nemo.collections.asr as nemo_asr
                import torch
            except ImportError:
                original_stderr.write("[Python] NEMO TOOLKIT NOT INSTALLED\n")
                original_stderr.write("[Python] Run: pip install nemo_toolkit[asr]\n")
                original_stderr.flush()
                return False

            self.device = "cuda" if torch.cuda.is_available() else "cpu"

            # Load Parakeet model (fast on CUDA)
            local_model_path = "parakeet-tdt-0.6b-v2/parakeet-tdt-0.6b-v2.nemo"

            if os.path.exists(local_model_path):
                self.model = nemo_asr.models.ASRModel.restore_from(local_model_path, map_location=self.device)
            else:
                self.model = nemo_asr.models.ASRModel.from_pretrained(
                    model_name="nvidia/parakeet-tdt-0.6b-v2",
                    map_location=self.device
                )

            self.model.eval()

            # ⚡ OPTIMIZATION: Load Silero VAD from local cache (skip hub download check)
            try:
                # force_reload=False reuses the cached copy instead of re-downloading
                vad_model, utils = torch.hub.load(
                    repo_or_dir='snakers4/silero-vad',
                    model='silero_vad',
                    force_reload=False,  # Use cached model
                    trust_repo=True,     # Skip security check
                    onnx=False,
                    verbose=False        # Suppress output
                )
                vad_model.eval()
                self.vad_model = vad_model
                self.get_speech_timestamps = utils[0]
            except Exception as e:
                # VAD is optional - continue without it, but say so: without
                # VAD every flushed buffer is transcribed unconditionally.
                self.vad_model = None
                original_stderr.write(f"[Python] ⚠️ Silero VAD unavailable, continuing without it: {str(e)}\n")
                original_stderr.flush()

            return True

        except Exception as e:
            original_stderr.write(f"[Python] Load failed: {str(e)}\n")
            original_stderr.write(f"{traceback.format_exc()}\n")
            original_stderr.flush()
            return False
    
    def set_microphone(self, device_index):
        """Remember which input device index start_listening() should open.

        The index is stored as given, without validation. Always returns True.
        """
        self.device_index = device_index
        return True
    
    def _send_response(self, text, confidence=0.78):
        """Emit one transcription event as a single JSON line on stdout.

        Args:
            text: Transcribed text to report.
            confidence: Confidence value to report alongside the text.
        """
        # Keyword order fixes the key order of the emitted JSON object
        payload = dict(
            type="transcription",
            text=text,
            confidence=confidence,
            language="en",
            enable_correction=False,
            streaming=True,
            word_count=self.total_words_transcribed,
        )
        print(json.dumps(payload), flush=True)
    
    def _send_error(self, error_message):
        """Emit an error event carrying ``error_message`` as a single JSON line on stdout."""
        payload = json.dumps({"type": "error", "message": error_message})
        print(payload, flush=True)
    
    def _process_audio(self, audio_np):
        """VAD-gate, transcribe, filter, and emit one buffered utterance.

        Args:
            audio_np: 1-D array of audio samples at ``self.sample_rate``
                (the listen loop passes int16 PCM scaled by 1/32768).

        Nothing is emitted when VAD finds no speech, when the transcript is
        empty, or when it is a known ghost phrase. Otherwise the running word
        count is updated and saved, and a transcription event is sent. Any
        exception is reported on the original stderr and as an error event;
        it is never raised to the caller.
        """
        try:
            if not self._vad_detects_speech(audio_np):
                return  # Skip transcription - no speech detected

            text = self._transcribe_to_text(audio_np)
            if not text:
                return

            import re

            # Capitalize "god" → "God" (even python respects the great programmer)
            text = re.sub(r'\bgod\b', 'God', text, flags=re.IGNORECASE)

            # Ghost word prevention filters
            confidence = 0.777  # Parakeet default confidence

            # Letters-only form, so punctuation and casing cannot dodge the filters
            letters_only = re.sub(r'[^a-zA-Z]', '', text.lower())

            # ALWAYS block "mm-hmm" / "mmhmm" / "Mm-hmm" regardless of confidence
            if letters_only in ('mmhmm', 'mhmm', 'mmhm'):
                self.original_stderr.write(f"[Python] ❌ Blocked mm-hmm variant: '{text}' (letters: '{letters_only}')\n")
                self.original_stderr.flush()
                return

            # Block "okay" when it's alone AND low confidence (ghost word prevention).
            # Compared on the letters-only form so "Okay." / "Okay!" are caught too.
            if letters_only == "okay" and len(text.split()) == 1 and confidence < 0.98:
                self.original_stderr.write(f"[Python] ❌ Blocked ghost 'okay' (confidence={confidence:.2f} < 0.98)\n")
                self.original_stderr.flush()
                return

            # Count words in transcription
            self.total_words_transcribed += len(text.split())
            self.save_word_count()

            # Send transcription to C# immediately
            self._send_response(text, confidence=confidence)

        except Exception as e:
            error_msg = f"Transcription error: {str(e)}"
            # Use the unfiltered stream: sys.stderr may be the keyword filter,
            # which would drop errors that mention e.g. "torch." or "cuda:".
            self.original_stderr.write(f"[Python] ERROR: {error_msg}\n")
            self.original_stderr.flush()
            self._send_error(error_msg)

    def _vad_detects_speech(self, audio_np, window_size=512, max_windows=12):
        """Return False only when the VAD model is loaded and finds no speech.

        Silero VAD requires exactly 512 samples per call at 16kHz. At most
        ``max_windows`` windows are scored to bound the cost; when the buffer
        holds more, the highest-energy windows are chosen (and scored in
        chronological order) so that speech preceded by a long stretch of
        silence is still inspected. Buffers with ``max_windows`` or fewer
        windows are scored in full.
        """
        if self.vad_model is None:
            return True

        import torch

        num_windows = len(audio_np) // window_size
        if num_windows == 0:
            return True

        samples = np.asarray(audio_np, dtype=np.float32)[:num_windows * window_size]
        windows = samples.reshape(num_windows, window_size)
        if num_windows > max_windows:
            energy = np.square(windows).mean(axis=1)
            loudest = np.sort(np.argsort(energy)[-max_windows:])
            windows = windows[loudest]

        # Use max speech probability across the scored windows
        max_speech_prob = max(
            self.vad_model(torch.FloatTensor(window), self.sample_rate).item()
            for window in windows
        )

        # User-defined VAD sensitivity threshold
        return max_speech_prob >= self.vad_sensitivity

    def _transcribe_to_text(self, audio_np):
        """Write the audio to a temporary WAV, transcribe it, and return the stripped text ("" if none)."""
        import tempfile
        import soundfile as sf

        temp_path = None
        try:
            # mkstemp gives us control over deletion; close the descriptor
            # immediately so soundfile can reopen the path by name
            temp_fd, temp_path = tempfile.mkstemp(suffix=".wav", prefix="parakeet_")
            os.close(temp_fd)

            sf.write(temp_path, audio_np, self.sample_rate)
            output = self.model.transcribe([temp_path])
        finally:
            # Always clean up temp file, even on error
            if temp_path and os.path.exists(temp_path):
                try:
                    os.unlink(temp_path)
                except Exception as cleanup_error:
                    self.original_stderr.write(f"[Python] ⚠️ Temp cleanup failed: {cleanup_error}\n")
                    self.original_stderr.flush()

        if not output:
            return ""
        first = output[0]
        # Results may be objects exposing .text or plain strings
        return first.text.strip() if hasattr(first, 'text') else str(first).strip()
    
    def _audio_callback(self, in_data, frame_count, time_info, status):
        """Stream callback: queue the captured bytes while listening.

        ``frame_count``, ``time_info`` and ``status`` are accepted but not
        used. Always returns ``(None, pyaudio.paContinue)``.
        """
        if not self.is_listening:
            return (None, pyaudio.paContinue)

        self.audio_queue.put(in_data)
        return (None, pyaudio.paContinue)
    
    def _listen_loop(self):
        """Consume queued audio chunks and flush buffered utterances for transcription.

        Runs until ``self.is_listening`` becomes False. Chunks are treated as
        16-bit mono PCM. A chunk whose RMS exceeds 333 counts as speech. The
        buffer is handed to ``_process_audio`` when either

        * its duration reaches ``self.phrase_time_limit`` (the buffer is then
          reset even if it was too short to transcribe), or
        * speech has been seen and the wall-clock time since the last speech
          chunk reaches ``self.silence_threshold`` with more than 0.5 s buffered.
        """
        # bytearray appends in place; repeated bytes concatenation would copy
        # the whole buffer for every chunk
        buffer = bytearray()
        buffer_duration = 0
        silence_duration = 0
        last_speech_time = time.time()  # Track LAST speech, not last audio
        speech_active = False

        while self.is_listening:
            try:
                chunk = self.audio_queue.get(timeout=0.1)
                buffer += chunk
                # 2 bytes per int16 sample
                buffer_duration += len(chunk) / (2 * self.sample_rate)

                # Check if audio has energy (speech detected)
                samples = np.frombuffer(chunk, dtype=np.int16).astype(np.float32)
                rms = np.sqrt(np.mean(samples**2))

                if rms > 333:
                    # Speech detected - update last speech time
                    last_speech_time = time.time()
                    # Clear the silence measured before this chunk. Without
                    # this, a long quiet period followed by the first loud
                    # chunk satisfied the silence test below and flushed the
                    # buffer right as the utterance began.
                    silence_duration = 0
                    speech_active = True
                else:
                    # Silence - calculate how long since last speech
                    silence_duration = time.time() - last_speech_time

                # Phrase-length cap
                if buffer_duration >= self.phrase_time_limit:
                    if buffer_duration > 0.5:
                        audio_np = np.frombuffer(buffer, dtype=np.int16).astype(np.float32) / 32768.0
                        self._process_audio(audio_np)

                    buffer = bytearray()
                    buffer_duration = 0
                    silence_duration = 0
                    last_speech_time = time.time()
                    speech_active = False
                    continue

                # Silence flush
                if silence_duration >= self.silence_threshold and buffer_duration > 0.5 and speech_active:
                    audio_np = np.frombuffer(buffer, dtype=np.int16).astype(np.float32) / 32768.0
                    self._process_audio(audio_np)

                    buffer = bytearray()
                    buffer_duration = 0
                    silence_duration = 0
                    last_speech_time = time.time()
                    speech_active = False

            except queue.Empty:
                continue
            except Exception as e:
                if self.is_listening:
                    sys.stderr.write(f"Listen error: {str(e)}\n")
                    sys.stderr.flush()
    
    def start_listening(self):
        """Open the microphone stream and start the background listen loop.

        Non-blocking: returns as soon as the stream and worker thread are
        started. If a stream is already open (repeated "start"), it is closed
        and reopened on the current ``self.device_index`` and the existing
        listen thread is reused. Audio left in the queue from an earlier
        session is discarded before a fresh session starts.

        Returns:
            bool: True when listening started. False when no model is loaded
            or the audio device could not be opened; in that case every
            partially acquired audio resource has been released.
        """
        if not self.model:
            sys.stderr.write("Quantum Brain could not be loaded\n")
            sys.stderr.flush()
            return False

        # A listen thread is already consuming the queue when this is True
        was_listening = self.is_listening

        try:
            # Never leak a previously opened stream / PortAudio instance
            self._release_audio_device()

            if not was_listening:
                # Drop chunks captured before the last stop
                while True:
                    try:
                        self.audio_queue.get_nowait()
                    except queue.Empty:
                        break

            self.pyaudio_instance = pyaudio.PyAudio()
            self.stream = self.pyaudio_instance.open(
                format=pyaudio.paInt16,
                channels=1,
                rate=self.sample_rate,
                input=True,
                input_device_index=self.device_index,
                frames_per_buffer=self.chunk_size,
                stream_callback=self._audio_callback
            )

            self.stream.start_stream()
            self.is_listening = True

            # Send confirmation that listening has started
            status = {"type": "status", "status": "listening", "message": "Microphone active"}
            print(json.dumps(status), flush=True)

            if not was_listening:
                listen_thread = threading.Thread(target=self._listen_loop)
                listen_thread.daemon = True
                listen_thread.start()

            # NON-BLOCKING: Return immediately so main loop can process stop commands
            return True

        except Exception as e:
            # No usable stream exists: stop any running loop and free whatever was opened
            self.is_listening = False
            self._release_audio_device()
            sys.stderr.write(f"Start failed: {str(e)}\n")
            sys.stderr.flush()
            return False

    def _release_audio_device(self):
        """Close the open stream and terminate PortAudio, if present; close errors are logged, not raised."""
        stream, self.stream = self.stream, None
        audio, self.pyaudio_instance = self.pyaudio_instance, None

        if stream is not None:
            try:
                stream.stop_stream()
                stream.close()
            except Exception as e:
                sys.stderr.write(f"Stream close error: {str(e)}\n")
                sys.stderr.flush()
        if audio is not None:
            try:
                audio.terminate()
            except Exception as e:
                sys.stderr.write(f"Audio terminate error: {str(e)}\n")
                sys.stderr.flush()
    
    def stop_listening(self):
        """Stop capturing, release the audio device, and report the stopped status.

        Each teardown step runs even if an earlier one raises (for example
        when the device disappeared): the stream is always closed, PortAudio
        is always terminated, and the "stopped" status is always sent.
        Teardown errors are written to ``sys.stderr`` rather than raised.
        """
        self.is_listening = False

        # Detach first so a failing teardown cannot leave stale handles behind
        stream, self.stream = self.stream, None
        audio, self.pyaudio_instance = self.pyaudio_instance, None

        if stream:
            try:
                try:
                    stream.stop_stream()
                finally:
                    stream.close()
            except Exception as e:
                sys.stderr.write(f"Stop error: {str(e)}\n")
                sys.stderr.flush()
        if audio:
            try:
                audio.terminate()
            except Exception as e:
                sys.stderr.write(f"Stop error: {str(e)}\n")
                sys.stderr.flush()

        # Send confirmation status back to C#
        status = {"type": "status", "status": "stopped", "message": "Microphone stopped"}
        print(json.dumps(status), flush=True)


def main():
    """Run the JSON-lines command loop between the host process and the bridge.

    Reads one JSON object per line from stdin and writes status, error and
    transcription events as JSON lines to stdout.

    * An object containing ``model`` or ``silence_duration`` is a config
      message: it sets the silence threshold, makes sure the model is loaded,
      selects the microphone, and answers ``initialized`` (or an error).
    * An object containing ``command`` runs one of ``start``, ``stop``,
      ``reload_settings``, ``reset_word_count`` or ``shutdown``; other
      command values are ignored.

    The loop ends on ``shutdown`` or when stdin is closed. Malformed input is
    answered with an error event and the loop continues.
    """
    bridge = ParakeetBridge()

    status = {"type": "status", "status": "initializing", "message": "Parakeet starting"}
    print(json.dumps(status), flush=True)

    for line in sys.stdin:
        try:
            line = line.strip()
            if not line:
                continue

            cmd = json.loads(line)
            if not isinstance(cmd, dict):
                raise ValueError("Command must be a JSON object")

            if "model" in cmd or "silence_duration" in cmd:
                # Fall back to the user's configured duration (not a hard-coded
                # constant) when the message omits the value or it is unusable
                requested = cmd.get("silence_duration", bridge.base_silence_threshold)
                usable = (
                    isinstance(requested, (int, float))
                    and not isinstance(requested, bool)
                    and requested > 0
                )
                bridge.silence_threshold = requested if usable else bridge.base_silence_threshold

                # load_model() always loads the same model, so a second config
                # message must not pay for (or race the listen thread with) a reload
                success = True if bridge.model is not None else bridge.load_model()

                if success:
                    device_index = cmd.get("device_index", 0)
                    bridge.set_microphone(device_index)

                    status = {"type": "status", "status": "initialized", "message": "Parakeet ready"}
                    print(json.dumps(status), flush=True)  # ✅ Show "Parakeet ready" first
                else:
                    error = {"type": "error", "message": "Failed to load model"}
                    print(json.dumps(error), flush=True)

            elif "command" in cmd:
                command = cmd["command"]

                if command == "start":
                    if "device_index" in cmd:
                        bridge.set_microphone(cmd["device_index"])
                    bridge.start_listening()

                elif command == "stop":
                    bridge.stop_listening()

                elif command == "reload_settings":
                    # Reload VAD sensitivity and silence duration from user_settings.json
                    bridge.load_user_settings()

                    # Build a fresh status; previously the last unrelated status
                    # dict (e.g. "initializing") was re-sent here
                    status = {
                        "type": "status",
                        "status": "settings_reloaded",
                        "message": f"VAD {bridge.vad_sensitivity}, Silence {bridge.base_silence_threshold}s",
                    }
                    print(json.dumps(status), flush=True)

                elif command == "reset_word_count":
                    # Reset word counter to zero
                    bridge.reset_word_count()

                    status = {"type": "status", "status": "word_count_reset", "message": "Word count reset to 0", "word_count": 0}
                    print(json.dumps(status), flush=True)

                elif command == "shutdown":
                    sys.stderr.write("[Python] 🛑 Shutdown command received - cleaning up...\n")
                    sys.stderr.flush()

                    # Stop listening if active
                    if bridge.is_listening:
                        bridge.stop_listening()

                    sys.stderr.write("[Python] ✅ Shutdown complete - exiting\n")
                    sys.stderr.flush()
                    break

        except json.JSONDecodeError:
            error = {"type": "error", "message": "JSON parse error"}
            print(json.dumps(error), flush=True)
        except Exception as e:
            error = {"type": "error", "message": str(e)}
            print(json.dumps(error), flush=True)

# Script entry point: run the command loop only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
