Spaces:
Runtime error
Runtime error
File size: 5,596 Bytes
db28818 585854e 42eb874 83a0630 585854e 4a78e11 585854e 44741f9 585854e 44741f9 585854e 44741f9 585854e 44741f9 585854e 44741f9 585854e e8897e7 585854e 4a78e11 e8897e7 585854e e8897e7 585854e e8897e7 585854e e8897e7 585854e e8897e7 585854e e8897e7 585854e e8897e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import streamlit as st
from transformers import AutoModel
from PIL import Image
import torch
import numpy as np
import urllib.request
# Module-level scratch dict. In the visible code nothing stores into or reads
# from it; the only reference is the "Clear Memory" button, which empties it.
memory = {}
@st.cache_resource
def load_model():
    """Load the pretrained "ragavsachdeva/magi" model and place it on a device.

    The model is moved to CUDA when ``torch.cuda.is_available()`` reports a
    GPU, otherwise to the CPU. The function is wrapped in
    ``st.cache_resource`` so the loaded object can be reused instead of being
    rebuilt on every call.

    Returns:
        The model object returned by ``AutoModel.from_pretrained``.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    # trust_remote_code=True lets the checkpoint's own modelling code run.
    magi = AutoModel.from_pretrained(
        "ragavsachdeva/magi",
        trust_remote_code=True,
    )
    magi.to(target_device)
    return magi
@st.cache_data
def read_image_as_np_array(image_path):
    """Load an image and return it as an RGB NumPy array with greyscale content.

    The image is converted to single-channel greyscale ("L") and then back to
    "RGB", so colour information is discarded while the RGB layout is kept.

    Args:
        image_path: Either an ``http://`` / ``https://`` URL string, or
            anything ``PIL.Image.open`` accepts (a local path, or a binary
            file-like object such as the value returned by
            ``st.file_uploader``).

    Returns:
        ``np.array`` of the converted PIL image.
    """
    # Only genuine URL strings are fetched over the network. A plain
    # substring test would misclassify local paths that merely contain
    # "http", and would iterate over (or fail on) non-string inputs such as
    # uploaded file objects.
    is_url = isinstance(image_path, str) and image_path.startswith(
        ("http://", "https://")
    )
    if is_url:
        # Close the HTTP response once the pixels are decoded; the conversion
        # happens inside the block so the data is read before closing.
        with urllib.request.urlopen(image_path) as response:
            image = Image.open(response).convert("L").convert("RGB")
    else:
        image = Image.open(image_path).convert("L").convert("RGB")
    return np.array(image)
@st.cache_data
def predict_detections_and_associations(
    image_path,
    character_detection_threshold,
    panel_detection_threshold,
    text_detection_threshold,
    character_character_matching_threshold,
    text_character_matching_threshold,
):
    """Run the model's detection and association step on a single image.

    The image is loaded with ``read_image_as_np_array`` and passed to
    ``model.predict_detections_and_associations`` as a one-element batch,
    with gradient tracking disabled.

    Args:
        image_path: Image source forwarded to ``read_image_as_np_array``.
        character_detection_threshold: Forwarded to the model unchanged.
        panel_detection_threshold: Forwarded to the model unchanged.
        text_detection_threshold: Forwarded to the model unchanged.
        character_character_matching_threshold: Forwarded to the model
            unchanged.
        text_character_matching_threshold: Forwarded to the model unchanged.

    Returns:
        The first (and only) entry of the model's per-image output.
    """
    threshold_kwargs = {
        "character_detection_threshold": character_detection_threshold,
        "panel_detection_threshold": panel_detection_threshold,
        "text_detection_threshold": text_detection_threshold,
        "character_character_matching_threshold": character_character_matching_threshold,
        "text_character_matching_threshold": text_character_matching_threshold,
    }
    page = read_image_as_np_array(image_path)
    with torch.no_grad():
        # The model takes a batch; send one page and unwrap its result.
        first_result = model.predict_detections_and_associations(
            [page], **threshold_kwargs
        )[0]
    return first_result
@st.cache_data
def predict_ocr(
    image_path,
    character_detection_threshold,
    panel_detection_threshold,
    text_detection_threshold,
    character_character_matching_threshold,
    text_character_matching_threshold,
):
    """Run OCR on the text boxes detected in a single image.

    Detection is obtained from ``predict_detections_and_associations`` for the
    same image and thresholds, and the ``"texts"`` entry of that result is
    passed to ``model.predict_ocr`` together with the image.

    Args:
        image_path: Image source forwarded to ``read_image_as_np_array`` and
            to ``predict_detections_and_associations``.
        character_detection_threshold: Forwarded to the detection step.
        panel_detection_threshold: Forwarded to the detection step.
        text_detection_threshold: Forwarded to the detection step.
        character_character_matching_threshold: Forwarded to the detection
            step.
        text_character_matching_threshold: Forwarded to the detection step.

    Returns:
        Whatever ``model.predict_ocr`` returns for the one-image batch, or
        ``None`` when the module-level ``generate_transcript`` flag is falsy.
    """
    # NOTE(review): this guard reads a module-level flag that is not one of
    # the function's arguments, so it is presumably not part of the cache key
    # used by st.cache_data -- confirm a None result cannot be cached and
    # replayed after the flag is switched on.
    if not generate_transcript:
        return None
    image = read_image_as_np_array(image_path)
    # Use the image_path argument, not the module-level uploader value, so
    # detection and OCR always operate on the same image this call was made
    # (and cached) for.
    result = predict_detections_and_associations(
        image_path,
        character_detection_threshold,
        panel_detection_threshold,
        text_detection_threshold,
        character_character_matching_threshold,
        text_character_matching_threshold,
    )
    text_bboxes_for_all_images = [result["texts"]]
    with torch.no_grad():
        ocr_results = model.predict_ocr([image], text_bboxes_for_all_images)
    return ocr_results
# Obtain the model through the st.cache_resource-wrapped loader; the
# prediction functions above read this module-level name.
model = load_model()
# Add a button to clear memory
# NOTE(review): `memory` is never written to or read from in the visible
# code, so pressing this button appears to have no observable effect --
# confirm what it was meant to clear.
if st.button("Clear Memory"):
    memory.clear()
# Streamlit UI elements
# Render the page banner as raw HTML: an inline <style> block followed by the
# title, subheading, authors and affiliation <div>s. unsafe_allow_html=True is
# required for Streamlit to emit the markup instead of escaping it.
st.markdown("""
<style> .title-container { background-color: #0d1117; padding: 20px; border-radius: 10px; margin: 20px; }
.title { font-size: 2em; text-align: center; color: #fff; font-family: 'Comic Sans MS', cursive; text-transform: uppercase;
letter-spacing: 0.1em; padding: 0.5em 0 0.2em; background: 0 0; } .title span { background: -webkit-linear-gradient(45deg,
#6495ed, #4169e1); -webkit-background-clip: text; -webkit-text-fill-color: transparent; } .subheading { font-size: 1.5em;
text-align: center; color: #ddd; font-family: 'Comic Sans MS', cursive; } .affil, .authors { font-size: 1em; text-align: center;
color: #ddd; font-family: 'Comic Sans MS', cursive; } .authors { padding-top: 1em; } </style>
<div class='title-container'> <div class='title'> The <span>Ma</span>n<span>g</span>a Wh<span>i</span>sperer </div>
<div class='subheading'> Automatically Generating Transcriptions for Comics </div> <div class='authors'> Ragav Sachdeva and
Andrew Zisserman </div> <div class='affil'> University of Oxford </div> </div>""", unsafe_allow_html=True)
# Image upload widget restricted to PNG/JPEG files. Despite its name,
# `path_to_image` holds the uploader's return value rather than a filesystem
# path; the processing code checks it against None before using it.
path_to_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
# Sidebar checkboxes selecting which outputs to produce; detections default to
# on, the transcript defaults to off.
st.sidebar.markdown("**Mode**")
generate_detections_and_associations = st.sidebar.checkbox("Generate detections and associations", True)
generate_transcript = st.sidebar.checkbox("Generate transcript (slower)", False)
# Hyperparameter Sliders
# Each slider ranges over [0.0, 1.0] in steps of 0.01; the third positional
# argument is the initial value. The values are passed to the prediction
# functions as the corresponding *_threshold arguments.
st.sidebar.markdown("**Hyperparameters**")
input_character_detection_threshold = st.sidebar.slider('Character detection threshold', 0.0, 1.0, 0.30, step=0.01)
input_panel_detection_threshold = st.sidebar.slider('Panel detection threshold', 0.0, 1.0, 0.2, step=0.01)
input_text_detection_threshold = st.sidebar.slider('Text detection threshold', 0.0, 1.0, 0.25, step=0.01)
input_character_character_matching_threshold = st.sidebar.slider('Character-character matching threshold', 0.0, 1.0, 0.7, step=0.01)
input_text_character_matching_threshold = st.sidebar.slider('Text-character matching threshold', 0.0, 1.0, 0.4, step=0.01)
# Main processing based on image input
if path_to_image is not None:
    image = read_image_as_np_array(path_to_image)
    st.markdown("**Prediction**")
    # Run predictions based on checkbox selections
    # The slider values are passed positionally, in the parameter order shared
    # by predict_detections_and_associations and predict_ocr.
    thresholds = (
        input_character_detection_threshold,
        input_panel_detection_threshold,
        input_text_detection_threshold,
        input_character_character_matching_threshold,
        input_text_character_matching_threshold,
    )
    # Both the visualisation and the transcript consume the detection result,
    # so compute it whenever either mode is enabled. Previously it was only
    # assigned in the detections branch, leaving `result` undefined for a
    # transcript-only run.
    if generate_detections_and_associations or generate_transcript:
        result = predict_detections_and_associations(path_to_image, *thresholds)
    if generate_detections_and_associations:
        output = model.visualise_single_image_prediction(image, result)
        st.image(output)
    if generate_transcript:
        ocr_results = predict_ocr(path_to_image, *thresholds)
        # predict_ocr was called with a one-image batch; take that image's entry.
        transcript = model.generate_transcript_for_single_image(result, ocr_results[0])
        st.text(transcript)
|