Thomas Male committed
Commit f22523d · 1 Parent(s): 58558bd

Upload handler.py


Created custom Handler

Files changed (1)
  1. handler.py +80 -0
handler.py ADDED
@@ -0,0 +1,80 @@
+ from typing import Dict, List, Any
+ import torch
+ from torch import autocast
+ from tqdm.auto import tqdm
+ from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
+ from point_e.diffusion.sampler import PointCloudSampler
+ from point_e.models.download import load_checkpoint
+ from point_e.models.configs import MODEL_CONFIGS, model_from_config
+ from point_e.util.plotting import plot_point_cloud
+ import base64
+ from io import BytesIO
+
+
+ # set device
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ if device.type != 'cuda':
+     raise ValueError("need to run on GPU")
+
+ class EndpointHandler():
+     def __init__(self, path=""):
+         # load the base text-to-point-cloud model and its diffusion config
+         print('creating base model...')
+         self.base_name = 'base40M-textvec'
+         self.base_model = model_from_config(MODEL_CONFIGS[self.base_name], device)
+         self.base_model.eval()
+         self.base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[self.base_name])
+
+         print('creating upsample model...')
+         self.upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
+         self.upsampler_model.eval()
+         self.upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])
+
+         print('downloading base checkpoint...')
+         self.base_model.load_state_dict(load_checkpoint(self.base_name, device))
+
+         print('downloading upsampler checkpoint...')
+         self.upsampler_model.load_state_dict(load_checkpoint('upsample', device))
+
+     def __call__(self, data: Any) -> Dict[str, Any]:
+         """
+         Args:
+             data (:obj:`dict`):
+                 includes the input data and the parameters for the inference.
+         Return:
+             A :obj:`dict` with the sampled point cloud under the "data" key.
+         """
+         inputs = data.pop("inputs", data)
+
+         sampler = PointCloudSampler(
+             device=device,
+             models=[self.base_model, self.upsampler_model],
+             diffusions=[self.base_diffusion, self.upsampler_diffusion],
+             num_points=[1024, 4096 - 1024],
+             aux_channels=['R', 'G', 'B'],
+             guidance_scale=[3.0, 0.0],
+             model_kwargs_key_filter=('texts', ''),  # Do not condition the upsampler at all
+         )
+
+         # Example of a test prompt to condition on:
+         # prompt = 'A bluebird mid-flight'
+
+         # run the inference pipeline, keeping the last (fully sampled) batch
+         with autocast(device.type):
+             samples = None
+             for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[inputs]))):
+                 samples = x
+             # image = self.pipe(inputs, guidance_scale=7.5)["sample"][0]
+
+         pc = sampler.output_to_point_clouds(samples)[0]
+         return {"data": pc}
+         # print(pc)
+
+         # encode image as base 64
+         # buffered = BytesIO()
+         # image.save(buffered, format="JPEG")
+         # img_str = base64.b64encode(buffered.getvalue())
+
+         # postprocess the prediction
+         # return {"image": img_str.decode()}
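For context, a minimal local smoke test of a custom handler like this one could look as follows. This is a sketch, not part of the commit: the prompt string, the path argument, and the shape check are illustrative assumptions.

# local smoke test (requires a CUDA GPU and the point_e package installed)
from handler import EndpointHandler

# instantiate once; this downloads and loads both Point-E checkpoints
handler = EndpointHandler(path=".")

# call with the same payload shape the inference endpoint would receive
result = handler({"inputs": "a red motorcycle"})  # hypothetical prompt
pc = result["data"]             # point_e PointCloud with R/G/B aux channels
print(pc.coords.shape)          # expected (4096, 3): 1024 base + 3072 upsampled points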