Advait3009 commited on
Commit
12d2cf1
·
verified ·
1 Parent(s): 0b9ca50

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ from diffusers import StableDiffusionPipeline
4
+ from sentence_transformers import SentenceTransformer, util
5
+ import torch
6
+
7
+ # --- Load models ---
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+
10
+ # Text-to-text model
11
+ text_model_name = "google/flan-t5-small"
12
+ text_tokenizer = AutoTokenizer.from_pretrained(text_model_name)
13
+ text_model = AutoModelForSeq2SeqLM.from_pretrained(text_model_name)
14
+
15
+ # Text-to-image model
16
+ image_model_id = "runwayml/stable-diffusion-v1-5"
17
+ image_pipe = StableDiffusionPipeline.from_pretrained(
18
+ image_model_id, torch_dtype=torch.float16 if device=="cuda" else torch.float32
19
+ )
20
+ image_pipe = image_pipe.to(device)
21
+
22
+ # Sentence similarity model
23
+ embedder = SentenceTransformer('all-MiniLM-L6-v2')
24
+
25
+ # Example prompts that mean “generate image”
26
+ image_triggers = [
27
+ "generate an image of",
28
+ "draw a",
29
+ "create a picture of",
30
+ "show me a",
31
+ "visualize",
32
+ "render",
33
+ "sketch",
34
+ ]
35
+
36
+ # --- Main Logic ---
37
+ def multimodal_agent(prompt):
38
+ # Step 1: Check similarity to image prompts
39
+ prompt_embedding = embedder.encode(prompt, convert_to_tensor=True)
40
+ trigger_embeddings = embedder.encode(image_triggers, convert_to_tensor=True)
41
+
42
+ cosine_scores = util.pytorch_cos_sim(prompt_embedding, trigger_embeddings)
43
+ max_score = torch.max(cosine_scores).item()
44
+
45
+ # Step 2: Threshold to decide
46
+ if max_score > 0.65:
47
+ # Generate image
48
+ image = image_pipe(prompt).images[0]
49
+ return None, image # Return image only
50
+ else:
51
+ # Generate text
52
+ inputs = text_tokenizer(prompt, return_tensors="pt")
53
+ outputs = text_model.generate(**inputs, max_new_tokens=100)
54
+ text = text_tokenizer.decode(outputs[0], skip_special_tokens=True)
55
+ return text, None # Return text only
56
+
57
+ # --- Gradio UI ---
58
+ with gr.Blocks() as demo:
59
+ gr.Markdown("# 🤖 Smart Multimodal AI Agent\nDecides text or image based on your prompt.")
60
+
61
+ input_box = gr.Textbox(label="Enter your prompt")
62
+ output_text = gr.Textbox(label="Text Output")
63
+ output_image = gr.Image(label="Image Output")
64
+
65
+ btn = gr.Button("Generate")
66
+ btn.click(multimodal_agent, inputs=input_box, outputs=[output_text, output_image])
67
+
68
+ demo.launch()