kusumakar commited on
Commit
81f8436
·
1 Parent(s): 0f16eb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -11,6 +11,7 @@ from PIL import Image
11
  from typing import List, Optional, Union
12
  import inspect
13
  import warnings
 
14
 
15
  # Load the Stable Diffusion model
16
  modelid = "CompVis/stable-diffusion-v1-4"
@@ -19,6 +20,8 @@ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(modelid, revision="fp16",
19
  pipe.to(device)
20
 
21
  url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
 
 
22
 
23
 
24
  def generate_image(prompt):
@@ -28,8 +31,9 @@ def generate_image(prompt):
28
 
29
  generator = torch.Generator(device=device).manual_seed(1024)
30
  with autocast("cuda"):
31
- prompt = torch.tensor(prompt, device=device).half()
32
- image = pipe(prompt=prompt, init_image=init_img, strength=0.75, guidance_scale=7.5, generator=generator).images[0]
 
33
 
34
  return image
35
 
 
11
  from typing import List, Optional, Union
12
  import inspect
13
  import warnings
14
+ import sentence_transformers
15
 
16
  # Load the Stable Diffusion model
17
  modelid = "CompVis/stable-diffusion-v1-4"
 
20
  pipe.to(device)
21
 
22
  url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
23
+ # Load the Sentence-BERT model for text embeddings
24
+ text_embedding_model = sentence_transformers.SentenceTransformer("paraphrase-MiniLM-L6-v2")
25
 
26
 
27
  def generate_image(prompt):
 
31
 
32
  generator = torch.Generator(device=device).manual_seed(1024)
33
  with autocast("cuda"):
34
+ prompt_embedding = text_embedding_model.encode([prompt])[0]
35
+ prompt_tensor = torch.tensor(prompt_embedding, device=device).half()
36
+ image = pipe(prompt=prompt_tensor, init_image=init_img, strength=0.75, guidance_scale=7.5, generator=generator).images[0]
37
 
38
  return image
39