z-uo commited on
Commit
b4d55e3
·
1 Parent(s): 2eef120

add raw text output

Browse files
Files changed (2) hide show
  1. app.py +7 -4
  2. test.py +3 -2
app.py CHANGED
@@ -8,7 +8,7 @@ import gradio as gr
8
 
9
  # import sys
10
  # sys.path.insert(0, './')
11
- from test import create_letr, draw_fig
12
  from models.preprocessing import *
13
  from models.misc import nested_tensor_from_tensor_list
14
 
@@ -57,9 +57,9 @@ def predict(inp, size, model_name):
57
  else:
58
  outputs = model(inputs)[0]
59
 
60
- draw_fig(image, outputs, orig_size)
61
 
62
- return image
63
 
64
 
65
  inputs = [
@@ -67,7 +67,10 @@ inputs = [
67
  gr.inputs.Radio(["256", "512", "1100"]),
68
  gr.inputs.Radio(["resnet50", "resnet101"]),
69
  ]
70
- outputs = gr.outputs.Image()
 
 
 
71
  gr.Interface(
72
  fn=predict,
73
  inputs=inputs,
 
8
 
9
  # import sys
10
  # sys.path.insert(0, './')
11
+ from test import create_letr, get_lines_and_draw
12
  from models.preprocessing import *
13
  from models.misc import nested_tensor_from_tensor_list
14
 
 
57
  else:
58
  outputs = model(inputs)[0]
59
 
60
+ lines = get_lines_and_draw(image, outputs, orig_size)
61
 
62
+ return image, str(lines)
63
 
64
 
65
  inputs = [
 
67
  gr.inputs.Radio(["256", "512", "1100"]),
68
  gr.inputs.Radio(["resnet50", "resnet101"]),
69
  ]
70
+ outputs = [
71
+ gr.outputs.Image(),
72
+ gr.outputs.Textbox()
73
+ ]
74
  gr.Interface(
75
  fn=predict,
76
  inputs=inputs,
test.py CHANGED
@@ -19,7 +19,7 @@ def create_letr(path):
19
  model.eval()
20
  return model
21
 
22
- def draw_fig(image, outputs, orig_size):
23
  # find lines
24
  out_logits, out_line = outputs['pred_logits'], outputs['pred_lines']
25
  prob = F.softmax(out_logits, -1)
@@ -42,6 +42,7 @@ def draw_fig(image, outputs, orig_size):
42
  for tp_id, line in enumerate(lines):
43
  y1, x1, y2, x2 = line
44
  draw.line((x1, y1, x2, y2), fill=500)
 
45
 
46
  if __name__ == '__main__':
47
  model = create_letr('resnet50/checkpoint0024.pth')
@@ -62,6 +63,6 @@ if __name__ == '__main__':
62
 
63
  with torch.no_grad():
64
  outputs = model(inputs)[0]
65
- draw_fig(image, outputs, orig_size)
66
 
67
  image.save('output.png')
 
19
  model.eval()
20
  return model
21
 
22
+ def get_lines_and_draw(image, outputs, orig_size):
23
  # find lines
24
  out_logits, out_line = outputs['pred_logits'], outputs['pred_lines']
25
  prob = F.softmax(out_logits, -1)
 
42
  for tp_id, line in enumerate(lines):
43
  y1, x1, y2, x2 = line
44
  draw.line((x1, y1, x2, y2), fill=500)
45
+ return lines
46
 
47
  if __name__ == '__main__':
48
  model = create_letr('resnet50/checkpoint0024.pth')
 
63
 
64
  with torch.no_grad():
65
  outputs = model(inputs)[0]
66
+ lines = get_lines_and_draw(image, outputs, orig_size)
67
 
68
  image.save('output.png')