Last commit not found
import streamlit as st | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MarianMTModel, MarianTokenizer | |
# Load models and tokenizers | |
@st.cache_resource
def load_healthscribe_model():
    """Load the HealthScribe clinical-note model and its tokenizer.

    Streamlit re-executes the whole script on every widget interaction.
    Wrapping the loader in ``st.cache_resource`` means the checkpoint is
    loaded once and the same objects are returned on later reruns instead
    of being re-created each time.

    Returns:
        tuple: ``(model, tokenizer)`` obtained from
        ``AutoModelForSeq2SeqLM.from_pretrained`` and
        ``AutoTokenizer.from_pretrained`` for the HealthScribe checkpoint.
    """
    model_name = "har1/HealthScribe-Clinical_Note_Generator"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return model, tokenizer
@st.cache_resource
def load_translation_model(model_name):
    """Load a MarianMT translation model and its tokenizer.

    The result is cached with ``st.cache_resource`` keyed on ``model_name``,
    so each language-pair checkpoint is loaded once rather than on every
    Streamlit rerun.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``
            (the caller builds it as ``Helsinki-NLP/opus-mt-<src>-<tgt>``).

    Returns:
        tuple: ``(model, tokenizer)`` for the requested checkpoint.
    """
    model = MarianMTModel.from_pretrained(model_name)
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    return model, tokenizer
# Load the HealthScribe model/tokenizer pair up front; the clinical-note
# branch of the UI uses both objects directly.
_healthscribe_pair = load_healthscribe_model()
healthscribe_model, healthscribe_tokenizer = _healthscribe_pair
# Language selection options: maps the label shown in the UI to a
# (source_code, target_code) pair. For every supported language the
# "English to X" entry is inserted first, followed by "X to English".
language_options = {
    label: codes
    for language, code in (
        ("French", "fr"),
        ("Spanish", "es"),
        ("German", "de"),
        ("Italian", "it"),
    )
    for label, codes in (
        (f"English to {language}", ("en", code)),
        (f"{language} to English", (code, "en")),
    )
}
# Page header and a one-line description of the app.
st.title("Multifunctional Text Processing App")
st.write("This app can generate clinical notes or translate text between languages.")

# The label picked here decides which tool the rest of the script renders.
task_choices = ["Generate Clinical Note", "Translate Text"]
task = st.selectbox("Select a task:", task_choices)
# Render the tool that matches the selected task.
if task == "Generate Clinical Note":
    st.subheader("Clinical Note Generator")
    input_text = st.text_area("Enter patient information or medical notes:", height=200)

    if st.button("Generate Clinical Note"):
        if input_text.strip():
            # Tokenize (input truncated to 512 tokens) and generate with
            # 5-beam search, capping the output at 512 tokens.
            inputs = healthscribe_tokenizer(
                input_text, return_tensors="pt", max_length=512, truncation=True
            )
            outputs = healthscribe_model.generate(
                inputs["input_ids"], max_length=512, num_beams=5, early_stopping=True
            )
            generated_note = healthscribe_tokenizer.decode(
                outputs[0], skip_special_tokens=True
            )

            # Display the result
            st.subheader("Generated Clinical Note")
            st.write(generated_note)
        else:
            # Whitespace-only input is treated as empty.
            st.warning("Please enter some text to generate a clinical note.")

elif task == "Translate Text":
    st.subheader("Translation Tool")
    language_pair = st.selectbox("Select language pair", list(language_options))
    src_lang, tgt_lang = language_options[language_pair]
    model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"

    # Input text to translate
    text = st.text_area("Enter text to translate:")

    if st.button("Translate"):
        if text.strip():
            with st.spinner("Translating..."):
                # Load the model only once a translation is actually
                # requested. Streamlit reruns this script on every widget
                # interaction, so loading at branch entry would repeat the
                # load on each change of the language pair or the text.
                translation_model, translation_tokenizer = load_translation_model(model_name)

                # Prepare the input, generate, and decode the first
                # (only) returned sequence.
                inputs = translation_tokenizer(
                    text, return_tensors="pt", padding=True, truncation=True
                )
                translation = translation_model.generate(**inputs)
                translated_text = translation_tokenizer.decode(
                    translation[0], skip_special_tokens=True
                )

            # Display translation
            st.write("**Original Text**:", text)
            st.write("**Translated Text**:", translated_text)
        else:
            # Whitespace-only input is treated as empty.
            st.warning("Please enter some text to translate.")