import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, pipeline
from PIL import Image
import torch
import warnings
# Suppress warnings
warnings.filterwarnings("ignore")
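
# Two model back-ends are used below: Phi-3.5-vision for queries that include an image,
# and Llama 3.1 (with a smaller fallback) for text-only queries.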
# Load Phi-3.5-vision model
phi_model_id = "microsoft/Phi-3.5-vision-instruct"
try:
    # Prefer FlashAttention when the package is installed; fall back to eager otherwise.
    phi_model = AutoModelForCausalLM.from_pretrained(
        phi_model_id,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,  # Use float16 to reduce memory usage
        _attn_implementation="flash_attention_2"
    )
except ImportError:
    print("FlashAttention not available, falling back to eager implementation.")
    phi_model = AutoModelForCausalLM.from_pretrained(
        phi_model_id,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,
        _attn_implementation="eager"
    )
phi_processor = AutoProcessor.from_pretrained(phi_model_id, trust_remote_code=True)
# Load Llama 3.1 model
llama_model_id = "meta-llama/Llama-3.1-8B"
try:
    llama_pipeline = pipeline("text-generation", model=llama_model_id, device_map="auto", torch_dtype=torch.float16)
except Exception as e:
    print(f"Error loading Llama 3.1 model: {e}")
    print("Falling back to a smaller, open-source model.")
    llama_model_id = "gpt2"  # Fallback to a smaller, open-source model
    llama_pipeline = pipeline("text-generation", model=llama_model_id, device_map="auto")
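
# analyze_image: runs an image + text query through Phi-3.5-vision. The prompt follows
# the model's chat template, with <|image_1|> as the placeholder for the attached image.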
def analyze_image(image, query):
    prompt = f"<|user|>\n<|image_1|>\n{query}<|end|>\n<|assistant|>\n"
    inputs = phi_processor(prompt, images=image, return_tensors="pt").to(phi_model.device)
    with torch.no_grad():
        output = phi_model.generate(**inputs, max_new_tokens=100)
    # Strip the prompt tokens so only the newly generated answer is decoded.
    output = output[:, inputs["input_ids"].shape[1]:]
    return phi_processor.decode(output[0], skip_special_tokens=True)
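
# generate_text: text-only path. Past turns are flattened into a plain "Human:" / "AI:"
# prompt, since the pipeline wraps a base text-generation model rather than a chat interface.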
def generate_text(query, history):
    context = "\n".join([f"{h[0]}\n{h[1]}" for h in history])
    prompt = f"{context}\nHuman: {query}\nAI:"
    response = llama_pipeline(prompt, max_new_tokens=100, do_sample=True, temperature=0.7)[0]['generated_text']
    return response.split("AI:")[-1].strip()
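
# chatbot: Gradio callback. Routes image uploads to Phi-3.5-vision and text-only queries
# to the text pipeline, then appends the turn to the shared chat history.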
def chatbot(image, query, history):
    if image is not None:
        response = analyze_image(Image.fromarray(image), query)
    else:
        response = generate_text(query, history)
    history.append((query, response))
    return "", history, history
with gr.Blocks() as demo:
    gr.Markdown("# Multi-Modal AI Assistant")
    with gr.Row():
        image_input = gr.Image(type="numpy", label="Upload an image (optional)")
        chat_history = gr.Chatbot(label="Chat History")
    query_input = gr.Textbox(label="Ask a question or enter a prompt")
    submit_button = gr.Button("Submit")
    state = gr.State([])
    submit_button.click(
        chatbot,
        inputs=[image_input, query_input, state],
        outputs=[query_input, chat_history, state]
    )
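
# launch() starts a local Gradio server; pass share=True for a temporary public URL.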
demo.launch()