Spaces:
Runtime error
Update app.py
app.py
CHANGED
@@ -1,125 +1,84 @@
 import streamlit as st
-from transformers import AutoModel
 from PIL import Image
-import …
 import numpy as np
-import …
 
[old lines 8-19 not recovered in this diff view]
-    image = Image.open(image_path).convert("L").convert("RGB")
-    image = np.array(image)
-    return image
 
-
-def predict_detections_and_associations(
-    image_path,
-    character_detection_threshold,
-    panel_detection_threshold,
-    text_detection_threshold,
-    character_character_matching_threshold,
-    text_character_matching_threshold,
-):
-    image = read_image_as_np_array(image_path)
-    with torch.no_grad():
-        result = model.predict_detections_and_associations(
-            [image],
-            character_detection_threshold=character_detection_threshold,
-            panel_detection_threshold=panel_detection_threshold,
-            text_detection_threshold=text_detection_threshold,
-            character_character_matching_threshold=character_character_matching_threshold,
-            text_character_matching_threshold=text_character_matching_threshold,
-        )[0]
-    return result
 
-
-def predict_ocr(
-    image_path,
-    character_detection_threshold,
-    panel_detection_threshold,
-    text_detection_threshold,
-    character_character_matching_threshold,
-    text_character_matching_threshold,
-):
-    if not generate_transcript:
-        return
-    image = read_image_as_np_array(image_path)
-    result = predict_detections_and_associations(
-        path_to_image,
-        character_detection_threshold,
-        panel_detection_threshold,
-        text_detection_threshold,
-        character_character_matching_threshold,
-        text_character_matching_threshold,
-    )
-    text_bboxes_for_all_images = [result["texts"]]
-    with torch.no_grad():
-        ocr_results = model.predict_ocr([image], text_bboxes_for_all_images)
-    return ocr_results
 
-
 
-st. …
-
 
[old lines 75-77 not recovered in this diff view]
-st.sidebar.markdown("**Hyperparameters**")
-input_character_detection_threshold = st.sidebar.slider('Character detection threshold', 0.0, 1.0, 0.30, step=0.01)
-input_panel_detection_threshold = st.sidebar.slider('Panel detection threshold', 0.0, 1.0, 0.2, step=0.01)
-input_text_detection_threshold = st.sidebar.slider('Text detection threshold', 0.0, 1.0, 0.25, step=0.01)
-input_character_character_matching_threshold = st.sidebar.slider('Character-character matching threshold', 0.0, 1.0, 0.7, step=0.01)
-input_text_character_matching_threshold = st.sidebar.slider('Text-character matching threshold', 0.0, 1.0, 0.4, step=0.01)
 
 
[old lines 86-88 not recovered in this diff view]
-st.markdown("**Prediction**")
-if generate_detections_and_associations or generate_transcript:
-    result = predict_detections_and_associations(
-        path_to_image,
-        input_character_detection_threshold,
-        input_panel_detection_threshold,
-        input_text_detection_threshold,
-        input_character_character_matching_threshold,
-        input_text_character_matching_threshold,
-    )
-
-if generate_transcript:
-    ocr_results = predict_ocr(
-        path_to_image,
-        input_character_detection_threshold,
-        input_panel_detection_threshold,
-        input_text_detection_threshold,
-        input_character_character_matching_threshold,
-        input_text_character_matching_threshold,
-    )
-
-if generate_detections_and_associations and generate_transcript:
-    col1, col2 = st.columns(2)
-    output = model.visualise_single_image_prediction(image, result)
-    col1.image(output)
-    text_bboxes_for_all_images = [result["texts"]]
-    ocr_results = model.predict_ocr([image], text_bboxes_for_all_images)
-    transcript = model.generate_transcript_for_single_image(result, ocr_results[0])
-    col2.text(transcript)
-
-elif generate_detections_and_associations:
-    output = model.visualise_single_image_prediction(image, result)
-    st.image(output)
-
-elif generate_transcript:
-    transcript = model.generate_transcript_for_single_image(result, ocr_results[0])
-    st.text(transcript)
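Note on the removed version: the old app drove everything through a single transformers AutoModel with custom methods (predict_detections_and_associations, predict_ocr, generate_transcript_for_single_image, visualise_single_image_prediction). The loading call and checkpoint id fall in the lines that were not recovered above; as a rough sketch, a model exposing custom methods like these is typically loaded with remote code enabled (the checkpoint id below is a placeholder, not taken from the diff):

    from transformers import AutoModel
    # Placeholder id: the real checkpoint is in the unrecovered lines above.
    model = AutoModel.from_pretrained("<detection-and-ocr-checkpoint>", trust_remote_code=True)

The replacement below swaps this single model for an OpenCV + DETR + pytesseract pipeline.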
 import streamlit as st
 from PIL import Image
+import cv2
 import numpy as np
+import pytesseract
+import torch
+from torchvision import models, transforms
+from transformers import DetrImageProcessor, DetrForObjectDetection
+
+# Load a pre-trained DETR model for object detection
+processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
+model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
+
+# Image transformations
+transform = transforms.Compose([
+    transforms.ToTensor()
+])
+
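The new imports (cv2, pytesseract, torch, torchvision, transformers, plus the DBSCAN used further down) only resolve if the Space's dependency files list them, and an unresolved import at startup is one plausible cause of the "Runtime error" badge at the top of this page. A sketch of what the dependency files would need for this version, assuming a standard Streamlit Space; exact contents are an assumption, not part of this commit:

    # requirements.txt
    streamlit
    Pillow
    numpy
    opencv-python-headless
    pytesseract
    torch
    torchvision
    transformers
    scikit-learn

    # packages.txt (apt packages; pytesseract needs the tesseract binary)
    tesseract-ocr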
+def detect_panels(image, threshold):
+    # Convert image to grayscale
+    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    edges = cv2.Canny(gray, 100, 200)
+    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+    panels = []
+    for cnt in contours:
+        x, y, w, h = cv2.boundingRect(cnt)
+        if w > threshold and h > threshold:
+            panels.append({"coords": (x, y, w, h)})
+    return panels
+
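detect_panels is called below with np.array(image), where image is a PIL image and therefore in RGB channel order, while cv2.COLOR_BGR2GRAY assumes OpenCV's BGR order. For a grayscale conversion this only changes the channel weighting, but matching the constant to the actual data is safer (a one-line sketch, otherwise identical behaviour):

    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)  # the array comes from PIL, so it is RGB

Note that threshold here is compared against panel width and height in pixels, which is why the panel slider below runs 0-500 rather than 0.0-1.0.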
+def detect_characters(image, threshold):
+    # Apply DETR model to detect characters
+    inputs = processor(images=image, return_tensors="pt")
+    outputs = model(**inputs)
+    logits = outputs.logits
+    bboxes = outputs.pred_boxes
+
+    # Filter results
+    characters = []
+    for logit, box in zip(logits[0], bboxes[0]):
+        if logit.argmax() == 0:  # Assuming '0' corresponds to 'character'
+            x, y, w, h = box * torch.tensor([image.width, image.height, image.width, image.height])
+            if w > threshold and h > threshold:
+                characters.append({"coords": (x.item(), y.item(), w.item(), h.item())})
+    return characters
 
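The raw-logits loop above rests on two assumptions that do not hold for facebook/detr-resnet-50: class index 0 is not a "character" class (the checkpoint is a COCO detector whose logits also include a trailing "no object" class), and pred_boxes are normalized (center_x, center_y, width, height), so multiplying them by (width, height, width, height) does not give pixel corner coordinates. A hedged alternative sketch that reuses the processor and model loaded above and relies on the processor's built-in post-processing; treating COCO's "person" label as a stand-in for a manga character is an assumption, not something this commit does:

    def detect_characters(image, score_threshold=0.7):
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        # post_process_object_detection converts normalized (cx, cy, w, h) boxes to
        # absolute (x0, y0, x1, y1) and drops low-score / "no object" predictions.
        target_sizes = torch.tensor([image.size[::-1]])  # PIL size is (w, h); DETR wants (h, w)
        results = processor.post_process_object_detection(
            outputs, threshold=score_threshold, target_sizes=target_sizes
        )[0]
        characters = []
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            if model.config.id2label[label.item()] == "person":  # rough proxy for "character"
                characters.append({"score": score.item(), "coords": box.tolist()})
        return characters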
+def match_text_to_characters(image, panels):
+    text_matches = []
+    for panel in panels:
+        x, y, w, h = map(int, panel['coords'])
+        panel_img = image.crop((x, y, x+w, y+h))
+        text = pytesseract.image_to_string(panel_img)
+        text_matches.append({"panel": panel, "dialog": text})
+    return text_matches
 
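pytesseract.image_to_string defaults to the English eng model and horizontal text segmentation; manga dialogue is usually Japanese and set vertically, which that default will not read. If the corresponding traineddata is installed next to the tesseract binary, the language can be passed explicitly (a sketch; having jpn_vert available on the Space is an assumption):

    text = pytesseract.image_to_string(panel_img, lang="jpn_vert")

Also, despite its name, match_text_to_characters never consults the detected characters: it OCRs each panel crop as a whole and attaches the text to the panel.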
+def match_characters(characters):
+    coords = np.array([((c['coords'][0] + c['coords'][2]) / 2, (c['coords'][1] + c['coords'][3]) / 2) for c in characters])
+    clustering = DBSCAN(eps=20, min_samples=1).fit(coords)
+    character_matches = [{"character": c, "cluster": cluster} for c, cluster in zip(characters, clustering.labels_)]
+    return character_matches
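DBSCAN is referenced here but never imported, so calling match_characters would raise a NameError; as committed, the function is also never invoked from the UI below. If it is meant to be used, the import comes from scikit-learn, and note that for a box stored as (x, y, w, h) the centre is (x + w/2, y + h/2), not ((x + w)/2, (y + h)/2). A sketch under the assumption that coords follows the (x, y, w, h) convention used by detect_panels:

    from sklearn.cluster import DBSCAN

    def match_characters(characters):
        if not characters:
            return []
        # Cluster character boxes by their centre points; eps is in pixels.
        centres = np.array([(c["coords"][0] + c["coords"][2] / 2,
                             c["coords"][1] + c["coords"][3] / 2) for c in characters])
        clustering = DBSCAN(eps=20, min_samples=1).fit(centres)
        return [{"character": c, "cluster": int(label)}
                for c, label in zip(characters, clustering.labels_)]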
 
+# Streamlit UI
+st.title("Advanced Manga Reader")
 
+uploaded_file = st.file_uploader("Upload a manga page", type=["jpg", "png"])
 
+if uploaded_file is not None:
+    image = Image.open(uploaded_file).convert('RGB')
+    st.image(image, caption='Uploaded Manga Page', use_column_width=True)
 
+    panel_threshold = st.slider("Panel Detection Threshold", 0, 500, 100)
+    character_threshold = st.slider("Character Detection Threshold", 0.0, 50.0, 10.0)
 
+    panels = detect_panels(np.array(image), panel_threshold)
+    characters = detect_characters(image, character_threshold)
+    dialogues = match_text_to_characters(image, panels)
 
+    st.write("Detected Panels:", panels)
+    st.write("Detected Characters:", characters)
+    st.write("Dialogues:", dialogues)
 
+    for dialogue in dialogues:
+        st.write(f"Panel: {dialogue['dialog']}")
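One last practical note: Streamlit re-executes the whole script on every widget interaction, so the module-level DetrImageProcessor and DetrForObjectDetection loading above is re-run each time. Wrapping the load in st.cache_resource keeps a single copy per process (a sketch, not something this commit does):

    @st.cache_resource
    def load_detr():
        processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
        model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
        return processor, model

    processor, model = load_detr()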