hlky HF Staff commited on
Commit
fbdce04
·
verified ·
1 Parent(s): 9b00d20

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +6 -3
handler.py CHANGED
@@ -27,6 +27,9 @@ class EndpointHandler:
27
  tensor = cast(torch.Tensor, data["inputs"])
28
  parameters = cast(dict, data.get("parameters", {}))
29
  do_scaling = cast(bool, parameters.get("do_scaling", True))
 
 
 
30
  output_type = cast(str, parameters.get("output_type", "pil"))
31
  partial_postprocess = cast(bool, parameters.get("partial_postprocess", False))
32
  if partial_postprocess and output_type != "pt":
@@ -34,8 +37,8 @@ class EndpointHandler:
34
 
35
  tensor = tensor.to(self.device, self.dtype)
36
 
37
- if do_scaling:
38
- tensor = tensor / self.vae.config.scaling_factor
39
 
40
  with torch.no_grad():
41
  frames = cast(torch.Tensor, self.vae.decode(tensor, return_dict=False)[0])
@@ -55,4 +58,4 @@ class EndpointHandler:
55
  elif output_type == "pt":
56
  frames = frames
57
 
58
- return frames
 
27
  tensor = cast(torch.Tensor, data["inputs"])
28
  parameters = cast(dict, data.get("parameters", {}))
29
  do_scaling = cast(bool, parameters.get("do_scaling", True))
30
+ scaling_factor = cast(float, parameters.get("scaling_factor", None))
31
+ if do_scaling and scaling_factor is None:
32
+ scaling_factor = self.vae.config.scaling_factor
33
  output_type = cast(str, parameters.get("output_type", "pil"))
34
  partial_postprocess = cast(bool, parameters.get("partial_postprocess", False))
35
  if partial_postprocess and output_type != "pt":
 
37
 
38
  tensor = tensor.to(self.device, self.dtype)
39
 
40
+ if scaling_factor is not None:
41
+ tensor = tensor / scaling_factor
42
 
43
  with torch.no_grad():
44
  frames = cast(torch.Tensor, self.vae.decode(tensor, return_dict=False)[0])
 
58
  elif output_type == "pt":
59
  frames = frames
60
 
61
+ return frames