daranaka committed on
Commit
e8897e7
1 Parent(s): 17192ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -41
app.py CHANGED
@@ -5,15 +5,15 @@ import torch
5
  import numpy as np
6
  import urllib.request
7
 
8
# Caching is intentionally skipped here: Streamlit's cache cannot serialize
# the model's PretrainedConfig, so the model is constructed directly.
def load_model():
    """Instantiate the magi model and place it on the GPU when one is available."""
    magi = AutoModel.from_pretrained("ragavsachdeva/magi", trust_remote_code=True)
    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    magi.to(run_device)
    return magi
14
 
15
- # Load the model outside caching
16
- model = load_model() # Place this line here, right after defining load_model
17
 
18
  @st.cache_data
19
  def read_image_as_np_array(image_path):
@@ -45,7 +45,6 @@ def predict_detections_and_associations(
45
  )[0]
46
  return result
47
 
48
-
49
  @st.cache_data
50
  def predict_ocr(
51
  image_path,
@@ -59,7 +58,7 @@ def predict_ocr(
59
  return
60
  image = read_image_as_np_array(image_path)
61
  result = predict_detections_and_associations(
62
- image_path,
63
  character_detection_threshold,
64
  panel_detection_threshold,
65
  text_detection_threshold,
@@ -71,23 +70,25 @@ def predict_ocr(
71
  ocr_results = model.predict_ocr([image], text_bboxes_for_all_images)
72
  return ocr_results
73
 
74
def clear_memory():
    """Reset the per-session memory of detected characters and transcript text."""
    fresh_state = {"characters": {}, "transcript": ""}
    st.session_state.memory = fresh_state
    st.write("Memory cleared.")
 
 
 
 
 
 
 
 
77
 
78
- model = load_model()
79
-
80
- # Display header and UI components
81
- st.markdown(""" <style> ... styles here ... </style> """, unsafe_allow_html=True)
82
  path_to_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
83
 
84
- # Memory control button
85
- st.button("Clear Memory", on_click=clear_memory)
86
-
87
  st.sidebar.markdown("**Mode**")
88
- generate_detections_and_associations = st.sidebar.toggle("Generate detections and associations", True)
89
- generate_transcript = st.sidebar.toggle("Generate transcript (slower)", False)
90
 
 
91
  st.sidebar.markdown("**Hyperparameters**")
92
  input_character_detection_threshold = st.sidebar.slider('Character detection threshold', 0.0, 1.0, 0.30, step=0.01)
93
  input_panel_detection_threshold = st.sidebar.slider('Panel detection threshold', 0.0, 1.0, 0.2, step=0.01)
@@ -95,11 +96,13 @@ input_text_detection_threshold = st.sidebar.slider('Text detection threshold', 0
95
  input_character_character_matching_threshold = st.sidebar.slider('Character-character matching threshold', 0.0, 1.0, 0.7, step=0.01)
96
  input_text_character_matching_threshold = st.sidebar.slider('Text-character matching threshold', 0.0, 1.0, 0.4, step=0.01)
97
 
 
98
  if path_to_image is not None:
99
  image = read_image_as_np_array(path_to_image)
100
  st.markdown("**Prediction**")
101
-
102
- if generate_detections_and_associations or generate_transcript:
 
103
  result = predict_detections_and_associations(
104
  path_to_image,
105
  input_character_detection_threshold,
@@ -108,6 +111,8 @@ if path_to_image is not None:
108
  input_character_character_matching_threshold,
109
  input_text_character_matching_threshold,
110
  )
 
 
111
 
112
  if generate_transcript:
113
  ocr_results = predict_ocr(
@@ -118,26 +123,5 @@ if path_to_image is not None:
118
  input_character_character_matching_threshold,
119
  input_text_character_matching_threshold,
120
  )
121
-
122
- # Append new characters and transcript to memory
123
- if generate_detections_and_associations:
124
- output = model.visualise_single_image_prediction(image, result)
125
- st.image(output)
126
- # Update character memory based on detected characters
127
- detected_characters = result.get("characters", {})
128
- st.session_state.memory["characters"].update(detected_characters)
129
-
130
- # Append the current transcript to the ongoing transcript in memory
131
  transcript = model.generate_transcript_for_single_image(result, ocr_results[0])
132
- st.session_state.memory["transcript"] += transcript + "\n"
133
-
134
- # Display the cumulative transcript from memory
135
- st.text(st.session_state.memory["transcript"])
136
-
137
- elif generate_detections_and_associations:
138
- output = model.visualise_single_image_prediction(image, result)
139
- st.image(output)
140
-
141
- elif generate_transcript:
142
- # Display the cumulative transcript
143
- st.text(st.session_state.memory["transcript"])
 
5
  import numpy as np
6
  import urllib.request
7
 
8
# Built without Streamlit caching: the model's config object is not
# serializable by st.cache_*, so the model is loaded once at import time.
def load_model():
    """Load the magi model and move it onto the best available device."""
    loaded = AutoModel.from_pretrained("ragavsachdeva/magi", trust_remote_code=True)
    loaded.to("cuda" if torch.cuda.is_available() else "cpu")
    return loaded
14
 
15
+ # Initialize the model once at the top level, outside any caching functions
16
+ model = load_model()
17
 
18
  @st.cache_data
19
  def read_image_as_np_array(image_path):
 
45
  )[0]
46
  return result
47
 
 
48
  @st.cache_data
49
  def predict_ocr(
50
  image_path,
 
58
  return
59
  image = read_image_as_np_array(image_path)
60
  result = predict_detections_and_associations(
61
+ path_to_image,
62
  character_detection_threshold,
63
  panel_detection_threshold,
64
  text_detection_threshold,
 
70
  ocr_results = model.predict_ocr([image], text_bboxes_for_all_images)
71
  return ocr_results
72
 
73
+ # Streamlit UI elements
74
+ st.markdown("""
75
+ <style> .title-container { background-color: #0d1117; padding: 20px; border-radius: 10px; margin: 20px; }
76
+ .title { font-size: 2em; text-align: center; color: #fff; font-family: 'Comic Sans MS', cursive; text-transform: uppercase;
77
+ letter-spacing: 0.1em; padding: 0.5em 0 0.2em; background: 0 0; } .title span { background: -webkit-linear-gradient(45deg,
78
+ #6495ed, #4169e1); -webkit-background-clip: text; -webkit-text-fill-color: transparent; } .subheading { font-size: 1.5em;
79
+ text-align: center; color: #ddd; font-family: 'Comic Sans MS', cursive; } .affil, .authors { font-size: 1em; text-align: center;
80
+ color: #ddd; font-family: 'Comic Sans MS', cursive; } .authors { padding-top: 1em; } </style>
81
+ <div class='title-container'> <div class='title'> The <span>Ma</span>n<span>g</span>a Wh<span>i</span>sperer </div>
82
+ <div class='subheading'> Automatically Generating Transcriptions for Comics </div> <div class='authors'> Ragav Sachdeva and
83
+ Andrew Zisserman </div> <div class='affil'> University of Oxford </div> </div>""", unsafe_allow_html=True)
84
 
 
 
 
 
85
  path_to_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
86
 
 
 
 
87
  st.sidebar.markdown("**Mode**")
88
+ generate_detections_and_associations = st.sidebar.checkbox("Generate detections and associations", True)
89
+ generate_transcript = st.sidebar.checkbox("Generate transcript (slower)", False)
90
 
91
+ # Hyperparameter Sliders
92
  st.sidebar.markdown("**Hyperparameters**")
93
  input_character_detection_threshold = st.sidebar.slider('Character detection threshold', 0.0, 1.0, 0.30, step=0.01)
94
  input_panel_detection_threshold = st.sidebar.slider('Panel detection threshold', 0.0, 1.0, 0.2, step=0.01)
 
96
  input_character_character_matching_threshold = st.sidebar.slider('Character-character matching threshold', 0.0, 1.0, 0.7, step=0.01)
97
  input_text_character_matching_threshold = st.sidebar.slider('Text-character matching threshold', 0.0, 1.0, 0.4, step=0.01)
98
 
99
+ # Main processing based on image input
100
  if path_to_image is not None:
101
  image = read_image_as_np_array(path_to_image)
102
  st.markdown("**Prediction**")
103
+
104
+ # Run predictions based on checkbox selections
105
+ if generate_detections_and_associations:
106
  result = predict_detections_and_associations(
107
  path_to_image,
108
  input_character_detection_threshold,
 
111
  input_character_character_matching_threshold,
112
  input_text_character_matching_threshold,
113
  )
114
+ output = model.visualise_single_image_prediction(image, result)
115
+ st.image(output)
116
 
117
  if generate_transcript:
118
  ocr_results = predict_ocr(
 
123
  input_character_character_matching_threshold,
124
  input_text_character_matching_threshold,
125
  )
 
 
 
 
 
 
 
 
 
 
126
  transcript = model.generate_transcript_for_single_image(result, ocr_results[0])
127
+ st.text(transcript)