joyson committed on
Commit
d5ca99f
·
verified ·
1 Parent(s): b9af579

Update image_to_image.py

Browse files
Files changed (1) hide show
  1. image_to_image.py +41 -60
image_to_image.py CHANGED
@@ -1,60 +1,41 @@
1
- import torch
2
- from diffusers import StableDiffusionXLImg2ImgPipeline, EulerDiscreteScheduler
3
- from PIL import Image
4
- from io import BytesIO
5
- from utils import load_unet_model
6
-
7
class ImageToImage:
    """
    Class to handle Image-to-Image transformations using Stable Diffusion XL.

    The pipeline is created from the SDXL base checkpoint, with its UNet
    replaced by the object returned from ``load_unet_model`` for the
    SDXL-Lightning 4-step checkpoint file.
    """

    def __init__(self, device="cpu"):
        """
        Load the UNet, build the img2img pipeline and configure its scheduler.

        Args:
            device (str): Passed to ``load_unet_model`` and used as the
                target of the pipeline's ``.to()`` call. Defaults to "cpu".
        """
        # Model and repository details
        self.base = "stabilityai/stable-diffusion-xl-base-1.0"
        self.repo = "ByteDance/SDXL-Lightning"
        self.ckpt = "sdxl_lightning_4step_unet.safetensors"
        self.device = device

        # Load the UNet model
        print("Loading Image-to-Image model...")
        self.unet = load_unet_model(self.base, self.repo, self.ckpt, device=self.device)

        # Initialize the pipeline around the custom UNet
        self.pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            self.base,
            unet=self.unet,
            torch_dtype=torch.float32,
        ).to(self.device)

        # Rebuild the scheduler from the pipeline's own config, overriding
        # only the timestep spacing.
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(
            self.pipe.scheduler.config,
            timestep_spacing="trailing",
        )
        print("Image-to-Image model loaded successfully.")

    async def transform_image(self, image, prompt):
        """
        Transform an uploaded image based on a text prompt.

        Args:
            image (PIL.Image): The input image to transform.
            prompt (str): The text prompt to guide the transformation.

        Returns:
            PIL.Image: The transformed image (first image of the pipeline
            output).

        Raises:
            ValueError: If ``prompt`` is empty.
        """
        if not prompt:
            raise ValueError("Prompt cannot be empty.")

        # Uploads are not guaranteed to be 3-channel (PNG with alpha,
        # palette or greyscale files); normalise to RGB so the pipeline
        # always receives three channels.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Resize the image as required by the model
        init_image = image.resize((768, 512))

        # NOTE(review): strength / guidance_scale are unchanged from the
        # original; confirm they suit the 4-step Lightning UNet.
        with torch.no_grad():
            result = self.pipe(
                prompt=prompt,
                image=init_image,
                strength=0.75,
                guidance_scale=7.5,
            )
        return result.images[0]
 
1
+ import torch
2
+ from diffusers import StableDiffusionImg2ImgPipeline
3
+ from PIL import Image
4
+ from io import BytesIO
5
+
6
class ImageToImage:
    """
    Class to handle Image-to-Image transformations using Stable Diffusion.
    """

    def __init__(self, device="cpu"):
        """
        Load the img2img pipeline and move it to ``device``.

        Args:
            device (str): Target of the pipeline's ``.to()`` call.
                Defaults to "cpu".
        """
        # Model and repository details
        model_id = "OFA-Sys/small-stable-diffusion-v0"

        # from_pretrained() returns the pipeline object itself, so .to() is
        # called on it directly; there is no intermediate ``.pipe``
        # attribute to go through.
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
        ).to(device)
        self.device = device
        print("Image-to-Image model loaded successfully.")

    async def transform_image(self, image, prompt):
        """
        Transform an uploaded image based on a text prompt.

        Args:
            image (PIL.Image): The input image to transform.
            prompt (str): The text prompt to guide the transformation.

        Returns:
            PIL.Image: The transformed image (first image of the pipeline
            output).

        Raises:
            ValueError: If ``prompt`` is empty.
        """
        if not prompt:
            raise ValueError("Prompt cannot be empty.")

        # Uploads are not guaranteed to be 3-channel (PNG with alpha,
        # palette or greyscale files); normalise to RGB so the pipeline
        # always receives three channels.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Resize the image as required by the model
        init_image = image.resize((512, 512))

        with torch.no_grad():
            result = self.pipe(
                prompt=prompt,
                image=init_image,
                strength=0.75,
                guidance_scale=7.5,
            )
        return result.images[0]