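"""Gradio demo for ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1.

Loads the model in 4-bit (NF4) via bitsandbytes and serves a simple
medical question-answering interface. An image input is accepted but
currently unused; only text queries are processed.
"""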
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import gradio as gr
# Get the API token from the environment; fail fast with a clear error if missing
api_token = os.getenv("HF_TOKEN", "").strip()
if not api_token:
    raise EnvironmentError("HF_TOKEN environment variable is not set")
# 4-bit quantization configuration (NF4 with double quantization)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",            # NormalFloat4 quantization
    bnb_4bit_use_double_quant=True,       # also quantize the quantization constants
    bnb_4bit_compute_dtype=torch.float16  # compute in fp16 for speed
)
# Load model with revision pinning | |
model = AutoModelForCausalLM.from_pretrained( | |
"ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1", | |
quantization_config=bnb_config, | |
device_map="auto", | |
torch_dtype=torch.float16, | |
trust_remote_code=True, | |
token=api_token, | |
revision="main" | |
) | |
tokenizer = AutoTokenizer.from_pretrained(
    "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
    trust_remote_code=True,
    token=api_token,
    revision="main"
)
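# Llama-family tokenizers often ship without a pad token; set one defensively
# so generate() can receive a valid pad_token_id. (Assumption: this remote
# tokenizer follows the standard HF tokenizer API.)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token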
def analyze_input(image_data, question):
    try:
        # Build the prompt; image support is not wired up yet, so an image
        # only changes the prompt wording
        if image_data is not None:
            prompt = f"Given the medical image and the question: {question}\nPlease provide a detailed analysis."
        else:
            prompt = f"Medical question: {question}\nAnswer: "
        # Tokenize the prompt and move it to the model's device
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
        # Generation settings
        generation_config = {
            "max_new_tokens": 256,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
        }
        # Generate a text-only response; pixel_values are deliberately omitted
        # since images are not processed yet. (generate() takes no model_inputs=
        # keyword, so the tokenized inputs are passed directly.)
        outputs = model.generate(
            input_ids=input_ids,
            pad_token_id=tokenizer.pad_token_id,
            **generation_config
        )
        # Decode the output and strip the echoed prompt from the start
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        if response.startswith(prompt):
            response = response[len(prompt):].strip()
        return {
            "status": "success",
            "response": response
        }
    except Exception as e:
        return {
            "status": "error",
            "message": str(e)
        }
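# Example text-only call (hypothetical query, for illustration):
#   analyze_input(None, "What are the common symptoms of iron-deficiency anemia?")
#   -> {"status": "success", "response": "..."}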
# Create the Gradio interface
demo = gr.Interface(
    fn=analyze_input,
    inputs=[
        gr.Image(type="numpy", label="Medical Image (Optional)"),
        gr.Textbox(label="Question", placeholder="Enter your medical query...")
    ],
    outputs=gr.JSON(label="Analysis"),
    title="Medical Query Analysis",
    description="Ask medical questions. For now, please focus on text-based queries without images.",
    flagging_mode="never"  # newer Gradio releases; older ones used allow_flagging="never"
)
# Launch the interface
demo.launch(
    share=True,  # ignored on Hugging Face Spaces; only relevant for local runs
    server_name="0.0.0.0",
    server_port=7860
)
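# Assumed dependencies (not pinned in the original source): torch, transformers,
# bitsandbytes (for 4-bit loading), accelerate (for device_map="auto"), gradio.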