emeersman commited on
Commit
e757853
·
1 Parent(s): 9c175fd

Update handler for img2img

Browse files
Files changed (1) hide show
  1. handler.py +40 -15
handler.py CHANGED
@@ -1,6 +1,9 @@
1
  from typing import Dict, List, Any
2
  import torch
3
- from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
 
 
 
4
 
5
  # set device
6
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -11,11 +14,16 @@ if device.type != 'cuda':
11
  model_id = "stabilityai/stable-diffusion-2-1-base"
12
 
13
  class EndpointHandler():
14
- def __init__(self):
15
  # load the optimized model
16
- self.pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
17
- self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)
18
- self.pipe = self.pipe.to(device)
 
 
 
 
 
19
 
20
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
21
  """
@@ -26,31 +34,48 @@ class EndpointHandler():
26
  A :obj:`dict`:. base64 encoded image
27
  """
28
  prompt = data.pop("inputs", data)
 
 
 
 
 
29
  params = data.pop("parameters", data)
30
 
31
  # hyperparamters
32
- num_inference_steps = params.pop("num_inference_steps", 20)
33
  guidance_scale = params.pop("guidance_scale", 7.5)
34
  negative_prompt = params.pop("negative_prompt", None)
35
  height = params.pop("height", None)
36
  width = params.pop("width", None)
37
  manual_seed = params.pop("manual_seed", -1)
38
 
39
- generator = torch.Generator(device)
40
 
41
- if (manual_seed != -1)
 
42
  generator.manual_seed(manual_seed)
43
-
44
- # run inference pipeline
45
- out = self.pipe(prompt,
46
- generator=generator,
47
  num_inference_steps=num_inference_steps,
48
  guidance_scale=guidance_scale,
49
  num_images_per_prompt=1,
50
  negative_prompt=negative_prompt,
51
  height=height,
52
  width=width
53
- )
54
-
55
- # return first generate PIL image
 
 
 
 
 
 
 
 
 
 
 
 
56
  return out.images[0]
 
1
  from typing import Dict, List, Any
2
  import torch
3
+ import requests
4
+ from PIL import Image
5
+ from io import BytesIO
6
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, DDIMScheduler
7
 
8
  # set device
9
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
14
  model_id = "stabilityai/stable-diffusion-2-1-base"
15
 
16
  class EndpointHandler():
17
+ def __init__(self, path=""):
18
  # load the optimized model
19
+ self.textPipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
20
+ self.textPipe.scheduler = DDIMScheduler.from_config(self.textPipe.scheduler.config)
21
+ self.textPipe = self.textPipe.to(device)
22
+
23
+ # create an img2img model
24
+ self.imgPipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
25
+ self.imgPipe.scheduler = DDIMScheduler.from_config(self.imgPipe.scheduler.config)
26
+ self.imgPipe = self.imgPipe.to(device)
27
 
28
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
29
  """
 
34
  A :obj:`dict`:. base64 encoded image
35
  """
36
  prompt = data.pop("inputs", data)
37
+ url = data.pop("url", data)
38
+ response = requests.get(url)
39
+ init_image = Image.open(BytesIO(response.content)).convert("RGB")
40
+ init_image.thumbnail((512, 512))
41
+
42
  params = data.pop("parameters", data)
43
 
44
  # hyperparamters
45
+ num_inference_steps = params.pop("num_inference_steps", 25)
46
  guidance_scale = params.pop("guidance_scale", 7.5)
47
  negative_prompt = params.pop("negative_prompt", None)
48
  height = params.pop("height", None)
49
  width = params.pop("width", None)
50
  manual_seed = params.pop("manual_seed", -1)
51
 
52
+ out = None
53
 
54
+ if data.get("url"):
55
+ generator = torch.Generator(device='cuda')
56
  generator.manual_seed(manual_seed)
57
+ # run img2img pipeline
58
+ out = self.imgPipe(prompt,
59
+ image=init_image,
 
60
  num_inference_steps=num_inference_steps,
61
  guidance_scale=guidance_scale,
62
  num_images_per_prompt=1,
63
  negative_prompt=negative_prompt,
64
  height=height,
65
  width=width
66
+ )
67
+ else:
68
+ # run text pipeline
69
+ out = self.textPipe(prompt,
70
+ image=init_image,
71
+ num_inference_steps=num_inference_steps,
72
+ guidance_scale=guidance_scale,
73
+ num_images_per_prompt=1,
74
+ negative_prompt=negative_prompt,
75
+ height=height,
76
+ width=width
77
+ )
78
+
79
+
80
+ # return first generated PIL image
81
  return out.images[0]