jpterry commited on
Commit
b971cba
·
1 Parent(s): 7f0b5ec

this sucks

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -10,6 +10,7 @@ import sys
10
  import timm
11
  from types import SimpleNamespace
12
  # from transformers import AutoModel, pipeline
 
13
  import torch
14
 
15
  sys.path.insert(1, "../")
@@ -176,7 +177,6 @@ def plot_activations(activation_1: list, activation_2: list, origin='lower'):
176
 
177
  return fig
178
 
179
-
180
  def predict_and_analyze(model_name, num_channels, dim, input_channel, image):
181
 
182
  '''
@@ -203,7 +203,7 @@ def predict_and_analyze(model_name, num_channels, dim, input_channel, image):
203
  print("Data loaded")
204
 
205
  print("Loading model")
206
- model_loading_name = model_path + "%s_%i_planet_detection" % (model_name, num_channels)
207
 
208
  if 'eff' in model_name:
209
  hparams = effnet_hparams[num_channels]
@@ -220,13 +220,15 @@ def predict_and_analyze(model_name, num_channels, dim, input_channel, image):
220
  depth_mult=hparams.depth_mult,
221
  )
222
 
223
- config.save_pretrained("%s_planet_detection" % (model_name))
224
-
225
- config = EfficientNetConfig.from_pretrained("%s_planet_detection" % (model_name))
226
 
227
  model = EfficientNetPreTrained(config)
 
 
 
228
 
229
- pretrained_model = timm.create_model("%s_planet_detection" % (model_name), pretrained=True)
230
  model.model.load_state_dict(pretrained_model.state_dict())
231
 
232
  # pipeline = pipeline(task="image-classification", model=model_loading_name)
 
10
  import timm
11
  from types import SimpleNamespace
12
  # from transformers import AutoModel, pipeline
13
+ from transformers import AutoModelForImageClassification
14
  import torch
15
 
16
  sys.path.insert(1, "../")
 
177
 
178
  return fig
179
 
 
180
  def predict_and_analyze(model_name, num_channels, dim, input_channel, image):
181
 
182
  '''
 
203
  print("Data loaded")
204
 
205
  print("Loading model")
206
+ model_loading_name = "%s_%i_planet_detection" % (model_name, num_channels)
207
 
208
  if 'eff' in model_name:
209
  hparams = effnet_hparams[num_channels]
 
220
  depth_mult=hparams.depth_mult,
221
  )
222
 
223
+ config.save_pretrained(model_loading_name)
224
+ config = EfficientNetConfig.from_pretrained(model_loading_name)
 
225
 
226
  model = EfficientNetPreTrained(config)
227
+
228
+ # config.register_for_auto_class()
229
+ # model.register_for_auto_class("AutoModelForImageClassification")
230
 
231
+ pretrained_model = timm.create_model(model_loading_name, pretrained=True)
232
  model.model.load_state_dict(pretrained_model.state_dict())
233
 
234
  # pipeline = pipeline(task="image-classification", model=model_loading_name)