Tanusree88 commited on
Commit
065f8a8
·
verified ·
1 Parent(s): c21163c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -2,7 +2,8 @@ import os
2
  import zipfile
3
  import numpy as np
4
  import torch
5
- from transformers import ViTForImageClassification, AdamW
 
6
  from PIL import Image
7
  from torch.utils.data import Dataset, DataLoader
8
  import streamlit as st
@@ -80,11 +81,12 @@ class CustomImageDataset(Dataset):
80
  label = self.labels[idx]
81
  return image, label
82
 
 
83
  # Training function
84
  def fine_tune_model(train_loader):
85
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=3)
86
  model.train()
87
- optimizer = AdamW(model.parameters(), lr=1e-4)
88
  criterion = torch.nn.CrossEntropyLoss()
89
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
90
  model.to(device)
 
2
  import zipfile
3
  import numpy as np
4
  import torch
5
+ from transformers import ViTForImageClassification
6
+ from torch.optim import AdamW
7
  from PIL import Image
8
  from torch.utils.data import Dataset, DataLoader
9
  import streamlit as st
 
81
  label = self.labels[idx]
82
  return image, label
83
 
84
+ # Training function
85
  # Training function
86
  def fine_tune_model(train_loader):
87
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=3)
88
  model.train()
89
+ optimizer = AdamW(model.parameters(), lr=1e-4) # Use PyTorch's AdamW
90
  criterion = torch.nn.CrossEntropyLoss()
91
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
92
  model.to(device)