Spaces:
Build error
Build error
Update model.py
Browse files
model.py
CHANGED
@@ -167,7 +167,7 @@ class MultiHeadClassification(nn.Module):
|
|
167 |
Returns:
|
168 |
None
|
169 |
"""
|
170 |
-
model = torch.load(path)
|
171 |
if head_name in self.heads:
|
172 |
num_classes = model['weight'].shape[0]
|
173 |
self.heads[head_name].load_state_dict(model)
|
@@ -209,7 +209,7 @@ class MultiHeadClassification(nn.Module):
|
|
209 |
Args:
|
210 |
path (str): Path to the file
|
211 |
"""
|
212 |
-
self.load_state_dict(torch.load(path))
|
213 |
self.to(self.torch_dtype).to(self.device)
|
214 |
|
215 |
def save_backbone(self, path):
|
|
|
167 |
Returns:
|
168 |
None
|
169 |
"""
|
170 |
+
model = torch.load(path, map_location=self.device)
|
171 |
if head_name in self.heads:
|
172 |
num_classes = model['weight'].shape[0]
|
173 |
self.heads[head_name].load_state_dict(model)
|
|
|
209 |
Args:
|
210 |
path (str): Path to the file
|
211 |
"""
|
212 |
+
self.load_state_dict(torch.load(path, map_location=self.device))
|
213 |
self.to(self.torch_dtype).to(self.device)
|
214 |
|
215 |
def save_backbone(self, path):
|