import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForSeq2SeqLM from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan import whisper import gradio as gr import re import pandas as pd import numpy as np import os import time import logging import threading import queue from scipy.io.wavfile import write as write_wav from html import escape import traceback # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler()] ) logger = logging.getLogger('profanity_detector') # Define device at the top of the script (global scope) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"Using device: {device}") # Global variables for models profanity_model = None profanity_tokenizer = None t5_model = None t5_tokenizer = None whisper_model = None tts_processor = None tts_model = None vocoder = None models_loaded = False # Default speaker embeddings for TTS speaker_embeddings = None # Queue for real-time audio processing audio_queue = queue.Queue() processing_active = False # Model loading with int8 quantization def load_models(): global profanity_model, profanity_tokenizer, t5_model, t5_tokenizer, whisper_model global tts_processor, tts_model, vocoder, speaker_embeddings, models_loaded try: logger.info("Loading profanity detection model...") PROFANITY_MODEL = "parsawar/profanity_model_3.1" profanity_tokenizer = AutoTokenizer.from_pretrained(PROFANITY_MODEL) # Load model with memory optimization using half-precision profanity_model = AutoModelForSequenceClassification.from_pretrained(PROFANITY_MODEL) # Move to GPU if available and optimize with half-precision where possible if torch.cuda.is_available(): profanity_model = profanity_model.to(device) # Convert to half precision to save memory (if possible) try: profanity_model = profanity_model.half() # Convert to FP16 logger.info("Successfully converted profanity model to half precision") except Exception as e: logger.warning(f"Could not convert to half precision: {str(e)}") logger.info("Loading detoxification model...") T5_MODEL = "s-nlp/t5-paranmt-detox" t5_tokenizer = AutoTokenizer.from_pretrained(T5_MODEL) # Load model with memory optimization t5_model = AutoModelForSeq2SeqLM.from_pretrained(T5_MODEL) # Move to GPU if available and optimize with half-precision where possible if torch.cuda.is_available(): t5_model = t5_model.to(device) # Convert to half precision to save memory (if possible) try: t5_model = t5_model.half() # Convert to FP16 logger.info("Successfully converted T5 model to half precision") except Exception as e: logger.warning(f"Could not convert to half precision: {str(e)}") logger.info("Loading Whisper speech-to-text model...") whisper_model = whisper.load_model("large") if torch.cuda.is_available(): whisper_model = whisper_model.to(device) logger.info("Loading Text-to-Speech model...") TTS_MODEL = "microsoft/speecht5_tts" tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL) # Load TTS models without automatic device mapping tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL) vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan") # Move models to appropriate device if torch.cuda.is_available(): tts_model = tts_model.to(device) vocoder = vocoder.to(device) # Speaker embeddings for TTS speaker_embeddings = torch.zeros((1, 512)) if torch.cuda.is_available(): speaker_embeddings = speaker_embeddings.to(device) models_loaded = True logger.info("All models loaded successfully.") return "Models loaded successfully." except Exception as e: error_msg = f"Error loading models: {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) return error_msg def detect_profanity(text: str, threshold: float = 0.5): """ Detect profanity in text with adjustable threshold Args: text: The input text to analyze threshold: Profanity detection threshold (0.0-1.0) Returns: Dictionary with analysis results """ if not models_loaded: return {"error": "Models not loaded yet. Please wait."} try: # Detect profanity and score inputs = profanity_tokenizer(text, return_tensors="pt", truncation=True, max_length=512) if torch.cuda.is_available(): inputs = inputs.to(device) with torch.no_grad(): outputs = profanity_model(**inputs).logits score = torch.nn.functional.softmax(outputs, dim=1)[0][1].item() # Identify specific profane words words = re.findall(r'\b\w+\b', text) profane_words = [] word_scores = {} if score > threshold: for word in words: if len(word) < 2: # Skip very short words continue word_inputs = profanity_tokenizer(word, return_tensors="pt", truncation=True, max_length=512) if torch.cuda.is_available(): word_inputs = word_inputs.to(device) with torch.no_grad(): word_outputs = profanity_model(**word_inputs).logits word_score = torch.nn.functional.softmax(word_outputs, dim=1)[0][1].item() word_scores[word] = word_score if word_score > threshold: profane_words.append(word.lower()) # Create highlighted version of the text highlighted_text = create_highlighted_text(text, profane_words) return { "text": text, "score": score, "profanity": score > threshold, "profane_words": profane_words, "highlighted_text": highlighted_text, "word_scores": word_scores } except Exception as e: error_msg = f"Error in profanity detection: {str(e)}" logger.error(error_msg) return {"error": error_msg, "text": text, "score": 0, "profanity": False} def create_highlighted_text(text, profane_words): """ Create HTML-formatted text with profane words highlighted """ if not profane_words: return escape(text) # Create a regex pattern matching any of the profane words (case insensitive) pattern = r'\b(' + '|'.join(re.escape(word) for word in profane_words) + r')\b' # Replace occurrences with highlighted versions def highlight_match(match): return f'{match.group(0)}' highlighted = re.sub(pattern, highlight_match, text, flags=re.IGNORECASE) return highlighted def rephrase_profanity(text): """ Rephrase text containing profanity """ if not models_loaded: return "Models not loaded yet. Please wait." try: # Rephrase using the detoxification model inputs = t5_tokenizer(text, return_tensors="pt", truncation=True, max_length=512) if torch.cuda.is_available(): inputs = inputs.to(device) # Use more conservative generation settings with error handling try: outputs = t5_model.generate( **inputs, max_length=512, num_beams=4, # Reduced from 5 to be more memory-efficient early_stopping=True, no_repeat_ngram_size=2, length_penalty=1.0 ) rephrased_text = t5_tokenizer.decode(outputs[0], skip_special_tokens=True) # Verify the output is reasonable if not rephrased_text or len(rephrased_text) < 3: logger.warning(f"T5 model produced unusable output: '{rephrased_text}'") return text # Return original if output is too short return rephrased_text.strip() except RuntimeError as e: # Handle potential CUDA out of memory error if "CUDA out of memory" in str(e): logger.warning("CUDA out of memory in T5 model. Trying with smaller beam size...") # Try again with smaller beam size outputs = t5_model.generate( **inputs, max_length=512, num_beams=2, # Use smaller beam size early_stopping=True ) rephrased_text = t5_tokenizer.decode(outputs[0], skip_special_tokens=True) return rephrased_text.strip() else: raise e # Re-raise if it's not a memory issue except Exception as e: error_msg = f"Error in rephrasing: {str(e)}" logger.error(error_msg) return text # Return original text if rephrasing fails def text_to_speech(text): """ Convert text to speech using SpeechT5 """ if not models_loaded: return None try: # Create a temporary file path to save the audio temp_file = f"temp_tts_output_{int(time.time())}.wav" # Process the text input inputs = tts_processor(text=text, return_tensors="pt") if torch.cuda.is_available(): inputs = inputs.to(device) # Generate speech with a fixed speaker embedding speech = tts_model.generate_speech( inputs["input_ids"], speaker_embeddings, vocoder=vocoder ) # Convert from PyTorch tensor to NumPy array speech_np = speech.cpu().numpy() # Save as WAV file (sampling rate is 16kHz for SpeechT5) write_wav(temp_file, 16000, speech_np) return temp_file except Exception as e: error_msg = f"Error in text-to-speech conversion: {str(e)}" logger.error(error_msg) return None def text_analysis(input_text, threshold=0.5): """ Analyze text for profanity with adjustable threshold """ if not models_loaded: return "Models not loaded yet. Please wait for initialization to complete.", None, None try: # Detect profanity with the given threshold result = detect_profanity(input_text, threshold=threshold) # Handle error case if "error" in result: return result["error"], None, None # Process results if result["profanity"]: clean_text = rephrase_profanity(input_text) profane_words_str = ", ".join(result["profane_words"]) toxicity_score = result["score"] classification = ( "Severe Toxicity" if toxicity_score >= 0.7 else "Moderate Toxicity" if toxicity_score >= 0.5 else "Mild Toxicity" if toxicity_score >= 0.35 else "Minimal Toxicity" if toxicity_score >= 0.2 else "No Toxicity" ) # Generate audio for the rephrased text audio_output = text_to_speech(clean_text) return ( f"Profanity Score: {result['score']:.4f}\n\n" f"Profane: {result['profanity']}\n" f"Classification: {classification}\n" f"Detected Profane Words: {profane_words_str}\n\n" f"Reworded: {clean_text}" ), result["highlighted_text"], audio_output else: # If no profanity detected, just convert the original text to speech audio_output = text_to_speech(input_text) return ( f"Profanity Score: {result['score']:.4f}\n" f"Profane: {result['profanity']}\n" f"Classification: No Toxicity" ), None, audio_output except Exception as e: error_msg = f"Error in text analysis: {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) return error_msg, None, None def analyze_audio(audio_path, threshold=0.5): """ Analyze audio for profanity with adjustable threshold """ if not models_loaded: return "Models not loaded yet. Please wait for initialization to complete.", None, None if not audio_path: return "No audio provided.", None, None try: # Transcribe audio result = whisper_model.transcribe(audio_path, fp16=torch.cuda.is_available()) text = result["text"] # Detect profanity with user-defined threshold analysis = detect_profanity(text, threshold=threshold) # Handle error case if "error" in analysis: return f"Error during analysis: {analysis['error']}\nTranscription: {text}", None, None if analysis["profanity"]: clean_text = rephrase_profanity(text) else: clean_text = text # Generate audio for the rephrased text audio_output = text_to_speech(clean_text) return ( f"Transcription: {text}\n\n" f"Profanity Score: {analysis['score']:.4f}\n" f"Profane: {'Yes' if analysis['profanity'] else 'No'}\n" f"Classification: {'Severe Toxicity' if analysis['score'] >= 0.7 else 'Moderate Toxicity' if analysis['score'] >= 0.5 else 'Mild Toxicity' if analysis['score'] >= 0.35 else 'Minimal Toxicity' if analysis['score'] >= 0.2 else 'No Toxicity'}\n" f"Profane Words: {', '.join(analysis['profane_words']) if analysis['profanity'] else 'None'}\n\n" f"Reworded: {clean_text}" ), analysis["highlighted_text"] if analysis["profanity"] else None, audio_output except Exception as e: error_msg = f"Error in audio analysis: {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) return error_msg, None, None # Global variables to store streaming results stream_results = { "transcript": "", "profanity_info": "", "clean_text": "", "audio_output": None } def process_stream_chunk(audio_chunk): """Process an audio chunk from the streaming interface""" global stream_results, processing_active if not processing_active or not models_loaded: return stream_results["transcript"], stream_results["profanity_info"], stream_results["clean_text"], stream_results["audio_output"] try: # The format of audio_chunk from Gradio streaming can vary # It can be: (numpy_array, sample_rate), (filepath, sample_rate, numpy_array) or just numpy_array # Let's handle all possible cases if audio_chunk is None: # No audio received return stream_results["transcript"], stream_results["profanity_info"], stream_results["clean_text"], stream_results["audio_output"] # Different Gradio versions return different formats temp_file = None if isinstance(audio_chunk, tuple): if len(audio_chunk) == 2: # Format: (numpy_array, sample_rate) samples, sample_rate = audio_chunk temp_file = f"temp_stream_{int(time.time())}.wav" write_wav(temp_file, sample_rate, samples) elif len(audio_chunk) == 3: # Format: (filepath, sample_rate, numpy_array) filepath, sample_rate, samples = audio_chunk # Use the provided filepath if it exists if os.path.exists(filepath): temp_file = filepath else: # Create our own file temp_file = f"temp_stream_{int(time.time())}.wav" write_wav(temp_file, sample_rate, samples) elif isinstance(audio_chunk, np.ndarray): # Just a numpy array, assume sample rate of 16000 for Whisper samples = audio_chunk sample_rate = 16000 temp_file = f"temp_stream_{int(time.time())}.wav" write_wav(temp_file, sample_rate, samples) elif isinstance(audio_chunk, str) and os.path.exists(audio_chunk): # It's a filepath temp_file = audio_chunk else: # Unknown format stream_results["profanity_info"] = f"Error: Unknown audio format: {type(audio_chunk)}" return stream_results["transcript"], stream_results["profanity_info"], stream_results["clean_text"], stream_results["audio_output"] # Make sure we have a valid file to process if not temp_file or not os.path.exists(temp_file): stream_results["profanity_info"] = "Error: Failed to create audio file for processing" return stream_results["transcript"], stream_results["profanity_info"], stream_results["clean_text"], stream_results["audio_output"] # Process with Whisper result = whisper_model.transcribe(temp_file, fp16=torch.cuda.is_available()) transcript = result["text"].strip() # Skip processing if transcript is empty if not transcript: # Clean up temp file if we created it if temp_file and temp_file.startswith("temp_stream_") and os.path.exists(temp_file): try: os.remove(temp_file) except: pass # Return current state, but update profanity info stream_results["profanity_info"] = "No speech detected. Keep talking..." return stream_results["transcript"], stream_results["profanity_info"], stream_results["clean_text"], stream_results["audio_output"] # Update transcript stream_results["transcript"] = transcript # Analyze for profanity analysis = detect_profanity(transcript, threshold=0.5) # Check if profanity was detected if analysis.get("profanity", False): profane_words = ", ".join(analysis.get("profane_words", [])) stream_results["profanity_info"] = f"Profanity Detected (Score: {analysis['score']:.2f})\nProfane Words: {profane_words}" # Rephrase to clean text clean_text = rephrase_profanity(transcript) stream_results["clean_text"] = clean_text # Create audio from cleaned text audio_file = text_to_speech(clean_text) if audio_file: stream_results["audio_output"] = audio_file else: stream_results["profanity_info"] = f"No Profanity Detected (Score: {analysis['score']:.2f})" stream_results["clean_text"] = transcript # Use original text for audio if no profanity audio_file = text_to_speech(transcript) if audio_file: stream_results["audio_output"] = audio_file # Clean up temporary file if we created it if temp_file and temp_file.startswith("temp_stream_") and os.path.exists(temp_file): try: os.remove(temp_file) except: pass return stream_results["transcript"], stream_results["profanity_info"], stream_results["clean_text"], stream_results["audio_output"] except Exception as e: error_msg = f"Error processing streaming audio: {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) # Update profanity info with error message stream_results["profanity_info"] = f"Error: {str(e)}" return stream_results["transcript"], stream_results["profanity_info"], stream_results["clean_text"], stream_results["audio_output"] def start_streaming(): """Start the real-time audio processing""" global processing_active, stream_results if not models_loaded: return "Models not loaded yet. Please wait for initialization to complete." if processing_active: return "Streaming is already active." # Reset results stream_results = { "transcript": "", "profanity_info": "Waiting for audio input...", "clean_text": "", "audio_output": None } processing_active = True logger.info("Started real-time audio processing") return "Started real-time audio processing. Speak into your microphone." def stop_streaming(): """Stop the real-time audio processing""" global processing_active if not processing_active: return "Streaming is not active." processing_active = False return "Stopped real-time audio processing." def create_ui(): """Create the Gradio UI""" # Simple CSS for styling css = """ /* Fix for dark mode text visibility */ .dark .gr-input, .dark textarea, .dark .gr-textbox, .dark [data-testid="textbox"] { color: white !important; background-color: #2c303b !important; } .dark .gr-box, .dark .gr-form, .dark .gr-panel, .dark .gr-block { color: white !important; } /* Highlighted text container - with dark mode fixes */ .highlighted-text { border: 1px solid #ddd; border-radius: 5px; padding: 10px; margin: 10px 0; background-color: #f9f9f9; font-family: sans-serif; max-height: 300px; overflow-y: auto; color: #333 !important; /* Ensure text is dark for light mode */ } /* Dark mode specific styling for highlighted text */ .dark .highlighted-text { background-color: #2c303b !important; color: #ffffff !important; border-color: #4a4f5a !important; } /* Make sure text in the highlighted container remains visible in both themes */ .highlighted-text, .dark .highlighted-text { color-scheme: light dark; } /* Loading animation */ .loading { display: inline-block; width: 20px; height: 20px; border: 3px solid rgba(0,0,0,.3); border-radius: 50%; border-top-color: #3498db; animation: spin 1s ease-in-out infinite; } @keyframes spin { to { transform: rotate(360deg); } } """ # Create a custom theme based on Soft but explicitly set to light mode light_theme = gr.themes.Soft( primary_hue="blue", secondary_hue="blue", neutral_hue="gray" ) # Set theme to light mode and disable theme switching with gr.Blocks(css=css, theme=light_theme, analytics_enabled=False) as ui: # Model initialization init_status = gr.State("") gr.Markdown( """ # Profanity Detection & Replacement System Detect, rephrase, and listen to cleaned content from text or audio! """, elem_classes="header" ) # The rest of your UI code remains unchanged... # Initialize models button with status indicators with gr.Row(): with gr.Column(scale=3): init_button = gr.Button("Initialize Models", variant="primary") init_output = gr.Textbox(label="Initialization Status", interactive=False) with gr.Column(scale=1): model_status = gr.HTML( """
Model Status: Not Loaded
Model Status: Loaded ✓
Model Status: Error ✗
Model Status: Loading...