mterris commited on
Commit
539d600
·
1 Parent(s): 35c18b7
Files changed (1) hide show
  1. factories.py +2 -2
factories.py CHANGED
@@ -472,7 +472,7 @@ class Metric():
472
  elif self.name == "LPIPS":
473
  self.metric = dinv.loss.metric.LPIPS(device=device_str)
474
 
475
- def __call__(self, x_net: torch.Tensor, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
476
  # it may happen that x_net and x do not have the same size, in which case we take the minimum size of both
477
  if x_net.shape[-1] != x.shape[-1]:
478
  min_size = min(x_net.shape[-1], x.shape[-1])
@@ -483,7 +483,7 @@ class Metric():
483
  else:
484
  x_net_crop = x_net
485
  x_crop = x
486
- return self.metric(x_net_crop, x_crop)
487
 
488
  @classmethod
489
  def get_list_metrics(cls, metric_names: List[str], device_str: str = "cpu") -> List["Metric"]:
 
472
  elif self.name == "LPIPS":
473
  self.metric = dinv.loss.metric.LPIPS(device=device_str)
474
 
475
+ def __call__(self, x_net: torch.Tensor, x: torch.Tensor, *args, crop=31, **kwargs) -> torch.Tensor:
476
  # it may happen that x_net and x do not have the same size, in which case we take the minimum size of both
477
  if x_net.shape[-1] != x.shape[-1]:
478
  min_size = min(x_net.shape[-1], x.shape[-1])
 
483
  else:
484
  x_net_crop = x_net
485
  x_crop = x
486
+ return self.metric(x_net_crop[..., crop:-crop, crop:-crop], x_crop[..., crop:-crop, crop:-crop])
487
 
488
  @classmethod
489
  def get_list_metrics(cls, metric_names: List[str], device_str: str = "cpu") -> List["Metric"]: