|
import logging |
|
import os |
|
from pathlib import Path |
|
from typing import List, Union |
|
import psutil |
|
|
|
import torch |
|
|
|
from relik.common.utils import is_package_available |
|
from relik.inference.annotator import Relik |
|
|
|
if not is_package_available("fastapi"): |
|
raise ImportError( |
|
"FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`." |
|
) |
|
from fastapi import FastAPI, HTTPException |
|
|
|
if not is_package_available("ray"): |
|
raise ImportError( |
|
"Ray is not installed. Please install Ray with `pip install relik[serve]`." |
|
) |
|
from ray import serve |
|
|
|
from relik.common.log import get_logger |
|
from relik.inference.serve.backend.utils import ( |
|
RayParameterManager, |
|
ServerParameterManager, |
|
) |
|
|
|
logger = get_logger(__name__, level=logging.INFO) |
|
|
|
VERSION = {} |
|
with open( |
|
Path(__file__).parent.parent.parent.parent / "version.py", "r" |
|
) as version_file: |
|
exec(version_file.read(), VERSION) |
|
|
|
|
|
SERVER_MANAGER = ServerParameterManager() |
|
RAY_MANAGER = RayParameterManager() |
|
|
|
app = FastAPI( |
|
title="ReLiK", |
|
version=VERSION["VERSION"], |
|
description="ReLiK REST API", |
|
) |
|
|
|
|
|
@serve.deployment( |
|
ray_actor_options={ |
|
"num_gpus": RAY_MANAGER.num_gpus |
|
if ( |
|
SERVER_MANAGER.device == "cuda" |
|
or SERVER_MANAGER.retriever_device == "cuda" |
|
or SERVER_MANAGER.reader_device == "cuda" |
|
) |
|
else 0 |
|
}, |
|
autoscaling_config={ |
|
"min_replicas": RAY_MANAGER.min_replicas, |
|
"max_replicas": RAY_MANAGER.max_replicas, |
|
}, |
|
) |
|
@serve.ingress(app) |
|
class RelikServer: |
|
def __init__( |
|
self, |
|
relik_pretrained: str | None = None, |
|
device: str = "cpu", |
|
retriever_device: str | None = None, |
|
document_index_device: str | None = None, |
|
reader_device: str | None = None, |
|
precision: str | int | torch.dtype = 32, |
|
retriever_precision: str | int | torch.dtype | None = None, |
|
document_index_precision: str | int | torch.dtype | None = None, |
|
reader_precision: str | int | torch.dtype | None = None, |
|
annotation_type: str = "char", |
|
retriever_batch_size: int = 32, |
|
reader_batch_size: int = 32, |
|
relik_config_override: dict | None = None, |
|
**kwargs, |
|
): |
|
num_threads = os.getenv("TORCH_NUM_THREADS", psutil.cpu_count(logical=False)) |
|
torch.set_num_threads(num_threads) |
|
logger.info(f"Torch is running on {num_threads} threads.") |
|
|
|
|
|
logger.info(f"RELIK_PRETRAINED: {relik_pretrained}") |
|
self.relik_pretrained = relik_pretrained |
|
|
|
if relik_config_override is None: |
|
relik_config_override = {} |
|
logger.info(f"RELIK_CONFIG_OVERRIDE: {relik_config_override}") |
|
self.relik_config_override = relik_config_override |
|
|
|
logger.info(f"DEVICE: {device}") |
|
self.device = device |
|
|
|
if retriever_device is not None: |
|
logger.info(f"RETRIEVER_DEVICE: {retriever_device}") |
|
self.retriever_device = retriever_device or device |
|
|
|
if document_index_device is not None: |
|
logger.info(f"INDEX_DEVICE: {document_index_device}") |
|
self.document_index_device = document_index_device or retriever_device |
|
|
|
if reader_device is not None: |
|
logger.info(f"READER_DEVICE: {reader_device}") |
|
self.reader_device = reader_device |
|
|
|
logger.info(f"PRECISION: {precision}") |
|
self.precision = precision |
|
|
|
if retriever_precision is not None: |
|
logger.info(f"RETRIEVER_PRECISION: {retriever_precision}") |
|
self.retriever_precision = retriever_precision or precision |
|
|
|
if document_index_precision is not None: |
|
logger.info(f"INDEX_PRECISION: {document_index_precision}") |
|
self.document_index_precision = document_index_precision or precision |
|
|
|
if reader_precision is not None: |
|
logger.info(f"READER_PRECISION: {reader_precision}") |
|
self.reader_precision = reader_precision or precision |
|
|
|
logger.info(f"ANNOTATION_TYPE: {annotation_type}") |
|
self.annotation_type = annotation_type |
|
|
|
self.relik = Relik.from_pretrained( |
|
self.relik_pretrained, |
|
device=self.device, |
|
retriever_device=self.retriever_device, |
|
document_index_device=self.document_index_device, |
|
reader_device=self.reader_device, |
|
precision=self.precision, |
|
retriever_precision=self.retriever_precision, |
|
document_index_precision=self.document_index_precision, |
|
reader_precision=self.reader_precision, |
|
**self.relik_config_override, |
|
) |
|
|
|
self.retriever_batch_size = retriever_batch_size |
|
self.reader_batch_size = reader_batch_size |
|
|
|
|
|
async def handle_batch(self, text: List[str]) -> List: |
|
return self.relik( |
|
text, |
|
annotation_type=self.annotation_type, |
|
retriever_batch_size=self.retriever_batch_size, |
|
reader_batch_size=self.reader_batch_size, |
|
) |
|
|
|
@app.post("/api/relik") |
|
async def relik_endpoint(self, text: Union[str, List[str]]): |
|
try: |
|
|
|
return await self.handle_batch(text) |
|
except Exception as e: |
|
|
|
logger.exception(e) |
|
raise HTTPException(status_code=500, detail=f"Server Error: {e}") |
|
|
|
|
|
server = RelikServer.bind(**vars(SERVER_MANAGER)) |
|
|