this sucks
Browse files
app.py
CHANGED
@@ -10,6 +10,7 @@ import sys
|
|
10 |
import timm
|
11 |
from types import SimpleNamespace
|
12 |
# from transformers import AutoModel, pipeline
|
|
|
13 |
import torch
|
14 |
|
15 |
sys.path.insert(1, "../")
|
@@ -176,7 +177,6 @@ def plot_activations(activation_1: list, activation_2: list, origin='lower'):
|
|
176 |
|
177 |
return fig
|
178 |
|
179 |
-
|
180 |
def predict_and_analyze(model_name, num_channels, dim, input_channel, image):
|
181 |
|
182 |
'''
|
@@ -203,7 +203,7 @@ def predict_and_analyze(model_name, num_channels, dim, input_channel, image):
|
|
203 |
print("Data loaded")
|
204 |
|
205 |
print("Loading model")
|
206 |
-
model_loading_name =
|
207 |
|
208 |
if 'eff' in model_name:
|
209 |
hparams = effnet_hparams[num_channels]
|
@@ -220,13 +220,15 @@ def predict_and_analyze(model_name, num_channels, dim, input_channel, image):
|
|
220 |
depth_mult=hparams.depth_mult,
|
221 |
)
|
222 |
|
223 |
-
config.save_pretrained(
|
224 |
-
|
225 |
-
config = EfficientNetConfig.from_pretrained("%s_planet_detection" % (model_name))
|
226 |
|
227 |
model = EfficientNetPreTrained(config)
|
|
|
|
|
|
|
228 |
|
229 |
-
pretrained_model = timm.create_model(
|
230 |
model.model.load_state_dict(pretrained_model.state_dict())
|
231 |
|
232 |
# pipeline = pipeline(task="image-classification", model=model_loading_name)
|
|
|
10 |
import timm
|
11 |
from types import SimpleNamespace
|
12 |
# from transformers import AutoModel, pipeline
|
13 |
+
from transformers import AutoModelForImageClassification
|
14 |
import torch
|
15 |
|
16 |
sys.path.insert(1, "../")
|
|
|
177 |
|
178 |
return fig
|
179 |
|
|
|
180 |
def predict_and_analyze(model_name, num_channels, dim, input_channel, image):
|
181 |
|
182 |
'''
|
|
|
203 |
print("Data loaded")
|
204 |
|
205 |
print("Loading model")
|
206 |
+
model_loading_name = "%s_%i_planet_detection" % (model_name, num_channels)
|
207 |
|
208 |
if 'eff' in model_name:
|
209 |
hparams = effnet_hparams[num_channels]
|
|
|
220 |
depth_mult=hparams.depth_mult,
|
221 |
)
|
222 |
|
223 |
+
config.save_pretrained(model_loading_name)
|
224 |
+
config = EfficientNetConfig.from_pretrained(model_loading_name)
|
|
|
225 |
|
226 |
model = EfficientNetPreTrained(config)
|
227 |
+
|
228 |
+
# config.register_for_auto_class()
|
229 |
+
# model.register_for_auto_class("AutoModelForImageClassification")
|
230 |
|
231 |
+
pretrained_model = timm.create_model(model_loading_name, pretrained=True)
|
232 |
model.model.load_state_dict(pretrained_model.state_dict())
|
233 |
|
234 |
# pipeline = pipeline(task="image-classification", model=model_loading_name)
|