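# Scraped from the Hugging Face file view (kept as comments so the file parses):
# author: sounar | commit message: "Update app.py" | commit c600b9f (verified) | size: 2.87 kB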
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from PIL import Image
import gradio as gr
import base64
import io
# Hugging Face access token, read from the HF_TOKEN environment variable.
# os.getenv returns None when the variable is unset, and calling .strip() on
# None would crash at import time. A missing or blank variable therefore
# becomes None, which lets from_pretrained fall back to any cached login
# instead of sending an empty token.
api_token = (os.getenv("HF_TOKEN") or "").strip() or None
# bitsandbytes quantization settings:
# - weights stored in 4-bit NF4 format;
# - the quantization constants are themselves quantized (double quantization);
# - matmul computation runs in float16.
bnb_config = BitsAndBytesConfig(
    **{
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.float16,
        "bnb_4bit_use_double_quant": True,
    }
)
# Hub repository that provides both the model weights and the tokenizer.
MODEL_ID = "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1"
# NOTE: "main" is a moving branch, NOT a pin. Every restart picks up the
# latest commit, including any new remote code executed because of
# trust_remote_code=True. Replace this with a specific commit SHA to
# actually pin the revision.
MODEL_REVISION = "main"

# Load the model with the 4-bit bnb_config defined above.
# device_map="auto" lets accelerate decide where the weights are placed.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    token=api_token,
    revision=MODEL_REVISION,
)
# Load the tokenizer from the same repository and revision as the model,
# so the two stay in sync.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    token=api_token,
    revision=MODEL_REVISION,
)
def analyze_input(image_data, question):
    """Answer a medical question with the loaded language model.

    Args:
        image_data: Image from the ``gr.Image(type="numpy")`` input, or None.
            Only its presence is checked: it selects the prompt template, but
            no pixel data is passed to the model (see NOTE below).
        question: Free-text question from the Gradio textbox.

    Returns:
        dict: ``{"status": "success", "response": <generated text>}`` on
        success. ``{"status": "error", "message": <description>}`` if the
        question is empty or if tokenization/generation raises.
    """
    # Reject empty input up front instead of spending a generation on it.
    if question is None or not str(question).strip():
        return {
            "status": "error",
            "message": "Please enter a question.",
        }
    try:
        # NOTE(review): the image only changes the prompt wording; its pixels
        # never reach the model. Feeding images presumably needs the model's
        # own multimodal API (trust_remote_code) -- TODO confirm and wire up.
        if image_data is not None:
            prompt = f"Given the medical image and the question: {question}\nPlease provide a detailed analysis."
        else:
            prompt = f"Medical question: {question}\nAnswer: "

        # Keep the attention mask next to input_ids. Without it, generate()
        # has to infer the mask and warns about it.
        encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
        input_ids = encoded["input_ids"]

        generation_config = {
            "max_new_tokens": 256,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
        }
        # The tokenizer may define no pad token; use EOS so generate() has
        # a pad_token_id to work with.
        if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
            generation_config["pad_token_id"] = tokenizer.eos_token_id

        # Pass the tensors as real generate() keyword arguments. A dict
        # wrapped under an unrecognised keyword is not treated as model input.
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=encoded.get("attention_mask"),
            **generation_config,
        )

        generated = outputs[0]
        prompt_len = input_ids.shape[-1]
        # Decoder-only generate() typically echoes the prompt tokens before
        # the new ones. Strip them at the token level, so the result is exact
        # even when decode(encode(prompt)) is not textually equal to prompt.
        if (
            isinstance(generated, torch.Tensor)
            and generated.shape[-1] >= prompt_len
            and generated[:prompt_len].tolist() == input_ids[0].tolist()
        ):
            generated = generated[prompt_len:]
        response = tokenizer.decode(generated, skip_special_tokens=True).strip()
        # Fallback for outputs that echo the prompt only as leading text.
        if response.startswith(prompt):
            response = response[len(prompt):].strip()

        return {
            "status": "success",
            "response": response,
        }
    except Exception as e:
        # UI boundary: report the failure in the JSON output instead of
        # crashing the Gradio handler.
        return {
            "status": "error",
            "message": str(e),
        }
# Web UI: an optional image plus a question go in, and the dict returned
# by analyze_input is rendered as JSON. Flagging is turned off.
demo = gr.Interface(
    title="Medical Query Analysis",
    description=(
        "Ask medical questions. "
        "For now, please focus on text-based queries without images."
    ),
    fn=analyze_input,
    inputs=[
        gr.Image(label="Medical Image (Optional)", type="numpy"),
        gr.Textbox(placeholder="Enter your medical query...", label="Question"),
    ],
    outputs=gr.JSON(label="Analysis"),
    flagging_mode="never",
)
# Start the web server only when this file is run as a script
# (e.g. `python app.py`). Importing the module -- from tests or another
# app -- then does not block in launch().
if __name__ == "__main__":
    demo.launch(
        # share=True requests a public Gradio share link when run locally.
        # NOTE(review): Gradio reportedly ignores this on Hugging Face
        # Spaces -- confirm for the deployed version.
        share=True,
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
    )