Advait3009 commited on
Commit
27d6624
·
verified ·
1 Parent(s): 9df0d5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -14
app.py CHANGED
@@ -3,26 +3,29 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  from diffusers import StableDiffusionPipeline
4
  from sentence_transformers import SentenceTransformer, util
5
  import torch
 
6
 
7
  # --- Load models ---
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
  # Text-to-text model
11
- text_model_name = "google/flan-t5-small"
12
  text_tokenizer = AutoTokenizer.from_pretrained(text_model_name)
13
- text_model = AutoModelForSeq2SeqLM.from_pretrained(text_model_name)
14
 
15
  # Text-to-image model
16
  image_model_id = "runwayml/stable-diffusion-v1-5"
17
  image_pipe = StableDiffusionPipeline.from_pretrained(
18
- image_model_id, torch_dtype=torch.float16 if device=="cuda" else torch.float32
 
 
19
  )
20
  image_pipe = image_pipe.to(device)
21
 
22
  # Sentence similarity model
23
  embedder = SentenceTransformer('all-MiniLM-L6-v2')
24
 
25
- # Example prompts that mean “generate image”
26
  image_triggers = [
27
  "generate an image of",
28
  "draw a",
@@ -33,30 +36,30 @@ image_triggers = [
33
  "sketch",
34
  ]
35
 
36
- # --- Main Logic ---
37
  def multimodal_agent(prompt):
38
- # Step 1: Check similarity to image prompts
39
  prompt_embedding = embedder.encode(prompt, convert_to_tensor=True)
40
  trigger_embeddings = embedder.encode(image_triggers, convert_to_tensor=True)
41
-
42
  cosine_scores = util.pytorch_cos_sim(prompt_embedding, trigger_embeddings)
43
  max_score = torch.max(cosine_scores).item()
44
 
45
- # Step 2: Threshold to decide
46
  if max_score > 0.65:
47
  # Generate image
48
- image = image_pipe(prompt).images[0]
49
- return None, image # Return image only
 
50
  else:
51
  # Generate text
52
- inputs = text_tokenizer(prompt, return_tensors="pt")
53
  outputs = text_model.generate(**inputs, max_new_tokens=100)
54
  text = text_tokenizer.decode(outputs[0], skip_special_tokens=True)
55
- return text, None # Return text only
56
 
57
- # --- Gradio UI ---
58
  with gr.Blocks() as demo:
59
- gr.Markdown("# 🤖 Smart Multimodal AI Agent\nDecides text or image based on your prompt.")
60
 
61
  input_box = gr.Textbox(label="Enter your prompt")
62
  output_text = gr.Textbox(label="Text Output")
 
3
  from diffusers import StableDiffusionPipeline
4
  from sentence_transformers import SentenceTransformer, util
5
  import torch
6
+ import contextlib
7
 
8
  # --- Load models ---
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
  # Text-to-text model
12
+ text_model_name = "google/flan-t5-large"
13
  text_tokenizer = AutoTokenizer.from_pretrained(text_model_name)
14
+ text_model = AutoModelForSeq2SeqLM.from_pretrained(text_model_name).to(device)
15
 
16
  # Text-to-image model
17
  image_model_id = "runwayml/stable-diffusion-v1-5"
18
  image_pipe = StableDiffusionPipeline.from_pretrained(
19
+ image_model_id,
20
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
21
+ safety_checker=None # Optional for debugging
22
  )
23
  image_pipe = image_pipe.to(device)
24
 
25
  # Sentence similarity model
26
  embedder = SentenceTransformer('all-MiniLM-L6-v2')
27
 
28
+ # Image-like trigger phrases
29
  image_triggers = [
30
  "generate an image of",
31
  "draw a",
 
36
  "sketch",
37
  ]
38
 
39
+ # --- Core logic ---
40
  def multimodal_agent(prompt):
41
+ # Step 1: Semantic similarity to image triggers
42
  prompt_embedding = embedder.encode(prompt, convert_to_tensor=True)
43
  trigger_embeddings = embedder.encode(image_triggers, convert_to_tensor=True)
 
44
  cosine_scores = util.pytorch_cos_sim(prompt_embedding, trigger_embeddings)
45
  max_score = torch.max(cosine_scores).item()
46
 
47
+ # Step 2: Decision branch
48
  if max_score > 0.65:
49
  # Generate image
50
+ with torch.autocast("cuda") if device == "cuda" else contextlib.nullcontext():
51
+ image = image_pipe(prompt).images[0]
52
+ return None, image
53
  else:
54
  # Generate text
55
+ inputs = text_tokenizer(prompt, return_tensors="pt").to(device)
56
  outputs = text_model.generate(**inputs, max_new_tokens=100)
57
  text = text_tokenizer.decode(outputs[0], skip_special_tokens=True)
58
+ return text, None
59
 
60
+ # --- UI ---
61
  with gr.Blocks() as demo:
62
+ gr.Markdown("# 🤖 Smart Multimodal AI Agent\nGive a prompt It decides text vs image automatically!")
63
 
64
  input_box = gr.Textbox(label="Enter your prompt")
65
  output_text = gr.Textbox(label="Text Output")