import gradio as gr
import torch
import torchaudio
import spaces
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Select the device: use a GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the Whisper-NER model and its processor.
processor = WhisperProcessor.from_pretrained("aiola/whisper-ner-v1")
model = WhisperForConditionalGeneration.from_pretrained("aiola/whisper-ner-v1")
model = model.to(device)


def unify_ner_text(text, symbols_to_replace=("/", " ", ":", "_")):
    """Standardize an entity-type string: collapse whitespace, replace separator symbols with hyphens, and lowercase."""
    text = " ".join(text.split())
    for symbol in symbols_to_replace:
        text = text.replace(symbol, "-")
    return text.lower()


@spaces.GPU  # Request a GPU for this call when running on Hugging Face Spaces (ZeroGPU).
def transcribe_and_recognize_entities(audio_file, prompt):
    target_sample_rate = 16000

    # Load the audio and resample it to the 16 kHz rate Whisper expects.
    signal, sampling_rate = torchaudio.load(audio_file)
    if sampling_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sampling_rate, new_freq=target_sample_rate
        )
        signal = resampler(signal)

    # Downmix multi-channel audio to mono by averaging the channels.
    if signal.ndim == 2:
        signal = torch.mean(signal, dim=0)

    input_features = processor(
        signal, sampling_rate=target_sample_rate, return_tensors="pt"
    ).input_features
    input_features = input_features.to(device)

    # Normalize the comma-separated entity types and rebuild the prompt.
    ner_types = prompt.split(",")
    processed_ner_types = [unify_ner_text(ner_type.strip()) for ner_type in ner_types]
    prompt = ", ".join(processed_ner_types)
    print(f"Prompt after unify_ner_text: {prompt}")

    prompt_ids = processor.get_prompt_ids(prompt, return_tensors="pt")
    prompt_ids = prompt_ids.to(device)

    predicted_ids = model.generate(
        input_features,
        max_new_tokens=256,
        prompt_ids=prompt_ids,
        language="en",
    )
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    # The decoded text starts with the prompt (preceded by a space); strip it
    # so only the transcription with the tagged entities remains.
    transcription = transcription[len(prompt) + 1:]

    return transcription


iface = gr.Interface(
    fn=transcribe_and_recognize_entities,
    inputs=[
        gr.Audio(label="Upload Audio", type="filepath"),
        gr.Textbox(label="Entity Recognition Prompt"),
    ],
    outputs=gr.Textbox(label="Transcription and Entities"),
    title="Whisper-NER Demo",
    description=(
        "Upload an audio file and enter a comma-separated list of entity types "
        "to identify. The model transcribes the audio and tags the recognized entities."
    ),
)

iface.launch(share=True)