Upload 7 files
Browse files
text_embedding_module/OCR/ocr_recog/RecModel.py
CHANGED
@@ -14,17 +14,17 @@ class RecModel(nn.Module):
|
|
14 |
def __init__(self, config):
|
15 |
super().__init__()
|
16 |
assert "in_channels" in config, "in_channels must in model config"
|
17 |
-
backbone_type = config
|
18 |
assert backbone_type in backbone_dict, f"backbone.type must in {backbone_dict}"
|
19 |
-
self.backbone = backbone_dict[backbone_type](config
|
20 |
|
21 |
-
neck_type = config
|
22 |
assert neck_type in neck_dict, f"neck.type must in {neck_dict}"
|
23 |
-
self.neck = neck_dict[neck_type](self.backbone.out_channels, **config
|
24 |
|
25 |
-
head_type = config
|
26 |
assert head_type in head_dict, f"head.type must in {head_dict}"
|
27 |
-
self.head = head_dict[head_type](self.neck.out_channels, **config
|
28 |
|
29 |
self.name = f"RecModel_{backbone_type}_{neck_type}_{head_type}"
|
30 |
|
|
|
14 |
def __init__(self, config):
|
15 |
super().__init__()
|
16 |
assert "in_channels" in config, "in_channels must in model config"
|
17 |
+
backbone_type = config["backbone"].pop("type")
|
18 |
assert backbone_type in backbone_dict, f"backbone.type must in {backbone_dict}"
|
19 |
+
self.backbone = backbone_dict[backbone_type](config['in_channels'], **config['backbone'])
|
20 |
|
21 |
+
neck_type = config['neck']("type")
|
22 |
assert neck_type in neck_dict, f"neck.type must in {neck_dict}"
|
23 |
+
self.neck = neck_dict[neck_type](self.backbone.out_channels, **config['neck'])
|
24 |
|
25 |
+
head_type = config['head']("type")
|
26 |
assert head_type in head_dict, f"head.type must in {head_dict}"
|
27 |
+
self.head = head_dict[head_type](self.neck.out_channels, **config['head'])
|
28 |
|
29 |
self.name = f"RecModel_{backbone_type}_{neck_type}_{head_type}"
|
30 |
|