File size: 2,607 Bytes
b902873
 
 
 
 
 
 
efddbeb
b902873
efddbeb
b902873
 
 
 
 
 
 
 
 
 
efddbeb
b902873
efddbeb
 
 
b902873
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efddbeb
b902873
 
 
 
 
efddbeb
 
b902873
efddbeb
 
b902873
efddbeb
 
b902873
efddbeb
 
 
 
 
 
 
 
b902873
efddbeb
 
 
b902873
 
 
efddbeb
b902873
 
 
efddbeb
 
b902873
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""
This file contains the FastAPI application that serves the web interface and handles the API requests.
"""

import os
import io
import base64
from typing import Dict

import numpy as np
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from starlette.requests import Request
from PIL import Image
import uvicorn
from ultralytics import YOLO

from utils.predict_bounding_boxes import predict_bounding_boxes
from utils.manga_ocr import get_text_from_image
from utils.translate_manga import translate_manga
from utils.process_contour import process_contour
from utils.write_text_on_image import add_text


# --- Model -------------------------------------------------------------------
# YOLO weights from a training-run directory; the path is relative to the
# current working directory. The model is built once at import time and the
# same instance is reused by every /predict request.
best_model_path = "./model_creation/runs/detect/train5"
_weights_file = os.path.join(best_model_path, "weights/best.pt")
object_detection_model = YOLO(_weights_file)

# --- Application -------------------------------------------------------------
app = FastAPI()

# Accept cross-origin requests from any origin, with any method and header.
app.add_middleware(
	CORSMiddleware,
	allow_headers=["*"],
	allow_methods=["*"],
	allow_origins=["*"],
)

# Serve static assets and web fonts from same-named local folders; each mount
# is registered under the folder's name.
for _mount_point, _folder in (("/static", "static"), ("/fonts", "fonts")):
	app.mount(_mount_point, StaticFiles(directory=_folder), name=_folder)

# Jinja2 templates rendered by the "/" route.
templates = Jinja2Templates(directory="templates")

@app.get("/")
def home(request: Request):
	"""Serve the web front-end by rendering ``index.html`` from the templates folder."""
	template_context = {"request": request}
	return templates.TemplateResponse("index.html", template_context)


@app.post("/predict")
def predict(request: Dict):
	"""Translate the text regions of an uploaded image and return the edited image.

	The JSON body must look like ``{"image": "<base64-encoded image bytes>"}``.
	Regions are located with ``object_detection_model``. Each region is OCR'd,
	translated, and handed to ``add_text`` together with the contour returned by
	``process_contour``. Both helpers live in ``utils`` and are assumed to edit
	the region array in place; confirm against their implementations.

	Returns:
		``{"image": "<base64-encoded PNG>"}`` containing the processed image.

	Raises:
		HTTPException: 400 when ``"image"`` is missing or cannot be decoded as
			an image.
	"""
	# Function-scope imports keep this endpoint self-contained.
	import tempfile

	from fastapi import HTTPException

	try:
		raw_bytes = base64.b64decode(request["image"])
		# Normalise to 3-channel RGB. RGBA, palette or greyscale uploads would
		# otherwise produce arrays that the final RGB re-encode misreads.
		image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
	except (KeyError, TypeError, ValueError, OSError) as err:
		# binascii.Error (bad base64) is a ValueError and PIL's
		# UnidentifiedImageError is an OSError; both are client errors.
		raise HTTPException(
			status_code=400,
			detail="Request body must contain a base64-encoded image under 'image'.",
		) from err

	# Use a per-request temporary directory instead of a fixed "image.png" in
	# the working directory. The fixed file was overwritten by concurrent
	# requests (sync endpoints run in a thread pool) and leaked on any error.
	with tempfile.TemporaryDirectory() as tmp_dir:
		image_path = os.path.join(tmp_dir, "image.png")
		image.save(image_path)

		results = predict_bounding_boxes(object_detection_model, image_path)
		pixels = np.array(image)

		for x1, y1, x2, y2, _, _ in results:
			# Basic slicing yields a view: in-place edits to this crop write
			# through to ``pixels``.
			region = pixels[int(y1):int(y2), int(x1):int(x2)]
			# The crop is already uint8 in [0, 255]. Multiplying it by 255
			# wraps around in uint8 arithmetic and scrambles the OCR input.
			text = get_text_from_image(Image.fromarray(region))
			region, contour = process_contour(region)
			text_translated = translate_manga(text)
			add_text(region, text_translated, contour)

	# Encode straight to an in-memory PNG. A copy on disk is not needed.
	buffer = io.BytesIO()
	Image.fromarray(pixels).save(buffer, format="PNG")
	img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")

	return {"image": img_str}


if __name__ == '__main__':
	# Local development server. uvicorn needs the app as an import string
	# ("module:attribute") when reload is enabled. 'app:app' assumes this module
	# is importable as `app`; update it if the file is renamed.
	uvicorn.run(
		'app:app',
		host='localhost',
		port=8000,
		reload=True,
	)