File size: 5,544 Bytes
2f044c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import logging
import os
from pathlib import Path
from typing import List, Union
import psutil
import torch
from relik.common.utils import is_package_available
from relik.inference.annotator import Relik
if not is_package_available("fastapi"):
raise ImportError(
"FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
)
from fastapi import FastAPI, HTTPException
if not is_package_available("ray"):
raise ImportError(
"Ray is not installed. Please install Ray with `pip install relik[serve]`."
)
from ray import serve
from relik.common.log import get_logger
from relik.inference.serve.backend.utils import (
RayParameterManager,
ServerParameterManager,
)
logger = get_logger(__name__, level=logging.INFO)
VERSION = {} # type: ignore
with open(
Path(__file__).parent.parent.parent.parent / "version.py", "r"
) as version_file:
exec(version_file.read(), VERSION)
# Env variables for server
SERVER_MANAGER = ServerParameterManager()
RAY_MANAGER = RayParameterManager()
app = FastAPI(
title="ReLiK",
version=VERSION["VERSION"],
description="ReLiK REST API",
)
@serve.deployment(
ray_actor_options={
"num_gpus": RAY_MANAGER.num_gpus
if (
SERVER_MANAGER.device == "cuda"
or SERVER_MANAGER.retriever_device == "cuda"
or SERVER_MANAGER.reader_device == "cuda"
)
else 0
},
autoscaling_config={
"min_replicas": RAY_MANAGER.min_replicas,
"max_replicas": RAY_MANAGER.max_replicas,
},
)
@serve.ingress(app)
class RelikServer:
def __init__(
self,
relik_pretrained: str | None = None,
device: str = "cpu",
retriever_device: str | None = None,
document_index_device: str | None = None,
reader_device: str | None = None,
precision: str | int | torch.dtype = 32,
retriever_precision: str | int | torch.dtype | None = None,
document_index_precision: str | int | torch.dtype | None = None,
reader_precision: str | int | torch.dtype | None = None,
annotation_type: str = "char",
retriever_batch_size: int = 32,
reader_batch_size: int = 32,
relik_config_override: dict | None = None,
**kwargs,
):
num_threads = os.getenv("TORCH_NUM_THREADS", psutil.cpu_count(logical=False))
torch.set_num_threads(num_threads)
logger.info(f"Torch is running on {num_threads} threads.")
# parameters
logger.info(f"RELIK_PRETRAINED: {relik_pretrained}")
self.relik_pretrained = relik_pretrained
if relik_config_override is None:
relik_config_override = {}
logger.info(f"RELIK_CONFIG_OVERRIDE: {relik_config_override}")
self.relik_config_override = relik_config_override
logger.info(f"DEVICE: {device}")
self.device = device
if retriever_device is not None:
logger.info(f"RETRIEVER_DEVICE: {retriever_device}")
self.retriever_device = retriever_device or device
if document_index_device is not None:
logger.info(f"INDEX_DEVICE: {document_index_device}")
self.document_index_device = document_index_device or retriever_device
if reader_device is not None:
logger.info(f"READER_DEVICE: {reader_device}")
self.reader_device = reader_device
logger.info(f"PRECISION: {precision}")
self.precision = precision
if retriever_precision is not None:
logger.info(f"RETRIEVER_PRECISION: {retriever_precision}")
self.retriever_precision = retriever_precision or precision
if document_index_precision is not None:
logger.info(f"INDEX_PRECISION: {document_index_precision}")
self.document_index_precision = document_index_precision or precision
if reader_precision is not None:
logger.info(f"READER_PRECISION: {reader_precision}")
self.reader_precision = reader_precision or precision
logger.info(f"ANNOTATION_TYPE: {annotation_type}")
self.annotation_type = annotation_type
self.relik = Relik.from_pretrained(
self.relik_pretrained,
device=self.device,
retriever_device=self.retriever_device,
document_index_device=self.document_index_device,
reader_device=self.reader_device,
precision=self.precision,
retriever_precision=self.retriever_precision,
document_index_precision=self.document_index_precision,
reader_precision=self.reader_precision,
**self.relik_config_override,
)
self.retriever_batch_size = retriever_batch_size
self.reader_batch_size = reader_batch_size
# @serve.batch()
async def handle_batch(self, text: List[str]) -> List:
return self.relik(
text,
annotation_type=self.annotation_type,
retriever_batch_size=self.retriever_batch_size,
reader_batch_size=self.reader_batch_size,
)
@app.post("/api/relik")
async def relik_endpoint(self, text: Union[str, List[str]]):
try:
# get predictions for the retriever
return await self.handle_batch(text)
except Exception as e:
# log the entire stack trace
logger.exception(e)
raise HTTPException(status_code=500, detail=f"Server Error: {e}")
server = RelikServer.bind(**vars(SERVER_MANAGER))
|