Update app.py
Browse files
app.py
CHANGED
@@ -95,8 +95,11 @@ class CustomImageDataset(Dataset):
|
|
95 |
|
96 |
# Training function for classification
|
97 |
def fine_tune_classification_model(train_loader):
|
98 |
-
model
|
|
|
|
|
99 |
model.train()
|
|
|
100 |
optimizer = AdamW(model.parameters(), lr=1e-4)
|
101 |
criterion = torch.nn.CrossEntropyLoss()
|
102 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
@@ -112,6 +115,7 @@ def fine_tune_classification_model(train_loader):
|
|
112 |
loss.backward()
|
113 |
optimizer.step()
|
114 |
running_loss += loss.item()
|
|
|
115 |
return running_loss / len(train_loader)
|
116 |
|
117 |
# Streamlit UI for Fine-tuning
|
@@ -140,8 +144,10 @@ if st.button("Start Training"):
|
|
140 |
|
141 |
# Segmentation function (using SegFormer)
|
142 |
def fine_tune_segmentation_model(train_loader):
|
143 |
-
model
|
|
|
144 |
model.train()
|
|
|
145 |
optimizer = AdamW(model.parameters(), lr=1e-4)
|
146 |
criterion = torch.nn.CrossEntropyLoss()
|
147 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
@@ -157,6 +163,7 @@ def fine_tune_segmentation_model(train_loader):
|
|
157 |
loss.backward()
|
158 |
optimizer.step()
|
159 |
running_loss += loss.item()
|
|
|
160 |
return running_loss / len(train_loader)
|
161 |
|
162 |
# Add a button for segmentation training
|
|
|
95 |
|
96 |
# Training function for classification
|
97 |
def fine_tune_classification_model(train_loader):
|
98 |
+
# Load the ResNet model with ignore_mismatched_sizes
|
99 |
+
model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50', num_labels=3, ignore_mismatched_sizes=True)
|
100 |
+
model.classifier = torch.nn.Linear(model.config.hidden_size, 3) # Update classifier for 3 labels
|
101 |
model.train()
|
102 |
+
|
103 |
optimizer = AdamW(model.parameters(), lr=1e-4)
|
104 |
criterion = torch.nn.CrossEntropyLoss()
|
105 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
115 |
loss.backward()
|
116 |
optimizer.step()
|
117 |
running_loss += loss.item()
|
118 |
+
|
119 |
return running_loss / len(train_loader)
|
120 |
|
121 |
# Streamlit UI for Fine-tuning
|
|
|
144 |
|
145 |
# Segmentation function (using SegFormer)
|
146 |
def fine_tune_segmentation_model(train_loader):
|
147 |
+
# Load the Segformer model with ignore_mismatched_sizes
|
148 |
+
model = SegformerForSemanticSegmentation.from_pretrained('nvidia/segformer-b0', num_labels=3, ignore_mismatched_sizes=True)
|
149 |
model.train()
|
150 |
+
|
151 |
optimizer = AdamW(model.parameters(), lr=1e-4)
|
152 |
criterion = torch.nn.CrossEntropyLoss()
|
153 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
163 |
loss.backward()
|
164 |
optimizer.step()
|
165 |
running_loss += loss.item()
|
166 |
+
|
167 |
return running_loss / len(train_loader)
|
168 |
|
169 |
# Add a button for segmentation training
|