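# Hugging Face Space "Bio-Medical MultiModal Analysis": serves
# ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1 (4-bit NF4 quantized)
# behind a Gradio interface for medical image and text Q&A.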
import os
import torch
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
from PIL import Image
import gradio as gr
import base64
import io
# Get the API token from the environment; fail fast with a clear error if unset
# (os.getenv returns None when the variable is missing, so calling .strip()
# directly would raise an opaque AttributeError)
api_token = os.getenv("HF_TOKEN")
if api_token is None:
    raise RuntimeError("HF_TOKEN environment variable is not set")
api_token = api_token.strip()
# Quantization configuration: 4-bit NF4 weights with double quantization,
# computing in float16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
# Load the quantized model, sharding it across available devices
model = AutoModel.from_pretrained(
    "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    token=api_token,
)
tokenizer = AutoTokenizer.from_pretrained(
    "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
    trust_remote_code=True,
    token=api_token,
)
def analyze_input(image_data, question):
    try:
        # Handle a base64 data-URI image if provided (e.g. via an API call)
        if isinstance(image_data, str) and image_data.startswith('data:image'):
            # Extract the base64 payload after the comma
            base64_data = image_data.split(',')[1]
            image_bytes = base64.b64decode(base64_data)
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        # Handle direct image input (numpy array from gr.Image)
        elif image_data is not None:
            image = Image.fromarray(image_data).convert('RGB')
        else:
            image = None

        # Tokenize the question and move the tensors to the model's device;
        # with device_map="auto" the inputs must not be left on the CPU
        inputs = tokenizer(question, return_tensors="pt").to(model.device)

        if image is not None:
            # NOTE: forwarding the image via an `images=` kwarg is an assumption
            # about this checkpoint's custom (trust_remote_code) generation
            # signature; consult the model card for its exact multimodal API.
            outputs = model.generate(**inputs, images=[image], max_new_tokens=256)
        else:
            outputs = model.generate(**inputs, max_new_tokens=256)

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {
            "status": "success",
            "response": response
        }
    except Exception as e:
        return {
            "status": "error",
            "message": str(e)
        }
# Create the Gradio interface; the image input is optional, so text-only
# questions also work
demo = gr.Interface(
    fn=analyze_input,
    inputs=[
        gr.Image(type="numpy", label="Medical Image"),
        gr.Textbox(label="Question", placeholder="Enter your medical query..."),
    ],
    outputs=gr.JSON(label="Analysis"),
    title="Bio-Medical MultiModal Analysis",
    description="Ask questions with or without an image",
    allow_flagging="never",
)
# Enable request queuing, then launch with API access enabled
# (launch(enable_queue=...) was removed in Gradio 4; call demo.queue() instead)
demo.queue()
demo.launch(
    share=True,
    server_name="0.0.0.0",
    server_port=7860,
)
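# Example client call against the running Space (a sketch: assumes the default
# /predict endpoint that gr.Interface exposes and requires
# `pip install gradio_client`; pass None as the image for a text-only query):
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   result = client.predict(None, "What are the common findings on a chest X-ray?",
#                           api_name="/predict")
#   print(result)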