Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,8 @@ import os
|
|
2 |
import zipfile
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
-
from transformers import ViTForImageClassification
|
|
|
6 |
from PIL import Image
|
7 |
from torch.utils.data import Dataset, DataLoader
|
8 |
import streamlit as st
|
@@ -80,11 +81,12 @@ class CustomImageDataset(Dataset):
|
|
80 |
label = self.labels[idx]
|
81 |
return image, label
|
82 |
|
|
|
83 |
# Training function
|
84 |
def fine_tune_model(train_loader):
|
85 |
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=3)
|
86 |
model.train()
|
87 |
-
optimizer = AdamW(model.parameters(), lr=1e-4)
|
88 |
criterion = torch.nn.CrossEntropyLoss()
|
89 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
90 |
model.to(device)
|
|
|
2 |
import zipfile
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
+
from transformers import ViTForImageClassification
|
6 |
+
from torch.optim import AdamW
|
7 |
from PIL import Image
|
8 |
from torch.utils.data import Dataset, DataLoader
|
9 |
import streamlit as st
|
|
|
81 |
label = self.labels[idx]
|
82 |
return image, label
|
83 |
|
84 |
+
# Training function
|
85 |
# Training function
|
86 |
def fine_tune_model(train_loader):
|
87 |
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=3)
|
88 |
model.train()
|
89 |
+
optimizer = AdamW(model.parameters(), lr=1e-4) # Use PyTorch's AdamW
|
90 |
criterion = torch.nn.CrossEntropyLoss()
|
91 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
92 |
model.to(device)
|