import gradio as gr
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Load model and processor
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
model = WhisperForConditionalGeneration.from_pretrained("aiola/whisper-ner-v1")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


def unify_ner_text(text, symbols_to_replace=("/", " ", ":", "_")):
    """Standardize entity text: normalize whitespace, replace symbols with hyphens, lowercase."""
    text = " ".join(text.split())
    for symbol in symbols_to_replace:
        text = text.replace(symbol, "-")
    return text.lower()


def transcribe_and_recognize_entities(audio_file, prompt):
    """Transcribe an audio file and tag the entity types listed in `prompt`."""
    target_sample_rate = 16000
    signal, sampling_rate = torchaudio.load(audio_file)

    # Resample to the 16 kHz rate Whisper expects
    resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=target_sample_rate)
    signal = resampler(signal)

    # Downmix stereo to mono
    if signal.ndim == 2:
        signal = torch.mean(signal, dim=0)
    signal = signal.cpu()  # ensure signal is on CPU for feature extraction

    input_features = processor(signal, sampling_rate=target_sample_rate, return_tensors="pt").input_features
    input_features = input_features.to(device)

    # Split the prompt into individual NER types and normalize each one
    ner_types = prompt.split(",")
    processed_ner_types = [unify_ner_text(ner_type.strip()) for ner_type in ner_types]
    prompt = ", ".join(processed_ner_types)
    print(f"Prompt after unify_ner_text: {prompt}")

    prompt_ids = processor.get_prompt_ids(prompt, return_tensors="pt")
    prompt_ids = prompt_ids.to(device)

    predicted_ids = model.generate(
        input_features,
        max_new_tokens=256,
        prompt_ids=prompt_ids,
        language="en",  # force English output
        generation_config=model.generation_config,
    )
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    # The decoded text begins with the prompt itself; slice it (and the
    # following space) off so only the transcription remains.
    prompt_length_in_transcription = len(prompt)
    transcription = transcription[prompt_length_in_transcription + 1:]

    return transcription


# Define Gradio interface
iface = gr.Interface(
    fn=transcribe_and_recognize_entities,
    inputs=[
        gr.Audio(label="Upload Audio", type="filepath"),
        gr.Textbox(label="Entity Recognition Prompt"),
    ],
    outputs=gr.Textbox(label="Transcription and Entities"),
    title="Whisper-NER Demo",
    description="Upload an audio file and enter entities to identify. The model will transcribe the audio and recognize entities.",
)

iface.launch(share=True)
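
# --- Usage sketch (kept as comments so the script's behavior is unchanged) --
# The pipeline can also be called directly, without the Gradio UI. The audio
# path and entity labels below are hypothetical, for illustration only:
#
#   result = transcribe_and_recognize_entities("sample.wav", "person, company, date")
#   print(result)
#
# The prompt is a comma-separated list of entity types; unify_ner_text()
# lowercases each type and replaces spaces/slashes/colons/underscores with
# hyphens before the types are encoded as prompt ids for generation.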