wifix199 commited on
Commit
9614ae1
1 Parent(s): 989332d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -4
app.py CHANGED
@@ -8,15 +8,28 @@ pipe.to("cpu") # Use "cuda" if GPU is available
8
 
9
  unet = pipe.unet
10
 
11
- def generate_image(prompt, unet):
12
- added_cond_kwargs = {"text_embeds": pipe.get_text_embedding(prompt)}
13
- image = unet(prompt, **added_cond_kwargs).images[0]
 
 
 
14
  return image
15
 
16
  def chatbot(prompt):
17
  # Generate the image based on the user's input
18
- image = generate_image(prompt, unet)
19
  return image
 
 
 
 
 
 
 
 
 
 
20
 
21
  # Create the Gradio interface
22
  interface = gr.Interface(
 
8
 
9
  unet = pipe.unet
10
 
11
+ def generate_image(prompt, unet, pipe):
12
+ # Encode the prompt
13
+ text_encoding = pipe.text_encoder(prompt, return_tensors="pt").to(unet.device)
14
+
15
+ # Generate the image
16
+ image = unet(text_embeddings=text_encoding.last_hidden_state).images[0]
17
  return image
18
 
19
  def chatbot(prompt):
20
  # Generate the image based on the user's input
21
+ image = generate_image(prompt, unet, pipe)
22
  return image
23
+
24
+ def get_aug_embed(self, text_embeds, image):
25
+ if text_embeds is None:
26
+ text_embeds = self.text_encoder(
27
+ text_embeds=text_embeds,
28
+ image=image,
29
+ height=self.unet.config.sample_size,
30
+ width=self.unet.config.sample_size,
31
+ )
32
+ return text_embeds
33
 
34
  # Create the Gradio interface
35
  interface = gr.Interface(