daranaka committed on
Commit 580bfba · verified · 1 Parent(s): f31366c

Update app.py

Files changed (1)
  1. app.py +47 -67
app.py CHANGED
@@ -4,9 +4,7 @@ from PIL import Image
  import torch
  import numpy as np
  import urllib.request
- import subprocess

- # Load model
  @st.cache_resource
  def load_model():
      model = AutoModel.from_pretrained("ragavsachdeva/magi", trust_remote_code=True)
@@ -14,7 +12,6 @@ def load_model():
      model.to(device)
      return model

- # Read image as numpy array
  @st.cache_data
  def read_image_as_np_array(image_path):
      if "http" in image_path:
@@ -24,7 +21,6 @@ def read_image_as_np_array(image_path):
      image = np.array(image)
      return image

- # Predict detections and associations
  @st.cache_data
  def predict_detections_and_associations(
      image_path,
@@ -46,7 +42,6 @@ def predict_detections_and_associations(
          )[0]
      return result

- # OCR prediction for transcript
  @st.cache_data
  def predict_ocr(
      image_path,
@@ -56,9 +51,11 @@ def predict_ocr(
      character_character_matching_threshold,
      text_character_matching_threshold,
  ):
+     if not generate_transcript:
+         return
      image = read_image_as_np_array(image_path)
      result = predict_detections_and_associations(
-         image_path,
+         path_to_image,
          character_detection_threshold,
          panel_detection_threshold,
          text_detection_threshold,
@@ -70,76 +67,59 @@ def predict_ocr(
          ocr_results = model.predict_ocr([image], text_bboxes_for_all_images)
      return ocr_results

- # Terminal command function
- def run_command(command):
-     try:
-         result = subprocess.run(command, shell=True, text=True, capture_output=True)
-         output = result.stdout + result.stderr
-         return output
-     except Exception as e:
-         return str(e)
-
- # Load the model
  model = load_model()

- # UI Design
- st.markdown("""<style>
-     .title-container { background-color: #0d1117; padding: 20px; border-radius: 10px; margin: 20px; }
-     .title { font-size: 2em; text-align: center; color: #fff; font-family: 'Comic Sans MS', cursive; text-transform: uppercase; letter-spacing: 0.1em; padding: 0.5em 0 0.2em; background: 0 0; }
-     .title span { background: -webkit-linear-gradient(45deg, #6495ed, #4169e1); -webkit-background-clip: text; -webkit-text-fill-color: transparent; }
-     .subheading { font-size: 1.5em; text-align: center; color: #ddd; font-family: 'Comic Sans MS', cursive; }
- </style>""", unsafe_allow_html=True)
-
- st.title("Manga Narrator and Terminal App")
-
- # File uploader for image
+ st.markdown(""" <style> .title-container { background-color: #0d1117; padding: 20px; border-radius: 10px; margin: 20px; } .title { font-size: 2em; text-align: center; color: #fff; font-family: 'Comic Sans MS', cursive; text-transform: uppercase; letter-spacing: 0.1em; padding: 0.5em 0 0.2em; background: 0 0; } .title span { background: -webkit-linear-gradient(45deg, #6495ed, #4169e1); -webkit-background-clip: text; -webkit-text-fill-color: transparent; } .subheading { font-size: 1.5em; text-align: center; color: #ddd; font-family: 'Comic Sans MS', cursive; } .affil, .authors { font-size: 1em; text-align: center; color: #ddd; font-family: 'Comic Sans MS', cursive; } .authors { padding-top: 1em; } </style> <div class='title-container'> <div class='title'> The <span>Ma</span>n<span>g</span>a Wh<span>i</span>sperer </div> <div class='subheading'> Automatically Generating Transcriptions for Comics </div> <div class='authors'> Ragav Sachdeva and Andrew Zisserman </div> <div class='affil'> University of Oxford </div> </div>""", unsafe_allow_html=True)
  path_to_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

- # Sidebar with hyperparameters
+ st.sidebar.markdown("**Mode**")
+ generate_detections_and_associations = st.sidebar.toggle("Generate detections and associations", True)
+ generate_transcript = st.sidebar.toggle("Generate transcript (slower)", False)
  st.sidebar.markdown("**Hyperparameters**")
- character_detection_threshold = st.sidebar.slider('Character detection threshold', 0.0, 1.0, 0.30, step=0.01)
- panel_detection_threshold = st.sidebar.slider('Panel detection threshold', 0.0, 1.0, 0.2, step=0.01)
- text_detection_threshold = st.sidebar.slider('Text detection threshold', 0.0, 1.0, 0.25, step=0.01)
- character_character_matching_threshold = st.sidebar.slider('Character-character matching threshold', 0.0, 1.0, 0.7, step=0.01)
- text_character_matching_threshold = st.sidebar.slider('Text-character matching threshold', 0.0, 1.0, 0.4, step=0.01)
+ input_character_detection_threshold = st.sidebar.slider('Character detection threshold', 0.0, 1.0, 0.30, step=0.01)
+ input_panel_detection_threshold = st.sidebar.slider('Panel detection threshold', 0.0, 1.0, 0.2, step=0.01)
+ input_text_detection_threshold = st.sidebar.slider('Text detection threshold', 0.0, 1.0, 0.25, step=0.01)
+ input_character_character_matching_threshold = st.sidebar.slider('Character-character matching threshold', 0.0, 1.0, 0.7, step=0.01)
+ input_text_character_matching_threshold = st.sidebar.slider('Text-character matching threshold', 0.0, 1.0, 0.4, step=0.01)
+

- # Generate Narration button
  if path_to_image is not None:
-     st.markdown("**Prediction**")
+     image = read_image_as_np_array(path_to_image)

-     # Button to generate narration
-     if st.button("Generate Narration"):
-         # Generate detections and associations
+     st.markdown("**Prediction**")
+     if generate_detections_and_associations or generate_transcript:
          result = predict_detections_and_associations(
-             path_to_image,
-             character_detection_threshold,
-             panel_detection_threshold,
-             text_detection_threshold,
-             character_character_matching_threshold,
-             text_character_matching_threshold,
-         )
-
-         # OCR result
+             path_to_image,
+             input_character_detection_threshold,
+             input_panel_detection_threshold,
+             input_text_detection_threshold,
+             input_character_character_matching_threshold,
+             input_text_character_matching_threshold,
+         )
+
+     if generate_transcript:
          ocr_results = predict_ocr(
              path_to_image,
-             character_detection_threshold,
-             panel_detection_threshold,
-             text_detection_threshold,
-             character_character_matching_threshold,
-             text_character_matching_threshold,
+             input_character_detection_threshold,
+             input_panel_detection_threshold,
+             input_text_detection_threshold,
+             input_character_character_matching_threshold,
+             input_text_character_matching_threshold,
          )
-
-         # Display results
-         st.image(result['image'], caption="Detected Panels and Characters")
-         st.text_area("Narration", result.get("narration", "Narration not available."))
-
- # Terminal command input
- st.markdown("**Terminal**")
- command_input = st.text_input("Enter a command", key='input')
- if st.button("Run Command"):
-     if command_input:
-         # Execute command
-         output = run_command(command_input)
-         # Display output
-         st.text_area("Terminal Output", value=output, height=300)
-
+
+     if generate_detections_and_associations and generate_transcript:
+         col1, col2 = st.columns(2)
+         output = model.visualise_single_image_prediction(image, result)
+         col1.image(output)
+         text_bboxes_for_all_images = [result["texts"]]
+         ocr_results = model.predict_ocr([image], text_bboxes_for_all_images)
+         transcript = model.generate_transcript_for_single_image(result, ocr_results[0])
+         col2.text(transcript)
+
+     elif generate_detections_and_associations:
+         output = model.visualise_single_image_prediction(image, result)
+         st.image(output)
+
+     elif generate_transcript:
+         transcript = model.generate_transcript_for_single_image(result, ocr_results[0])
+         st.text(transcript)
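
For reference, the detection, OCR, and transcript flow that the updated app.py wires into Streamlit can also be exercised directly. The sketch below is a minimal, non-authoritative example: the batched predict_detections_and_associations call and its default thresholds are assumptions (that call site is elided in the hunks above), while predict_ocr, generate_transcript_for_single_image, result["texts"], and the "ragavsachdeva/magi" checkpoint are taken from the code shown; "page.png" is a placeholder path.

# Minimal standalone sketch (not part of the commit); assumes the Magi model's
# batched predict_detections_and_associations API with default thresholds.
from transformers import AutoModel
from PIL import Image
import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained("ragavsachdeva/magi", trust_remote_code=True).to(device)
model.eval()

# "page.png" is a hypothetical local manga page; the app reads uploads or URLs instead.
image = np.array(Image.open("page.png").convert("RGB"))

with torch.no_grad():
    # Panels, characters, text boxes and their associations for one page.
    result = model.predict_detections_and_associations([image])[0]
    # OCR over the detected text boxes (same call as in app.py above).
    ocr_results = model.predict_ocr([image], [result["texts"]])

# Per-page transcript, as in the "Generate transcript (slower)" branch of the app.
transcript = model.generate_transcript_for_single_image(result, ocr_results[0])
print(transcript)

In the Streamlit app these steps are wrapped in @st.cache_resource and @st.cache_data so the sidebar sliders and toggles can be adjusted without reloading the model or recomputing unchanged predictions.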