import os
import torch
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
from PIL import Image
import gradio as gr
import base64
import io
# Get the Hugging Face API token from the environment (needed to download the gated model)
api_token = os.getenv("HF_TOKEN", "").strip()
if not api_token:
    raise ValueError("HF_TOKEN environment variable is not set")
# Quantization configuration
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.float16
)
# Load model
model = AutoModel.from_pretrained(
"ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
quantization_config=bnb_config,
device_map="auto",
torch_dtype=torch.float16,
trust_remote_code=True,
token=api_token
)
tokenizer = AutoTokenizer.from_pretrained(
"ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
trust_remote_code=True,
token=api_token
)
def analyze_input(image_data, question):
    try:
        # Handle a base64-encoded data URL if provided
        if isinstance(image_data, str) and image_data.startswith('data:image'):
            # Extract the base64 payload after the comma
            base64_data = image_data.split(',')[1]
            image_bytes = base64.b64decode(base64_data)
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        # Handle a direct numpy image from the Gradio component
        elif image_data is not None:
            image = Image.fromarray(image_data).convert('RGB')
        else:
            image = None

        # Process with or without an image
        if image is not None:
            # Multimodal path: assumes the model's remote code (trust_remote_code=True)
            # exposes a MiniCPM-V-style chat() helper; adjust to the checkpoint's
            # documented API if it differs.
            response = model.chat(
                image=image,
                msgs=[{"role": "user", "content": question}],
                tokenizer=tokenizer
            )
        else:
            # Text-only path: tokenize on the model's device and generate directly
            inputs = tokenizer(question, return_tensors="pt").to(model.device)
            outputs = model.generate(**inputs, max_new_tokens=256)
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return {
            "status": "success",
            "response": response
        }
    except Exception as e:
        return {
            "status": "error",
            "message": str(e)
        }
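# Quick sanity check of analyze_input outside Gradio (a minimal sketch; the image
# path below is hypothetical, and running it exercises the full model, so only
# try it on a machine with enough GPU memory):
#
#   import numpy as np
#   from PIL import Image
#
#   print(analyze_input(None, "What are typical findings in iron-deficiency anemia?"))
#   xray = np.array(Image.open("sample_chest_xray.png").convert("RGB"))
#   print(analyze_input(xray, "Describe any visible abnormalities."))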
# Create Gradio interface
demo = gr.Interface(
fn=analyze_input,
inputs=[
        gr.Image(type="numpy", label="Medical Image"),  # passed to analyze_input as a numpy array, or None if omitted
gr.Textbox(label="Question", placeholder="Enter your medical query...")
],
outputs=gr.JSON(label="Analysis"),
title="Bio-Medical MultiModal Analysis",
description="Ask questions with or without an image",
allow_flagging="never",
)
# Launch with API access enabled; queue requests so long generations don't block each other
demo.queue()
demo.launch(
    share=True,
    server_name="0.0.0.0",
    server_port=7860
)
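# Example remote call against the running app (a minimal sketch, assuming the
# gradio_client package is installed and the URL/api_name match this deployment;
# the image path is hypothetical):
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://localhost:7860")
#   result = client.predict(
#       handle_file("sample_chest_xray.png"),   # Medical Image input
#       "What abnormality is visible here?",    # Question input
#       api_name="/predict",
#   )
#   print(result)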