HaoFeng2019 commited on
Commit
980de13
·
1 Parent(s): 759cd1c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -0
app.py CHANGED
@@ -20,6 +20,35 @@ import gradio as gr
20
 
21
  example_img_list = ['51_1 copy.png','48_2 copy.png','25.jpg']
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  class GeoTr_Seg(nn.Module):
24
  def __init__(self):
25
  super(GeoTr_Seg, self).__init__()
@@ -36,6 +65,9 @@ class GeoTr_Seg(nn.Module):
36
 
37
  return bm
38
 
 
 
 
39
  # Initialize models
40
  GeoTr_Seg_model = GeoTr_Seg()
41
  #IllTr_model = IllTr()
 
20
 
21
  example_img_list = ['51_1 copy.png','48_2 copy.png','25.jpg']
22
 
23
+ def reload_model(model, path=""):
24
+ if not bool(path):
25
+ return model
26
+ else:
27
+ model_dict = model.state_dict()
28
+ pretrained_dict = torch.load(path, map_location='cpu')
29
+ # print(len(pretrained_dict.keys()))
30
+ pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
31
+ # print(len(pretrained_dict.keys()))
32
+ model_dict.update(pretrained_dict)
33
+ model.load_state_dict(model_dict)
34
+
35
+ return model
36
+
37
+
38
+ def reload_segmodel(model, path=""):
39
+ if not bool(path):
40
+ return model
41
+ else:
42
+ model_dict = model.state_dict()
43
+ pretrained_dict = torch.load(path, map_location='cpu')
44
+ # print(len(pretrained_dict.keys()))
45
+ pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}
46
+ # print(len(pretrained_dict.keys()))
47
+ model_dict.update(pretrained_dict)
48
+ model.load_state_dict(model_dict)
49
+
50
+ return model
51
+
52
  class GeoTr_Seg(nn.Module):
53
  def __init__(self):
54
  super(GeoTr_Seg, self).__init__()
 
65
 
66
  return bm
67
 
68
+
69
+
70
+
71
  # Initialize models
72
  GeoTr_Seg_model = GeoTr_Seg()
73
  #IllTr_model = IllTr()