Spaces:
Runtime error
Runtime error
import streamlit as st | |
from PIL import Image | |
import cv2 | |
import numpy as np | |
import pytesseract | |
import torch | |
from torchvision import models, transforms | |
from transformers import DetrImageProcessor, DetrForObjectDetection | |
@st.cache_resource
def _load_detr(checkpoint="facebook/detr-resnet-50"):
    """Load the DETR image processor and detection model for *checkpoint*.

    Streamlit re-runs this entire script on every widget interaction (e.g.
    moving one of the threshold sliders). Loading the weights at module level
    would therefore re-read them on every rerun; ``st.cache_resource`` keeps a
    single shared instance per server process instead.

    Returns:
        ``(processor, model)`` tuple.
    """
    return (
        DetrImageProcessor.from_pretrained(checkpoint),
        DetrForObjectDetection.from_pretrained(checkpoint),
    )


# Load a pre-trained DETR model for object detection (cached across reruns).
processor, model = _load_detr()

# Image transformations: converts a PIL image / ndarray to a tensor.
# NOTE(review): not referenced in the visible code; kept as a module-level
# name in case something else imports it.
transform = transforms.Compose([
    transforms.ToTensor()
])
def detect_panels(image, threshold, canny_low=100, canny_high=200):
    """Find candidate manga panels as bounding boxes of outer edge contours.

    Args:
        image: Page as a uint8 NumPy array, either a 3-channel RGB image
            (e.g. ``np.array(pil_image.convert('RGB'))``) or a 2-D grayscale
            image.
        threshold: A box is kept only if both its width and height are
            strictly greater than this many pixels.
        canny_low: Lower hysteresis threshold passed to ``cv2.Canny``.
        canny_high: Upper hysteresis threshold passed to ``cv2.Canny``.

    Returns:
        List of ``{"coords": (x, y, w, h)}`` dicts, where ``(x, y)`` is the
        top-left corner reported by ``cv2.boundingRect``.
    """
    if image.ndim == 2:
        # Already single-channel; cvtColor would reject it.
        gray = image
    else:
        # Arrays coming from PIL are RGB, not OpenCV's native BGR, so use the
        # RGB conversion to get the correct luminance weights.
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, canny_low, canny_high)
    # RETR_EXTERNAL keeps only outermost contours; contours nested inside
    # another contour are ignored.
    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    panels = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        if w > threshold and h > threshold:
            panels.append({"coords": (x, y, w, h)})
    return panels
def detect_characters(image, threshold, label_id=None):
    """Detect characters on a page using the module-level DETR model.

    Args:
        image: PIL image of the page; ``image.width``/``image.height`` are
            used to scale DETR's normalized boxes back to pixels.
        threshold: A box is kept only if both its width and height are
            strictly greater than this many pixels.
        label_id: Class index treated as a "character". By default this is
            looked up as ``"person"`` in ``model.config.label2id``, falling
            back to ``0`` (the value the code previously hard-coded).
            NOTE(review): confirm against the checkpoint's label map if a
            different model is loaded.

    Returns:
        List of ``{"coords": (x, y, w, h)}`` dicts in pixels, with ``(x, y)``
        the top-left corner, matching the convention of ``detect_panels``.
    """
    if label_id is None:
        label_id = model.config.label2id.get("person", 0)

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    predicted = outputs.logits[0].argmax(dim=-1)
    scale = torch.tensor(
        [image.width, image.height, image.width, image.height], dtype=torch.float32
    )

    characters = []
    for cls, box in zip(predicted.tolist(), outputs.pred_boxes[0]):
        if cls != label_id:
            continue
        # DETR boxes are (center_x, center_y, width, height), normalized to
        # the image size; convert to top-left (x, y, w, h) in pixels.
        cx, cy, w, h = (box * scale).tolist()
        if w > threshold and h > threshold:
            characters.append({"coords": (cx - w / 2, cy - h / 2, w, h)})
    return characters
def match_text_to_characters(image, panels, lang=None):
    """OCR the text inside each panel of the page.

    Despite its name, this attributes text to *panels*, not to individual
    characters.

    Args:
        image: PIL image of the full page.
        panels: Iterable of dicts with a ``"coords"`` entry ``(x, y, w, h)``
            (top-left corner plus size), e.g. the output of ``detect_panels``.
        lang: Tesseract language code(s) forwarded to
            ``pytesseract.image_to_string`` (for example ``"jpn"`` or
            ``"jpn_vert"`` for Japanese manga). ``None`` keeps the default.

    Returns:
        One ``{"panel": panel, "dialog": text}`` dict per input panel, in
        input order.
    """
    text_matches = []
    for panel in panels:
        x, y, w, h = map(int, panel["coords"])
        # PIL crop boxes are (left, upper, right, lower).
        panel_img = image.crop((x, y, x + w, y + h))
        text = pytesseract.image_to_string(panel_img, lang=lang)
        text_matches.append({"panel": panel, "dialog": text})
    return text_matches
def match_characters(characters, eps=20):
    """Group character boxes whose centres lie close together.

    Every box is its own core point (as with DBSCAN at ``min_samples=1``).
    Two boxes are linked when their centres are within Euclidean distance
    ``eps``, and clusters are the connected components of that graph.
    Implemented with NumPy because ``DBSCAN`` is not imported by this module.

    Args:
        characters: Sequence of dicts with a ``"coords"`` entry
            ``(x, y, w, h)`` (top-left corner plus size).
        eps: Maximum centre-to-centre distance, in pixels, for two boxes to
            be linked directly.

    Returns:
        One ``{"character": c, "cluster": label}`` dict per input, in input
        order. Labels are consecutive ints starting at 0, assigned in the
        order each cluster's first member appears.
    """
    if not characters:
        return []

    boxes = np.array([c["coords"] for c in characters], dtype=float)
    # Centre = top-left + half the size.
    centres = boxes[:, :2] + boxes[:, 2:] / 2
    # Pairwise adjacency matrix; quadratic in the number of boxes.
    distances = np.linalg.norm(centres[:, None, :] - centres[None, :, :], axis=-1)
    adjacent = distances <= eps

    labels = [-1] * len(characters)
    next_label = 0
    for seed in range(len(characters)):
        if labels[seed] != -1:
            continue
        # Depth-first flood fill of the component containing `seed`.
        labels[seed] = next_label
        stack = [seed]
        while stack:
            i = stack.pop()
            for j in np.flatnonzero(adjacent[i]):
                j = int(j)
                if labels[j] == -1:
                    labels[j] = next_label
                    stack.append(j)
        next_label += 1

    return [{"character": c, "cluster": label} for c, label in zip(characters, labels)]
# Streamlit UI
st.title("Advanced Manga Reader")

uploaded_file = st.file_uploader("Upload a manga page", type=["jpg", "png"])
if uploaded_file is not None:
    # Force 3-channel RGB whatever the uploaded file's mode was.
    image = Image.open(uploaded_file).convert('RGB')
    st.image(image, caption='Uploaded Manga Page', use_column_width=True)

    # Widgets render in creation order: panel slider first, then character.
    panel_threshold = st.slider("Panel Detection Threshold", 0, 500, 100)
    character_threshold = st.slider("Character Detection Threshold", 0.0, 50.0, 10.0)

    # detect_panels works on a NumPy array; the other two take the PIL image.
    page_array = np.array(image)
    panels = detect_panels(page_array, panel_threshold)
    characters = detect_characters(image, character_threshold)
    dialogues = match_text_to_characters(image, panels)

    for heading, payload in (
        ("Detected Panels:", panels),
        ("Detected Characters:", characters),
        ("Dialogues:", dialogues),
    ):
        st.write(heading, payload)

    for dialogue in dialogues:
        st.write(f"Panel: {dialogue['dialog']}")