Update script.py
Browse files
script.py
CHANGED
@@ -161,7 +161,7 @@ def generate_embeddings(metadata_file_path, root_dir):
|
|
161 |
|
162 |
loader = DataLoader(test_dataset, batch_size=3, shuffle=False)
|
163 |
|
164 |
-
device = torch.device(
|
165 |
model = timm.create_model(
|
166 |
"timm/vit_large_patch14_reg4_dinov2.lvd142m", pretrained=False
|
167 |
)
|
@@ -225,7 +225,7 @@ class FungiMEEModel(nn.Module):
|
|
225 |
super().__init__()
|
226 |
|
227 |
print("Setting up Pytorch Model")
|
228 |
-
self.device = torch.device(
|
229 |
print(f"Using devide: {self.device}")
|
230 |
|
231 |
self.date_embedding = MlpHead(
|
@@ -279,7 +279,7 @@ class FungiEnsembleModel(nn.Module):
|
|
279 |
super().__init__()
|
280 |
|
281 |
self.models = nn.ModuleList()
|
282 |
-
self.device = torch.device(
|
283 |
|
284 |
for model in models:
|
285 |
model = model.to(self.device)
|
|
|
161 |
|
162 |
loader = DataLoader(test_dataset, batch_size=3, shuffle=False)
|
163 |
|
164 |
+
device = torch.device('cpu')
|
165 |
model = timm.create_model(
|
166 |
"timm/vit_large_patch14_reg4_dinov2.lvd142m", pretrained=False
|
167 |
)
|
|
|
225 |
super().__init__()
|
226 |
|
227 |
print("Setting up Pytorch Model")
|
228 |
+
self.device = torch.device('cpu')
|
229 |
print(f"Using devide: {self.device}")
|
230 |
|
231 |
self.date_embedding = MlpHead(
|
|
|
279 |
super().__init__()
|
280 |
|
281 |
self.models = nn.ModuleList()
|
282 |
+
self.device = torch.device('cpu')
|
283 |
|
284 |
for model in models:
|
285 |
model = model.to(self.device)
|