Spaces:
Running
Running
Commit
·
665358e
1
Parent(s):
65ca521
Made changes to mediseg.py
Browse files- app/mediseg.py +3 -3
app/mediseg.py
CHANGED
@@ -117,7 +117,7 @@ class UNetMulti(nn.Module):
|
|
117 |
|
118 |
def process_brain_tumor(image: Image.Image, model_path=os.path.join("models", "brain_tumor_unet_multiclass.pth")) -> str:
|
119 |
model = UNetMulti(in_channels=3, out_channels=4).to(device)
|
120 |
-
model.load_state_dict(torch.load(model_path, map_location=device))
|
121 |
model.eval()
|
122 |
|
123 |
transform_img = transforms.Compose([
|
@@ -204,7 +204,7 @@ class UNetBinary(nn.Module):
|
|
204 |
|
205 |
def process_endoscopy(image: Image.Image, model_path=os.path.join("models", "endoscopy_unet.pth")) -> str:
|
206 |
model = UNetBinary(in_channels=3, out_channels=1).to(device)
|
207 |
-
model.load_state_dict(torch.load(model_path, map_location=device))
|
208 |
model.eval()
|
209 |
|
210 |
transform_img = transforms.Compose([
|
@@ -284,7 +284,7 @@ def process_pneumonia(image: Image.Image, model_path=os.path.join("models", "pne
|
|
284 |
model = models.resnet18(pretrained=False)
|
285 |
num_ftrs = model.fc.in_features
|
286 |
model.fc = nn.Linear(num_ftrs, 2) # 2 classes: normal and pneumonia
|
287 |
-
model.load_state_dict(torch.load(model_path, map_location=device))
|
288 |
model.to(device)
|
289 |
model.eval()
|
290 |
|
|
|
117 |
|
118 |
def process_brain_tumor(image: Image.Image, model_path=os.path.join("models", "brain_tumor_unet_multiclass.pth")) -> str:
|
119 |
model = UNetMulti(in_channels=3, out_channels=4).to(device)
|
120 |
+
model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
|
121 |
model.eval()
|
122 |
|
123 |
transform_img = transforms.Compose([
|
|
|
204 |
|
205 |
def process_endoscopy(image: Image.Image, model_path=os.path.join("models", "endoscopy_unet.pth")) -> str:
|
206 |
model = UNetBinary(in_channels=3, out_channels=1).to(device)
|
207 |
+
model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
|
208 |
model.eval()
|
209 |
|
210 |
transform_img = transforms.Compose([
|
|
|
284 |
model = models.resnet18(pretrained=False)
|
285 |
num_ftrs = model.fc.in_features
|
286 |
model.fc = nn.Linear(num_ftrs, 2) # 2 classes: normal and pneumonia
|
287 |
+
model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
|
288 |
model.to(device)
|
289 |
model.eval()
|
290 |
|