daranaka committed on
Commit
83a0630
1 Parent(s): 580bfba

Update app.py

Files changed (1)
  1. app.py +72 -113
app.py CHANGED
@@ -1,125 +1,84 @@
  import streamlit as st
- from transformers import AutoModel
  from PIL import Image
- import torch
  import numpy as np
- import urllib.request

- @st.cache_resource
- def load_model():
-     model = AutoModel.from_pretrained("ragavsachdeva/magi", trust_remote_code=True)
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     model.to(device)
-     return model

- @st.cache_data
- def read_image_as_np_array(image_path):
-     if "http" in image_path:
-         image = Image.open(urllib.request.urlopen(image_path)).convert("L").convert("RGB")
-     else:
-         image = Image.open(image_path).convert("L").convert("RGB")
-     image = np.array(image)
-     return image

- @st.cache_data
- def predict_detections_and_associations(
-     image_path,
-     character_detection_threshold,
-     panel_detection_threshold,
-     text_detection_threshold,
-     character_character_matching_threshold,
-     text_character_matching_threshold,
- ):
-     image = read_image_as_np_array(image_path)
-     with torch.no_grad():
-         result = model.predict_detections_and_associations(
-             [image],
-             character_detection_threshold=character_detection_threshold,
-             panel_detection_threshold=panel_detection_threshold,
-             text_detection_threshold=text_detection_threshold,
-             character_character_matching_threshold=character_character_matching_threshold,
-             text_character_matching_threshold=text_character_matching_threshold,
-         )[0]
-     return result

- @st.cache_data
- def predict_ocr(
-     image_path,
-     character_detection_threshold,
-     panel_detection_threshold,
-     text_detection_threshold,
-     character_character_matching_threshold,
-     text_character_matching_threshold,
- ):
-     if not generate_transcript:
-         return
-     image = read_image_as_np_array(image_path)
-     result = predict_detections_and_associations(
-         path_to_image,
-         character_detection_threshold,
-         panel_detection_threshold,
-         text_detection_threshold,
-         character_character_matching_threshold,
-         text_character_matching_threshold,
-     )
-     text_bboxes_for_all_images = [result["texts"]]
-     with torch.no_grad():
-         ocr_results = model.predict_ocr([image], text_bboxes_for_all_images)
-     return ocr_results

- model = load_model()

- st.markdown(""" <style> .title-container { background-color: #0d1117; padding: 20px; border-radius: 10px; margin: 20px; } .title { font-size: 2em; text-align: center; color: #fff; font-family: 'Comic Sans MS', cursive; text-transform: uppercase; letter-spacing: 0.1em; padding: 0.5em 0 0.2em; background: 0 0; } .title span { background: -webkit-linear-gradient(45deg, #6495ed, #4169e1); -webkit-background-clip: text; -webkit-text-fill-color: transparent; } .subheading { font-size: 1.5em; text-align: center; color: #ddd; font-family: 'Comic Sans MS', cursive; } .affil, .authors { font-size: 1em; text-align: center; color: #ddd; font-family: 'Comic Sans MS', cursive; } .authors { padding-top: 1em; } </style> <div class='title-container'> <div class='title'> The <span>Ma</span>n<span>g</span>a Wh<span>i</span>sperer </div> <div class='subheading'> Automatically Generating Transcriptions for Comics </div> <div class='authors'> Ragav Sachdeva and Andrew Zisserman </div> <div class='affil'> University of Oxford </div> </div>""", unsafe_allow_html=True)
- path_to_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

- st.sidebar.markdown("**Mode**")
- generate_detections_and_associations = st.sidebar.toggle("Generate detections and associations", True)
- generate_transcript = st.sidebar.toggle("Generate transcript (slower)", False)
- st.sidebar.markdown("**Hyperparameters**")
- input_character_detection_threshold = st.sidebar.slider('Character detection threshold', 0.0, 1.0, 0.30, step=0.01)
- input_panel_detection_threshold = st.sidebar.slider('Panel detection threshold', 0.0, 1.0, 0.2, step=0.01)
- input_text_detection_threshold = st.sidebar.slider('Text detection threshold', 0.0, 1.0, 0.25, step=0.01)
- input_character_character_matching_threshold = st.sidebar.slider('Character-character matching threshold', 0.0, 1.0, 0.7, step=0.01)
- input_text_character_matching_threshold = st.sidebar.slider('Text-character matching threshold', 0.0, 1.0, 0.4, step=0.01)

- if path_to_image is not None:
-     image = read_image_as_np_array(path_to_image)
-
-     st.markdown("**Prediction**")
-     if generate_detections_and_associations or generate_transcript:
-         result = predict_detections_and_associations(
-             path_to_image,
-             input_character_detection_threshold,
-             input_panel_detection_threshold,
-             input_text_detection_threshold,
-             input_character_character_matching_threshold,
-             input_text_character_matching_threshold,
-         )
-
-     if generate_transcript:
-         ocr_results = predict_ocr(
-             path_to_image,
-             input_character_detection_threshold,
-             input_panel_detection_threshold,
-             input_text_detection_threshold,
-             input_character_character_matching_threshold,
-             input_text_character_matching_threshold,
-         )
-
-     if generate_detections_and_associations and generate_transcript:
-         col1, col2 = st.columns(2)
-         output = model.visualise_single_image_prediction(image, result)
-         col1.image(output)
-         text_bboxes_for_all_images = [result["texts"]]
-         ocr_results = model.predict_ocr([image], text_bboxes_for_all_images)
-         transcript = model.generate_transcript_for_single_image(result, ocr_results[0])
-         col2.text(transcript)
-
-     elif generate_detections_and_associations:
-         output = model.visualise_single_image_prediction(image, result)
-         st.image(output)
-
-     elif generate_transcript:
-         transcript = model.generate_transcript_for_single_image(result, ocr_results[0])
-         st.text(transcript)
 
  import streamlit as st
  from PIL import Image
+ import cv2
  import numpy as np
+ import pytesseract
+ import torch
+ from torchvision import models, transforms
+ from transformers import DetrImageProcessor, DetrForObjectDetection
+ from sklearn.cluster import DBSCAN  # missing in the commit; required by match_characters below
+
+ # Load a pre-trained DETR model for object detection
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
+
+ # Image transformations
+ transform = transforms.Compose([
+     transforms.ToTensor()
+ ])
+
+ def detect_panels(image, threshold):
+     # Convert image to grayscale
+     gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+     edges = cv2.Canny(gray, 100, 200)
+     contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+     panels = []
+     for cnt in contours:
+         x, y, w, h = cv2.boundingRect(cnt)
+         if w > threshold and h > threshold:
+             panels.append({"coords": (x, y, w, h)})
+     return panels
+
+ def detect_characters(image, threshold):
+     # Apply DETR model to detect characters
+     inputs = processor(images=image, return_tensors="pt")
+     outputs = model(**inputs)
+     logits = outputs.logits
+     bboxes = outputs.pred_boxes
+
+     # Filter results
+     characters = []
+     for logit, box in zip(logits[0], bboxes[0]):
+         if logit.argmax() == 0:  # Assuming '0' corresponds to 'character'
+             x, y, w, h = box * torch.tensor([image.width, image.height, image.width, image.height])
+             if w > threshold and h > threshold:
+                 characters.append({"coords": (x.item(), y.item(), w.item(), h.item())})
+     return characters

+ def match_text_to_characters(image, panels):
+     text_matches = []
+     for panel in panels:
+         x, y, w, h = map(int, panel['coords'])
+         panel_img = image.crop((x, y, x+w, y+h))
+         text = pytesseract.image_to_string(panel_img)
+         text_matches.append({"panel": panel, "dialog": text})
+     return text_matches

+ def match_characters(characters):
+     coords = np.array([((c['coords'][0] + c['coords'][2]) / 2, (c['coords'][1] + c['coords'][3]) / 2) for c in characters])
+     clustering = DBSCAN(eps=20, min_samples=1).fit(coords)
+     character_matches = [{"character": c, "cluster": cluster} for c, cluster in zip(characters, clustering.labels_)]
+     return character_matches
+ # Streamlit UI
+ st.title("Advanced Manga Reader")

+ uploaded_file = st.file_uploader("Upload a manga page", type=["jpg", "png"])

+ if uploaded_file is not None:
+     image = Image.open(uploaded_file).convert('RGB')
+     st.image(image, caption='Uploaded Manga Page', use_column_width=True)
+
+     panel_threshold = st.slider("Panel Detection Threshold", 0, 500, 100)
+     character_threshold = st.slider("Character Detection Threshold", 0.0, 50.0, 10.0)
+
+     panels = detect_panels(np.array(image), panel_threshold)
+     characters = detect_characters(image, character_threshold)
+     dialogues = match_text_to_characters(image, panels)
+
+     st.write("Detected Panels:", panels)
+     st.write("Detected Characters:", characters)
+     st.write("Dialogues:", dialogues)
+
+     for dialogue in dialogues:
+         st.write(f"Panel: {dialogue['dialog']}")