emeersman commited on
Commit
fb68c01
·
1 Parent(s): fe706b1

Switch to PIL image processing

Browse files
Files changed (1) hide show
  1. handler.py +25 -18
handler.py CHANGED
@@ -1,7 +1,7 @@
1
  from typing import Dict, List, Any
2
  import torch
3
- from torch import autocast
4
  from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
 
5
  import base64
6
  from io import BytesIO
7
 
@@ -28,32 +28,39 @@ class EndpointHandler():
28
  A :obj:`dict`:. base64 encoded image
29
  """
30
  inputs = data.pop("inputs", data)
 
31
 
32
  # hyperparamters
33
- num_inference_steps = data.pop("num_inference_steps", 25)
34
- guidance_scale = data.pop("guidance_scale", 7.5)
35
- negative_prompt = data.pop("negative_prompt", None)
36
- height = data.pop("height", None)
37
- width = data.pop("width", None)
38
- manual_seed = data.pop("manual_seed", -1)
39
 
40
  generator = torch.Generator(device).manual_seed(manual_seed)
41
 
 
 
 
42
  # run inference pipeline
43
- with autocast(device.type):
44
- image = self.pipe(inputs,
45
- generator=generator,
46
  num_inference_steps=num_inference_steps,
47
  guidance_scale=guidance_scale,
48
  num_images_per_prompt=1,
49
  negative_prompt=negative_prompt,
50
  height=height,
51
- width=width).images[0]
 
52
 
53
- # encode image as base 64
54
- buffered = BytesIO()
55
- image.save(buffered, format="JPEG")
56
- img_str = base64.b64encode(buffered.getvalue())
57
-
58
- # postprocess the prediction
59
- return {"image": img_str.decode()}
 
 
 
1
  from typing import Dict, List, Any
2
  import torch
 
3
  from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
4
+ from PIL import Image
5
  import base64
6
  from io import BytesIO
7
 
 
28
  A :obj:`dict`:. base64 encoded image
29
  """
30
  inputs = data.pop("inputs", data)
31
+ params = data.pop("parameters", data)
32
 
33
  # hyperparamters
34
+ num_inference_steps = params.pop("num_inference_steps", 20)
35
+ guidance_scale = params.pop("guidance_scale", 7.5)
36
+ negative_prompt = params.pop("negative_prompt", None)
37
+ height = params.pop("height", None)
38
+ width = params.pop("width", None)
39
+ manual_seed = params.pop("manual_seed", -1)
40
 
41
  generator = torch.Generator(device).manual_seed(manual_seed)
42
 
43
+ if encoded_image is not None:
44
+ image = self.decode_base64_image(encoded_image)
45
+
46
  # run inference pipeline
47
+ out = self.pipe(inputs,
48
+ image=image,
49
+ generator=generator,
50
  num_inference_steps=num_inference_steps,
51
  guidance_scale=guidance_scale,
52
  num_images_per_prompt=1,
53
  negative_prompt=negative_prompt,
54
  height=height,
55
+ width=width
56
+ )
57
 
58
+ # return first generate PIL image
59
+ return out.images[0]
60
+
61
+ # helper to decode input image
62
+ def decode_base64_image(self, image_string):
63
+ base64_image = base64.b64decode(image_string)
64
+ buffer = BytesIO(base64_image)
65
+ image = Image.open(buffer)
66
+ return image