import os
import torch
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
from PIL import Image
import gradio as gr
import base64
import io
# Get API token from environment variable (used to authenticate the model download)
api_token = os.getenv("HF_TOKEN")
if api_token:
    api_token = api_token.strip()
# Quantization configuration
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16
)
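# Note: NF4 4-bit quantization with double quantization stores the 8B-parameter
# weights at roughly 4 bits each (on the order of 5 GB of GPU memory, a rough
# estimate), while matrix multiplications run in float16 via bnb_4bit_compute_dtype.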
# Load model
model = AutoModel.from_pretrained(
    "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    token=api_token
)
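# device_map="auto" lets accelerate place the quantized weights on the available
# GPU(s); trust_remote_code=True runs the custom multimodal code shipped with the
# model repository.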
tokenizer = AutoTokenizer.from_pretrained(
    "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
    trust_remote_code=True,
    token=api_token
)
def analyze_input(image_data, question):
    try:
        # Handle base64 image if provided (e.g. from an API call)
        if isinstance(image_data, str) and image_data.startswith('data:image'):
            # Extract base64 data after the comma
            base64_data = image_data.split(',')[1]
            image_bytes = base64.b64decode(base64_data)
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        # Handle direct image input from the Gradio component (numpy array)
        elif image_data is not None:
            image = Image.fromarray(image_data).convert('RGB')
        else:
            image = None

        # Process with or without image. The image branch relies on the model's
        # custom remote code accepting an `images` argument here; the exact
        # multimodal call may differ between revisions of the model repo.
        if image is not None:
            inputs = model.prepare_inputs_for_generation(
                input_ids=tokenizer(question, return_tensors="pt").input_ids.to(model.device),
                images=[image]
            )
        else:
            inputs = tokenizer(question, return_tensors="pt").to(model.device)

        outputs = model.generate(**inputs, max_new_tokens=256)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return {
            "status": "success",
            "response": response
        }
    except Exception as e:
        return {
            "status": "error",
            "message": str(e)
        }
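# Quick local smoke test (a sketch; text-only path, run after the model has loaded):
#   print(analyze_input(None, "What are the typical symptoms of iron-deficiency anemia?"))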
# Create Gradio interface
demo = gr.Interface(
    fn=analyze_input,
    inputs=[
        # gr.Image no longer accepts optional=True; an empty upload arrives as None
        gr.Image(type="numpy", label="Medical Image"),
        gr.Textbox(label="Question", placeholder="Enter your medical query...")
    ],
    outputs=gr.JSON(label="Analysis"),
    title="Bio-Medical MultiModal Analysis",
    description="Ask questions with or without an image",
    allow_flagging="never",
)
# Enable the request queue (replaces the deprecated enable_queue launch argument)
demo.queue()

# Launch with API access enabled
demo.launch(
    share=True,
    server_name="0.0.0.0",
    server_port=7860
)
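# Example of calling the Space programmatically (a sketch, not executed here).
# Assumes the Space is published as "<user>/<space-name>" (placeholder) and that
# the gradio_client package is installed; "/predict" is Gradio's default api_name
# for a single Interface.
#
#   from gradio_client import Client, handle_file
#   client = Client("<user>/<space-name>")
#   result = client.predict(
#       handle_file("chest_xray.png"),        # hypothetical local image file
#       "Describe any abnormalities.",        # question
#       api_name="/predict"
#   )
#   print(result)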