Upload model
Browse files- modeling_basnet.py +18 -19
modeling_basnet.py
CHANGED
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
|
|
14 |
|
15 |
|
16 |
@dataclass
|
17 |
-
class
|
18 |
dout: torch.Tensor
|
19 |
d1: Optional[torch.Tensor] = None
|
20 |
d2: Optional[torch.Tensor] = None
|
@@ -25,6 +25,11 @@ class BASNetModelOutput(ModelOutput):
|
|
25 |
db: Optional[torch.Tensor] = None
|
26 |
|
27 |
|
|
|
|
|
|
|
|
|
|
|
28 |
class RefUnet(nn.Module):
|
29 |
def __init__(self, in_ch: int, inc_ch: int) -> None:
|
30 |
super().__init__()
|
@@ -466,27 +471,21 @@ class BASNetModel(PreTrainedModel):
|
|
466 |
d6_act = torch.sigmoid(d6)
|
467 |
db_act = torch.sigmoid(db)
|
468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
469 |
if not return_dict:
|
470 |
-
return (
|
471 |
-
dout_act,
|
472 |
-
d1_act,
|
473 |
-
d2_act,
|
474 |
-
d3_act,
|
475 |
-
d4_act,
|
476 |
-
d5_act,
|
477 |
-
d6_act,
|
478 |
-
db_act,
|
479 |
-
)
|
480 |
|
481 |
return BASNetModelOutput(
|
482 |
-
|
483 |
-
d1=d1_act,
|
484 |
-
d2=d2_act,
|
485 |
-
d3=d3_act,
|
486 |
-
d4=d4_act,
|
487 |
-
d5=d5_act,
|
488 |
-
d6=d6_act,
|
489 |
-
db=db_act,
|
490 |
)
|
491 |
|
492 |
|
|
|
14 |
|
15 |
|
16 |
@dataclass
|
17 |
+
class BasNetSideOutput(ModelOutput):
|
18 |
dout: torch.Tensor
|
19 |
d1: Optional[torch.Tensor] = None
|
20 |
d2: Optional[torch.Tensor] = None
|
|
|
25 |
db: Optional[torch.Tensor] = None
|
26 |
|
27 |
|
28 |
+
@dataclass
|
29 |
+
class BASNetModelOutput(ModelOutput):
|
30 |
+
activated: BasNetSideOutput
|
31 |
+
|
32 |
+
|
33 |
class RefUnet(nn.Module):
|
34 |
def __init__(self, in_ch: int, inc_ch: int) -> None:
|
35 |
super().__init__()
|
|
|
471 |
d6_act = torch.sigmoid(d6)
|
472 |
db_act = torch.sigmoid(db)
|
473 |
|
474 |
+
side_outputs = (
|
475 |
+
dout_act,
|
476 |
+
d1_act,
|
477 |
+
d2_act,
|
478 |
+
d3_act,
|
479 |
+
d4_act,
|
480 |
+
d5_act,
|
481 |
+
d6_act,
|
482 |
+
db_act,
|
483 |
+
)
|
484 |
if not return_dict:
|
485 |
+
return (side_outputs,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
486 |
|
487 |
return BASNetModelOutput(
|
488 |
+
activated=BasNetSideOutput(*side_outputs),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
489 |
)
|
490 |
|
491 |
|