import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, pipeline
from PIL import Image
import torch
import warnings

warnings.filterwarnings("ignore")
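
# Load the Phi-3.5 vision-language model. FlashAttention 2 is tried first for
# faster attention; if the flash-attn package is not installed, the except
# branch falls back to the eager implementation.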
phi_model_id = "microsoft/Phi-3.5-vision-instruct"
try:
    phi_model = AutoModelForCausalLM.from_pretrained(
        phi_model_id,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,
        _attn_implementation="flash_attention_2"
    )
except ImportError:
    print("FlashAttention not available, falling back to eager implementation.")
    phi_model = AutoModelForCausalLM.from_pretrained(
        phi_model_id,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,
        _attn_implementation="eager"
    )

phi_processor = AutoProcessor.from_pretrained(phi_model_id, trust_remote_code=True)
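
# Load Llama 3.1 as a text-generation pipeline for text-only turns. If it
# cannot be loaded (for example, gated access or insufficient memory), fall
# back to the much smaller open-source GPT-2 model.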
llama_model_id = "meta-llama/Llama-3.1-8B"
try:
    llama_pipeline = pipeline("text-generation", model=llama_model_id, device_map="auto", torch_dtype=torch.float16)
except Exception as e:
    print(f"Error loading Llama 3.1 model: {e}")
    print("Falling back to a smaller, open-source model.")
    llama_model_id = "gpt2"
    llama_pipeline = pipeline("text-generation", model=llama_model_id, device_map="auto")
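
# Vision path: wrap the query in Phi-3.5-vision's chat template, where the
# <|image_1|> placeholder marks the position of the attached image.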
def analyze_image(image, query):
    prompt = f"<|user|>\n<|image_1|>\n{query}<|end|>\n<|assistant|>\n"
    inputs = phi_processor(prompt, images=image, return_tensors="pt").to(phi_model.device)

    with torch.no_grad():
        output = phi_model.generate(**inputs, max_new_tokens=100)

    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = output[:, inputs["input_ids"].shape[1]:]
    return phi_processor.decode(new_tokens[0], skip_special_tokens=True)
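
# Text path: Llama-3.1-8B is a base (non-instruct) model, so the chat history
# is flattened into a plain "Human:/AI:" transcript and only the completion
# after the final "AI:" marker is returned.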
def generate_text(query, history):
    context = "\n".join([f"{h[0]}\n{h[1]}" for h in history])
    prompt = f"{context}\nHuman: {query}\nAI:"

    response = llama_pipeline(prompt, max_new_tokens=100, do_sample=True, temperature=0.7)[0]['generated_text']
    return response.split("AI:")[-1].strip()
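
# Route each submission: use the vision model when an image is uploaded,
# otherwise fall through to the text-only model. The empty string returned
# first clears the textbox; the history is echoed to both Chatbot and State.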
def chatbot(image, query, history):
    if image is not None:
        response = analyze_image(Image.fromarray(image), query)
    else:
        response = generate_text(query, history)

    history.append((query, response))
    return "", history, history
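
# Gradio UI: an image upload alongside the chat window, a textbox for prompts,
# and a State component that carries the (user, assistant) history across calls.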
with gr.Blocks() as demo:
    gr.Markdown("# Multi-Modal AI Assistant")

    with gr.Row():
        image_input = gr.Image(type="numpy", label="Upload an image (optional)")
        chat_history = gr.Chatbot(label="Chat History")

    query_input = gr.Textbox(label="Ask a question or enter a prompt")
    submit_button = gr.Button("Submit")

    state = gr.State([])

    submit_button.click(
        chatbot,
        inputs=[image_input, query_input, state],
        outputs=[query_input, chat_history, state]
    )

demo.launch()
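
# Note: demo.launch() serves the app locally (http://127.0.0.1:7860 by default);
# pass share=True for a temporary public link.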