shunk031 commited on
Commit
0a9e7b7
·
1 Parent(s): 6bfc688

deploy: eb4fa6c501d824a2e1c134c4e3832ffcb7656abf

Browse files
Files changed (1) hide show
  1. layout_overlap.py +14 -16
layout_overlap.py CHANGED
@@ -66,8 +66,8 @@ class LayoutOverlap(evaluate.Metric):
66
  citation=_CITATION,
67
  features=ds.Features(
68
  {
69
- "batch_bbox": ds.Sequence(ds.Sequence(ds.Value("float64"))),
70
- "batch_mask": ds.Sequence(ds.Value("bool")),
71
  }
72
  ),
73
  codebase_urls=[
@@ -146,35 +146,33 @@ class LayoutOverlap(evaluate.Metric):
146
  def _compute(
147
  self,
148
  *,
149
- batch_bbox: Union[npt.NDArray[np.float64], List[List[int]]],
150
- batch_mask: Union[npt.NDArray[np.bool_], List[List[bool]]],
151
  ) -> Dict[str, npt.NDArray[np.float64]]:
152
 
153
  # shape: (B, model_max_length, C)
154
- batch_bbox = np.array(batch_bbox)
155
  # shape: (B, model_max_length)
156
- batch_mask = np.array(batch_mask)
157
 
158
- assert batch_bbox.ndim == 3
159
- assert batch_mask.ndim == 2
160
 
161
  # S: model_max_length
162
- B, S, C = batch_bbox.shape
163
 
164
  # shape: batch_bbox (B, S, C), batch_mask (B, S) -> (B, S, 1) -> (B, S, C)
165
- batch_bbox[np.repeat(~batch_mask[:, :, None], axis=2, repeats=C)] = 0.0
166
  # shape: (C, B, S)
167
- batch_bbox = batch_bbox.transpose(2, 0, 1)
168
 
169
- A = self.__calculate_a1_ai(batch_bbox)
170
 
171
  # shape: (B,)
172
- score_ac_layout_gan = self._compute_ac_layout_gan(
173
- S=S, batch_mask=batch_mask, **A
174
- )
175
  # shape: (B,)
176
  score_layout_gan_pp = self._compute_layout_gan_pp(
177
- score_ac_layout_gan=score_ac_layout_gan, batch_mask=batch_mask
178
  )
179
  # shape: (B,)
180
  score_layout_gan = self._compute_layout_gan(B=B, S=S, ai=A["ai"])
 
66
  citation=_CITATION,
67
  features=ds.Features(
68
  {
69
+ "bbox": ds.Sequence(ds.Sequence(ds.Value("float64"))),
70
+ "mask": ds.Sequence(ds.Value("bool")),
71
  }
72
  ),
73
  codebase_urls=[
 
146
  def _compute(
147
  self,
148
  *,
149
+ bbox: Union[npt.NDArray[np.float64], List[List[int]]],
150
+ mask: Union[npt.NDArray[np.bool_], List[List[bool]]],
151
  ) -> Dict[str, npt.NDArray[np.float64]]:
152
 
153
  # shape: (B, model_max_length, C)
154
+ bbox = np.array(bbox)
155
  # shape: (B, model_max_length)
156
+ mask = np.array(mask)
157
 
158
+ assert bbox.ndim == 3
159
+ assert mask.ndim == 2
160
 
161
  # S: model_max_length
162
+ B, S, C = bbox.shape
163
 
164
  # shape: batch_bbox (B, S, C), batch_mask (B, S) -> (B, S, 1) -> (B, S, C)
165
+ bbox[np.repeat(~mask[:, :, None], axis=2, repeats=C)] = 0.0
166
  # shape: (C, B, S)
167
+ bbox = bbox.transpose(2, 0, 1)
168
 
169
+ A = self.__calculate_a1_ai(bbox)
170
 
171
  # shape: (B,)
172
+ score_ac_layout_gan = self._compute_ac_layout_gan(S=S, batch_mask=mask, **A)
 
 
173
  # shape: (B,)
174
  score_layout_gan_pp = self._compute_layout_gan_pp(
175
+ score_ac_layout_gan=score_ac_layout_gan, batch_mask=mask
176
  )
177
  # shape: (B,)
178
  score_layout_gan = self._compute_layout_gan(B=B, S=S, ai=A["ai"])