daranaka committed
Commit 585854e
1 Parent(s): 6552ee7

Update app.py

Files changed (1):
  1. app.py +140 -81
app.py CHANGED
@@ -1,84 +1,143 @@
 import streamlit as st
+from transformers import AutoModel
 from PIL import Image
-import cv2
-import numpy as np
-import pytesseract
 import torch
-from torchvision import models, transforms
-from transformers import DetrImageProcessor, DetrForObjectDetection
-
-# Load a pre-trained DETR model for object detection
-processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
-model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
-
-# Image transformations
-transform = transforms.Compose([
-    transforms.ToTensor()
-])
-
-def detect_panels(image, threshold):
-    # Convert image to grayscale
-    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
-    edges = cv2.Canny(gray, 100, 200)
-    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-
-    panels = []
-    for cnt in contours:
-        x, y, w, h = cv2.boundingRect(cnt)
-        if w > threshold and h > threshold:
-            panels.append({"coords": (x, y, w, h)})
-    return panels
-
-def detect_characters(image, threshold):
-    # Apply DETR model to detect characters
-    inputs = processor(images=image, return_tensors="pt")
-    outputs = model(**inputs)
-    logits = outputs.logits
-    bboxes = outputs.pred_boxes
-
-    # Filter results
-    characters = []
-    for logit, box in zip(logits[0], bboxes[0]):
-        if logit.argmax() == 0:  # Assuming '0' corresponds to 'character'
-            x, y, w, h = box * torch.tensor([image.width, image.height, image.width, image.height])
-            if w > threshold and h > threshold:
-                characters.append({"coords": (x.item(), y.item(), w.item(), h.item())})
-    return characters
-
-def match_text_to_characters(image, panels):
-    text_matches = []
-    for panel in panels:
-        x, y, w, h = map(int, panel['coords'])
-        panel_img = image.crop((x, y, x+w, y+h))
-        text = pytesseract.image_to_string(panel_img)
-        text_matches.append({"panel": panel, "dialog": text})
-    return text_matches
-
-def match_characters(characters):
-    coords = np.array([((c['coords'][0] + c['coords'][2]) / 2, (c['coords'][1] + c['coords'][3]) / 2) for c in characters])
-    clustering = DBSCAN(eps=20, min_samples=1).fit(coords)
-    character_matches = [{"character": c, "cluster": cluster} for c, cluster in zip(characters, clustering.labels_)]
-    return character_matches
-
-# Streamlit UI
-st.title("Advanced Manga Reader")
-
-uploaded_file = st.file_uploader("Upload a manga page", type=["jpg", "png"])
-
-if uploaded_file is not None:
-    image = Image.open(uploaded_file).convert('RGB')
-    st.image(image, caption='Uploaded Manga Page', use_column_width=True)
-
-    panel_threshold = st.slider("Panel Detection Threshold", 0, 500, 100)
-    character_threshold = st.slider("Character Detection Threshold", 0.0, 50.0, 10.0)
-
-    panels = detect_panels(np.array(image), panel_threshold)
-    characters = detect_characters(image, character_threshold)
-    dialogues = match_text_to_characters(image, panels)
-
-    st.write("Detected Panels:", panels)
-    st.write("Detected Characters:", characters)
-    st.write("Dialogues:", dialogues)
-
-    for dialogue in dialogues:
-        st.write(f"Panel: {dialogue['dialog']}")
+import numpy as np
+import urllib.request
+
+# Initialize session state for memory if not already
+if "memory" not in st.session_state:
+    st.session_state.memory = {"characters": {}, "transcript": ""}
+
+@st.cache_resource
+def load_model():
+    model = AutoModel.from_pretrained("ragavsachdeva/magi", trust_remote_code=True)
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model.to(device)
+    return model
+
+@st.cache_data
+def read_image_as_np_array(image_path):
+    if "http" in image_path:
+        image = Image.open(urllib.request.urlopen(image_path)).convert("L").convert("RGB")
+    else:
+        image = Image.open(image_path).convert("L").convert("RGB")
+    image = np.array(image)
+    return image
+
+@st.cache_data
+def predict_detections_and_associations(
+    image_path,
+    character_detection_threshold,
+    panel_detection_threshold,
+    text_detection_threshold,
+    character_character_matching_threshold,
+    text_character_matching_threshold,
+):
+    image = read_image_as_np_array(image_path)
+    with torch.no_grad():
+        result = model.predict_detections_and_associations(
+            [image],
+            character_detection_threshold=character_detection_threshold,
+            panel_detection_threshold=panel_detection_threshold,
+            text_detection_threshold=text_detection_threshold,
+            character_character_matching_threshold=character_character_matching_threshold,
+            text_character_matching_threshold=text_character_matching_threshold,
+        )[0]
+    return result
+
+@st.cache_data
+def predict_ocr(
+    image_path,
+    character_detection_threshold,
+    panel_detection_threshold,
+    text_detection_threshold,
+    character_character_matching_threshold,
+    text_character_matching_threshold,
+):
+    if not generate_transcript:
+        return
+    image = read_image_as_np_array(image_path)
+    result = predict_detections_and_associations(
+        image_path,
+        character_detection_threshold,
+        panel_detection_threshold,
+        text_detection_threshold,
+        character_character_matching_threshold,
+        text_character_matching_threshold,
+    )
+    text_bboxes_for_all_images = [result["texts"]]
+    with torch.no_grad():
+        ocr_results = model.predict_ocr([image], text_bboxes_for_all_images)
+    return ocr_results
+
+def clear_memory():
+    st.session_state.memory = {"characters": {}, "transcript": ""}
+    st.write("Memory cleared.")
+
+model = load_model()
+
+# Display header and UI components
+st.markdown(""" <style> ... styles here ... </style> """, unsafe_allow_html=True)
+path_to_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
+
+# Memory control button
+st.button("Clear Memory", on_click=clear_memory)
+
+st.sidebar.markdown("**Mode**")
+generate_detections_and_associations = st.sidebar.toggle("Generate detections and associations", True)
+generate_transcript = st.sidebar.toggle("Generate transcript (slower)", False)
+
+st.sidebar.markdown("**Hyperparameters**")
+input_character_detection_threshold = st.sidebar.slider('Character detection threshold', 0.0, 1.0, 0.30, step=0.01)
+input_panel_detection_threshold = st.sidebar.slider('Panel detection threshold', 0.0, 1.0, 0.2, step=0.01)
+input_text_detection_threshold = st.sidebar.slider('Text detection threshold', 0.0, 1.0, 0.25, step=0.01)
+input_character_character_matching_threshold = st.sidebar.slider('Character-character matching threshold', 0.0, 1.0, 0.7, step=0.01)
+input_text_character_matching_threshold = st.sidebar.slider('Text-character matching threshold', 0.0, 1.0, 0.4, step=0.01)
+
+if path_to_image is not None:
+    image = read_image_as_np_array(path_to_image)
+    st.markdown("**Prediction**")
+
+    if generate_detections_and_associations or generate_transcript:
+        result = predict_detections_and_associations(
+            path_to_image,
+            input_character_detection_threshold,
+            input_panel_detection_threshold,
+            input_text_detection_threshold,
+            input_character_character_matching_threshold,
+            input_text_character_matching_threshold,
+        )
+
+    if generate_transcript:
+        ocr_results = predict_ocr(
+            path_to_image,
+            input_character_detection_threshold,
+            input_panel_detection_threshold,
+            input_text_detection_threshold,
+            input_character_character_matching_threshold,
+            input_text_character_matching_threshold,
+        )
+
+    # Append new characters and transcript to memory
+    if generate_detections_and_associations and generate_transcript:
+        output = model.visualise_single_image_prediction(image, result)
+        st.image(output)
+        # Update character memory based on detected characters
+        detected_characters = result.get("characters", {})
+        st.session_state.memory["characters"].update(detected_characters)
+
+        # Append the current transcript to the ongoing transcript in memory
+        transcript = model.generate_transcript_for_single_image(result, ocr_results[0])
+        st.session_state.memory["transcript"] += transcript + "\n"
+
+        # Display the cumulative transcript from memory
+        st.text(st.session_state.memory["transcript"])
+
+    elif generate_detections_and_associations:
+        output = model.visualise_single_image_prediction(image, result)
+        st.image(output)
+
+    elif generate_transcript:
+        # Display the cumulative transcript
+        st.text(st.session_state.memory["transcript"])
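
For reference, the magi API that this commit switches to can be exercised outside Streamlit. A minimal sketch, assuming the same methods and default thresholds used in the diff above; "page.png" is a placeholder path:

import numpy as np
import torch
from PIL import Image
from transformers import AutoModel

# Load the magi model once, on GPU when available (mirrors load_model above)
model = AutoModel.from_pretrained("ragavsachdeva/magi", trust_remote_code=True)
model.to("cuda" if torch.cuda.is_available() else "cpu")

# Grayscale-then-RGB conversion, matching read_image_as_np_array in the app
image = np.array(Image.open("page.png").convert("L").convert("RGB"))

with torch.no_grad():
    # Detections and associations for a single page, using the app's defaults
    result = model.predict_detections_and_associations(
        [image],
        character_detection_threshold=0.30,
        panel_detection_threshold=0.2,
        text_detection_threshold=0.25,
        character_character_matching_threshold=0.7,
        text_character_matching_threshold=0.4,
    )[0]
    # OCR over the detected text boxes, then a per-page transcript
    ocr_results = model.predict_ocr([image], [result["texts"]])

print(model.generate_transcript_for_single_image(result, ocr_results[0]))

The app itself is launched with streamlit run app.py. The detections-and-associations pass runs on every rerun, while the transcript pass sits behind its own toggle because, as the sidebar label notes, it is slower.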