HaoFeng2019 commited on
Commit
eb95227
·
1 Parent(s): 5c5a030

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -47
app.py CHANGED
@@ -1,5 +1,3 @@
1
- #origin
2
-
3
  from seg import U2NETP
4
  from GeoTr import GeoTr
5
  from IllTr import IllTr
@@ -18,10 +16,6 @@ import argparse
18
  import warnings
19
  warnings.filterwarnings('ignore')
20
 
21
-
22
-
23
-
24
-
25
  import gradio as gr
26
 
27
  example_img_list = ['51_1 copy.png','48_2 copy.png','25.jpg']
@@ -41,48 +35,21 @@ class GeoTr_Seg(nn.Module):
41
  bm = (2 * (bm / 286.8) - 1) * 0.99
42
 
43
  return bm
44
-
45
-
46
- def reload_model(model, path=""):
47
- if not bool(path):
48
- return model
49
- else:
50
- model_dict = model.state_dict()
51
- pretrained_dict = torch.load(path, map_location='cpu')
52
- #print(len(pretrained_dict.keys()))
53
- pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
54
- #print(len(pretrained_dict.keys()))
55
- model_dict.update(pretrained_dict)
56
- model.load_state_dict(model_dict)
57
 
58
- return model
59
-
60
-
61
- def reload_segmodel(model, path=""):
62
- if not bool(path):
63
- return model
64
- else:
65
- model_dict = model.state_dict()
66
- pretrained_dict = torch.load(path, map_location='cpu')
67
- #print(len(pretrained_dict.keys()))
68
- pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}
69
- #print(len(pretrained_dict.keys()))
70
- model_dict.update(pretrained_dict)
71
- model.load_state_dict(model_dict)
72
-
73
- return model
74
-
75
 
 
 
 
 
76
 
 
 
 
77
 
78
  def process_image(input_image):
79
- GeoTr_Seg_model = GeoTr_Seg()#.cuda()
80
- reload_segmodel(GeoTr_Seg_model.msk, './model_pretrained/seg.pth')
81
- reload_model(GeoTr_Seg_model.GeoTr, './model_pretrained/geotr.pth')
82
-
83
- IllTr_model = IllTr()#.cuda()
84
- reload_model(IllTr_model, './model_pretrained/illtr.pth')
85
-
86
  GeoTr_Seg_model.eval()
87
  IllTr_model.eval()
88
 
@@ -112,13 +79,9 @@ def process_image(input_image):
112
  else:
113
  return Image.fromarray(img_geo)
114
 
115
-
116
-
117
  # Define Gradio interface
118
  input_image = gr.inputs.Image()
119
  output_image = gr.outputs.Image(type='pil')
120
 
121
-
122
  iface = gr.Interface(fn=process_image, inputs=input_image, outputs=output_image, title="DocTr", examples=example_img_list)
123
  iface.launch()
124
-
 
 
 
1
  from seg import U2NETP
2
  from GeoTr import GeoTr
3
  from IllTr import IllTr
 
16
  import warnings
17
  warnings.filterwarnings('ignore')
18
 
 
 
 
 
19
  import gradio as gr
20
 
21
  example_img_list = ['51_1 copy.png','48_2 copy.png','25.jpg']
 
35
  bm = (2 * (bm / 286.8) - 1) * 0.99
36
 
37
  return bm
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ # Initialize models
40
+ GeoTr_Seg_model = GeoTr_Seg()
41
+ IllTr_model = IllTr()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ # Load models only once
44
+ reload_segmodel(GeoTr_Seg_model.msk, './model_pretrained/seg.pth')
45
+ reload_model(GeoTr_Seg_model.GeoTr, './model_pretrained/geotr.pth')
46
+ reload_model(IllTr_model, './model_pretrained/illtr.pth')
47
 
48
+ # Compile models (assuming PyTorch 2.0)
49
+ GeoTr_Seg_model = torch.compile(GeoTr_Seg_model)
50
+ IllTr_model = torch.compile(IllTr_model)
51
 
52
  def process_image(input_image):
 
 
 
 
 
 
 
53
  GeoTr_Seg_model.eval()
54
  IllTr_model.eval()
55
 
 
79
  else:
80
  return Image.fromarray(img_geo)
81
 
 
 
82
  # Define Gradio interface
83
  input_image = gr.inputs.Image()
84
  output_image = gr.outputs.Image(type='pil')
85
 
 
86
  iface = gr.Interface(fn=process_image, inputs=input_image, outputs=output_image, title="DocTr", examples=example_img_list)
87
  iface.launch()