ahatamiz commited on
Commit
8a2c177
1 Parent(s): 2f4b387

Delete hf_model.py

Browse files
Files changed (1) hide show
  1. hf_model.py +0 -52
hf_model.py DELETED
@@ -1,52 +0,0 @@
1
- # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from collections import namedtuple
16
- from typing import Optional, List, Union
17
-
18
- import torch
19
- from transformers import PretrainedConfig, PreTrainedModel
20
- from .mamba_vision import *
21
- from timm.models import create_model, load_checkpoint
22
-
23
-
24
- class MambaVisionConfig(PretrainedConfig):
25
-
26
- def __init__(
27
- self,
28
- args: Optional[dict] = None,
29
- **kwargs,
30
- ):
31
- self.args = args
32
- super().__init__(**kwargs)
33
-
34
-
35
- class MambaVisionModel(PreTrainedModel):
36
- """Pretrained Hugging Face model for MambaVision.
37
-
38
- This class inherits from PreTrainedModel, which provides
39
- HuggingFace's functionality for loading and saving models.
40
- """
41
-
42
- config_class = MambaVisionConfig
43
-
44
- def __init__(self, config):
45
- super().__init__(config)
46
- MambaVisionArgs = namedtuple("MambaVisionArgs", config.args.keys())
47
- args = MambaVisionArgs(**config.args)
48
- self.config = config
49
- self.model = create_model(args.model)
50
-
51
- def forward(self, x: torch.Tensor):
52
- return self.model.forward(x)