LocalScribe1 / app.py
KG0101's picture
Update app.py
b39f388 verified
raw
history blame
5.23 kB
import spaces
import torch
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from typing import Iterator
import os
# --- Configuration -----------------------------------------------------------
MODEL_NAME = "openai/whisper-large-v3-turbo"  # Whisper checkpoint for speech-to-text
BATCH_SIZE = 8  # batch_size passed to the ASR pipeline call in transcribe()
# NOTE(review): not referenced by any code in this file — presumably intended
# as an upload size cap (MB); confirm or wire it into the UI.
FILE_LIMIT_MB = 5000
# Upper bound on prompt tokens fed to the LLM; overridable via the environment.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Device for the ASR pipeline: CUDA device index 0 when a GPU is visible,
# otherwise the string "cpu".
device = "cpu"
if torch.cuda.is_available():
    device = 0
# Initialize the LLM that turns transcripts into SOAP notes.
# The model and tokenizer are loaded unconditionally. If this were guarded by a
# CUDA check, `llm`/`tokenizer` would be undefined on hosts without a GPU at
# import time, and generate_soap would fail with a NameError on the first
# request. Half precision is only used when a GPU is present; on CPU the model
# is loaded in float32, matching the ASR pipeline's CPU fallback.
llm_model_id = "chuanli11/Llama-3.2-3B-Instruct-uncensored"
llm_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
llm = AutoModelForCausalLM.from_pretrained(llm_model_id, torch_dtype=llm_dtype, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
# NOTE(review): presumably disables a tokenizer-supplied default system prompt
# (a legacy Llama-2 chat-template attribute); likely a no-op for Llama 3.2 — confirm.
tokenizer.use_default_system_prompt = False
# Speech-to-text pipeline built from the Whisper checkpoint above.
# chunk_length_s=30 enables chunked long-form transcription in 30-second windows.
pipe = pipeline(
    "automatic-speech-recognition",
    model=MODEL_NAME,
    device=device,
    chunk_length_s=30,
)
# Speech-to-text entry point; decorated with spaces.GPU (Hugging Face ZeroGPU).
@spaces.GPU
def transcribe(inputs, task):
    """Run the Whisper pipeline on an audio input and return the recognised text.

    Args:
        inputs: Audio to transcribe (the UI passes a file path), or None.
        task: Whisper generation task, "transcribe" or "translate".

    Returns:
        The full transcript text.

    Raises:
        gr.Error: If no audio was supplied.
    """
    if inputs is not None:
        result = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)
        return result["text"]
    raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
# Prompt construction and streaming generation of SOAP notes with the LLM.

# Instruction placed ahead of the transcript in the user message.
_SOAP_TASK_PROMPT = """
Convert the following transcribed conversation into a clinical SOAP note.
The text includes dialogue between a physician and a patient. Please clearly distinguish between the physician's and the patient's statements.
Extract and organize the information into the relevant sections of a SOAP note:
- Subjective (symptoms and patient statements),
- Objective (clinical findings and observations),
- Assessment (diagnosis or potential diagnoses),
- Plan (treatment and follow-up).
Ensure the note is concise, clear, and accurately reflects the conversation.
"""


def _soap_conversation(transcribed_text, system_prompt):
    """Return the chat messages asking the LLM to turn *transcribed_text* into a SOAP note."""
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"{_SOAP_TASK_PROMPT}\n\nTranscribed conversation:\n{transcribed_text}"},
    ]


def _encode_soap_prompt(transcribed_text, system_prompt):
    """Tokenize the SOAP conversation into a batch-of-one tensor of prompt ids.

    ``add_generation_prompt=True`` appends the assistant-turn header, so the
    instruct model starts answering as the assistant. Without it, the model may
    continue the user message instead.
    """
    return tokenizer.apply_chat_template(
        _soap_conversation(transcribed_text, system_prompt),
        add_generation_prompt=True,
        return_tensors="pt",
    )


def _build_soap_input_ids(transcribed_text, system_prompt):
    """Return prompt ids no longer than MAX_INPUT_TOKEN_LENGTH where possible.

    Over-long prompts are shortened by dropping tokens from the *start of the
    transcript*. The system prompt, the SOAP instructions and the trailing
    assistant header are therefore always kept; slicing the tail of the encoded
    prompt would discard them. If the prompt still does not fit with an empty
    transcript (an extremely long system prompt), that over-length prompt is
    returned.
    """
    input_ids = _encode_soap_prompt(transcribed_text, system_prompt)
    if input_ids.shape[1] <= MAX_INPUT_TOKEN_LENGTH:
        return input_ids

    transcript_ids = tokenizer(transcribed_text, add_special_tokens=False)["input_ids"]
    # Decoding and re-encoding can shift token boundaries slightly, so repeat
    # until the prompt fits. Each pass drops at least one transcript token,
    # so the loop terminates.
    while input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH and transcript_ids:
        excess = input_ids.shape[1] - MAX_INPUT_TOKEN_LENGTH
        transcript_ids = transcript_ids[excess:]
        input_ids = _encode_soap_prompt(tokenizer.decode(transcript_ids), system_prompt)
    gr.Warning(
        f"Trimmed the start of the transcript as the prompt was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
    )
    return input_ids


@spaces.GPU
def generate_soap(
    transcribed_text: str,
    system_prompt: str = "You are a world class clinical assistant.",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a clinical SOAP note generated from *transcribed_text*.

    ``llm.generate`` runs on a background thread and pushes decoded text into
    a ``TextIteratorStreamer``. This generator reads that text as it arrives.

    Args:
        transcribed_text: Physician/patient conversation transcript.
        system_prompt: System message for the chat template.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept for sampling.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The SOAP note generated so far. Each value is the full accumulated
        text, not just the newest fragment, so the last value is the complete
        note.
    """
    input_ids = _build_soap_input_ids(transcribed_text, system_prompt).to(llm.device)
    # With timeout=10.0, iterating the streamer raises queue.Empty if no new
    # text arrives within 10 s, instead of blocking forever (for example when
    # generate() fails on its thread).
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    generation_thread = Thread(target=llm.generate, kwargs=generate_kwargs)
    generation_thread.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer is exhausted once generate() finishes. Join so the worker
    # thread has fully completed before this (GPU-decorated) call returns.
    generation_thread.join()
# Combine transcription and SOAP generation into one Gradio callback.
def transcribe_and_generate_soap(audio, task, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
    """Transcribe *audio* and return the finished SOAP note as a single string.

    Args:
        audio: Audio input forwarded to ``transcribe`` (the UI passes a file path).
        task: Whisper task, "transcribe" or "translate".
        system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Generation settings forwarded unchanged to ``generate_soap``.

    Returns:
        The complete generated SOAP note ("" if nothing was generated).

    Raises:
        gr.Error: Propagated from ``transcribe`` when no audio was supplied.
    """
    transcribed_text = transcribe(audio, task)
    # generate_soap yields cumulative snapshots ("A", "AB", "ABC", ...): each
    # value is the whole note so far. Keep only the last one. Concatenating
    # every yield would repeat each prefix in the output.
    soap_note = ""
    for snapshot in generate_soap(
        transcribed_text,
        system_prompt,
        max_new_tokens,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
    ):
        soap_note = snapshot
    return soap_note
# Gradio UI: a single tab hosting a gr.Interface that feeds the audio input and
# generation controls to transcribe_and_generate_soap and shows the note as text.
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    with gr.Tab("Clinical SOAP Note from Audio"):
        audio_transcribe_and_soap = gr.Interface(
            fn=transcribe_and_generate_soap,
            # Order must match transcribe_and_generate_soap's positional parameters.
            inputs=[
                gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio Input"),
                gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
                gr.Textbox(label="System Prompt", lines=2, value="You are a world class clinical assistant."),
                gr.Slider(label="Max new tokens", minimum=1, maximum=2048, value=1024, step=1),
                gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, value=0.6, step=0.1),
                gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, value=0.9, step=0.05),
                gr.Slider(label="Top-k", minimum=1, maximum=1000, value=50, step=1),
                gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, value=1.2, step=0.05),
            ],
            outputs="text",
            title="Generate Clinical SOAP Note from Audio",
            description="Transcribe audio input and convert it into a structured clinical SOAP note.",
        )

# Enable request queuing, then start the server with server-side rendering off.
demo.queue()
demo.launch(ssr_mode=False)