ppierzc committed on
Commit
0449ff8
·
1 Parent(s): 238552d

add(handler) add custom input handler

Browse files
Files changed (1) hide show
  1. input_handler.py +45 -0
input_handler.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ from torch import autocast
4
+ from diffusers import AutoPipelineForText2Image
5
+ import base64
6
+ from io import BytesIO
7
+
8
# Select the compute device once at import time; the handler below reuses it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Refuse to load on anything but a CUDA device.
if device.type == "cpu":
    raise ValueError("need to run on GPU")
13
+
14
+
15
class EndpointHandler:
    """Custom inference handler that turns a prompt into a base64 JPEG.

    The handler wraps the ``stabilityai/sdxl-turbo`` text-to-image pipeline,
    loaded in fp16, with LoRA weights applied on top of it.
    """

    def __init__(self, path=""):
        """Load the pipeline, apply the LoRA weights and move it to ``device``.

        Args:
            path: Location handed to ``load_lora_weights``; the weights are
                read from ``pytorch_lora_weights.safetensors`` there.
        """
        # Uses the module-level ``torch`` import; no local re-import is needed.
        self.pipe = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/sdxl-turbo",
            torch_dtype=torch.float16,
            variant="fp16",
        )
        self.pipe.load_lora_weights(
            path, weight_name="pytorch_lora_weights.safetensors"
        )

        self.pipe = self.pipe.to(device)

    def __call__(self, data: Any) -> Dict[str, str]:
        """Generate one image for the request and return it base64-encoded.

        Args:
            data: Request payload. If it has an ``"inputs"`` key, that value is
                removed from ``data`` and used as the pipeline input;
                otherwise ``data`` itself is passed to the pipeline.

        Returns:
            A dict ``{"image": <str>}`` whose value is the base64 text of the
            generated image encoded as JPEG.
        """
        # NOTE: ``pop`` mutates the caller's payload (the key is removed).
        inputs = data.pop("inputs", data)

        # Run the pipeline with one inference step and guidance_scale=0.0,
        # keeping only the first generated image.
        with autocast(device.type):
            result = self.pipe(inputs, num_inference_steps=1, guidance_scale=0.0)
            image = result.images[0]

        # Serialise the image to JPEG in memory, then base64-encode the bytes.
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())

        # Decode to ``str`` so the response is JSON-serialisable.
        return {"image": img_str.decode()}