shunk031 commited on
Commit
c8808cc
·
verified ·
1 Parent(s): c083e1f

Upload model

Browse files
Files changed (1) hide show
  1. modeling_basnet.py +18 -19
modeling_basnet.py CHANGED
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
14
 
15
 
16
  @dataclass
17
- class BASNetModelOutput(ModelOutput):
18
  dout: torch.Tensor
19
  d1: Optional[torch.Tensor] = None
20
  d2: Optional[torch.Tensor] = None
@@ -25,6 +25,11 @@ class BASNetModelOutput(ModelOutput):
25
  db: Optional[torch.Tensor] = None
26
 
27
 
 
 
 
 
 
28
  class RefUnet(nn.Module):
29
  def __init__(self, in_ch: int, inc_ch: int) -> None:
30
  super().__init__()
@@ -466,27 +471,21 @@ class BASNetModel(PreTrainedModel):
466
  d6_act = torch.sigmoid(d6)
467
  db_act = torch.sigmoid(db)
468
 
 
 
 
 
 
 
 
 
 
 
469
  if not return_dict:
470
- return (
471
- dout_act,
472
- d1_act,
473
- d2_act,
474
- d3_act,
475
- d4_act,
476
- d5_act,
477
- d6_act,
478
- db_act,
479
- )
480
 
481
  return BASNetModelOutput(
482
- dout=dout_act,
483
- d1=d1_act,
484
- d2=d2_act,
485
- d3=d3_act,
486
- d4=d4_act,
487
- d5=d5_act,
488
- d6=d6_act,
489
- db=db_act,
490
  )
491
 
492
 
 
14
 
15
 
16
  @dataclass
17
+ class BasNetSideOutput(ModelOutput):
18
  dout: torch.Tensor
19
  d1: Optional[torch.Tensor] = None
20
  d2: Optional[torch.Tensor] = None
 
25
  db: Optional[torch.Tensor] = None
26
 
27
 
28
+ @dataclass
29
+ class BASNetModelOutput(ModelOutput):
30
+ activated: BasNetSideOutput
31
+
32
+
33
  class RefUnet(nn.Module):
34
  def __init__(self, in_ch: int, inc_ch: int) -> None:
35
  super().__init__()
 
471
  d6_act = torch.sigmoid(d6)
472
  db_act = torch.sigmoid(db)
473
 
474
+ side_outputs = (
475
+ dout_act,
476
+ d1_act,
477
+ d2_act,
478
+ d3_act,
479
+ d4_act,
480
+ d5_act,
481
+ d6_act,
482
+ db_act,
483
+ )
484
  if not return_dict:
485
+ return (side_outputs,)
 
 
 
 
 
 
 
 
 
486
 
487
  return BASNetModelOutput(
488
+ activated=BasNetSideOutput(*side_outputs),
 
 
 
 
 
 
 
489
  )
490
 
491