# app.py — author: sounar; commit 8e90fc6 ("Update app.py"); 2.82 kB
import os
import torch
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
import gradio as gr
from PIL import Image
from torchvision.transforms import ToTensor
# Get API token from environment variable
api_token = os.getenv("HF_TOKEN").strip()
# Quantization configuration: weights are loaded in 4-bit NF4 format with
# double (nested) quantization enabled; computation uses float16.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)
# Hub repository id shared by the model and its tokenizer (kept in one place
# so the two can never drift apart).
MODEL_ID = "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1"

# Initialize model and tokenizer.
# trust_remote_code=True executes modelling code shipped in the repository
# (presumably what provides the `chat` method used by analyze_input); only
# enable it for a repository you trust.
model = AutoModel.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    # NOTE(review): flash_attention_2 needs the flash-attn package and a
    # compatible GPU; confirm both are available where this runs.
    attn_implementation="flash_attention_2",
    token=api_token,
)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    token=api_token,
)
def analyze_input(image, question):
    """Ask the multimodal model a question, optionally about an image.

    Args:
        image: A PIL image (the input is declared as ``gr.Image(type="pil")``)
            or None when nothing was uploaded.
        question: The question text from the textbox.

    Returns:
        ``{"status": "success", "response": <generated text>}`` on success, or
        ``{"status": "error", "message": <exception text>}`` if anything
        raised; in that case the full traceback is printed to stdout.
    """
    try:
        content = [question]
        if image is not None:
            # Normalise the uploaded image to RGB before handing it to the model.
            image = image.convert('RGB')
            content.insert(0, image)
        # Prepare messages in the format expected by the model. The image is
        # only included when one was provided, so no None entry is passed on
        # to the model's remote code.
        msgs = [{'role': 'user', 'content': content}]
        # Generate response using the chat method (streamed chunk by chunk)
        response_stream = model.chat(
            image=image,
            msgs=msgs,
            tokenizer=tokenizer,
            sampling=True,
            temperature=0.95,
            stream=True,
        )
        # Echo chunks to the server console as they arrive and collect them;
        # joining once avoids repeated string concatenation.
        chunks = []
        for new_text in response_stream:
            chunks.append(new_text)
            print(new_text, flush=True, end='')
        print(flush=True)  # end the streamed line so later log output starts fresh
        return {"status": "success", "response": "".join(chunks)}
    except Exception as e:
        # UI boundary: report the failure to the client instead of crashing
        # the app, but keep the full traceback in the server log.
        import traceback
        error_trace = traceback.format_exc()
        print(f"Error occurred: {error_trace}")
        return {"status": "error", "message": str(e)}
# Prompt used both as the textbox placeholder and its pre-filled value.
DEFAULT_QUESTION = (
    "Give the modality, organ, analysis, abnormalities (if any), "
    "treatment (if abnormalities are present)?"
)

# Create Gradio interface: an image and a question go in, and the status /
# response dict returned by analyze_input is rendered as JSON.
demo = gr.Interface(
    fn=analyze_input,
    inputs=[
        gr.Image(type="pil", label="Upload Medical Image"),
        gr.Textbox(
            label="Medical Question",
            placeholder=DEFAULT_QUESTION,
            value=DEFAULT_QUESTION,
        ),
    ],
    outputs=gr.JSON(label="Analysis"),
    title="Medical Image Analysis Assistant",
    description=(
        "Upload a medical image and ask questions about it. "
        "The AI will analyze the image and provide detailed responses."
    ),
)
# Launch the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    # Bind to all network interfaces on port 7860.
    # NOTE(review): share=True asks Gradio for a public share link when run
    # outside a hosted environment; confirm that public exposure is intended.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_kwargs)