|
import os.path |
|
import shutil |
|
from typing import List, Optional, Union |
|
|
|
from inference.core.env import MODEL_CACHE_DIR |
|
from inference.core.utils.file_system import ( |
|
dump_bytes, |
|
dump_json, |
|
dump_text_lines, |
|
read_json, |
|
read_text_file, |
|
) |
|
|
|
|
|
def initialise_cache(model_id: Optional[str] = None) -> None:
    """Make sure the cache directory exists on disk.

    Args:
        model_id: When given, the per-model sub-directory of the cache root is
            created; when ``None``, the cache root itself is created.

    Intermediate directories are created as needed, and an already-existing
    directory is not an error.
    """
    os.makedirs(get_cache_dir(model_id=model_id), exist_ok=True)
|
|
|
|
|
def are_all_files_cached(files: List[str], model_id: Optional[str] = None) -> bool:
    """Check whether every file in ``files`` is present in the cache.

    Args:
        files: File names, relative to the cache directory.
        model_id: Optional model identifier selecting the per-model cache dir.

    Returns:
        ``True`` when each file is cached (also for an empty ``files``);
        ``False`` as soon as the first missing file is found.
    """
    for file_name in files:
        if not is_file_cached(file=file_name, model_id=model_id):
            # Short-circuit: remaining files need not be checked.
            return False
    return True
|
|
|
|
|
def is_file_cached(file: str, model_id: Optional[str] = None) -> bool:
    """Report whether ``file`` exists as a regular file in the cache.

    Args:
        file: File name relative to the cache directory.
        model_id: Optional model identifier selecting the per-model cache dir.

    Returns:
        ``True`` only if the resolved path is an existing regular file
        (directories do not count).
    """
    return os.path.isfile(get_cache_file_path(file=file, model_id=model_id))
|
|
|
|
|
def load_text_file_from_cache(
    file: str,
    model_id: Optional[str] = None,
    split_lines: bool = False,
    strip_white_chars: bool = False,
) -> Union[str, List[str]]:
    """Read a text file stored in the cache.

    Args:
        file: File name relative to the cache directory.
        model_id: Optional model identifier selecting the per-model cache dir.
        split_lines: Forwarded to ``read_text_file``; presumably returns a
            list of lines instead of a single string when set (see its
            implementation to confirm).
        strip_white_chars: Forwarded to ``read_text_file``.

    Returns:
        Whatever ``read_text_file`` produces for the resolved cache path.
    """
    source_path = get_cache_file_path(file=file, model_id=model_id)
    reader_options = {
        "split_lines": split_lines,
        "strip_white_chars": strip_white_chars,
    }
    return read_text_file(path=source_path, **reader_options)
|
|
|
|
|
def load_json_from_cache(
    file: str, model_id: Optional[str] = None, **kwargs
) -> Optional[Union[dict, list]]:
    """Load a JSON document stored in the cache.

    Args:
        file: File name relative to the cache directory.
        model_id: Optional model identifier selecting the per-model cache dir.
        **kwargs: Extra keyword arguments forwarded unchanged to ``read_json``.

    Returns:
        The value returned by ``read_json`` for the resolved cache path.
    """
    return read_json(path=get_cache_file_path(file=file, model_id=model_id), **kwargs)
|
|
|
|
|
def save_bytes_in_cache(
    content: bytes,
    file: str,
    model_id: Optional[str] = None,
    allow_override: bool = True,
) -> None:
    """Write raw bytes to a file inside the cache.

    Args:
        content: Bytes to persist.
        file: Destination file name relative to the cache directory.
        model_id: Optional model identifier selecting the per-model cache dir.
        allow_override: Forwarded to ``dump_bytes``; presumably controls
            whether an existing file may be overwritten (confirm in helper).
    """
    target_path = get_cache_file_path(file=file, model_id=model_id)
    dump_bytes(
        path=target_path,
        content=content,
        allow_override=allow_override,
    )
|
|
|
|
|
def save_json_in_cache(
    content: Union[dict, list],
    file: str,
    model_id: Optional[str] = None,
    allow_override: bool = True,
    **kwargs,
) -> None:
    """Serialise ``content`` as JSON into a file inside the cache.

    Args:
        content: JSON-serialisable dict or list.
        file: Destination file name relative to the cache directory.
        model_id: Optional model identifier selecting the per-model cache dir.
        allow_override: Forwarded to ``dump_json``; presumably controls
            whether an existing file may be overwritten (confirm in helper).
        **kwargs: Extra keyword arguments forwarded unchanged to ``dump_json``.
    """
    target_path = get_cache_file_path(file=file, model_id=model_id)
    dump_json(
        path=target_path,
        content=content,
        allow_override=allow_override,
        **kwargs,
    )
|
|
|
|
|
def save_text_lines_in_cache(
    content: List[str],
    file: str,
    model_id: Optional[str] = None,
    allow_override: bool = True,
) -> None:
    """Write a list of text lines to a file inside the cache.

    Args:
        content: Lines to persist; formatting is delegated to
            ``dump_text_lines``.
        file: Destination file name relative to the cache directory.
        model_id: Optional model identifier selecting the per-model cache dir.
        allow_override: Forwarded to ``dump_text_lines``; presumably controls
            whether an existing file may be overwritten (confirm in helper).
    """
    target_path = get_cache_file_path(file=file, model_id=model_id)
    dump_text_lines(
        path=target_path,
        content=content,
        allow_override=allow_override,
    )
|
|
|
|
|
def get_cache_file_path(file: str, model_id: Optional[str] = None) -> str:
    """Resolve ``file`` to a path inside the (optionally per-model) cache dir.

    Args:
        file: File name relative to the cache directory.
        model_id: Optional model identifier selecting the per-model cache dir.

    Returns:
        The joined path string; the file is not required to exist.

    NOTE(review): ``file`` is joined verbatim with ``os.path.join``. An
    absolute ``file`` replaces the cache dir entirely, and ``..`` segments
    can point outside it. Confirm callers never pass untrusted names.
    """
    return os.path.join(get_cache_dir(model_id=model_id), file)
|
|
|
|
|
def clear_cache(model_id: Optional[str] = None) -> None:
    """Remove the cache directory and everything beneath it.

    Args:
        model_id: When given, only that model's sub-directory is removed;
            when ``None``, the whole cache root is removed.

    A cache directory that does not exist is silently ignored, so calling this
    repeatedly (or concurrently) is safe. Other OS errors (permissions, the
    path being a regular file or a symlink, ...) still propagate.
    """
    cache_dir = get_cache_dir(model_id=model_id)
    # EAFP instead of `os.path.exists` + `rmtree`: the check-then-delete form
    # raised FileNotFoundError if the directory vanished in between (e.g. a
    # concurrent clear_cache call).
    try:
        shutil.rmtree(cache_dir)
    except FileNotFoundError:
        pass
|
|
|
|
|
def get_cache_dir(model_id: Optional[str] = None) -> str:
    """Return the cache directory path, optionally scoped to one model.

    Args:
        model_id: When ``None``, the cache root ``MODEL_CACHE_DIR`` is
            returned; otherwise ``model_id`` is joined onto that root.

    Returns:
        The directory path; it is not created here.
    """
    if model_id is None:
        return MODEL_CACHE_DIR
    return os.path.join(MODEL_CACHE_DIR, model_id)
|
|