AlexandreL2024 commited on
Commit
c115c70
·
verified ·
1 Parent(s): 8d9bc84

Update tasks/image.py

Browse files
Files changed (1) hide show
  1. tasks/image.py +24 -12
tasks/image.py CHANGED
@@ -3,6 +3,7 @@ import torch.nn as nn
3
  import torch.optim as optim
4
  from torchvision import transforms
5
  from torch.utils.data import DataLoader, Dataset
 
6
 
7
  from fastapi import APIRouter
8
  from datetime import datetime
@@ -180,19 +181,30 @@ async def evaluate_image(request: ImageEvaluationRequest):
180
 
181
 
182
  # Training loop
183
- num_epochs = 10
184
- for epoch in range(num_epochs):
185
- for images, labels in train_loader :
186
- images, labels = images.to(device), labels.to(device)
187
- # Zero the parameter gradients
188
- optimizer.zero_grad()
189
 
190
- # Forward + backward + optimize
191
- outputs = model(images)
192
- loss = criterion(outputs, labels)
193
- loss.backward()
194
- optimizer.step()
195
- print(f'Epoch [{epoch + 1}/10], Loss: {loss.item():.4f}')
 
 
 
 
 
 
 
 
 
 
 
196
 
197
  # Evaluation loop
198
  model.eval() # Set the model to evaluation mode
 
3
  import torch.optim as optim
4
  from torchvision import transforms
5
  from torch.utils.data import DataLoader, Dataset
6
+ from huggingface_hub import hf_hub_download
7
 
8
  from fastapi import APIRouter
9
  from datetime import datetime
 
181
 
182
 
183
  # Training loop
184
+ # num_epochs = 10
185
+ # for epoch in range(num_epochs):
186
+ # for images, labels in train_loader :
187
+ # images, labels = images.to(device), labels.to(device)
188
+ # # Zero the parameter gradients
189
+ # optimizer.zero_grad()
190
 
191
+ # # Forward + backward + optimize
192
+ # outputs = model(images)
193
+ # loss = criterion(outputs, labels)
194
+ # loss.backward()
195
+ # optimizer.step()
196
+ # print(f'Epoch [{epoch + 1}/10], Loss: {loss.item():.4f}')
197
+
198
+ # Charging pre-trained model
199
+ repo_id = "AlexandreL2024/CNN-Image-Classification"
200
+ filename = "model_CNN_2Layers.pth"
201
+
202
+ # Upload file .pth from Hugging Face
203
+ model_path = hf_hub_download(repo_id=repo_id, filename=filename)
204
+
205
+ # Charger le modèle avec torch.load()
206
+ model = ImageClassifier()
207
+ model = model.load_state_dict(torch.load(model_path))
208
 
209
  # Evaluation loop
210
  model.eval() # Set the model to evaluation mode