# testapi / manga_translator / manga_translator.py
# (uploaded by Sunday01, commit 9dce458 "up" — web-scrape header converted to comments)
import asyncio
import base64
import io
import cv2
from aiohttp.web_middlewares import middleware
from omegaconf import OmegaConf
import langcodes
import langdetect
import requests
import os
import re
import torch
import time
import logging
import numpy as np
from PIL import Image
from typing import List, Tuple, Union
from aiohttp import web
from marshmallow import Schema, fields, ValidationError
from manga_translator.utils.threading import Throttler
from .args import DEFAULT_ARGS, translator_chain
from .utils import (
BASE_PATH,
LANGUAGE_ORIENTATION_PRESETS,
ModelWrapper,
Context,
PriorityLock,
load_image,
dump_image,
replace_prefix,
visualize_textblocks,
add_file_logger,
remove_file_logger,
is_valuable_text,
rgb2hex,
hex2rgb,
get_color_name,
natural_sort,
sort_regions,
)
from .detection import DETECTORS, dispatch as dispatch_detection, prepare as prepare_detection
from .upscaling import dispatch as dispatch_upscaling, prepare as prepare_upscaling, UPSCALERS
from .ocr import OCRS, dispatch as dispatch_ocr, prepare as prepare_ocr
from .textline_merge import dispatch as dispatch_textline_merge
from .mask_refinement import dispatch as dispatch_mask_refinement
from .inpainting import INPAINTERS, dispatch as dispatch_inpainting, prepare as prepare_inpainting
from .translators import (
TRANSLATORS,
VALID_LANGUAGES,
LANGDETECT_MAP,
LanguageUnsupportedException,
TranslatorChain,
dispatch as dispatch_translation,
prepare as prepare_translation,
)
from .colorization import dispatch as dispatch_colorization, prepare as prepare_colorization
from .rendering import dispatch as dispatch_rendering, dispatch_eng_render
from .save import save_result
# Will be overwritten by __main__.py if module is being run directly (with python -m)
logger = logging.getLogger('manga_translator')


def set_main_logger(l):
    """Replace the module-level logger (called by __main__.py at startup)."""
    global logger
    logger = l
class TranslationInterrupt(Exception):
    """
    Can be raised from within a progress hook to prematurely terminate
    the translation. Caught in translate_file(), where it stops the retry
    loop without reporting an error.
    """
    pass
class MangaTranslator():

    def __init__(self, params: dict = None):
        """Create a translator configured from an argparse-style *params* dict."""
        self._progress_hooks = []   # async callables awaited via _report_progress()
        self._add_logger_hook()     # default hook: log each pipeline stage

        params = params or {}
        self.parse_init_params(params)
        self.result_sub_folder = ''  # appended to the result path (set per-task in web mode)

        # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
        # in PyTorch 1.12 and later.
        torch.backends.cuda.matmul.allow_tf32 = True

        # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
        torch.backends.cudnn.allow_tf32 = True
def parse_init_params(self, params: dict):
    """Read construction-time options from *params* into attributes.

    Accepts an empty dict: every key now has a fallback. Raises if GPU use
    is requested but torch can see no CUDA/Metal device.
    """
    self.verbose = params.get('verbose', False)
    self.ignore_errors = params.get('ignore_errors', False)
    # check mps for apple silicon or cuda for nvidia
    device = 'mps' if torch.backends.mps.is_available() else 'cuda'
    self.device = device if params.get('use_gpu', False) else 'cpu'
    self._gpu_limited_memory = params.get('use_gpu_limited', False)
    # --use-gpu-limited implies GPU use even without --use-gpu
    if self._gpu_limited_memory and not self.using_gpu:
        self.device = device
    if self.using_gpu and (not torch.cuda.is_available() and not torch.backends.mps.is_available()):
        raise Exception(
            'CUDA or Metal compatible device could not be found in torch whilst --use-gpu args was set.\n' \
            'Is the correct pytorch version installed? (See https://pytorch.org/)')
    if params.get('model_dir'):
        ModelWrapper._MODEL_DIR = params.get('model_dir')
    # BUGFIX: int(params.get('kernel_size')) raised TypeError when the key was
    # absent (int(None)). Fall back to 3 — the CLI default for --kernel-size.
    self.kernel_size = int(params.get('kernel_size') or 3)
    os.environ['INPAINTING_PRECISION'] = params.get('inpainting_precision', 'fp32')
@property
def using_gpu(self):
    """True when inference is configured for a GPU backend (CUDA or Apple Metal)."""
    dev = self.device
    return dev == 'mps' or dev.startswith('cuda')
async def translate_path(self, path: str, dest: str = None, params: dict = None):
    """
    Translates an image or folder (recursively) specified through the path.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(path)
    path = os.path.abspath(os.path.expanduser(path))
    dest = os.path.abspath(os.path.expanduser(dest)) if dest else ''
    params = params or {}

    # Handle format
    file_ext = params.get('format')
    if params.get('save_quality', 100) < 100:
        # Lossy save quality only makes sense for JPEG output.
        if not params.get('format'):
            file_ext = 'jpg'
        elif params.get('format') != 'jpg':
            raise ValueError('--save-quality of lower than 100 is only supported for .jpg files')

    if os.path.isfile(path):
        # Determine destination file path
        if not dest:
            # Use the same folder as the source
            p, ext = os.path.splitext(path)
            _dest = f'{p}-translated.{file_ext or ext[1:]}'
        elif not os.path.basename(dest):
            # *dest* ends in a path separator, i.e. names a folder
            p, ext = os.path.splitext(os.path.basename(path))
            # If the folders differ use the original filename from the source
            if os.path.dirname(path) != dest:
                _dest = os.path.join(dest, f'{p}.{file_ext or ext[1:]}')
            else:
                _dest = os.path.join(dest, f'{p}-translated.{file_ext or ext[1:]}')
        else:
            p, ext = os.path.splitext(dest)
            _dest = f'{p}.{file_ext or ext[1:]}'
        await self.translate_file(path, _dest, params)

    elif os.path.isdir(path):
        # Determine destination folder path
        if path[-1] == '\\' or path[-1] == '/':
            path = path[:-1]
        _dest = dest or path + '-translated'
        if os.path.exists(_dest) and not os.path.isdir(_dest):
            raise FileExistsError(_dest)

        translated_count = 0
        for root, subdirs, files in os.walk(path):
            files = natural_sort(files)
            # Mirror the source directory structure under _dest.
            dest_root = replace_prefix(root, path, _dest)
            os.makedirs(dest_root, exist_ok=True)
            for f in files:
                if f.lower() == '.thumb':
                    continue

                file_path = os.path.join(root, f)
                output_dest = replace_prefix(file_path, path, _dest)
                p, ext = os.path.splitext(output_dest)
                output_dest = f'{p}.{file_ext or ext[1:]}'

                if await self.translate_file(file_path, output_dest, params):
                    translated_count += 1
        if translated_count == 0:
            logger.info('No further untranslated files found. Use --overwrite to write over existing translations.')
        else:
            logger.info(f'Done. Translated {translated_count} image{"" if translated_count == 1 else "s"}')
async def translate_file(self, path: str, dest: str, params: dict):
    """Translate a single file with retry handling.

    Returns True when the file was translated (or skipped as already
    translated), False when all attempts failed or were interrupted.
    """
    if not params.get('overwrite') and os.path.exists(dest):
        logger.info(
            f'Skipping as already translated: "{dest}". Use --overwrite to overwrite existing translations.')
        await self._report_progress('saved', True)
        return True
    logger.info(f'Translating: "{path}"')

    # Turn dict to context to make values also accessible through params.<property>
    params = params or {}
    ctx = Context(**params)
    self._preprocess_params(ctx)

    attempts = 0
    # ctx.attempts == -1 means retry indefinitely.
    while ctx.attempts == -1 or attempts < ctx.attempts + 1:
        if attempts > 0:
            logger.info(f'Retrying translation! Attempt {attempts}'
                        + (f' of {ctx.attempts}' if ctx.attempts != -1 else ''))
        try:
            return await self._translate_file(path, dest, ctx)

        except TranslationInterrupt:
            break
        except Exception as e:
            if isinstance(e, LanguageUnsupportedException):
                await self._report_progress('error-lang', True)
            else:
                await self._report_progress('error', True)
            # Re-raise once the retry budget is exhausted, unless errors are ignored.
            if not self.ignore_errors and not (ctx.attempts == -1 or attempts < ctx.attempts):
                raise
            else:
                logger.error(f'{e.__class__.__name__}: {e}',
                             exc_info=e if self.verbose else None)
            attempts += 1
    return False
async def _translate_file(self, path: str, dest: str, ctx: Context) -> bool:
    """Translate a single .txt file or image and save the result to *dest*.

    Returns True when something was produced/saved, False otherwise.
    """
    if path.endswith('.txt'):
        # Plain-text mode: one query per line.
        # Explicit UTF-8 so CJK text survives on platforms with a non-UTF-8
        # default encoding (e.g. Windows cp1252).
        with open(path, 'r', encoding='utf-8') as f:
            queries = f.read().split('\n')
        translated_sentences = \
            await dispatch_translation(ctx.translator, queries, ctx.use_mtpe, ctx,
                                       'cpu' if self._gpu_limited_memory else self.device)
        p, ext = os.path.splitext(dest)
        if ext != '.txt':
            dest = p + '.txt'
        logger.info(f'Saving "{dest}"')
        with open(dest, 'w', encoding='utf-8') as f:
            f.write('\n'.join(translated_sentences))
        return True

    # TODO: Add .gif handler

    else:  # Treat as image
        try:
            img = Image.open(path)
            img.verify()
            # verify() invalidates the image handle, so reopen for actual use.
            img = Image.open(path)
        except Exception:
            # BUGFIX: logger.warn is a deprecated alias of logger.warning.
            logger.warning(f'Failed to open image: {path}')
            return False

        ctx = await self.translate(img, ctx)
        result = ctx.result

        # Save result
        if ctx.skip_no_text and not ctx.text_regions:
            logger.debug('Not saving due to --skip-no-text')
            return True
        if result:
            logger.info(f'Saving "{dest}"')
            save_result(result, dest, ctx)
            await self._report_progress('saved', True)

            if ctx.save_text or ctx.save_text_file or ctx.prep_manual:
                if ctx.prep_manual:
                    # Save original image next to translated
                    p, ext = os.path.splitext(dest)
                    img_filename = p + '-orig' + ext
                    img_path = os.path.join(os.path.dirname(dest), img_filename)
                    img.save(img_path, quality=ctx.save_quality)
                if ctx.text_regions:
                    self._save_text_to_file(path, ctx)
            return True
        return False
async def translate(self, image: Image.Image, params: Union[dict, Context] = None) -> Context:
    """
    Translates a PIL image from a manga. Returns dict with result and intermediates of translation.
    Default params are taken from args.py.

    ```py
    translation_dict = await translator.translate(image)
    result = translation_dict.result
    ```
    """
    # TODO: Take list of images to speed up batch processing
    if not isinstance(params, Context):
        params = params or {}
        ctx = Context(**params)
        self._preprocess_params(ctx)
    else:
        # Already a prepared Context (e.g. from translate_file) — use as-is.
        ctx = params

    ctx.input = image
    ctx.result = None

    # preload and download models (not strictly necessary, remove to lazy load)
    logger.info('Loading models')
    if ctx.upscale_ratio:
        await prepare_upscaling(ctx.upscaler)
    await prepare_detection(ctx.detector)
    await prepare_ocr(ctx.ocr, self.device)
    await prepare_inpainting(ctx.inpainter, self.device)
    await prepare_translation(ctx.translator)
    if ctx.colorizer:
        await prepare_colorization(ctx.colorizer)

    # translate
    return await self._translate(ctx)
def _preprocess_params(self, ctx: Context):
    """Fill in derived/default values on *ctx* before translation starts.

    Raises:
        Exception: when --font-color is not a parseable hex value.
    """
    # params auto completion
    # TODO: Move args into ctx.args and only calculate once, or just copy into ctx
    for arg in DEFAULT_ARGS:
        ctx.setdefault(arg, DEFAULT_ARGS[arg])

    if 'direction' not in ctx:
        if ctx.force_horizontal:
            ctx.direction = 'h'
        elif ctx.force_vertical:
            ctx.direction = 'v'
        else:
            ctx.direction = 'auto'
    if 'alignment' not in ctx:
        if ctx.align_left:
            ctx.alignment = 'left'
        elif ctx.align_center:
            ctx.alignment = 'center'
        elif ctx.align_right:
            ctx.alignment = 'right'
        else:
            ctx.alignment = 'auto'
    if ctx.prep_manual:
        # Manual preparation renders nothing; text is added by hand later.
        ctx.renderer = 'none'
    ctx.setdefault('renderer', 'manga2eng' if ctx.manga2eng else 'default')

    # Resolve the translator specification into a TranslatorChain.
    if ctx.selective_translation is not None:
        ctx.selective_translation.target_lang = ctx.target_lang
        ctx.translator = ctx.selective_translation
    elif ctx.translator_chain is not None:
        ctx.target_lang = ctx.translator_chain.langs[-1]
        ctx.translator = ctx.translator_chain
    else:
        ctx.translator = TranslatorChain(f'{ctx.translator}:{ctx.target_lang}')
    if ctx.gpt_config:
        ctx.gpt_config = OmegaConf.load(ctx.gpt_config)

    if ctx.filter_text:
        ctx.filter_text = re.compile(ctx.filter_text)

    if ctx.font_color:
        colors = ctx.font_color.split(':')
        try:
            ctx.font_color_fg = hex2rgb(colors[0])
            ctx.font_color_bg = hex2rgb(colors[1]) if len(colors) > 1 else None
        except Exception:
            # BUGFIX: was a bare `except:` that also swallowed
            # KeyboardInterrupt/SystemExit into a config error.
            raise Exception(f'Invalid --font-color value: {ctx.font_color}. Use a hex value such as FF0000')
async def _translate(self, ctx: Context) -> Context:
    """Run the full pipeline on *ctx*: colorize → upscale → detect → ocr →
    merge → translate → mask-refine → inpaint → render.

    Intermediates are stored on *ctx*; the final image ends up in ctx.result.
    Bails out early (returning the intermediate image) when no usable text
    is found or translation returns nothing.
    """
    # -- Colorization
    if ctx.colorizer:
        await self._report_progress('colorizing')
        ctx.img_colorized = await self._run_colorizer(ctx)
    else:
        ctx.img_colorized = ctx.input

    # -- Upscaling
    # The default text detector doesn't work very well on smaller images, might want to
    # consider adding automatic upscaling on certain kinds of small images.
    if ctx.upscale_ratio:
        await self._report_progress('upscaling')
        ctx.upscaled = await self._run_upscaling(ctx)
    else:
        ctx.upscaled = ctx.img_colorized

    ctx.img_rgb, ctx.img_alpha = load_image(ctx.upscaled)

    # -- Detection
    await self._report_progress('detection')
    ctx.textlines, ctx.mask_raw, ctx.mask = await self._run_detection(ctx)
    if self.verbose:
        cv2.imwrite(self._result_path('mask_raw.png'), ctx.mask_raw)

    if not ctx.textlines:
        await self._report_progress('skip-no-regions', True)
        # If no text was found result is intermediate image product
        ctx.result = ctx.upscaled
        return await self._revert_upscale(ctx)

    if self.verbose:
        img_bbox_raw = np.copy(ctx.img_rgb)
        for txtln in ctx.textlines:
            cv2.polylines(img_bbox_raw, [txtln.pts], True, color=(255, 0, 0), thickness=2)
        cv2.imwrite(self._result_path('bboxes_unfiltered.png'), cv2.cvtColor(img_bbox_raw, cv2.COLOR_RGB2BGR))

    # -- OCR
    await self._report_progress('ocr')
    ctx.textlines = await self._run_ocr(ctx)

    # Drop lines whose detected source language is in --skip-lang.
    if ctx.skip_lang is not None :
        filtered_textlines = []
        skip_langs = ctx.skip_lang.split(',')
        for txtln in ctx.textlines :
            try :
                source_language = LANGDETECT_MAP.get(langdetect.detect(txtln.text), 'UNKNOWN')
            except Exception :
                # langdetect can fail on very short/odd strings
                source_language = 'UNKNOWN'
            if source_language not in skip_langs :
                filtered_textlines.append(txtln)
        ctx.textlines = filtered_textlines

    if not ctx.textlines:
        await self._report_progress('skip-no-text', True)
        # If no text was found result is intermediate image product
        ctx.result = ctx.upscaled
        return await self._revert_upscale(ctx)

    # -- Textline merge
    await self._report_progress('textline_merge')
    ctx.text_regions = await self._run_textline_merge(ctx)

    if self.verbose:
        bboxes = visualize_textblocks(cv2.cvtColor(ctx.img_rgb, cv2.COLOR_BGR2RGB), ctx.text_regions)
        cv2.imwrite(self._result_path('bboxes.png'), bboxes)

    # -- Translation
    await self._report_progress('translating')
    ctx.text_regions = await self._run_text_translation(ctx)
    await self._report_progress('after-translating')

    if not ctx.text_regions:
        await self._report_progress('error-translating', True)
        ctx.result = ctx.upscaled
        return await self._revert_upscale(ctx)
    elif ctx.text_regions == 'cancel':
        # Web mode can return the sentinel 'cancel' from manual translation.
        await self._report_progress('cancelled', True)
        ctx.result = ctx.upscaled
        return await self._revert_upscale(ctx)

    # -- Mask refinement
    # (Delayed to take advantage of the region filtering done after ocr and translation)
    if ctx.mask is None:
        await self._report_progress('mask-generation')
        ctx.mask = await self._run_mask_refinement(ctx)

    if self.verbose:
        inpaint_input_img = await dispatch_inpainting('none', ctx.img_rgb, ctx.mask, ctx.inpainting_size,
                                                      self.using_gpu, self.verbose)
        cv2.imwrite(self._result_path('inpaint_input.png'), cv2.cvtColor(inpaint_input_img, cv2.COLOR_RGB2BGR))
        cv2.imwrite(self._result_path('mask_final.png'), ctx.mask)

    # -- Inpainting
    await self._report_progress('inpainting')
    ctx.img_inpainted = await self._run_inpainting(ctx)
    # RGBA-style stack of inpainted image + mask, e.g. for GIMP layer export
    ctx.gimp_mask = np.dstack((cv2.cvtColor(ctx.img_inpainted, cv2.COLOR_RGB2BGR), ctx.mask))

    if self.verbose:
        cv2.imwrite(self._result_path('inpainted.png'), cv2.cvtColor(ctx.img_inpainted, cv2.COLOR_RGB2BGR))

    # -- Rendering
    await self._report_progress('rendering')
    ctx.img_rendered = await self._run_text_rendering(ctx)

    await self._report_progress('finished', True)
    ctx.result = dump_image(ctx.input, ctx.img_rendered, ctx.img_alpha)

    return await self._revert_upscale(ctx)
async def _revert_upscale(self, ctx: Context):
    """Resize ctx.result back to the original input size when
    `revert_upscaling` is set; otherwise leave *ctx* untouched."""
    if not ctx.revert_upscaling:
        return ctx
    await self._report_progress('downscaling')
    ctx.result = ctx.result.resize(ctx.input.size)
    return ctx
async def _run_colorizer(self, ctx: Context):
    """Run the configured colorization model over the original input image."""
    colorized = await dispatch_colorization(ctx.colorizer, device=self.device, image=ctx.input, **ctx)
    return colorized
async def _run_upscaling(self, ctx: Context):
    """Upscale the (possibly colorized) image; the dispatcher is batched, so unwrap."""
    batch = [ctx.img_colorized]
    results = await dispatch_upscaling(ctx.upscaler, batch, ctx.upscale_ratio, self.device)
    return results[0]
async def _run_detection(self, ctx: Context):
    """Detect text lines; returns (textlines, raw mask, refined mask or None)."""
    return await dispatch_detection(
        ctx.detector, ctx.img_rgb, ctx.detection_size, ctx.text_threshold,
        ctx.box_threshold, ctx.unclip_ratio, ctx.det_invert,
        ctx.det_gamma_correct, ctx.det_rotate, ctx.det_auto_rotate,
        self.device, self.verbose)
async def _run_ocr(self, ctx: Context):
    """OCR each detected line, drop empty results and apply forced font colors."""
    textlines = await dispatch_ocr(ctx.ocr, ctx.img_rgb, ctx.textlines, ctx, self.device, self.verbose)

    kept = []
    for line in textlines:
        if not line.text.strip():
            continue  # OCR found no text in this line
        if ctx.font_color_fg:
            line.fg_r, line.fg_g, line.fg_b = ctx.font_color_fg
        if ctx.font_color_bg:
            line.bg_r, line.bg_g, line.bg_b = ctx.font_color_bg
        kept.append(line)
    return kept
async def _run_textline_merge(self, ctx: Context):
    """Merge OCR'd lines into text regions and filter out unwanted regions."""
    text_regions = await dispatch_textline_merge(ctx.textlines, ctx.img_rgb.shape[1], ctx.img_rgb.shape[0],
                                                 verbose=self.verbose)
    new_text_regions = []
    for region in text_regions:
        # NOTE(review): due to operator precedence this condition groups as
        # (len >= min_text_length AND not valuable) OR (source lang == target lang)
        # — confirm this grouping is intended.
        if len(region.text) >= ctx.min_text_length \
                and not is_valuable_text(region.text) \
                or (not ctx.no_text_lang_skip and langcodes.tag_distance(region.source_lang, ctx.target_lang) == 0):
            if region.text.strip():
                logger.info(f'Filtered out: {region.text}')
        else:
            if ctx.font_color_fg or ctx.font_color_bg:
                if ctx.font_color_bg:
                    # A forced background color should not be auto-adjusted later.
                    region.adjust_bg_color = False
            new_text_regions.append(region)
    text_regions = new_text_regions

    # Sort ctd (comic text detector) regions left to right. Otherwise right to left.
    # Sorting will improve text translation quality.
    text_regions = sort_regions(text_regions, right_to_left=True if ctx.detector != 'ctd' else False)
    return text_regions
async def _run_text_translation(self, ctx: Context):
    """Translate every region's text, apply case options and layout hints,
    then drop regions whose translations are junk (numeric-only, matching
    --filter-text, or identical to the source)."""
    translated_sentences = \
        await dispatch_translation(ctx.translator,
                                   [region.text for region in ctx.text_regions],
                                   ctx.use_mtpe,
                                   ctx, 'cpu' if self._gpu_limited_memory else self.device)

    for region, translation in zip(ctx.text_regions, translated_sentences):
        if ctx.uppercase:
            translation = translation.upper()
        elif ctx.lowercase:
            # BUGFIX: previously called .upper() here, making --lowercase
            # behave exactly like --uppercase.
            translation = translation.lower()
        region.translation = translation
        region.target_lang = ctx.target_lang
        region._alignment = ctx.alignment
        region._direction = ctx.direction

    # Filter out regions by their translations
    new_text_regions = []
    for region in ctx.text_regions:
        # TODO: Maybe print reasons for filtering
        if not ctx.translator == 'none' and (region.translation.isnumeric() \
                or ctx.filter_text and re.search(ctx.filter_text, region.translation)
                or not ctx.translator == 'original' and region.text.lower().strip() == region.translation.lower().strip()):
            if region.translation.strip():
                logger.info(f'Filtered out: {region.translation}')
        else:
            new_text_regions.append(region)
    return new_text_regions
async def _run_mask_refinement(self, ctx: Context):
    """Refine the raw detection mask to tightly fit the final text regions."""
    return await dispatch_mask_refinement(
        ctx.text_regions, ctx.img_rgb, ctx.mask_raw, 'fit_text',
        ctx.mask_dilation_offset, ctx.ignore_bubble, self.verbose, self.kernel_size)
async def _run_inpainting(self, ctx: Context):
    """Erase the original text by inpainting the masked areas."""
    return await dispatch_inpainting(
        ctx.inpainter, ctx.img_rgb, ctx.mask, ctx.inpainting_size,
        self.device, self.verbose)
async def _run_text_rendering(self, ctx: Context):
    """Draw the translated text onto the inpainted image and return it."""
    if ctx.renderer == 'none':
        return ctx.img_inpainted
    # manga2eng currently only supports horizontal left to right rendering
    use_eng_renderer = (
        ctx.renderer == 'manga2eng'
        and ctx.text_regions
        and LANGUAGE_ORIENTATION_PRESETS.get(ctx.text_regions[0].target_lang) == 'h'
    )
    if use_eng_renderer:
        return await dispatch_eng_render(ctx.img_inpainted, ctx.img_rgb, ctx.text_regions, ctx.font_path, ctx.line_spacing)
    return await dispatch_rendering(
        ctx.img_inpainted, ctx.text_regions, ctx.font_path, ctx.font_size,
        ctx.font_size_offset, ctx.font_size_minimum, not ctx.no_hyphenation,
        ctx.render_mask, ctx.line_spacing)
def _result_path(self, path: str) -> str:
    """
    Returns path to result folder where intermediate images are saved when using verbose flag
    or web mode input/result images are cached.
    """
    parts = (BASE_PATH, 'result', self.result_sub_folder, path)
    return os.path.join(*parts)
def add_progress_hook(self, ph):
    # ph: async callable (state: str, finished: bool), awaited by _report_progress()
    self._progress_hooks.append(ph)
async def _report_progress(self, state: str, finished: bool = False):
    """Await every registered progress hook with the new pipeline state."""
    for hook in self._progress_hooks:
        await hook(state, finished)
def _add_logger_hook(self):
    """Install the default progress hook that logs each pipeline state."""
    # TODO: Pass ctx to logger hook
    LOG_MESSAGES = {
        'upscaling': 'Running upscaling',
        'detection': 'Running text detection',
        'ocr': 'Running ocr',
        'mask-generation': 'Running mask refinement',
        'translating': 'Running text translation',
        'rendering': 'Running rendering',
        'colorizing': 'Running colorization',
        'downscaling': 'Running downscaling',
    }
    LOG_MESSAGES_SKIP = {
        'skip-no-regions': 'No text regions! - Skipping',
        'skip-no-text': 'No text regions with text! - Skipping',
        'error-translating': 'Text translator returned empty queries',
        'cancelled': 'Image translation cancelled',
    }
    LOG_MESSAGES_ERROR = {
        # 'error-lang': 'Target language not supported by chosen translator',
    }

    async def ph(state, finished):
        if state in LOG_MESSAGES:
            logger.info(LOG_MESSAGES[state])
        elif state in LOG_MESSAGES_SKIP:
            # BUGFIX: logger.warn is a deprecated alias of logger.warning.
            logger.warning(LOG_MESSAGES_SKIP[state])
        elif state in LOG_MESSAGES_ERROR:
            logger.error(LOG_MESSAGES_ERROR[state])

    self.add_progress_hook(ph)
def _save_text_to_file(self, image_path: str, ctx: Context):
    """Append detected texts, translations and color info to a .txt file."""
    cached_colors = []

    def identify_colors(fg_rgb: List[int]):
        # Returns (1-based color id, human-readable name), reusing ids for
        # colors within a small manhattan distance of an already-seen color.
        idx = 0
        for rgb, _ in cached_colors:
            # If similar color already saved
            if abs(rgb[0] - fg_rgb[0]) + abs(rgb[1] - fg_rgb[1]) + abs(rgb[2] - fg_rgb[2]) < 50:
                break
            else:
                idx += 1
        else:
            # for/else: no similar color found — register a new one.
            cached_colors.append((fg_rgb, get_color_name(fg_rgb)))
        return idx + 1, cached_colors[idx][1]

    s = f'\n[{image_path}]\n'
    for i, region in enumerate(ctx.text_regions):
        fore, back = region.get_font_colors()
        color_id, color_name = identify_colors(fore)

        s += f'\n-- {i + 1} --\n'
        s += f'color: #{color_id}: {color_name} (fg, bg: {rgb2hex(*fore)} {rgb2hex(*back)})\n'
        s += f'text: {region.text}\n'
        s += f'trans: {region.translation}\n'
        for line in region.lines:
            s += f'coords: {list(line.ravel())}\n'
    s += '\n'

    text_output_file = ctx.text_output_file
    if not text_output_file:
        text_output_file = os.path.splitext(image_path)[0] + '_translations.txt'

    with open(text_output_file, 'a', encoding='utf-8') as f:
        f.write(s)
class MangaTranslatorWeb(MangaTranslator):
    """
    Translator client that executes tasks on behalf of the webserver in web_main.py.
    """

    def __init__(self, params: dict = None):
        # BUGFIX: the .get() calls below raised AttributeError when params
        # was None (the documented default); normalize it first.
        params = params or {}
        super().__init__(params)
        self.host = params.get('host', '127.0.0.1')
        # 0.0.0.0 is a bind address, not a connectable one — talk to localhost.
        if self.host == '0.0.0.0':
            self.host = '127.0.0.1'
        self.port = params.get('port', 5003)
        self.nonce = params.get('nonce', '')
        self.ignore_errors = params.get('ignore_errors', True)

        self._task_id = None
        self._params = None
async def _init_connection(self):
    """Register this worker with the web server, advertising usable translators."""
    from .translators import MissingAPIKeyException, get_translator

    usable = []
    for name in TRANSLATORS:
        try:
            get_translator(name)
        except MissingAPIKeyException:
            continue  # translator not configured; don't advertise it
        usable.append(name)

    payload = {
        'nonce': self.nonce,
        'capabilities': {
            'translators': usable,
        },
    }
    requests.post(f'http://{self.host}:{self.port}/connect-internal', json=payload)
async def _send_state(self, state: str, finished: bool):
    """Progress hook: push a task state update to the web server."""
    # wait for translation to be saved first (bad solution?)
    # i.e. only the 'saved' state (not 'finished') reports completion.
    finished = finished and not state == 'finished'
    while True:
        try:
            data = {
                'task_id': self._task_id,
                'nonce': self.nonce,
                'state': state,
                'finished': finished,
            }
            requests.post(f'http://{self.host}:{self.port}/task-update-internal', json=data, timeout=20)
            break
        except Exception:
            # if translation is finished server has to know
            # (retry forever on failure; non-final states are best-effort)
            if finished:
                continue
            else:
                break
def _get_task(self):
    """Long-poll the web server for the next task; (None, None) on any failure."""
    try:
        resp = requests.get(f'http://{self.host}:{self.port}/task-internal?nonce={self.nonce}',
                            timeout=3600)
        rjson = resp.json()
    except Exception:
        return None, None
    return rjson.get('task_id'), rjson.get('data')
async def listen(self, translation_params: dict = None):
    """
    Listens for translation tasks from web server.
    """
    logger.info('Waiting for translation tasks')

    await self._init_connection()
    self.add_progress_hook(self._send_state)

    while True:
        self._task_id, self._params = self._get_task()
        if self._params and 'exit' in self._params:
            break
        if not (self._task_id and self._params):
            # No pending task — poll again shortly.
            await asyncio.sleep(0.1)
            continue

        self.result_sub_folder = self._task_id
        logger.info(f'Processing task {self._task_id}')
        if translation_params is not None:
            # Combine default params with params chosen by webserver
            for p, default_value in translation_params.items():
                current_value = self._params.get(p)
                self._params[p] = current_value if current_value is not None else default_value
        if self.verbose:
            # Write log file
            log_file = self._result_path('log.txt')
            add_file_logger(log_file)

        # final.png will be renamed if format param is set
        await self.translate_path(self._result_path('input.png'), self._result_path('final.png'),
                                  params=self._params)
        print()

        if self.verbose:
            remove_file_logger(log_file)
        self._task_id = None
        self._params = None
        self.result_sub_folder = ''
async def _run_text_translation(self, ctx: Context):
    """Translate, then optionally wait for the user's manual corrections."""
    # Run machine translation as reference for manual translation (if `--translator=none` is not set)
    text_regions = await super()._run_text_translation(ctx)

    if ctx.get('manual', False):
        logger.info('Waiting for user input from manual translation')
        requests.post(f'http://{self.host}:{self.port}/request-manual-internal', json={
            'task_id': self._task_id,
            'nonce': self.nonce,
            'texts': [r.text for r in text_regions],
            'translations': [r.translation for r in text_regions],
        }, timeout=20)

        # wait for at most 1 hour for manual translation
        wait_until = time.time() + 3600
        while time.time() < wait_until:
            ret = requests.post(f'http://{self.host}:{self.port}/get-manual-result-internal', json={
                'task_id': self._task_id,
                'nonce': self.nonce
            }, timeout=20).json()
            if 'result' in ret:
                manual_translations = ret['result']
                if isinstance(manual_translations, str):
                    if manual_translations == 'error':
                        return []
                i = 0
                # Apply user translations; an empty entry removes its region
                # (hence the pop + index decrement).
                for translation in manual_translations:
                    if not translation.strip():
                        text_regions.pop(i)
                        i = i - 1
                    else:
                        text_regions[i].translation = translation
                        text_regions[i].target_lang = ctx.translator.langs[-1]
                    i = i + 1
                break
            elif 'cancel' in ret:
                # Sentinel understood by _translate() as user cancellation.
                return 'cancel'
            await asyncio.sleep(0.1)
    return text_regions
class MangaTranslatorWS(MangaTranslator):
    """Websocket worker client: connects out to a task server and processes tasks."""

    def __init__(self, params: dict = None):
        # BUGFIX: the .get() calls below raised AttributeError when params
        # was None (the documented default); normalize it first.
        params = params or {}
        super().__init__(params)
        self.url = params.get('ws_url')
        self.secret = params.get('ws_secret', os.getenv('WS_SECRET', ''))
        self.ignore_errors = params.get('ignore_errors', True)

        self._task_id = None
        self._websocket = None
async def listen(self, translation_params: dict = None):
    """Connect to the websocket task server and process tasks until stopped.

    Architecture: the websocket client runs in its own thread with its own
    event loop (self._server_loop); actual translation is scheduled back
    onto the caller's loop via run_coroutine_threadsafe.
    """
    from threading import Thread
    import io
    import aioshutil
    from aiofiles import os
    import websockets
    from .server import ws_pb2

    self._server_loop = asyncio.new_event_loop()
    self.task_lock = PriorityLock()
    self.counter = 0  # monotonically increasing task counter (used for lock priority)

    async def _send_and_yield(websocket, msg):
        # send message and yield control to the event loop (to actually send the message)
        await websocket.send(msg)
        await asyncio.sleep(0)

    send_throttler = Throttler(0.2)
    send_and_yield = send_throttler.wrap(_send_and_yield)

    async def sync_state(state, finished):
        # Progress hook: forward pipeline state to the server over the socket.
        if self._websocket is None:
            return
        msg = ws_pb2.WebSocketMessage()
        msg.status.id = self._task_id
        msg.status.status = state
        # Hop onto the websocket thread's loop to do the actual send.
        self._server_loop.call_soon_threadsafe(
            asyncio.create_task,
            send_and_yield(self._websocket, msg.SerializeToString())
        )

    self.add_progress_hook(sync_state)

    async def translate(task_id, websocket, image, params):
        # Serialize translations through the priority lock (earlier tasks first).
        async with self.task_lock((1 << 31) - params['ws_count']):
            self._task_id = task_id
            self._websocket = websocket
            result = await self.translate(image, params)
            self._task_id = None
            self._websocket = None
        return result

    async def server_send_status(websocket, task_id, status):
        # Unthrottled, direct status send (used outside the translate pipeline).
        msg = ws_pb2.WebSocketMessage()
        msg.status.id = task_id
        msg.status.status = status
        await websocket.send(msg.SerializeToString())
        await asyncio.sleep(0)

    async def server_process_inner(main_loop, logger_task, session, websocket, task) -> Tuple[bool, bool]:
        # Returns (success, has_translation_mask).
        logger_task.info(f'-- Processing task {task.id}')
        await server_send_status(websocket, task.id, 'pending')

        if self.verbose:
            await aioshutil.rmtree(f'result/{task.id}', ignore_errors=True)
            await os.makedirs(f'result/{task.id}', exist_ok=True)

        params = {
            'target_lang': task.target_language,
            'skip_lang': task.skip_language,
            'detector': task.detector,
            'direction': task.direction,
            'translator': task.translator,
            'size': task.size,
            'ws_event_loop': asyncio.get_event_loop(),
            'ws_count': self.counter,
        }
        self.counter += 1

        logger_task.info(f'-- Downloading image from {task.source_image}')
        await server_send_status(websocket, task.id, 'downloading')
        async with session.get(task.source_image) as resp:
            if resp.status == 200:
                source_image = await resp.read()
            else:
                msg = ws_pb2.WebSocketMessage()
                msg.status.id = task.id
                msg.status.status = 'error-download'
                await websocket.send(msg.SerializeToString())
                await asyncio.sleep(0)
                return False, False

        logger_task.info(f'-- Translating image')
        if translation_params:
            # Combine default params with task-specific params.
            for p, default_value in translation_params.items():
                current_value = params.get(p)
                params[p] = current_value if current_value is not None else default_value

        image = Image.open(io.BytesIO(source_image))

        (ori_w, ori_h) = image.size
        if max(ori_h, ori_w) > 1200:
            params['upscale_ratio'] = 1

        await server_send_status(websocket, task.id, 'preparing')
        # translation_dict = await self.translate(image, params)
        # Run the translation on the main loop (this coroutine lives on the
        # websocket thread's loop).
        translation_dict = await asyncio.wrap_future(
            asyncio.run_coroutine_threadsafe(
                translate(task.id, websocket, image, params),
                main_loop
            )
        )
        await send_throttler.flush()

        output: Image.Image = translation_dict.result
        if output is not None:
            await server_send_status(websocket, task.id, 'saving')

            output = output.resize((ori_w, ori_h), resample=Image.LANCZOS)

            img = io.BytesIO()
            output.save(img, format='PNG')
            if self.verbose:
                output.save(self._result_path('ws_final.png'))

            img_bytes = img.getvalue()
            logger_task.info(f'-- Uploading result to {task.translation_mask}')
            await server_send_status(websocket, task.id, 'uploading')
            async with session.put(task.translation_mask, data=img_bytes) as resp:
                if resp.status != 200:
                    logger_task.error(f'-- Failed to upload result:')
                    logger_task.error(f'{resp.status}: {resp.reason}')
                    msg = ws_pb2.WebSocketMessage()
                    msg.status.id = task.id
                    msg.status.status = 'error-upload'
                    await websocket.send(msg.SerializeToString())
                    await asyncio.sleep(0)
                    return False, False

        return True, output is not None

    async def server_process(main_loop, session, websocket, task) -> bool:
        # Wrapper that guarantees a finish_task message is always sent.
        logger_task = logger.getChild(f'{task.id}')
        try:
            (success, has_translation_mask) = await server_process_inner(main_loop, logger_task, session, websocket,
                                                                         task)
        except Exception as e:
            logger_task.error(f'-- Task failed with exception:')
            logger_task.error(f'{e.__class__.__name__}: {e}', exc_info=e if self.verbose else None)
            (success, has_translation_mask) = False, False
        finally:
            result = ws_pb2.WebSocketMessage()
            result.finish_task.id = task.id
            result.finish_task.success = success
            result.finish_task.has_translation_mask = has_translation_mask
            await websocket.send(result.SerializeToString())
            await asyncio.sleep(0)

            logger_task.info(f'-- Task finished')

    async def async_server_thread(main_loop):
        from aiohttp import ClientSession, ClientTimeout
        timeout = ClientTimeout(total=30)
        async with ClientSession(timeout=timeout) as session:
            logger_conn = logger.getChild('connection')
            if self.verbose:
                logger_conn.setLevel(logging.DEBUG)
            # websockets.connect as async iterator: reconnects automatically.
            async for websocket in websockets.connect(
                    self.url,
                    extra_headers={
                        'x-secret': self.secret,
                    },
                    max_size=1_000_000,
                    logger=logger_conn
            ):
                bg_tasks = set()
                try:
                    logger.info('-- Connected to websocket server')

                    async for raw in websocket:
                        # logger.info(f'Got message: {raw}')
                        msg = ws_pb2.WebSocketMessage()
                        msg.ParseFromString(raw)
                        if msg.WhichOneof('message') == 'new_task':
                            task = msg.new_task
                            bg_task = asyncio.create_task(server_process(main_loop, session, websocket, task))
                            # Keep a strong reference so the task isn't GC'd mid-flight.
                            bg_tasks.add(bg_task)
                            bg_task.add_done_callback(bg_tasks.discard)

                except Exception as e:
                    logger.error(f'{e.__class__.__name__}: {e}', exc_info=e if self.verbose else None)
                finally:
                    logger.info('-- Disconnected from websocket server')
                    for bg_task in bg_tasks:
                        bg_task.cancel()

    def server_thread(future, main_loop, server_loop):
        # Runs the websocket client loop on its dedicated thread.
        asyncio.set_event_loop(server_loop)
        try:
            server_loop.run_until_complete(async_server_thread(main_loop))
        finally:
            future.set_result(None)

    future = asyncio.Future()
    Thread(
        target=server_thread,
        args=(future, asyncio.get_running_loop(), self._server_loop),
        daemon=True
    ).start()

    # create a future that is never done
    await future
async def _run_text_translation(self, ctx: Context):
    """Translate, releasing the task lock while waiting on remote translators."""
    coroutine = super()._run_text_translation(ctx)
    if ctx.translator.has_offline():
        # Offline translator: keep holding the lock and translate inline.
        return await coroutine
    else:
        task_id = self._task_id
        websocket = self._websocket
        # Remote translator: let other tasks run while this one waits on the network.
        await self.task_lock.release()
        result = await asyncio.wrap_future(
            asyncio.run_coroutine_threadsafe(
                coroutine,
                ctx.ws_event_loop
            )
        )
        # NOTE(review): reacquires with base 1 << 30 while translate() locks
        # with 1 << 31 — presumably prioritizing resuming tasks over new
        # ones; confirm against PriorityLock's ordering.
        await self.task_lock.acquire((1 << 30) - ctx.ws_count)
        self._task_id = task_id
        self._websocket = websocket
        return result
async def _run_text_rendering(self, ctx: Context):
    """Render translated text, then keep only the changed pixels so the
    result can be overlaid onto the original image client-side."""
    # Seed the render mask from the inpainting mask (values >= 127 count
    # as masked); shape becomes (H, W, 1) for broadcasting below.
    render_mask = (ctx.mask >= 127).astype(np.uint8)[:, :, None]
    output = await super()._run_text_rendering(ctx)
    # Additionally mark every pixel the renderer actually touched.
    render_mask[np.sum(ctx.img_rgb != output, axis=2) > 0] = 1
    ctx.render_mask = render_mask
    if self.verbose:
        # Debug dumps of the rendering input/output and the final mask.
        cv2.imwrite(self._result_path('ws_render_in.png'), cv2.cvtColor(ctx.img_rgb, cv2.COLOR_RGB2BGR))
        cv2.imwrite(self._result_path('ws_render_out.png'), cv2.cvtColor(output, cv2.COLOR_RGB2BGR))
        cv2.imwrite(self._result_path('ws_mask.png'), render_mask * 255)
    # only keep sections in mask
    if self.verbose:
        cv2.imwrite(self._result_path('ws_inmask.png'), cv2.cvtColor(ctx.img_rgb, cv2.COLOR_RGB2BGRA) * render_mask)
    # Zero out everything outside the mask; the alpha channel is masked too,
    # so unmasked regions end up fully transparent.
    output = cv2.cvtColor(output, cv2.COLOR_RGB2RGBA) * render_mask
    if self.verbose:
        cv2.imwrite(self._result_path('ws_output.png'), cv2.cvtColor(output, cv2.COLOR_RGBA2BGRA) * render_mask)
    return output
# Experimental. May be replaced by a refactored server/web_main.py in the future.
class MangaTranslatorAPI(MangaTranslator):
def __init__(self, params: dict = None):
    """Set up the HTTP API wrapper around the base translator.

    nest_asyncio is applied because aiohttp's web.run_app drives its own
    event loop while the pipeline may re-enter the already-running one.
    """
    import nest_asyncio
    nest_asyncio.apply()
    super().__init__(params)
    # Server binding / logging configuration with their fallbacks.
    for option, fallback in (
        ('host', '127.0.0.1'),
        ('port', '5003'),
        ('log_web', False),
        ('ignore_errors', True),
    ):
        setattr(self, option, params.get(option, fallback))
    # Per-request state plus the default translation parameters.
    self._task_id = None
    self._params = None
    self.params = params
    # FIFO of request tickets used to serialize request handling.
    self.queue = []
async def wait_queue(self, id: int):
    """Cooperatively block until *id* reaches the head of the request queue."""
    while True:
        if self.queue[0] == id:
            return
        await asyncio.sleep(0.05)
def remove_from_queue(self, id: int):
    """Release ticket *id* from the queue (ValueError if absent, like list.remove)."""
    self.queue.pop(self.queue.index(id))
def generate_id(self):
    """Return the next request ticket: one past the current maximum (1 when empty).

    Replaces a bare ``except:`` (which also swallowed KeyboardInterrupt and
    SystemExit) with ``max()``'s ``default`` argument for the empty-queue case.
    """
    return max(self.queue, default=0) + 1
def middleware_factory(self):
    """Build an aiohttp middleware that serializes requests through the FIFO queue.

    Each request takes a ticket, waits until it reaches the head of the
    queue, runs the handler, and always releases its ticket afterwards.
    """
    @middleware
    async def sample_middleware(request, handler):
        # 'req_id' rather than 'id' to avoid shadowing the builtin.
        req_id = self.generate_id()
        self.queue.append(req_id)
        try:
            await self.wait_queue(req_id)
        except Exception as e:
            logger.error(f'{e.__class__.__name__}: {e}')
        try:
            # todo make cancellable
            response = await handler(request)
        except Exception as e:
            # Narrowed from a bare 'except:' so CancelledError/KeyboardInterrupt
            # still propagate; any handler failure maps to a 500.
            logger.error(f'{e.__class__.__name__}: {e}')
            response = web.json_response({'error': "Internal Server Error", 'status': 500},
                                         status=500)
        finally:
            # Handle cases where a user leaves the queue, request fails, or is completed.
            # 'finally' guarantees the ticket is released even if an exception
            # propagated past the narrowed handler above.
            try:
                self.remove_from_queue(req_id)
            except Exception as e:
                logger.error(f'{e.__class__.__name__}: {e}')
        return response
    return sample_middleware
async def get_file(self, image, base64Images, url) -> Image.Image:
    """Resolve exactly one of the three input sources into a PIL image.

    Args:
        image: uploaded multipart file field, or None.
        base64Images: base64 payload, optionally a full data URI, or None.
        url: remote image URL to download, or None.

    Raises:
        ValidationError: no source given, download failed, or image too large.
    """
    if image is not None:
        content = image.file.read()
    elif base64Images is not None:
        # Accept full data URIs ("data:image/png;base64,...") as well as
        # bare base64 payloads.
        if 'base64,' in base64Images:
            base64Images = base64Images.split('base64,')[1]
        content = base64.b64decode(base64Images)
    elif url is not None:
        from aiohttp import ClientSession
        async with ClientSession() as session:
            async with session.get(url) as resp:
                if resp.status == 200:
                    content = await resp.read()
                else:
                    # Raising (instead of returning a web.Response, as before)
                    # lets err_handling map the failure to a 422 rather than
                    # passing a Response object on as if it were an image.
                    raise ValidationError(f"couldn't fetch image from url (status {resp.status})")
    else:
        raise ValidationError("image source doesn't exist")
    img = Image.open(io.BytesIO(content))
    # verify() detects truncated/corrupt data but invalidates the object,
    # so the image must be reopened afterwards.
    img.verify()
    img = Image.open(io.BytesIO(content))
    if img.width * img.height > 8000 ** 2:
        raise ValidationError("image is too large")
    return img
async def listen(self, translation_params: dict = None):
    """Start the HTTP API server and block while serving requests.

    Registers one POST route per pipeline depth; each handler sets
    run_until_state so the shared progress hook interrupts the pipeline
    once that stage begins.
    """
    self.params = translation_params
    app = web.Application(client_max_size=1024 * 1024 * 50, middlewares=[self.middleware_factory()])
    routes = web.RouteTableDef()

    run_until_state = ''

    async def hook(state, finished):
        # Abort the pipeline as soon as the requested stage starts.
        if run_until_state and run_until_state == state and not finished:
            raise TranslationInterrupt()
    self.add_progress_hook(hook)

    @routes.post("/get_text")
    async def text_api(req):
        nonlocal run_until_state
        # Stop before translating: returns OCR text only.
        run_until_state = 'translating'
        return await self.err_handling(self.run_translate, req, self.format_translate)

    @routes.post("/translate")
    async def translate_api(req):
        nonlocal run_until_state
        # Stop after translation, before inpainting/rendering.
        run_until_state = 'after-translating'
        return await self.err_handling(self.run_translate, req, self.format_translate)

    @routes.post("/inpaint_translate")
    async def inpaint_translate_api(req):
        nonlocal run_until_state
        # Run through inpainting, stop before rendering.
        run_until_state = 'rendering'
        return await self.err_handling(self.run_translate, req, self.format_translate)

    @routes.post("/colorize_translate")
    async def colorize_translate_api(req):
        nonlocal run_until_state
        run_until_state = 'rendering'
        # ri=True: include the colorized image in the response.
        return await self.err_handling(self.run_translate, req, self.format_translate, True)

    # #@routes.post("/file")
    # async def file_api(req):
    #     #TODO: return file
    #     return await self.err_handling(self.file_exec, req, None)

    app.add_routes(routes)
    web.run_app(app, host=self.host, port=self.port)
async def run_translate(self, translation_params, img):
    """Adapter matching err_handling's func(ctx, image) calling convention."""
    result = await self.translate(img, translation_params)
    return result
async def err_handling(self, func, req, format, ri=False):
    """Validate a request, run the pipeline with retries, format the response.

    Args:
        func: coroutine taking (ctx, image) — usually run_translate.
        req: incoming aiohttp request (JSON or multipart body).
        format: callable (ctx, return_image) producing the web.Response.
        ri: forwarded to *format* as its return_image flag.

    Returns:
        A web.Response — either the formatted result or a JSON error with
        a matching HTTP status (415/422/500).
    """
    try:
        if req.content_type in ('application/json', 'multipart/form-data'):
            if req.content_type == 'application/json':
                d = await req.json()
            else:
                d = await req.post()
            schema = self.PostSchema()
            data = schema.load(d)
            # Chain strings must be parsed into TranslatorChain objects.
            if 'translator_chain' in data:
                data['translator_chain'] = translator_chain(data['translator_chain'])
            if 'selective_translation' in data:
                data['selective_translation'] = translator_chain(data['selective_translation'])
            ctx = Context(**dict(self.params, **data))
            self._preprocess_params(ctx)
            if data.get('image') is None and data.get('base64Images') is None and data.get('url') is None:
                return web.json_response({'error': "Missing input", 'status': 422})
            fil = await self.get_file(data.get('image'), data.get('base64Images'), data.get('url'))
            # Raw input payloads are no longer needed once decoded.
            for key in ('image', 'base64Images', 'url'):
                data.pop(key, None)
            attempts = 0
            # ctx.attempts == -1 means retry forever.
            while ctx.attempts == -1 or attempts <= ctx.attempts:
                if attempts > 0:
                    logger.info(f'Retrying translation! Attempt {attempts}' + (
                        f' of {ctx.attempts}' if ctx.attempts != -1 else ''))
                try:
                    await func(ctx, fil)
                    break
                except TranslationInterrupt:
                    # Deliberate early stop from a progress hook — the
                    # partial result in ctx is the intended output.
                    break
                except Exception as e:
                    # Module logger instead of print(), consistent with the
                    # rest of the file.
                    logger.error(f'{e.__class__.__name__}: {e}', exc_info=e if self.verbose else None)
                    attempts += 1
            if ctx.attempts != -1 and attempts > ctx.attempts:
                return web.json_response({'error': "Internal Server Error", 'status': 500},
                                         status=500)
            try:
                return format(ctx, ri)
            except Exception as e:
                logger.error(f'{e.__class__.__name__}: {e}', exc_info=e if self.verbose else None)
                return web.json_response({'error': "Failed to format", 'status': 500},
                                         status=500)
        else:
            return web.json_response({'error': "Wrong content type: " + req.content_type, 'status': 415},
                                     status=415)
    except ValueError as e:
        logger.error(str(e))
        return web.json_response({'error': "Wrong input type", 'status': 422}, status=422)
    except ValidationError as e:
        logger.error(str(e))
        return web.json_response({'error': "Input invalid", 'status': 422}, status=422)
def format_translate(self, ctx: Context, return_image: bool):
    """Serialize translation results into the API's JSON response.

    Emits one entry per text region (bbox, translations, font colors,
    inpainted background crop as a data URI) plus, optionally, the whole
    colorized image.
    """
    regions = ctx.text_regions
    inpainted = ctx.img_inpainted
    overlay_ext = ctx['overlay_ext'] if 'overlay_ext' in ctx else 'jpg'

    def encode_data_uri(arr):
        # Encode an image array as "data:image/<ext>;base64,...".
        _, buf = cv2.imencode('.' + overlay_ext, arr)
        return "data:image/" + overlay_ext + ";base64," + base64.b64encode(buf).decode("utf-8")

    details = []
    for idx, region in enumerate(regions):
        min_x, min_y, max_x, max_y = region.xyxy
        if 'translations' in ctx:
            texts = {key: value[idx] for key, value in ctx['translations'].items()}
        else:
            texts = {}
        texts["originalText"] = region.text
        if inpainted is not None:
            background = encode_data_uri(inpainted[min_y:max_y, min_x:max_x])
        else:
            background = None
        region.adjust_bg_color = False
        fg, bg = region.get_font_colors()
        details.append({
            'text': texts,
            'minX': int(min_x),
            'minY': int(min_y),
            'maxX': int(max_x),
            'maxY': int(max_y),
            'textColor': {
                'fg': fg.tolist(),
                'bg': bg.tolist()
            },
            'language': region.source_lang,
            'background': background
        })
    img = None
    if return_image and ctx.img_colorized is not None:
        img = encode_data_uri(np.array(ctx.img_colorized))
    return web.json_response({'details': details, 'img': img})
class PostSchema(Schema):
    """Marshmallow schema validating POST bodies for the HTTP API routes.

    All fields are optional; module-name fields are validated
    case-insensitively against the registered module key sets.
    """
    # Pipeline module / language selection
    target_lang = fields.Str(required=False, validate=lambda a: a.upper() in VALID_LANGUAGES)
    detector = fields.Str(required=False, validate=lambda a: a.lower() in DETECTORS)
    ocr = fields.Str(required=False, validate=lambda a: a.lower() in OCRS)
    inpainter = fields.Str(required=False, validate=lambda a: a.lower() in INPAINTERS)
    upscaler = fields.Str(required=False, validate=lambda a: a.lower() in UPSCALERS)
    translator = fields.Str(required=False, validate=lambda a: a.lower() in TRANSLATORS)
    direction = fields.Str(required=False, validate=lambda a: a.lower() in {'auto', 'h', 'v'})
    # Tuning parameters forwarded into the translation Context
    skip_language = fields.Str(required=False)
    upscale_ratio = fields.Integer(required=False)
    translator_chain = fields.Str(required=False)
    selective_translation = fields.Str(required=False)
    attempts = fields.Integer(required=False)
    detection_size = fields.Integer(required=False)
    text_threshold = fields.Float(required=False)
    box_threshold = fields.Float(required=False)
    unclip_ratio = fields.Float(required=False)
    inpainting_size = fields.Integer(required=False)
    det_rotate = fields.Bool(required=False)
    det_auto_rotate = fields.Bool(required=False)
    det_invert = fields.Bool(required=False)
    det_gamma_correct = fields.Bool(required=False)
    min_text_length = fields.Integer(required=False)
    colorization_size = fields.Integer(required=False)
    denoise_sigma = fields.Integer(required=False)
    mask_dilation_offset = fields.Integer(required=False)
    ignore_bubble = fields.Integer(required=False)
    gpt_config = fields.String(required=False)
    filter_text = fields.String(required=False)
    # api specific
    overlay_ext = fields.Str(required=False)
    base64Images = fields.Raw(required=False)
    image = fields.Raw(required=False)
    url = fields.Raw(required=False)
    # no functionality except preventing errors when given
    fingerprint = fields.Raw(required=False)
    clientUuid = fields.Raw(required=False)