aagoluoglu commited on
Commit
efa687f
·
verified ·
1 Parent(s): 15b4e31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -15,6 +15,7 @@ import torch
15
 
16
  sns.set_theme()
17
 
 
18
  www_dir = Path(__file__).parent.resolve() / "www"
19
 
20
  df = pd.read_csv(Path(__file__).parent / "penguins.csv", na_values="NA")
@@ -64,8 +65,10 @@ def server(input: Inputs, output: Outputs, session: Session):
64
  model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
65
  processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
66
 
67
- # Create an instance of the model from my fine-tuned model with the loaded configuration
68
- model = SamModel.from_pretrained("aagoluoglu/SAM_Sidewalks", config=model_config)
 
 
69
 
70
  # set the device to cuda if available, otherwise use cpu
71
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -84,7 +87,6 @@ def server(input: Inputs, output: Outputs, session: Session):
84
  """Displays the uploaded image"""
85
  img_src = uploaded_image_path()
86
  if img_src:
87
- dir = Path(__file__).resolve().parent
88
  img: ImgData = {"src": str(dir / uploaded_image_path()), "width": "100px"}
89
  return img
90
  else:
 
15
 
16
  sns.set_theme()
17
 
18
+ dir = Path(__file__).resolve().parent
19
  www_dir = Path(__file__).parent.resolve() / "www"
20
 
21
  df = pd.read_csv(Path(__file__).parent / "penguins.csv", na_values="NA")
 
65
  model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
66
  processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
67
 
68
+ # Create an instance of the model architecture with the loaded configuration
69
+ model = SamModel(config=model_config)
70
+ # Update the model by loading the weights from saved file.
71
+ model.load_state_dict(torch.load(str(dir / "checkpoint.pth")))
72
 
73
  # set the device to cuda if available, otherwise use cpu
74
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
87
  """Displays the uploaded image"""
88
  img_src = uploaded_image_path()
89
  if img_src:
 
90
  img: ImgData = {"src": str(dir / uploaded_image_path()), "width": "100px"}
91
  return img
92
  else: