Sreekanth Tangirala commited on
Commit
de2aabe
·
0 Parent(s):

first commit

Browse files
Files changed (5) hide show
  1. .gitignore +85 -0
  2. app.py +40 -0
  3. model.py +29 -0
  4. requirements.txt +5 -0
  5. train.py +128 -0
.gitignore ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual Environment
24
+ venv/
25
+ env/
26
+ ENV/
27
+ .env
28
+ .venv
29
+ env.bak/
30
+ venv.bak/
31
+
32
+ # PyTorch specific
33
+ *.pth
34
+ *.pt
35
+ *.pkl
36
+ *.onnx
37
+ data/
38
+ runs/
39
+ checkpoints/
40
+
41
+ # IDE specific
42
+ .idea/
43
+ .vscode/
44
+ *.swp
45
+ *.swo
46
+ .DS_Store
47
+
48
+ # Jupyter Notebook
49
+ .ipynb_checkpoints
50
+ *.ipynb
51
+
52
+ # Logs and databases
53
+ *.log
54
+ *.sqlite
55
+ logs/
56
+ wandb/
57
+
58
+ # Distribution / packaging
59
+ .Python
60
+ build/
61
+ develop-eggs/
62
+ dist/
63
+ downloads/
64
+ eggs/
65
+ .eggs/
66
+ lib/
67
+ lib64/
68
+ parts/
69
+ sdist/
70
+ var/
71
+ wheels/
72
+ *.egg-info/
73
+ .installed.cfg
74
+ *.egg
75
+
76
+ # Unit test / coverage reports
77
+ htmlcov/
78
+ .tox/
79
+ .coverage
80
+ .coverage.*
81
+ .cache
82
+ nosetests.xml
83
+ coverage.xml
84
+ *.cover
85
+ .hypothesis/
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ from torchvision.models import resnet50
6
+
7
+ # Load model
8
+ model = resnet50(pretrained=False)
9
+ model.fc = nn.Linear(model.fc.in_features, 10)
10
+ model.load_state_dict(torch.load('best_model.pth'))
11
+ model.eval()
12
+
13
+ # Define classes (for CIFAR-10)
14
+ classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
15
+ 'dog', 'frog', 'horse', 'ship', 'truck']
16
+
17
+ def predict(image):
18
+ transform = transforms.Compose([
19
+ transforms.Resize(224),
20
+ transforms.ToTensor(),
21
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
22
+ ])
23
+
24
+ img_tensor = transform(image).unsqueeze(0)
25
+
26
+ with torch.no_grad():
27
+ outputs = model(img_tensor)
28
+ _, predicted = outputs.max(1)
29
+
30
+ return classes[predicted.item()]
31
+
32
+ # Create Gradio interface
33
+ iface = gr.Interface(
34
+ fn=predict,
35
+ inputs=gr.Image(type="pil"),
36
+ outputs=gr.Label(num_top_classes=1),
37
+ examples=[["example1.jpg"], ["example2.jpg"]]
38
+ )
39
+
40
+ iface.launch()
model.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision.models import resnet50
4
+
5
+ def get_model(num_classes):
6
+ """
7
+ Initialize a ResNet50 model from scratch
8
+ Args:
9
+ num_classes (int): Number of output classes
10
+ Returns:
11
+ model: ResNet50 model with custom final layer
12
+ """
13
+ model = resnet50(pretrained=False)
14
+ model.fc = nn.Linear(model.fc.in_features, num_classes)
15
+ return model
16
+
17
+ def save_model(model, path):
18
+ """
19
+ Save model state dict
20
+ """
21
+ torch.save(model.state_dict(), path)
22
+
23
+ def load_model(num_classes, path):
24
+ """
25
+ Load a saved model
26
+ """
27
+ model = get_model(num_classes)
28
+ model.load_state_dict(torch.load(path))
29
+ return model
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ --find-links https://download.pytorch.org/whl/torch_stable.html
2
+ torch>=2.1.0
3
+ torchvision>=0.16.0
4
+ gradio==4.19.2
5
+ numpy==1.24.3
train.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import torchvision
5
+ import torchvision.transforms as transforms
6
+ from torch.utils.data import DataLoader, Subset
7
+ from model import get_model, save_model
8
+ from tqdm import tqdm
9
+
10
+ def get_transforms():
11
+ """
12
+ Define the image transformations
13
+ """
14
+ return transforms.Compose([
15
+ transforms.Resize(224),
16
+ transforms.ToTensor(),
17
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
18
+ std=[0.229, 0.224, 0.225])
19
+ ])
20
+
21
+ def get_data(subset_size=None):
22
+ """
23
+ Load and prepare the dataset
24
+ Args:
25
+ subset_size (int): If provided, return only a subset of data
26
+ """
27
+ transform = get_transforms()
28
+ trainset = torchvision.datasets.CIFAR10(
29
+ root='./data',
30
+ train=True,
31
+ download=True,
32
+ transform=transform
33
+ )
34
+
35
+ if subset_size:
36
+ indices = torch.randperm(len(trainset))[:subset_size]
37
+ trainset = Subset(trainset, indices)
38
+
39
+ trainloader = DataLoader(
40
+ trainset,
41
+ batch_size=32,
42
+ shuffle=True,
43
+ num_workers=2
44
+ )
45
+
46
+ return trainloader
47
+
48
+ def train_model(model, trainloader, epochs=100, device='cuda'):
49
+ """
50
+ Train the model
51
+ Args:
52
+ model: The ResNet50 model
53
+ trainloader: DataLoader for training data
54
+ epochs (int): Number of epochs to train
55
+ device (str): Device to train on ('cuda' or 'cpu')
56
+ """
57
+ model = model.to(device)
58
+ criterion = nn.CrossEntropyLoss()
59
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
60
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(
61
+ optimizer,
62
+ 'max',
63
+ patience=5
64
+ )
65
+
66
+ best_acc = 0.0
67
+
68
+ # Create epoch progress bar
69
+ epoch_pbar = tqdm(range(epochs), desc='Training')
70
+
71
+ for epoch in epoch_pbar:
72
+ model.train()
73
+ running_loss = 0.0
74
+ correct = 0
75
+ total = 0
76
+
77
+ # Create batch progress bar
78
+ batch_pbar = tqdm(trainloader, leave=False, desc=f'Epoch {epoch+1}')
79
+
80
+ for inputs, labels in batch_pbar:
81
+ inputs, labels = inputs.to(device), labels.to(device)
82
+
83
+ optimizer.zero_grad()
84
+ outputs = model(inputs)
85
+ loss = criterion(outputs, labels)
86
+ loss.backward()
87
+ optimizer.step()
88
+
89
+ running_loss += loss.item()
90
+ _, predicted = outputs.max(1)
91
+ total += labels.size(0)
92
+ correct += predicted.eq(labels).sum().item()
93
+
94
+ # Update batch progress bar
95
+ batch_pbar.set_postfix({'loss': f'{loss.item():.3f}'})
96
+
97
+ epoch_acc = 100. * correct / total
98
+ avg_loss = running_loss/len(trainloader)
99
+
100
+ # Update epoch progress bar
101
+ epoch_pbar.set_postfix({
102
+ 'loss': f'{avg_loss:.3f}',
103
+ 'accuracy': f'{epoch_acc:.2f}%'
104
+ })
105
+
106
+ scheduler.step(epoch_acc)
107
+
108
+ if epoch_acc > best_acc:
109
+ best_acc = epoch_acc
110
+ save_model(model, 'best_model.pth')
111
+
112
+ if epoch_acc > 70:
113
+ print(f"\nReached target accuracy of 70%!")
114
+ break
115
+
116
+ if __name__ == "__main__":
117
+ # Set device
118
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
119
+ print(f"Using device: {device}")
120
+
121
+ # Get data
122
+ trainloader = get_data(subset_size=5000) # Using subset for initial testing
123
+
124
+ # Initialize model
125
+ model = get_model(num_classes=10)
126
+
127
+ # Train model
128
+ train_model(model, trainloader, epochs=10, device=device)