import logging
from pathlib import Path
from typing import List, Optional, Union

from relik.common.utils import is_package_available

if not is_package_available("fastapi"):
    raise ImportError(
        "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
    )
from fastapi import FastAPI, HTTPException

if not is_package_available("ray"):
    raise ImportError(
        "Ray is not installed. Please install Ray with `pip install relik[serve]`."
    )
from ray import serve

from relik.common.log import get_logger
from relik.inference.data.tokenizers import SpacyTokenizer, WhitespaceTokenizer
from relik.inference.data.window.manager import WindowManager
from relik.inference.serve.backend.utils import (
    RayParameterManager,
    ServerParameterManager,
)
from relik.retriever.data.utils import batch_generator
from relik.retriever.pytorch_modules import GoldenRetriever

logger = get_logger(__name__, level=logging.INFO)

VERSION = {}  # type: ignore
with open(Path(__file__).parent.parent.parent / "version.py", "r") as version_file:
    exec(version_file.read(), VERSION)

# Env variables for server
SERVER_MANAGER = ServerParameterManager()
RAY_MANAGER = RayParameterManager()

app = FastAPI(
    title="Golden Retriever",
    version=VERSION["VERSION"],
    description="Golden Retriever REST API",
)

@serve.deployment(
    ray_actor_options={
        "num_gpus": RAY_MANAGER.num_gpus if SERVER_MANAGER.device == "cuda" else 0
    },
    autoscaling_config={
        "min_replicas": RAY_MANAGER.min_replicas,
        "max_replicas": RAY_MANAGER.max_replicas,
    },
)
@serve.ingress(app)
class GoldenRetrieverServer:
    """Ray Serve deployment exposing a GoldenRetriever retriever through a FastAPI app."""

    def __init__(
        self,
        question_encoder: str,
        document_index: str,
        passage_encoder: Optional[str] = None,
        top_k: int = 100,
        device: str = "cpu",
        index_device: Optional[str] = None,
        precision: int = 32,
        index_precision: Optional[int] = None,
        use_faiss: bool = False,
        window_batch_size: int = 32,
        window_size: int = 32,
        window_stride: int = 16,
        split_on_spaces: bool = False,
    ):
        # parameters
        self.question_encoder = question_encoder
        self.passage_encoder = passage_encoder
        self.document_index = document_index
        self.top_k = top_k
        self.device = device
        self.index_device = index_device or device
        self.precision = precision
        self.index_precision = index_precision or precision
        self.use_faiss = use_faiss
        self.window_batch_size = window_batch_size
        self.window_size = window_size
        self.window_stride = window_stride
        self.split_on_spaces = split_on_spaces

        # log stuff for debugging
        logger.info("Initializing GoldenRetrieverServer with parameters:")
        logger.info(f"QUESTION_ENCODER: {self.question_encoder}")
        logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}")
        logger.info(f"DOCUMENT_INDEX: {self.document_index}")
        logger.info(f"TOP_K: {self.top_k}")
        logger.info(f"DEVICE: {self.device}")
        logger.info(f"INDEX_DEVICE: {self.index_device}")
        logger.info(f"PRECISION: {self.precision}")
        logger.info(f"INDEX_PRECISION: {self.index_precision}")
        logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}")
        logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}")

        self.retriever = GoldenRetriever(
            question_encoder=self.question_encoder,
            passage_encoder=self.passage_encoder,
            document_index=self.document_index,
            device=self.device,
            index_device=self.index_device,
            index_precision=self.index_precision,
        )
        self.retriever.eval()

        if self.split_on_spaces:
            logger.info("Using WhitespaceTokenizer")
            self.tokenizer = WhitespaceTokenizer()
            # logger.info("Using RegexTokenizer")
            # self.tokenizer = RegexTokenizer()
        else:
            logger.info("Using SpacyTokenizer")
            self.tokenizer = SpacyTokenizer(language="en")

        self.window_manager = WindowManager(tokenizer=self.tokenizer)
    # @serve.batch()
    async def handle_batch(
        self, documents: List[str], document_topics: List[str]
    ) -> List:
        """Retrieve the top-k passages for a batch of documents (optionally paired with topics)."""
        return self.retriever.retrieve(
            documents, text_pair=document_topics, k=self.top_k, precision=self.precision
        )
@app.post("/api/retrieve")
async def retrieve_endpoint(
self,
documents: Union[str, List[str]],
document_topics: Optional[Union[str, List[str]]] = None,
):
try:
# normalize input
if isinstance(documents, str):
documents = [documents]
if document_topics is not None:
if isinstance(document_topics, str):
document_topics = [document_topics]
assert len(documents) == len(document_topics)
# get predictions
return await self.handle_batch(documents, document_topics)
except Exception as e:
# log the entire stack trace
logger.exception(e)
raise HTTPException(status_code=500, detail=f"Server Error: {e}")
@app.post("/api/gerbil")
async def gerbil_endpoint(self, documents: Union[str, List[str]]):
try:
# normalize input
if isinstance(documents, str):
documents = [documents]
# output list
windows_passages = []
# split documents into windows
document_windows = [
window
for doc_id, document in enumerate(documents)
for window in self.window_manager(
self.tokenizer,
document,
window_size=self.window_size,
stride=self.window_stride,
doc_id=doc_id,
)
]
# get text and topic from document windows and create new list
model_inputs = [
(window.text, window.doc_topic) for window in document_windows
]
# batch generator
for batch in batch_generator(
model_inputs, batch_size=self.window_batch_size
):
text, text_pair = zip(*batch)
batch_predictions = await self.handle_batch(text, text_pair)
windows_passages.extend(
[
[p.label for p in predictions]
for predictions in batch_predictions
]
)
# add passage to document windows
for window, passages in zip(document_windows, windows_passages):
# clean up passages (remove everything after first <def> tag if present)
passages = [c.split(" <def>", 1)[0] for c in passages]
window.window_candidates = passages
# return document windows
return document_windows
except Exception as e:
# log the entire stack trace
logger.exception(e)
raise HTTPException(status_code=500, detail=f"Server Error: {e}")
server = GoldenRetrieverServer.bind(**vars(SERVER_MANAGER))
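
# Usage sketch (an assumption for illustration, not part of the configuration above):
# with the environment variables read by ServerParameterManager/RayParameterManager set,
# the bound application can be deployed locally with Ray Serve, e.g.
#
#   from ray import serve
#   serve.run(server)
#
# after which the FastAPI routes (/api/retrieve, /api/gerbil) are exposed through the
# Serve HTTP proxy (http://localhost:8000 by default).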