import json
from typing import Optional, List, AsyncIterator

from aiohttp import ClientSession
from openai.types.chat import ChatCompletionMessageParam
from pydantic import ValidationError
from text_generation import AsyncClient
from text_generation.errors import parse_error
from text_generation.types import Parameters, Request, Response, StreamResponse

from api.adapter import get_prompt_adapter
from api.utils.compat import model_dump

class TGIEngine:
    def __init__(
        self,
        model: AsyncClient,
        model_name: str,
        prompt_name: Optional[str] = None,
    ):
        """
        Initializes the TGIEngine object.

        Args:
            model: The text_generation AsyncClient connected to the TGI server.
            model_name: The name of the model.
            prompt_name: The name of the prompt (optional).
        """
        self.model = model
        self.model_name = model_name.lower()
        self.prompt_name = prompt_name.lower() if prompt_name is not None else None
        self.prompt_adapter = get_prompt_adapter(self.model_name, prompt_name=self.prompt_name)

    def apply_chat_template(
        self, messages: List[ChatCompletionMessageParam],
    ) -> str:
        """
        Applies a chat template to the given messages and returns the processed output.

        Args:
            messages: A list of ChatCompletionMessageParam objects representing the chat messages.

        Returns:
            str: The processed output as a string.
        """
        return self.prompt_adapter.apply_chat_template(messages)

    async def generate(
        self,
        prompt: str,
        do_sample: bool = True,
        max_new_tokens: int = 20,
        best_of: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        decoder_input_details: bool = True,
        top_n_tokens: Optional[int] = None,
    ) -> Response:
        """
        Given a prompt, generate the following text asynchronously.

        Args:
            prompt (`str`):
                Input text
            do_sample (`bool`):
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
            best_of (`int`):
                Generate `best_of` sequences and return the one with the highest token logprobs
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            return_full_text (`bool`):
                Whether to prepend the prompt to the generated text
            seed (`int`):
                Random sampling seed
            stop_sequences (`List[str]`):
                Stop generating tokens if a member of `stop_sequences` is generated
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_k (`int`):
                The number of highest probability vocabulary tokens to keep for top-k filtering.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to
                `top_p` or higher are kept for generation.
            truncate (`int`):
                Truncate input tokens to the given size
            typical_p (`float`):
                Typical Decoding mass.
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            decoder_input_details (`bool`):
                Return the decoder input token logprobs and ids
            top_n_tokens (`int`):
                Return the `n` most likely tokens at each step

        Returns:
            Response: generated response
        """
        # Validate parameters
        parameters = Parameters(
            best_of=best_of,
            details=True,
            decoder_input_details=decoder_input_details,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop_sequences if stop_sequences is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
            top_n_tokens=top_n_tokens,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

        # Call the TGI `/generate` endpoint and return the parsed response
        async with ClientSession(
            headers=self.model.headers, cookies=self.model.cookies, timeout=self.model.timeout
        ) as session:
            async with session.post(f"{self.model.base_url}/generate", json=model_dump(request)) as resp:
                payload = await resp.json()

                if resp.status != 200:
                    raise parse_error(resp.status, payload)
                return Response(**payload)

    async def generate_stream(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        best_of: Optional[int] = 1,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        top_n_tokens: Optional[int] = None,
    ) -> AsyncIterator[StreamResponse]:
        """
        Given a prompt, generate the following stream of tokens asynchronously.

        Args:
            prompt (`str`):
                Input text
            do_sample (`bool`):
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
            best_of (`int`):
                Generate `best_of` sequences and return the one with the highest token logprobs
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            return_full_text (`bool`):
                Whether to prepend the prompt to the generated text
            seed (`int`):
                Random sampling seed
            stop_sequences (`List[str]`):
                Stop generating tokens if a member of `stop_sequences` is generated
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_k (`int`):
                The number of highest probability vocabulary tokens to keep for top-k filtering.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to
                `top_p` or higher are kept for generation.
            truncate (`int`):
                Truncate input tokens to the given size
            typical_p (`float`):
                Typical Decoding mass.
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            top_n_tokens (`int`):
                Return the `n` most likely tokens at each step

        Returns:
            AsyncIterator: stream of generated tokens
        """
        # Validate parameters
        parameters = Parameters(
            best_of=best_of,
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop_sequences if stop_sequences is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
            top_n_tokens=top_n_tokens,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)

        # Call the TGI `/generate_stream` endpoint and parse the Server-Sent Events
        async with ClientSession(
            headers=self.model.headers, cookies=self.model.cookies, timeout=self.model.timeout
        ) as session:
            async with session.post(f"{self.model.base_url}/generate_stream", json=model_dump(request)) as resp:
                if resp.status != 200:
                    raise parse_error(resp.status, await resp.json())

                # Parse ServerSentEvents
                async for byte_payload in resp.content:
                    # Skip empty keep-alive lines
                    if byte_payload == b"\n":
                        continue

                    payload = byte_payload.decode("utf-8")

                    # Event data
                    if payload.startswith("data:"):
                        # Decode payload
                        json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
                        # Parse payload
                        try:
                            response = StreamResponse(**json_payload)
                        except ValidationError:
                            # If we failed to parse the payload, then it is an error payload
                            raise parse_error(resp.status, json_payload)
                        yield response

    def stop(self):
        """
        Gets the stop property of the prompt adapter.

        Returns:
            The stop property of the prompt adapter, or None if it does not exist.
        """
        return self.prompt_adapter.stop if hasattr(self.prompt_adapter, "stop") else None
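

# --- Illustrative usage sketch ---
# A minimal sketch of how TGIEngine might be driven end to end, assuming a TGI
# server is reachable at http://localhost:8080 and that a prompt adapter is
# registered for the model name used below. The URL, model name, and messages
# are placeholders for illustration only, not part of the engine.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        client = AsyncClient("http://localhost:8080")  # assumed TGI endpoint
        engine = TGIEngine(model=client, model_name="llama-2")  # assumed model name

        # Build a single prompt string from OpenAI-style chat messages.
        prompt = engine.apply_chat_template(
            [{"role": "user", "content": "Write a haiku about the sea."}]
        )

        # Non-streaming generation: returns a full Response object.
        response = await engine.generate(prompt, max_new_tokens=64, temperature=0.7)
        print(response.generated_text)

        # Streaming generation: yields StreamResponse objects token by token.
        async for chunk in engine.generate_stream(prompt, max_new_tokens=64):
            if not chunk.token.special:
                print(chunk.token.text, end="", flush=True)

    asyncio.run(_demo())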