# LocalScribe1 / app.py
# Author: KG0101
# Commit 5320fa6: "Update to separate outputs" (verified)
import spaces
import torch
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from typing import Iterator
import os
# --- Configuration ---------------------------------------------------------
# Whisper checkpoint used by the speech-recognition pipeline.
MODEL_NAME = "openai/whisper-large-v3-turbo"
# Batch size handed to the ASR pipeline when it processes audio chunks.
BATCH_SIZE = 8
# Upload size limit in megabytes (not referenced elsewhere in this file as shown).
FILE_LIMIT_MB = 5000
# Maximum prompt length (in tokens) fed to the LLM; overridable via env var.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))
# Pipeline device: CUDA device index 0 when a GPU is visible, else the CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = "cpu"
# Initialize the LLM used for SOAP-note generation.
# NOTE: this only runs when CUDA is available; on a CPU-only host `llm` and
# `tokenizer` are never defined, so generate_soap will fail with a NameError.
if torch.cuda.is_available():
    llm_model_id = "chuanli11/Llama-3.2-3B-Instruct-uncensored"
    tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
    # Legacy flag asking the tokenizer not to inject a default system prompt
    # (generate_soap always supplies its own). Possibly a no-op for Llama 3.2
    # chat templates — verify against the installed transformers version.
    tokenizer.use_default_system_prompt = False
    llm = AutoModelForCausalLM.from_pretrained(
        llm_model_id,
        torch_dtype=torch.float16,  # half-precision weights on the GPU
        device_map="auto",
    )
# Initialize the transcription pipeline.
# chunk_length_s=30 lets Whisper process recordings longer than its 30 s
# window by splitting them into chunks.
pipe = pipeline(
    "automatic-speech-recognition",
    model=MODEL_NAME,
    device=device,
    chunk_length_s=30,
)
# Function to transcribe audio inputs
@spaces.GPU
def transcribe(inputs, task):
    """Run the Whisper pipeline on an audio file and return its text.

    Args:
        inputs: Audio file path from the ``gr.Audio`` component
            (``type="filepath"``); ``None`` when nothing was recorded/uploaded.
        task: Whisper task selected in the UI, ``"transcribe"`` or
            ``"translate"``.

    Returns:
        The ``"text"`` field of the pipeline result. Timestamps are requested
        from the pipeline but not returned.

    Raises:
        gr.Error: If no audio was provided.
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    whisper_options = {"task": task}
    asr_result = pipe(
        inputs,
        batch_size=BATCH_SIZE,
        generate_kwargs=whisper_options,
        return_timestamps=True,
    )
    return asr_result["text"]
# Function to generate SOAP notes using LLM
@spaces.GPU
def generate_soap(
    transcribed_text: str,
    system_prompt: str = "You are a world class clinical assistant.",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a clinical SOAP note generated from a (possibly edited) transcript.

    Builds a chat prompt (system prompt + SOAP instructions + transcript),
    runs ``llm.generate`` on a background thread, and yields the accumulated
    text after every streamed chunk so the output textbox updates live.

    Args:
        transcribed_text: Conversation text to turn into a SOAP note.
        system_prompt: System message placed first in the chat template.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to ``generate``.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.
        top_k: Top-k sampling cutoff forwarded to ``generate``.
        repetition_penalty: Repetition penalty forwarded to ``generate``.

    Yields:
        The SOAP note generated so far (cumulative text, not deltas).

    Raises:
        gr.Error: If ``transcribed_text`` is empty or whitespace only.
    """
    # Mirror transcribe()'s guard: generating from an empty transcript only
    # produces a hallucinated note.
    if not transcribed_text or not transcribed_text.strip():
        raise gr.Error("No transcription provided! Please transcribe audio or paste text before generating a SOAP note.")

    task_prompt = """
Convert the following transcribed conversation into a clinical SOAP note.
The text includes dialogue between a physician and a patient. Please clearly distinguish between the physician's and the patient's statements.
Extract and organize the information into the relevant sections of a SOAP note:
- Subjective (symptoms and patient statements),
- Objective (clinical findings and observations),
- Assessment (diagnosis or potential diagnoses),
- Plan (treatment and follow-up).
Ensure the note is concise, clear, and accurately reflects the conversation.
"""

    def build_input_ids(transcript: str):
        """Tokenize the full chat prompt for *transcript*."""
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"{task_prompt}\n\nTranscribed conversation:\n{transcript}"},
        ]
        # add_generation_prompt appends the assistant header so the model
        # answers as the assistant instead of continuing the user turn.
        return tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, return_tensors="pt"
        )

    input_ids = build_input_ids(transcribed_text)
    excess = input_ids.shape[1] - MAX_INPUT_TOKEN_LENGTH
    if excess > 0:
        # Trim the transcript (keeping its most recent part) rather than
        # slicing the whole prompt from the left, which would drop the system
        # prompt and the SOAP instructions.
        transcript_ids = tokenizer(transcribed_text, add_special_tokens=False)["input_ids"]
        transcribed_text = tokenizer.decode(transcript_ids[excess:])
        input_ids = build_input_ids(transcribed_text)
        # Re-tokenizing decoded text can shift the count slightly; enforce the
        # hard cap as a last resort.
        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(llm.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        # Single unpadded sequence: every position is attended to.
        "attention_mask": torch.ones_like(input_ids),
        "streamer": streamer,
        # Slider values may arrive as floats; generate() requires ints here.
        "max_new_tokens": int(max_new_tokens),
        "do_sample": True,
        "top_p": top_p,
        "top_k": int(top_k),
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    # generate() blocks, so run it on a worker thread and consume the streamer here.
    thread = Thread(target=llm.generate, kwargs=generate_kwargs)
    thread.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    thread.join()
# Gradio Interface combining transcription and SOAP note generation
demo = gr.Blocks(theme=gr.themes.Ocean())
with demo:
    with gr.Tab("Clinical SOAP Note from Audio"):
        # Transcription Interface
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio Input")
        task_input = gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
        transcription_output = gr.Textbox(label="Transcription Output")
        # Transcription button
        transcribe_button = gr.Button("Transcribe")

        # SOAP Generation Interface
        transcribed_text_input = gr.Textbox(label="Edit Transcription before SOAP Generation", lines=5)
        system_prompt_input = gr.Textbox(label="System Prompt", lines=2, value="You are a world class clinical assistant.")
        max_new_tokens_input = gr.Slider(label="Max new tokens", minimum=1, maximum=2048, value=1024, step=1)
        temperature_input = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, value=0.6, step=0.1)
        top_p_input = gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, value=0.9, step=0.05)
        top_k_input = gr.Slider(label="Top-k", minimum=1, maximum=1000, value=50, step=1)
        repetition_penalty_input = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, value=1.2, step=0.05)
        soap_output = gr.Textbox(label="Generated SOAP Note Output")
        # SOAP generation button
        generate_soap_button = gr.Button("Generate SOAP Note")

        # Event wiring is done after all components exist so the transcription
        # step can also reach the editable textbox. Layout is determined by
        # component creation order above, not by where listeners are attached.
        # The raw transcription stays in its own output box; a successful run
        # additionally pre-fills the editable box so the user does not have to
        # copy it by hand. .success() skips the copy when transcription raises,
        # so a failed run never overwrites the user's edits.
        transcribe_button.click(
            fn=transcribe,
            inputs=[audio_input, task_input],
            outputs=transcription_output,
        ).success(
            fn=lambda text: text,
            inputs=transcription_output,
            outputs=transcribed_text_input,
        )
        generate_soap_button.click(
            fn=generate_soap,
            inputs=[
                transcribed_text_input,
                system_prompt_input,
                max_new_tokens_input,
                temperature_input,
                top_p_input,
                top_k_input,
                repetition_penalty_input,
            ],
            outputs=soap_output,
        )
demo.queue().launch(ssr_mode=False)