Update handler.py
Browse files- handler.py +6 -3
handler.py
CHANGED
@@ -27,6 +27,9 @@ class EndpointHandler:
|
|
27 |
tensor = cast(torch.Tensor, data["inputs"])
|
28 |
parameters = cast(dict, data.get("parameters", {}))
|
29 |
do_scaling = cast(bool, parameters.get("do_scaling", True))
|
|
|
|
|
|
|
30 |
output_type = cast(str, parameters.get("output_type", "pil"))
|
31 |
partial_postprocess = cast(bool, parameters.get("partial_postprocess", False))
|
32 |
if partial_postprocess and output_type != "pt":
|
@@ -34,8 +37,8 @@ class EndpointHandler:
|
|
34 |
|
35 |
tensor = tensor.to(self.device, self.dtype)
|
36 |
|
37 |
-
if
|
38 |
-
tensor = tensor /
|
39 |
|
40 |
with torch.no_grad():
|
41 |
frames = cast(torch.Tensor, self.vae.decode(tensor, return_dict=False)[0])
|
@@ -55,4 +58,4 @@ class EndpointHandler:
|
|
55 |
elif output_type == "pt":
|
56 |
frames = frames
|
57 |
|
58 |
-
return frames
|
|
|
27 |
tensor = cast(torch.Tensor, data["inputs"])
|
28 |
parameters = cast(dict, data.get("parameters", {}))
|
29 |
do_scaling = cast(bool, parameters.get("do_scaling", True))
|
30 |
+
scaling_factor = cast(float, parameters.get("scaling_factor", None))
|
31 |
+
if do_scaling and scaling_factor is None:
|
32 |
+
scaling_factor = self.vae.config.scaling_factor
|
33 |
output_type = cast(str, parameters.get("output_type", "pil"))
|
34 |
partial_postprocess = cast(bool, parameters.get("partial_postprocess", False))
|
35 |
if partial_postprocess and output_type != "pt":
|
|
|
37 |
|
38 |
tensor = tensor.to(self.device, self.dtype)
|
39 |
|
40 |
+
if scaling_factor is not None:
|
41 |
+
tensor = tensor / scaling_factor
|
42 |
|
43 |
with torch.no_grad():
|
44 |
frames = cast(torch.Tensor, self.vae.decode(tensor, return_dict=False)[0])
|
|
|
58 |
elif output_type == "pt":
|
59 |
frames = frames
|
60 |
|
61 |
+
return frames
|