import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    pipeline,
)

MODEL_NAME = "openai/whisper-large-v3-turbo"
BATCH_SIZE = 8
FILE_LIMIT_MB = 1000
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = 0 if torch.cuda.is_available() else "cpu"

# Initialize the LLM used for SOAP note generation (GPU only)
if torch.cuda.is_available():
    llm_model_id = "NousResearch/Meta-Llama-3.1-8B-Instruct"
    llm = AutoModelForCausalLM.from_pretrained(
        llm_model_id, torch_dtype=torch.float16, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
    tokenizer.use_default_system_prompt = False

# Initialize the Whisper transcription pipeline
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    chunk_length_s=30,
    device=device,
)


# Function to transcribe (or translate) audio inputs with Whisper
@spaces.GPU
def transcribe(inputs, task):
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    text = pipe(
        inputs,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": task},
        return_timestamps=True,
    )["text"]
    return text


# Function to generate SOAP notes using the LLM, streaming partial output
@spaces.GPU
def generate_soap(
    transcribed_text: str,
    system_prompt: str = "You are a world class clinical assistant.",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    task_prompt = """
Convert the following transcribed conversation into a clinical SOAP note.
The text includes dialogue between a physician and a patient. Please clearly
distinguish between the physician's and the patient's statements. Extract and
organize the information into the relevant sections of a SOAP note:
- Subjective (symptoms and patient statements),
- Objective (clinical findings and observations),
- Assessment (diagnosis or potential diagnoses),
- Plan (treatment and follow-up).
Ensure the note is concise, clear, and accurately reflects the conversation.
""" conversation = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": f"{task_prompt}\n\nTranscribed conversation:\n{transcribed_text}"} ] input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt") if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH: input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.") input_ids = input_ids.to(llm.device) streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) generate_kwargs = dict( {"input_ids": input_ids}, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p, top_k=top_k, temperature=temperature, num_beams=1, repetition_penalty=repetition_penalty, ) t = Thread(target=llm.generate, kwargs=generate_kwargs) t.start() outputs = [] for text in streamer: outputs.append(text) yield "".join(outputs) # Combine transcription and SOAP generation def transcribe_and_generate_soap(audio, task, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty): # First transcribe the audio transcribed_text = transcribe(audio, task) # Then generate SOAP notes based on the transcription return generate_soap( transcribed_text, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty ) # Gradio Interface combining transcription and SOAP note generation demo = gr.Blocks(theme=gr.themes.Ocean()) with demo: with gr.Tab("Clinical SOAP Note from Audio"): audio_transcribe_and_soap = gr.Interface( fn=transcribe_and_generate_soap, inputs=[ gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio Input"), gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"), gr.Textbox(label="System Prompt", lines=2, value="You are a world class clinical assistant."), gr.Slider(label="Max new tokens", minimum=1, maximum=2048, value=1024, step=1), gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, value=0.6, step=0.1), gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, value=0.9, step=0.05), gr.Slider(label="Top-k", minimum=1, maximum=1000, value=50, step=1), gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, value=1.2, step=0.05) ], outputs="text", title="Generate Clinical SOAP Note from Audio", description="Transcribe audio input and convert it into a structured clinical SOAP note." ) demo.queue().launch(ssr_mode=False)