# MultiMed / app.py
# Last change by Tonic: "Update app.py" (commit 13def01), 22.4 kB.
# Welcome to Team Tonic's MultiMed
from gradio_client import Client
import os
import numpy as np
import base64
import gradio as gr
import tempfile
import requests
import json
import dotenv
from scipy.io.wavfile import write
import PIL
from openai import OpenAI
import time
from PIL import Image
import io
import hashlib
import datetime
from utils import build_logger
from transformers import AutoTokenizer, MistralForCausalLM
import torch
import random
from textwrap import wrap
import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, MistralForCausalLM
from peft import PeftModel, PeftConfig
import torch
import os
# Registry of Gradio layout blocks, filled in while the UI is built.
components = {}

# Load variables from a local .env file into the environment before reading them.
dotenv.load_dotenv()

# Remote client for the facebook/seamless_m4t Space, used for speech translation.
seamless_client = Client("facebook/seamless_m4t")

# Hugging Face access token, exposed under both names used in this file.
HuggingFace_Token = os.getenv("HuggingFace_Token")
hf_token = HuggingFace_Token
def check_hallucination(assertion, citation):
    """Score ``assertion`` against ``citation`` with Vectara's hallucination model.

    Sends ``"<assertion> [SEP] <citation>"`` to the Hugging Face Inference API
    and returns the first score as a Markdown snippet.

    Raises:
        requests.HTTPError: if the API answers with an error status, instead of
            failing later with an opaque ``KeyError`` while indexing the error body.
    """
    api_url = "https://api-inference.huggingface.co/models/vectara/hallucination_evaluation_model"
    request_headers = {"Authorization": f"Bearer {HuggingFace_Token}"}
    payload = {"inputs": f"{assertion} [SEP] {citation}"}
    response = requests.post(api_url, headers=request_headers, json=payload, timeout=120)
    response.raise_for_status()
    # Assumes a nested list response like [[{"score": ...}, ...]] (indexed as before) — TODO confirm.
    score = response.json()[0][0]["score"]
    return f"**hallucination score:** {score}"
# Define the API parameters
VAPI_URL = "https://api-inference.huggingface.co/models/vectara/hallucination_evaluation_model"
headers = {"Authorization": f"Bearer {HuggingFace_Token}"}


# Function to query the API
def query(payload):
    """POST ``payload`` to the hallucination-evaluation endpoint and return the decoded JSON.

    A timeout (the same 120 s used by ``check_hallucination``) keeps a stalled
    request from blocking the caller forever; ``requests`` waits indefinitely
    when no timeout is given.
    """
    response = requests.post(VAPI_URL, headers=headers, json=payload, timeout=120)
    return response.json()
# Function to evaluate hallucination
def evaluate_hallucination(input1, input2):
    """Return a red/green risk label derived from the hallucination model's score.

    ``input1`` and ``input2`` are joined as ``"<input1>. <input2>"`` and sent
    through ``query``; a score below 0.5 is labelled high risk, anything else
    low risk. The score is shown with two decimals.
    """
    api_output = query({"inputs": f"{input1}. {input2}"})
    score = api_output[0][0]['score']
    if score < 0.5:
        return f"🔴 High risk. Score: {score:.2f}"
    return f"🟢 Low risk. Score: {score:.2f}"
def process_speech(input_language, audio_input):
    """Translate recorded speech into English text via the seamless_m4t Space.

    Args:
        input_language: Name of the spoken (source) language.
        audio_input: Path to the recorded audio file, or ``None`` when nothing
            has been recorded/saved yet.

    Returns:
        The translated text, or a retry message when ``audio_input`` is ``None``.
    """
    if audio_input is None:
        return "no audio or audio did not save yet \nplease try again ! "

    print(f"audio : {audio_input}")
    print(f"audio type : {type(audio_input)}")

    prediction = seamless_client.predict(
        "S2TT",          # task code (presumably speech-to-text translation)
        "file",          # how the audio is supplied
        None,
        audio_input,     # audio_name
        "",
        input_language,  # source language
        "English",       # target language
        api_name="/run",
    )
    translated = prediction[1]  # element 1 of the result holds the text

    try:
        text = f"{translated}"
    except Exception as err:
        text = f"{err}"
    return text
def decode_image(encoded_image: str) -> Image.Image:
    """Decode base64 text into a PIL image.

    The return annotation names the ``Image.Image`` class (the ``Image``
    module itself is not a type).

    Raises:
        binascii.Error: if the text cannot be base64-decoded (e.g. bad padding).
        PIL.UnidentifiedImageError: if the decoded bytes are not a readable image.
    """
    image_bytes = base64.b64decode(encoded_image.encode("utf-8"))
    # PIL opens lazily, so the in-memory buffer must stay referenced by the image.
    return Image.open(io.BytesIO(image_bytes))
def encode_image(image: Image.Image, format: str = "PNG") -> str:
    """Serialize a PIL image in ``format`` and return the bytes as base64 text."""
    with io.BytesIO() as stream:
        image.save(stream, format=format)
        payload = stream.getvalue()
    return base64.b64encode(payload).decode("utf-8")
def get_conv_log_filename(log_dir=None):
    """Return the path of today's conversation log, ``<log_dir>/YYYY-MM-DD-conv.json``.

    Args:
        log_dir: Directory for the log file. Defaults to the module-level
            ``LOGDIR``; no assignment to ``LOGDIR`` is visible in this file,
            so calling without ``log_dir`` raises ``NameError`` unless it is
            defined elsewhere.

    Uses local time from ``datetime.datetime.now()``.
    """
    if log_dir is None:
        log_dir = LOGDIR
    today = datetime.datetime.now()
    return os.path.join(log_dir, f"{today.year}-{today.month:02d}-{today.day:02d}-conv.json")
def get_conv_image_dir(log_dir=None):
    """Return ``<log_dir>/images``, creating the directory if it does not exist.

    Args:
        log_dir: Base log directory. Defaults to the module-level ``LOGDIR``;
            no assignment to ``LOGDIR`` is visible in this file, so calling
            without ``log_dir`` raises ``NameError`` unless it is defined elsewhere.
    """
    if log_dir is None:
        log_dir = LOGDIR
    image_dir = os.path.join(log_dir, "images")
    os.makedirs(image_dir, exist_ok=True)
    return image_dir
def get_image_name(image, image_dir=None):
    """Return a content-addressed file name ``<md5 of PNG bytes>.png`` for ``image``.

    The image is rendered to PNG in memory and hashed, so identical pixels map
    to the same name. When ``image_dir`` is given, the name is joined onto it.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        digest = hashlib.md5(png_buffer.getvalue()).hexdigest()
    filename = f"{digest}.png"
    return filename if image_dir is None else os.path.join(image_dir, filename)
def resize_image(image, max_size):
    """Resize ``image`` so its wider side becomes ``max_size``, keeping the aspect ratio.

    For square or portrait images the height becomes ``max_size``. The other
    side is truncated with ``int()``. Smaller images are enlarged, not just shrunk.
    """
    w, h = image.size
    ratio = float(w) / float(h)
    if w > h:
        target = (max_size, int(max_size / ratio))
    else:
        target = (int(max_size * ratio), max_size)
    return image.resize(target)
def _save_temp_png(image):
    """Write a PIL image to a new ``.png`` temporary file and return its path."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
        # Save through the open handle instead of re-opening the file by name,
        # which fails on platforms that lock open temporary files.
        image.save(tmp_file, format="PNG")
        return tmp_file.name


def process_image(image_input):
    """Caption an image with the hosted Fuyu-8B Gradio demo.

    Args:
        image_input: A NumPy array, a PIL image, a base64-encoded image string,
            or a file-path/URL string.

    Returns:
        The result of the remote ``predict`` call (used by callers as caption text).

    Any temporary PNG created here is removed afterwards, even if the remote
    call fails. Paths or URLs supplied by the caller are never deleted.
    """
    # Hard-coded replica URL of the Fuyu-8B demo Space.
    client = Client("https://adept-fuyu-8b-demo.hf.space/--replicas/pqjvl/")

    temp_path = None  # only set when this function created the file
    if isinstance(image_input, np.ndarray):
        temp_path = _save_temp_png(Image.fromarray(image_input))
    elif isinstance(image_input, str):
        try:
            decoded = decode_image(image_input)
        except Exception:
            decoded = None  # not base64 image data: treat it as a path or URL
        if decoded is not None:
            temp_path = _save_temp_png(decoded)
    else:
        # Assumed to be a PIL image.
        temp_path = _save_temp_png(image_input)

    image_path = temp_path if temp_path is not None else image_input
    try:
        return client.predict(
            image_path,  # file path or URL of the image
            True,        # extra flag the endpoint takes (original note: e.g. detailed captioning)
            fn_index=2,  # which function of the Space to call
        )
    finally:
        # Previously any str input whose path contained "tmp" (e.g. Gradio
        # uploads) was deleted too; only our own temp file is removed now.
        if temp_path is not None:
            os.remove(temp_path)
# Vectara metadata field name -> key used in the returned source dicts.
_VECTARA_SOURCE_FIELDS = {'title': 'title', 'author': 'author', 'pageNumber': 'page number'}


def _build_vectara_request(user_message, customer_id, corpus_id):
    """Return the JSON body of a Vectara v1 ``/query`` request for ``user_message``."""
    return {
        "query": [
            {
                "query": user_message,
                "queryContext": "",
                # NOTE(review): Vectara's `start` looks like a 0-based offset;
                # 1 may skip the top hit — confirm against the API docs.
                "start": 1,
                "numResults": 25,
                "contextConfig": {
                    "charsBefore": 0,
                    "charsAfter": 0,
                    "sentencesBefore": 2,
                    "sentencesAfter": 2,
                    "startTag": "%START_SNIPPET%",
                    "endTag": "%END_SNIPPET%",
                },
                "rerankingConfig": {
                    "rerankerId": 272725718,
                    "mmrConfig": {
                        "diversityBias": 0.35
                    }
                },
                "corpusKey": [
                    {
                        "customerId": customer_id,
                        "corpusId": corpus_id,
                        "semantics": 0,
                        "metadataFilter": "",
                        "lexicalInterpolationConfig": {
                            "lambda": 0
                        },
                        "dim": []
                    }
                ],
                "summary": [
                    {
                        "maxSummarizedResults": 5,
                        "responseLang": "auto",
                        "summarizerPromptName": "vectara-summary-ext-v1.2.0"
                    }
                ]
            }
        ]
    }


def _extract_vectara_summary(query_data):
    """Return the first response set's summary text, or a placeholder if it is missing."""
    try:
        return query_data['responseSet'][0]['summary'][0]['text']
    except (KeyError, IndexError, TypeError):
        return 'No summary available'


def _extract_vectara_sources(query_data, max_per_set=5):
    """Collect title/author/page metadata from the top ``max_per_set`` hits of every response set."""
    sources_info = []
    for response_set in query_data.get('responseSet', []):
        for source in response_set.get('response', [])[:max_per_set]:
            source_info = {}
            for metadata in source.get('metadata', []):
                key = _VECTARA_SOURCE_FIELDS.get(metadata.get('name', ''))
                if key is not None:
                    source_info[key] = metadata.get('value', '')
            if source_info:
                sources_info.append(source_info)
    return sources_info


def query_vectara(text):
    """Query the Vectara corpus with ``text`` and return the summary plus sources.

    Credentials are read from the ``CUSTOMER_ID``, ``CORPUS_ID`` and
    ``API_KEY`` environment variables (populated from ``.env`` at startup).

    Returns:
        A JSON string ``{"summary": ..., "sources": [{"title", "author",
        "page number"}, ...]}`` on success; ``"Error: <status>"`` on a non-200
        response; ``"No data found in the response."`` for an empty body.
    """
    customer_id = os.getenv('CUSTOMER_ID')
    corpus_id = os.getenv('CORPUS_ID')
    api_key = os.getenv('API_KEY')
    api_key_header = {
        "customer-id": customer_id,
        "x-api-key": api_key,
    }

    # Plain HTTP request via `requests`; without a timeout it could hang forever.
    response = requests.post(
        "https://api.vectara.io/v1/query",
        json=_build_vectara_request(text, customer_id, corpus_id),
        verify=True,
        headers=api_key_header,
        timeout=120,
    )
    if response.status_code != 200:
        return f"Error: {response.status_code}"

    query_data = response.json()
    if not query_data:
        return "No data found in the response."

    result = {
        "summary": _extract_vectara_summary(query_data),
        "sources": _extract_vectara_sources(query_data),
    }
    return json.dumps(result, indent=2)
def convert_to_markdown(vectara_response_json):
    """Render ``query_vectara`` output as a Markdown summary plus numbered sources.

    Args:
        vectara_response_json: JSON text with optional ``"summary"`` and
            ``"sources"`` keys, as produced by ``query_vectara``.

    Returns:
        Markdown text. Input that is not JSON is returned unchanged, because
        ``query_vectara`` reports failures as plain messages (e.g.
        ``"Error: 500"``). Empty or non-object JSON yields
        ``"No data found in the response."``.
    """
    no_data = "No data found in the response."
    try:
        vectara_response = json.loads(vectara_response_json)
    except json.JSONDecodeError:
        # Already a human-readable status/error message; surface it as-is.
        return vectara_response_json
    except TypeError:
        return no_data
    if not isinstance(vectara_response, dict) or not vectara_response:
        return no_data

    summary = vectara_response.get('summary', 'No summary available')
    sources_info = vectara_response.get('sources', [])

    markdown_summary = f' {summary}\n\n'
    source_lines = []
    for number, source_info in enumerate(sources_info, start=1):
        author = source_info.get('author', 'Unknown author')
        title = source_info.get('title', 'Unknown title')
        page_number = source_info.get('page number', 'Unknown page number')
        source_lines.append(f"{number}. {title} by {author}, Page {page_number}\n")
    markdown_sources = "".join(source_lines)
    return f"{markdown_summary}**Sources:**\n{markdown_sources}"
# Functions to Wrap the Prompt Correctly
def wrap_text(text, width=90):
    """Hard-wrap every line of ``text`` to at most ``width`` characters.

    Existing newlines are kept; each line is wrapped independently with
    ``textwrap.fill``, so blank lines survive as blank lines.
    """
    # Only `wrap` is imported from textwrap at module level, so the module
    # name itself was unbound and this function raised NameError.
    import textwrap

    return "\n".join(textwrap.fill(line, width=width) for line in text.split("\n"))
def multimodal_prompt(user_input, system_prompt="You are an expert medical analyst:", model=None, max_length=512):
    """Sample a completion for ``user_input`` followed by ``system_prompt``.

    Args:
        user_input: Text placed first in the prompt.
        system_prompt: Text appended directly after ``user_input`` (no separator).
        model: Causal LM used for generation; defaults to the module-level
            ``peft_model``.
        max_length: Total token budget (prompt + completion) passed to ``generate``.

    Returns:
        The decoded output sequence (special tokens skipped).
    """
    # `model` and `max_length` used to be read from globals that are never
    # defined in this file, so every call raised NameError.
    if model is None:
        model = peft_model

    # NOTE(review): the system prompt is appended *after* the user text with no
    # separator, unlike ChatBot.predict which puts it first — confirm intent.
    formatted_input = f"{user_input}{system_prompt}"
    encodeds = tokenizer(formatted_input, return_tensors="pt", add_special_tokens=False)
    # Put inputs where the model's weights are; the module-level `device` can be
    # "cuda" while peft_model is never moved off the CPU.
    model_inputs = encodeds.to(model.device)

    output = model.generate(
        **model_inputs,
        max_length=max_length,
        use_cache=True,
        early_stopping=True,
        bos_token_id=model.config.bos_token_id,
        eos_token_id=model.config.eos_token_id,
        pad_token_id=model.config.eos_token_id,
        temperature=0.1,
        do_sample=True,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Preferred device: the GPU when one is available.
# NOTE(review): peft_model below is not moved to `device`; callers must place
# their input tensors accordingly.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Base checkpoint and the fine-tuned PEFT adapter applied on top of it.
base_model_id = "stabilityai/stablelm-3b-4e1t"
model_directory = "Tonic/stablemed"

# Base-model tokenizer: left padding, with EOS doubling as the pad token.
tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    token=hf_token,
    trust_remote_code=True,
    padding_side="left",
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'

# Adapter configuration, then the base weights wrapped with the adapter.
peft_config = PeftConfig.from_pretrained(model_directory, token=hf_token)
peft_model = AutoModelForCausalLM.from_pretrained(base_model_id, token=hf_token, trust_remote_code=True)
peft_model = PeftModel.from_pretrained(peft_model, model_directory, token=hf_token)
class ChatBot:
    """Single-turn chat wrapper around the module-level StableMed ``peft_model``."""

    def __init__(self):
        # Kept for compatibility; ``predict`` neither reads nor updates it.
        self.history = []

    def predict(self, user_input, system_prompt="You are an expert medical analyst:", max_new_tokens=512):
        """Generate a reply to ``user_input`` under ``system_prompt``.

        Args:
            user_input: The user's message.
            system_prompt: Instruction placed before the message inside the
                ``[INST] ... [/INST]`` template.
            max_new_tokens: Upper bound on generated tokens, independent of
                how long the prompt is.

        Returns:
            Only the newly generated text, without the echoed prompt.
        """
        formatted_input = f"<s>[INST]{system_prompt} {user_input}[/INST]"
        user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt")
        # max_length=512 counted the prompt too, so long retrieval summaries left
        # little or no room for an answer; bound only the new tokens instead.
        response = peft_model.generate(
            input_ids=user_input_ids,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
        )
        # Causal-LM generate() returns prompt + continuation; drop the prompt so
        # callers (and the hallucination check) see only the model's answer.
        new_tokens = response[0][user_input_ids.shape[-1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True)


bot = ChatBot()
def process_summary_with_stablemed(summary):
    """Have the StableMed bot turn a retrieval summary into instructor-style guidance.

    Returns:
        Whatever ``bot.predict`` produces for ``summary`` under the instructor prompt.
    """
    instructor_prompt = "You are a medical instructor . Assess and describe the proper options to your students in minute detail. Propose a course of action for them to base their recommendations on based on your description."
    return bot.predict(summary, instructor_prompt)
# Main function to handle the Gradio interface logic
def process_and_query(input_language=None, audio_input=None, image_input=None, text_input=None):
    """Run the MultiMed pipeline on whichever inputs the user supplied.

    The query is built from three optional sources:
    - the text question;
    - speech translated by ``process_speech``;
    - an image caption from ``process_image``.

    The merged query goes to ``query_vectara``, and the result is rendered
    with ``convert_to_markdown``. StableMed then answers it via
    ``process_summary_with_stablemed``. Finally, ``evaluate_hallucination``
    scores that answer against the Markdown.

    Args:
        input_language: Source language of ``audio_input``.
        audio_input: Recorded audio file path, or ``None``.
        image_input: Image in any form ``process_image`` accepts, or ``None``.
        text_input: Free-text question; empty or whitespace-only counts as absent.

    Returns:
        ``(final_response, hallucination_label)``, or a fixed pair of error
        strings when no input was given or any step raised.
    """
    # UI visibility is not handled here: Gradio layout blocks have no
    # hide()/show() methods, and calling them raised AttributeError before
    # any processing could happen.
    try:
        query_parts = []
        # Gradio sends "" for an empty textbox; testing `is not None` let the
        # bare preamble through and the empty-input check below never fired.
        if text_input and text_input.strip():
            query_parts.append("The user asks the following to his health adviser: " + text_input)
        if audio_input is not None:
            audio_text = process_speech(input_language, audio_input)
            print("Audio Text:", audio_text)  # Debug print
            query_parts.append(audio_text)
        if image_input is not None:
            image_text = process_image(image_input)
            print("Image Text:", image_text)  # Debug print
            query_parts.append(image_text)

        combined_text = "\n".join(query_parts)
        if not combined_text.strip():
            return "Error: Please provide some input (text, audio, or image).", "No hallucination evaluation"

        # Retrieve from Vectara and render the result as Markdown.
        vectara_response_json = query_vectara(combined_text)
        print("Vectara Response:", vectara_response_json)  # Debug print
        markdown_output = convert_to_markdown(vectara_response_json)

        # Generate the answer with the StableMed model.
        final_response = process_summary_with_stablemed(markdown_output)
        print("Final Response:", final_response)  # Debug print

        hallucination_label = evaluate_hallucination(final_response, markdown_output)
        print("Hallucination Label:", hallucination_label)  # Debug print
        return final_response, hallucination_label
    except Exception as e:
        # UI boundary: report and return displayable text rather than crash the event.
        print(f"An error occurred: {e}")
        return "Error occurred during processing.", "No hallucination evaluation"
# Markdown intro text for the app. Removed a stray closing `</h3>` (no matching
# opening tag) and a stray trailing `"` that would otherwise render literally.
welcome_message = """
# 👋🏻Welcome to ⚕🗣️😷MultiMed - Access Chat ⚕🗣️😷
🗣️📝 This is an educational and accessible conversational tool.
### How To Use ⚕🗣️😷MultiMed⚕:
🗣️📝Interact with ⚕🗣️😷MultiMed⚕ in any language using image, audio or text!
📚🌟💼 that uses [Tonic/stablemed](https://huggingface.co/Tonic/stablemed) and [adept/fuyu-8B](https://huggingface.co/adept/fuyu-8b) with [Vectara](https://huggingface.co/vectara) embeddings + retrieval.
do [get in touch](https://discord.gg/GWpVpekp). You can also use 😷MultiMed⚕️ on your own data & in your own way by cloning this space. 🧬🔬🔍 Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/TeamTonic/MultiMed?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>
### Join us :
🌟TeamTonic🌟 is always making cool demos! Join our active builder's🛠️community on 👻Discord: [Discord](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Polytonic](https://github.com/tonic-ai) & contribute to 🌟 [PolyGPT](https://github.com/tonic-ai/polygpt-alpha)
"""
# Language names offered in the dropdown; the selection is used as the source
# language for speech translation.
languages = [
    "Afrikaans", "Amharic", "Modern Standard Arabic", "Moroccan Arabic",
    "Egyptian Arabic", "Assamese", "Asturian", "North Azerbaijani",
    "Belarusian", "Bengali", "Bosnian", "Bulgarian",
    "Catalan", "Cebuano", "Czech", "Central Kurdish",
    "Mandarin Chinese", "Welsh", "Danish", "German",
    "Greek", "English", "Estonian", "Basque",
    "Finnish", "French", "West Central Oromo", "Irish",
    "Galician", "Gujarati", "Hebrew", "Hindi",
    "Croatian", "Hungarian", "Armenian", "Igbo",
    "Indonesian", "Icelandic", "Italian", "Javanese",
    "Japanese", "Kamba", "Kannada", "Georgian",
    "Kazakh", "Kabuverdianu", "Halh Mongolian", "Khmer",
    "Kyrgyz", "Korean", "Lao", "Lithuanian",
    "Luxembourgish", "Ganda", "Luo", "Standard Latvian",
    "Maithili", "Malayalam", "Marathi", "Macedonian",
    "Maltese", "Meitei", "Burmese", "Dutch",
    "Norwegian Nynorsk", "Norwegian Bokmål", "Nepali", "Nyanja",
    "Occitan", "Odia", "Punjabi", "Southern Pashto",
    "Western Persian", "Polish", "Portuguese", "Romanian",
    "Russian", "Slovak", "Slovenian", "Shona",
    "Sindhi", "Somali", "Spanish", "Serbian",
    "Swedish", "Swahili", "Tamil", "Telugu",
    "Tajik", "Tagalog", "Thai", "Turkish",
    "Ukrainian", "Urdu", "Northern Uzbek", "Vietnamese",
    "Xhosa", "Yoruba", "Cantonese", "Colloquial Malay",
    "Standard Malay", "Zulu",
]
# A placeholder re-definition of `process_and_query` ("Your processing logic
# here", returning "Processed Text in <language>") used to sit here. Because it
# came later in the module it silently replaced the full pipeline defined
# above, so the Gradio button never ran speech, image, retrieval or StableMed.
# It has been removed; the earlier definition accepts the same four arguments.
def clear():
    """Build the Gradio updates that reset the interface.

    Returns:
        A 5-tuple of ``gr.update`` values, in this order:
        1. the language dropdown, with its value cleared to ``None``;
        2. the speech-to-text section, hidden;
        3. the image identification section, hidden;
        4. the text summarization section, hidden;
        5. the results row, hidden.
    """
    # Returned rather than applied: Gradio layout blocks have no reset()/hide()
    # methods, so calling them raised AttributeError.
    hidden = [gr.update(visible=False) for _ in range(4)]
    return (gr.update(value=None), *hidden)
def on_language_change(language):
    """Build visibility updates for the input sections after a language change.

    Args:
        language: The selected language; any falsy value (``None``, ``""``)
            means no language is selected.

    Returns:
        A 3-tuple of ``gr.update(visible=...)`` values for the speech-to-text,
        image identification and text summarization sections. They are visible
        exactly when ``language`` is truthy.
    """
    # Returned rather than applied: Gradio layout blocks have no show()/hide().
    visible = bool(language)
    return tuple(gr.update(visible=visible) for _ in range(3))
with gr.Blocks(theme='ParityError/Anime') as iface:
    gr.Markdown(welcome_message)

    # Visibility is declared with `visible=` and changed by returning
    # gr.update(visible=...) from handlers. Gradio layout blocks have no
    # hide()/show()/reset() methods, so the previous calls crashed at startup.
    # The input sections start visible because a language is preselected, and
    # re-selecting that same language would never fire `.change`.
    default_language = "English"
    show_inputs = bool(default_language)

    with gr.Row() as language_selection:
        input_language = gr.Dropdown(languages, label="Select the language", value=default_language, interactive=True)
    components['language_selection'] = language_selection

    with gr.Accordion("Speech to Text", open=False, visible=show_inputs) as speech_to_text:
        audio_input = gr.Audio(label="Speak", type="filepath", sources="microphone")
        audio_output = gr.Markdown(label="Output text")
    components['speech_to_text'] = speech_to_text

    with gr.Accordion("Image Identification", open=False, visible=show_inputs) as image_identification:
        image_input = gr.Image(label="Upload image")
        image_output = gr.Markdown(label="Output text")
    components['image_identification'] = image_identification

    with gr.Accordion("Text Summarization", open=False, visible=show_inputs) as text_summarization:
        text_input = gr.Textbox(label="Input text", lines=5)
        text_button = gr.Button("Process text")
    components['text_summarization'] = text_summarization

    # The single output pair. A duplicate text_output/hallucination_output
    # inside "Text Summarization" was immediately rebound here and never written to.
    with gr.Row(visible=False) as results:
        text_output = gr.Markdown()
        hallucination_output = gr.Label(label="Hallucination Evaluation")
    components['results'] = results

    clear_button = gr.Button("Clear")

    input_sections = [speech_to_text, image_identification, text_summarization]

    def _toggle_input_sections(language):
        """Show the input sections iff a language is selected (one update per section)."""
        return [gr.update(visible=bool(language)) for _ in input_sections]

    def _reset_interface():
        """Clear the language choice and hide the input sections and the results row."""
        return [gr.update(value=None)] + [gr.update(visible=False) for _ in range(len(input_sections) + 1)]

    # The old `.change(on_language_change)` passed no inputs, so the handler was
    # called without its `language` argument.
    input_language.change(_toggle_input_sections, inputs=[input_language], outputs=input_sections)
    clear_button.click(_reset_interface, inputs=[], outputs=[input_language, *input_sections, results])
    # Reveal the results row first, then run the pipeline into it.
    text_button.click(lambda: gr.update(visible=True), inputs=[], outputs=[results]).then(
        process_and_query,
        inputs=[input_language, audio_input, image_input, text_input],
        outputs=[text_output, hallucination_output],
    )

iface.launch(show_error=True, debug=True)