Jinglong Xiong commited on
Commit
bb5a422
·
1 Parent(s): 33dda45

can generate multiple variations

Browse files
Files changed (1) hide show
  1. gen_image.py +40 -7
gen_image.py CHANGED
@@ -3,6 +3,15 @@ import torch
3
 
4
  class ImageGenerator:
5
  def __init__(self, model_id="stabilityai/stable-diffusion-2-1-base", device="cuda"):
 
 
 
 
 
 
 
 
 
6
  scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
7
  self.pipe = StableDiffusionPipeline.from_pretrained(
8
  model_id,
@@ -10,25 +19,49 @@ class ImageGenerator:
10
  torch_dtype=torch.float16
11
  )
12
  self.pipe = self.pipe.to(device)
 
 
13
 
14
- def generate(self, prompt, negative_prompt=None, output_path=None):
15
- image = self.pipe(prompt, negative_prompt=negative_prompt).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  if output_path:
18
- image.save(output_path)
 
19
 
20
  return image
21
 
22
-
23
  # Example usage
24
  if __name__ == "__main__":
25
  generator = ImageGenerator()
26
  import time
27
  start_time = time.time()
28
  image = generator.generate(
29
- prompt="magenta trapezoids layered on a transluscent silver sheet, simple, icon",
30
- negative_prompt="3d, blurry, complex geometry, realistic",
31
- output_path="sheet.png"
32
  )
33
  end_time = time.time()
34
  print(f"Time taken: {end_time - start_time} seconds")
 
3
 
4
  class ImageGenerator:
5
  def __init__(self, model_id="stabilityai/stable-diffusion-2-1-base", device="cuda"):
6
+ """
7
+ Initialize the image generator with a specific model.
8
+
9
+ Args:
10
+ model_id (str): The model identifier for the stable diffusion model.
11
+ Default is "stabilityai/stable-diffusion-2-1-base".
12
+ device (str): The device to run the model on, either "cuda" or "cpu".
13
+ Default is "cuda".
14
+ """
15
  scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
16
  self.pipe = StableDiffusionPipeline.from_pretrained(
17
  model_id,
 
19
  torch_dtype=torch.float16
20
  )
21
  self.pipe = self.pipe.to(device)
22
+ self.positive_prompt = "simple, icon"
23
+ self.negative_prompt = "3d, blurry, complex geometry, realistic"
24
 
25
+ def generate(self, prompt, negative_prompt=None, output_path=None, num_images=1, num_inference_steps=50):
26
+ """
27
+ Generate an image based on the provided prompt.
28
+
29
+ Args:
30
+ prompt (str): The text description to generate an image from.
31
+ negative_prompt (str, optional): Elements to avoid in the generated image.
32
+ If None, uses the default negative prompt.
33
+ output_path (str, optional): Path to save the generated image.
34
+ If None, the image is not saved to disk.
35
+ num_images (int, optional): Number of images to generate.
36
+
37
+ Returns:
38
+ PIL.Image.Image: The generated image.
39
+ """
40
+ prompt = f"{prompt}, {self.positive_prompt}"
41
+ if negative_prompt is None:
42
+ negative_prompt = self.negative_prompt
43
+ images = self.pipe(
44
+ prompt,
45
+ negative_prompt=negative_prompt,
46
+ num_inference_steps=50,
47
+ num_images_per_prompt=num_images
48
+ ).images
49
 
50
  if output_path:
51
+ for i, image in enumerate(images):
52
+ image.save(f".cache/{output_path.replace('.png', f'_{i}.png')}")
53
 
54
  return image
55
 
 
56
  # Example usage
57
  if __name__ == "__main__":
58
  generator = ImageGenerator()
59
  import time
60
  start_time = time.time()
61
  image = generator.generate(
62
+ prompt="magenta trapezoids layered on a transluscent silver sheet",
63
+ output_path="sheet.png",
64
+ num_images=4
65
  )
66
  end_time = time.time()
67
  print(f"Time taken: {end_time - start_time} seconds")