joyson commited on
Commit
b9af579
Β·
verified Β·
1 Parent(s): 56fca0b

Update text_to_image.py

Browse files
Files changed (1) hide show
  1. text_to_image.py +35 -55
text_to_image.py CHANGED
@@ -1,55 +1,35 @@
1
- import torch
2
- import spaces
3
- from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
4
- from PIL import Image
5
- from io import BytesIO
6
- from utils import load_unet_model
7
-
8
- @spaces.GPU
9
- class TextToImage:
10
- """
11
- Class to handle Text-to-Image generation using Stable Diffusion XL.
12
- """
13
- def __init__(self, device="cpu"):
14
- # Model and repository details
15
- self.base = "stabilityai/stable-diffusion-xl-base-1.0"
16
- self.repo = "ByteDance/SDXL-Lightning"
17
- self.ckpt = "sdxl_lightning_4step_unet.safetensors"
18
- self.device = device
19
-
20
- # Load the UNet model
21
- print("Loading Text-to-Image model...")
22
- self.unet = load_unet_model(self.base, self.repo, self.ckpt, device=self.device)
23
-
24
- # Initialize the pipeline
25
- self.pipe = StableDiffusionXLPipeline.from_pretrained(
26
- self.base,
27
- unet=self.unet,
28
- torch_dtype=torch.float32,
29
- ).to(self.device)
30
-
31
- # Set the scheduler
32
- self.pipe.scheduler = EulerDiscreteScheduler.from_config(
33
- self.pipe.scheduler.config,
34
- timestep_spacing="trailing"
35
- )
36
- print("Text-to-Image model loaded successfully.")
37
-
38
-
39
- async def generate_image(self, prompt):
40
- """
41
- Generate an image from a text prompt.
42
-
43
- Args:
44
- prompt (str): The text prompt to generate the image.
45
-
46
- Returns:
47
- PIL.Image: The generated image.
48
- """
49
- with torch.no_grad():
50
- image = self.pipe(
51
- prompt,
52
- num_inference_steps=4,
53
- guidance_scale=0
54
- ).images[0]
55
- return image
 
1
+ import torch
2
+ import spaces
3
+ from diffusers import StableDiffusionPipeline
4
+ from PIL import Image
5
+ from io import BytesIO
6
+ from utils import load_unet_model
7
+
8
+ @spaces.GPU
9
+ class TextToImage:
10
+ """
11
+ Class to handle Text-to-Image generation using Stable Diffusion XL.
12
+ """
13
+ def __init__(self, device="cpu"):
14
+ # Model and repository details
15
+ model_id = "OFA-Sys/small-stable-diffusion-v0"
16
+ self.device = device
17
+ self.pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32).pipe.to(device)
18
+ print("Text-to-Image model loaded successfully.")
19
+
20
+
21
+ async def generate_image(self, prompt):
22
+ """
23
+ Generate an image from a text prompt.
24
+
25
+ Args:
26
+ prompt (str): The text prompt to generate the image.
27
+
28
+ Returns:
29
+ PIL.Image: The generated image.
30
+ """
31
+ with torch.no_grad():
32
+ image = self.pipe(
33
+ prompt
34
+ ).images[0]
35
+ return image