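"""OpenAI-compatible chat completion route backed by a text-generation-inference (TGI) engine."""
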
import time
import uuid
from functools import partial
from typing import (
    Dict,
    Any,
    AsyncIterator,
)

import anyio
from fastapi import APIRouter, Depends
from fastapi import HTTPException, Request
from loguru import logger
from openai.types.chat import (
    ChatCompletionMessage,
    ChatCompletion,
    ChatCompletionChunk,
)
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from sse_starlette import EventSourceResponse
from text_generation.types import StreamResponse, Response

from api.core.tgi import TGIEngine
from api.models import GENERATE_ENGINE
from api.utils.compat import model_dump
from api.utils.protocol import Role, ChatCompletionCreateParams
from api.utils.request import (
    check_api_key,
    handle_request,
    get_event_publisher,
)

chat_router = APIRouter(prefix="/chat")


def get_engine():
    yield GENERATE_ENGINE
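

# POST /chat/completions: OpenAI-compatible chat completion endpoint backed by a
# text-generation-inference (TGI) engine. Returns a full ChatCompletion, or an SSE
# stream of chat.completion.chunk events when `stream` is requested.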
@chat_router.post("/completions", dependencies=[Depends(check_api_key)])
async def create_chat_completion(
    request: ChatCompletionCreateParams,
    raw_request: Request,
    engine: TGIEngine = Depends(get_engine),
):
    # Reject empty conversations and conversations that already end with an assistant turn.
    if (not request.messages) or request.messages[-1]["role"] == Role.ASSISTANT:
        raise HTTPException(status_code=400, detail="Invalid request")

    request = await handle_request(request, engine.prompt_adapter.stop)
    # Default to 512 new tokens when the client does not specify max_tokens.
    request.max_tokens = request.max_tokens or 512

    prompt = engine.apply_chat_template(request.messages)
    # Keep only the sampling parameters that map directly onto TGI generation arguments.
    include = {
        "temperature",
        "best_of",
        "repetition_penalty",
        "typical_p",
        "watermark",
    }
    params = model_dump(request, include=include)
    params.update(
        dict(
            prompt=prompt,
            do_sample=request.temperature > 1e-5,
            max_new_tokens=request.max_tokens,
            stop_sequences=request.stop,
            # TGI rejects top_p >= 1.0, so clamp it just below 1.
            top_p=request.top_p if request.top_p < 1.0 else 0.99,
        )
    )
    logger.debug(f"==== request ====\n{params}")

    request_id: str = f"chatcmpl-{str(uuid.uuid4())}"
    if request.stream:
        # Streaming: adapt the raw TGI token stream into chat.completion.chunk events
        # and publish them to the client over Server-Sent Events.
        generator = engine.generate_stream(**params)
        iterator = create_chat_completion_stream(generator, params, request_id)
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(
                get_event_publisher,
                request=raw_request,
                inner_send_chan=send_chan,
                iterator=iterator,
            ),
        )
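
    # Non-streaming: run a single generation and wrap the result as a ChatCompletion.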
    response: Response = await engine.generate(**params)
    # TGI reports finish reasons such as "eos_token" or "stop_sequence"; map everything
    # other than "length" onto the OpenAI "stop" value.
    finish_reason = response.details.finish_reason.value
    finish_reason = "length" if finish_reason == "length" else "stop"

    message = ChatCompletionMessage(role="assistant", content=response.generated_text)
    choice = Choice(
        index=0,
        message=message,
        finish_reason=finish_reason,
        logprobs=None,
    )

    num_prompt_tokens = len(response.details.prefill)
    num_generated_tokens = response.details.generated_tokens
    usage = CompletionUsage(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
    )
    return ChatCompletion(
        id=request_id,
        choices=[choice],
        created=int(time.time()),
        model=request.model,
        object="chat.completion",
        usage=usage,
    )
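

# Adapts TGI's StreamResponse tokens into OpenAI ChatCompletionChunk objects: an initial
# chunk carrying the assistant role, one chunk per generated token (special tokens are
# skipped), and a final empty chunk with finish_reason="stop".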
async def create_chat_completion_stream(
    generator: AsyncIterator[StreamResponse], params: Dict[str, Any], request_id: str
) -> AsyncIterator[ChatCompletionChunk]:
    # First chunk with role
    choice = ChunkChoice(
        index=0,
        delta=ChoiceDelta(role="assistant", content=""),
        finish_reason=None,
        logprobs=None,
    )
    yield ChatCompletionChunk(
        id=request_id,
        choices=[choice],
        created=int(time.time()),
        model=params.get("model", "llm"),
        object="chat.completion.chunk",
    )

    async for output in generator:
        output: StreamResponse
        # Skip special tokens (e.g. EOS) so they never show up in the streamed content.
        if output.token.special:
            continue

        choice = ChunkChoice(
            index=0,
            delta=ChoiceDelta(content=output.token.text),
            finish_reason=None,
            logprobs=None,
        )
        yield ChatCompletionChunk(
            id=request_id,
            choices=[choice],
            created=int(time.time()),
            model=params.get("model", "llm"),
            object="chat.completion.chunk",
        )

    # Final chunk marks the end of the stream.
    choice = ChunkChoice(
        index=0,
        delta=ChoiceDelta(),
        finish_reason="stop",
        logprobs=None,
    )
    yield ChatCompletionChunk(
        id=request_id,
        choices=[choice],
        created=int(time.time()),
        model=params.get("model", "llm"),
        object="chat.completion.chunk",
    )
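

# Example request (illustrative only; assumes the application mounts `chat_router` under
# an OpenAI-style "/v1" prefix on localhost:8000 and that API-key checking is enabled):
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -H "Authorization: Bearer $API_KEY" \
#     -d '{"model": "llm", "messages": [{"role": "user", "content": "Hello"}], "stream": false}'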