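"""Multi-Modal AI Assistant.

A Gradio chat app that routes image questions to microsoft/Phi-3.5-vision-instruct
and text-only questions to a Llama 3.1 text-generation pipeline (falling back to
GPT-2 if the gated Llama weights cannot be loaded).
"""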
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, pipeline
from PIL import Image
import torch
import warnings

# Suppress warnings
warnings.filterwarnings("ignore")

# Load Phi-3.5-vision model
phi_model_id = "microsoft/Phi-3.5-vision-instruct"
try:
    phi_model = AutoModelForCausalLM.from_pretrained(
        phi_model_id,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,  # Use float16 to reduce memory usage
        _attn_implementation="flash_attention_2"  # Prefer FlashAttention 2 when flash-attn is installed
    )
except ImportError:
    print("FlashAttention not available, falling back to eager implementation.")
    phi_model = AutoModelForCausalLM.from_pretrained(
        phi_model_id,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,
        _attn_implementation="eager"
    )

phi_processor = AutoProcessor.from_pretrained(phi_model_id, trust_remote_code=True)

# Load Llama 3.1 model
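# Note: the Llama 3.1 checkpoints are gated on the Hugging Face Hub, so loading them
# requires an authenticated account with approved access; otherwise the fallback below is used.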
llama_model_id = "meta-llama/Llama-3.1-8B"
try:
    llama_pipeline = pipeline("text-generation", model=llama_model_id, device_map="auto", torch_dtype=torch.float16)
except Exception as e:
    print(f"Error loading Llama 3.1 model: {e}")
    print("Falling back to a smaller, open-source model.")
    llama_model_id = "gpt2"  # Fallback to a smaller, open-source model
    llama_pipeline = pipeline("text-generation", model=llama_model_id, device_map="auto")

def analyze_image(image, query):
    # Phi-3.5-vision chat template: one image placeholder followed by the user query.
    prompt = f"<|user|>\n<|image_1|>\n{query}<|end|>\n<|assistant|>\n"
    inputs = phi_processor(prompt, images=image, return_tensors="pt").to(phi_model.device)

    with torch.no_grad():
        output = phi_model.generate(**inputs, max_new_tokens=100)
    # generate() returns the prompt tokens followed by the answer; decode only the new tokens.
    generated = output[:, inputs["input_ids"].shape[1]:]
    return phi_processor.decode(generated[0], skip_special_tokens=True)

def generate_text(query, history):
    # Rebuild the conversation with the same Human/AI labels used in the prompt below.
    context = "\n".join([f"Human: {h[0]}\nAI: {h[1]}" for h in history])
    prompt = f"{context}\nHuman: {query}\nAI:"

    # return_full_text=False keeps the prompt out of the pipeline output.
    response = llama_pipeline(prompt, max_new_tokens=100, do_sample=True, temperature=0.7,
                              return_full_text=False)[0]['generated_text']
    # Cut the answer off at any hallucinated follow-up "Human:" turn.
    return response.split("Human:")[0].strip()

def chatbot(image, query, history):
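    # Route to the vision model when an image is supplied; otherwise answer with the text pipeline.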
    if image is not None:
        response = analyze_image(Image.fromarray(image), query)
    else:
        response = generate_text(query, history)
    
    history.append((query, response))
    return "", history, history

with gr.Blocks() as demo:
    gr.Markdown("# Multi-Modal AI Assistant")
    
    with gr.Row():
        image_input = gr.Image(type="numpy", label="Upload an image (optional)")
        chat_history = gr.Chatbot(label="Chat History")
    
    query_input = gr.Textbox(label="Ask a question or enter a prompt")
    submit_button = gr.Button("Submit")
    
    # Shared chat history: a list of (query, response) tuples passed between callbacks.
    state = gr.State([])
    
    submit_button.click(
        chatbot,
        inputs=[image_input, query_input, state],
        outputs=[query_input, chat_history, state]
    )

demo.launch()