import streamlit as st
from PIL import Image
import cv2
import numpy as np
import pytesseract
import torch
from torchvision import transforms
from transformers import DetrImageProcessor, DetrForObjectDetection
from sklearn.cluster import DBSCAN
# Load a pre-trained DETR model for object detection
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
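# Note: this checkpoint is trained on COCO, which has no "manga character"
# class; detect_characters() below treats the COCO 'person' class as a
# rough proxy for characters on the page.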
# Image transformations
transform = transforms.Compose([
    transforms.ToTensor()
])
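# Note: `transform` above is currently unused; DetrImageProcessor performs
# its own resizing and normalization when it prepares model inputs.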

def detect_panels(image, threshold):
    # Convert the RGB array (from PIL) to grayscale, then find panel
    # borders via Canny edges and external contours.
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, 100, 200)
    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    panels = []
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        # Keep only boxes larger than the user-set threshold to skip noise
        if w > threshold and h > threshold:
            panels.append({"coords": (x, y, w, h)})
    return panels
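# Note: cv2.RETR_EXTERNAL keeps only outermost contours, so panels nested
# inside another detected border are folded into it; the Canny thresholds
# (100, 200) are a starting point and may need tuning per scan quality.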

def detect_characters(image, threshold):
    # Run the DETR model on the full page; the COCO 'person' class is used
    # as a rough proxy for manga characters.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits        # (1, num_queries, num_classes + 1)
    bboxes = outputs.pred_boxes    # normalized (cx, cy, w, h) per query
    # Filter results
    characters = []
    for logit, box in zip(logits[0], bboxes[0]):
        # Drop the trailing "no object" logit, then check for COCO label 1 ('person')
        if logit[:-1].argmax() == 1:
            cx, cy, w, h = box * torch.tensor([image.width, image.height, image.width, image.height])
            x, y = cx - w / 2, cy - h / 2  # convert box center to top-left corner
            if w > threshold and h > threshold:
                characters.append({"coords": (x.item(), y.item(), w.item(), h.item())})
    return characters
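# Note: a library-provided alternative to the manual filtering above is
# processor.post_process_object_detection(outputs, target_sizes=..., threshold=...),
# which returns scores, labels, and pixel-space (x0, y0, x1, y1) boxes.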

def match_text_to_characters(image, panels):
    # OCR each detected panel with Tesseract and pair the text with its panel.
    text_matches = []
    for panel in panels:
        x, y, w, h = map(int, panel['coords'])
        panel_img = image.crop((x, y, x + w, y + h))
        text = pytesseract.image_to_string(panel_img)
        text_matches.append({"panel": panel, "dialog": text})
    return text_matches
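# Note: image_to_string assumes horizontal text by default; for Japanese
# pages, pytesseract.image_to_string(panel_img, lang='jpn_vert') may handle
# vertical text better, assuming the jpn_vert traineddata is installed.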

def match_characters(characters):
    # Cluster detections by box center ((x + w/2, y + h/2) for an (x, y, w, h)
    # box) so nearby detections of the same character can be grouped.
    coords = np.array([(c['coords'][0] + c['coords'][2] / 2,
                        c['coords'][1] + c['coords'][3] / 2) for c in characters])
    clustering = DBSCAN(eps=20, min_samples=1).fit(coords)
    character_matches = [{"character": c, "cluster": int(cluster)}
                         for c, cluster in zip(characters, clustering.labels_)]
    return character_matches

# Streamlit UI
st.title("Advanced Manga Reader")
uploaded_file = st.file_uploader("Upload a manga page", type=["jpg", "png"])
if uploaded_file is not None:
    image = Image.open(uploaded_file).convert('RGB')
    st.image(image, caption='Uploaded Manga Page', use_container_width=True)
    panel_threshold = st.slider("Panel Detection Threshold", 0, 500, 100)
    character_threshold = st.slider("Character Detection Threshold", 0.0, 50.0, 10.0)
    panels = detect_panels(np.array(image), panel_threshold)
    characters = detect_characters(image, character_threshold)
    dialogues = match_text_to_characters(image, panels)
    st.write("Detected Panels:", panels)
    st.write("Detected Characters:", characters)
    st.write("Dialogues:", dialogues)
    for dialogue in dialogues:
        st.write(f"Dialogue: {dialogue['dialog']}")