File size: 4,506 Bytes
7456850 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import logging
import os
import json
from peewee import Model, SqliteDatabase, AutoField, CharField, TextField, SQL
from typing import Optional
# The database object is created without a path (deferred initialization);
# init_db() supplies the real file path later through db.init().
db = SqliteDatabase(None)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class _TranslationCache(Model):
    """Peewee model holding one cached translation per row.

    A row is keyed by the combination of translation engine name,
    serialized engine parameters and original text. The table-level
    UNIQUE constraint is declared with ``ON CONFLICT REPLACE``, so
    inserting a row whose key already exists replaces the stored row.
    """

    id = AutoField()
    # Engine name; TranslationCache.__init__ asserts it is shorter than 20 chars.
    translate_engine = CharField(max_length=20)
    # Engine parameters as text; TranslationCache writes a JSON string here.
    translate_engine_params = TextField()
    # Source text that was translated.
    original_text = TextField()
    # Cached translation of original_text.
    translation = TextField()

    class Meta:
        # Bound to the module-level, lazily initialized database object.
        database = db
        # Raw SQL table constraint; the literal is kept verbatim.
        constraints = [
            SQL(
                """
UNIQUE (
translate_engine,
translate_engine_params,
original_text
)
ON CONFLICT REPLACE
"""
            )
        ]
class TranslationCache:
    """Cache front end scoped to one translation engine and its parameters.

    Every lookup and store is keyed by the engine name, a canonical JSON
    serialization of the engine parameters, and the original text.

    Attributes:
        translate_engine: Name of the translation engine.
        params: This instance's private copy of the engine parameters.
        translate_engine_params: JSON string of ``params`` with dict keys
            sorted at every nesting level; used as part of the cache key.
    """

    @staticmethod
    def _sort_dict_recursively(obj):
        """Return *obj* with dict keys sorted at every nesting level.

        Dicts and lists are rebuilt recursively; any other value is
        returned unchanged.
        """
        if isinstance(obj, dict):
            return {
                key: TranslationCache._sort_dict_recursively(obj[key])
                for key in sorted(obj)
            }
        if isinstance(obj, list):
            return [TranslationCache._sort_dict_recursively(item) for item in obj]
        return obj

    def __init__(
        self, translate_engine: str, translate_engine_params: Optional[dict] = None
    ):
        """Create a cache view for *translate_engine*.

        Args:
            translate_engine: Engine name, must be shorter than 20 characters.
            translate_engine_params: Optional parameters that distinguish
                cache entries of the same engine; ``None`` means no parameters.

        Raises:
            AssertionError: If the engine name is 20 characters or longer.
        """
        # Kept as an assert so callers see the same exception type as before.
        assert (
            len(translate_engine) < 20
        ), "current cache require translate engine name less than 20 characters"
        self.translate_engine = translate_engine
        self.replace_params(translate_engine_params)

    # The program typically starts multi-threaded translation
    # only after cache parameters are fully configured,
    # so thread safety doesn't need to be considered here.
    def replace_params(self, params: Optional[dict] = None):
        """Replace all engine parameters with *params* (``None`` means empty)."""
        if params is None:
            params = {}
        # Store a private shallow copy: update_params()/add_params() mutate
        # self.params and must never modify the dict the caller passed in.
        self.params = dict(params)
        self.translate_engine_params = json.dumps(
            self._sort_dict_recursively(self.params)
        )

    def update_params(self, params: Optional[dict] = None):
        """Merge *params* into the current parameters and re-serialize them."""
        if params is None:
            params = {}
        self.params.update(params)
        self.replace_params(self.params)

    def add_params(self, k: str, v):
        """Set the single parameter *k* to *v* and re-serialize the parameters."""
        self.params[k] = v
        self.replace_params(self.params)

    # Since peewee and the underlying sqlite are thread-safe,
    # get and set operations don't need locks.
    def get(self, original_text: str) -> Optional[str]:
        """Return the cached translation of *original_text*, or ``None`` if absent."""
        result = _TranslationCache.get_or_none(
            translate_engine=self.translate_engine,
            translate_engine_params=self.translate_engine_params,
            original_text=original_text,
        )
        return result.translation if result else None

    def set(self, original_text: str, translation: str):
        """Store *translation* for *original_text*.

        Caching is best effort: any failure is logged at debug level and
        swallowed so that a cache problem never aborts a translation.
        """
        try:
            _TranslationCache.create(
                translate_engine=self.translate_engine,
                translate_engine_params=self.translate_engine_params,
                original_text=original_text,
                translation=translation,
            )
        except Exception as e:
            logger.debug("Error setting cache: %s", e)
def init_db(remove_exists=False):
    """Initialize the on-disk translation cache database.

    The database lives at ``~/.cache/pdf2zh/cache.v1.db``; the folder is
    created when missing. The module-level ``db`` object is pointed at that
    file and the cache table is created if it does not exist yet.

    Args:
        remove_exists: When true, delete any existing database file first,
            together with its ``-wal`` and ``-shm`` sidecar files.
    """
    cache_folder = os.path.join(os.path.expanduser("~"), ".cache", "pdf2zh")
    os.makedirs(cache_folder, exist_ok=True)
    # The current version does not support database migration, so add the version number to the file name.
    cache_db_path = os.path.join(cache_folder, "cache.v1.db")
    if remove_exists:
        # WAL mode keeps "-wal"/"-shm" files next to the database. They must
        # go with it, otherwise a stale log could be paired with the new file.
        for suffix in ("", "-wal", "-shm"):
            try:
                os.remove(cache_db_path + suffix)
            except FileNotFoundError:
                pass
    db.init(
        cache_db_path,
        pragmas={
            "journal_mode": "wal",
            "busy_timeout": 1000,
        },
    )
    db.create_tables([_TranslationCache], safe=True)
def init_test_db():
    """Create a temporary SQLite database and bind the cache model to it.

    Returns:
        The connected database object with the cache table created; pass it
        to ``clean_test_db`` to drop the table and delete the files.
    """
    import tempfile

    # mkstemp creates the file atomically, unlike the deprecated, race-prone
    # mktemp. SQLite treats the resulting empty file as a new database.
    fd, cache_db_path = tempfile.mkstemp(suffix=".db")
    os.close(fd)
    test_db = SqliteDatabase(
        cache_db_path,
        pragmas={
            "journal_mode": "wal",
            "busy_timeout": 1000,
        },
    )
    test_db.bind([_TranslationCache], bind_refs=False, bind_backrefs=False)
    test_db.connect()
    test_db.create_tables([_TranslationCache], safe=True)
    return test_db
def clean_test_db(test_db):
    """Drop the cache table, close *test_db* and delete its files on disk.

    Removes the database file and then its ``-wal`` and ``-shm`` companions,
    each only if it exists.
    """
    test_db.drop_tables([_TranslationCache])
    test_db.close()
    base_path = test_db.database
    for suffix in ("", "-wal", "-shm"):
        candidate = base_path + suffix
        if os.path.exists(candidate):
            os.remove(candidate)
# Import-time side effect: initialize the default on-disk cache database.
init_db()
|