PranayChamala commited on
Commit
665358e
·
1 Parent(s): 65ca521

Made changes to mediseg.py

Browse files
Files changed (1) hide show
  1. app/mediseg.py +3 -3
app/mediseg.py CHANGED
@@ -117,7 +117,7 @@ class UNetMulti(nn.Module):
117
 
118
  def process_brain_tumor(image: Image.Image, model_path=os.path.join("models", "brain_tumor_unet_multiclass.pth")) -> str:
119
  model = UNetMulti(in_channels=3, out_channels=4).to(device)
120
- model.load_state_dict(torch.load(model_path, map_location=device))
121
  model.eval()
122
 
123
  transform_img = transforms.Compose([
@@ -204,7 +204,7 @@ class UNetBinary(nn.Module):
204
 
205
  def process_endoscopy(image: Image.Image, model_path=os.path.join("models", "endoscopy_unet.pth")) -> str:
206
  model = UNetBinary(in_channels=3, out_channels=1).to(device)
207
- model.load_state_dict(torch.load(model_path, map_location=device))
208
  model.eval()
209
 
210
  transform_img = transforms.Compose([
@@ -284,7 +284,7 @@ def process_pneumonia(image: Image.Image, model_path=os.path.join("models", "pne
284
  model = models.resnet18(pretrained=False)
285
  num_ftrs = model.fc.in_features
286
  model.fc = nn.Linear(num_ftrs, 2) # 2 classes: normal and pneumonia
287
- model.load_state_dict(torch.load(model_path, map_location=device))
288
  model.to(device)
289
  model.eval()
290
 
 
117
 
118
  def process_brain_tumor(image: Image.Image, model_path=os.path.join("models", "brain_tumor_unet_multiclass.pth")) -> str:
119
  model = UNetMulti(in_channels=3, out_channels=4).to(device)
120
+ model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
121
  model.eval()
122
 
123
  transform_img = transforms.Compose([
 
204
 
205
  def process_endoscopy(image: Image.Image, model_path=os.path.join("models", "endoscopy_unet.pth")) -> str:
206
  model = UNetBinary(in_channels=3, out_channels=1).to(device)
207
+ model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
208
  model.eval()
209
 
210
  transform_img = transforms.Compose([
 
284
  model = models.resnet18(pretrained=False)
285
  num_ftrs = model.fc.in_features
286
  model.fc = nn.Linear(num_ftrs, 2) # 2 classes: normal and pneumonia
287
+ model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
288
  model.to(device)
289
  model.eval()
290