from fastapi import FastAPI
from pydantic import BaseModel
from webscout.LLM import LLM
from typing import Dict, List, Any
import time

app = FastAPI()
class Model(BaseModel):
    id: str
    object: str
    created: int
    owned_by: str

class Message(BaseModel):
    role: str
    content: str

class CompletionRequest(BaseModel):
    model: str
    messages: List[Message]

class CompletionResponse(BaseModel):
    id: str
    object: str
    created: int
    model: str
    choices: List[Dict[str, Any]]
    usage: Dict[str, int]
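
# Illustrative request/response shapes for the models above (a sketch; the
# concrete values are placeholders, not part of the original):
#
#   CompletionRequest:
#     {"model": "meta-llama/Meta-Llama-3-8B-Instruct",
#      "messages": [{"role": "user", "content": "Hello"}]}
#
#   CompletionResponse:
#     {"id": "chatcmpl-1", "object": "chat.completion", "created": 1234567890,
#      "model": "meta-llama/Meta-Llama-3-8B-Instruct",
#      "choices": [{"index": 0, "message": {...}, "finish_reason": "stop"}],
#      "usage": {"prompt_tokens": 1, "total_tokens": 3}}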
models = [
    {"id": "meta-llama/Meta-Llama-3-70B-Instruct", "object": "model", "created": 1686935002, "owned_by": "meta"},
    {"id": "google/gemma-2-27b-it", "object": "model", "created": 1686935002, "owned_by": "google"},
    {"id": "google/gemma-2-9b-it", "object": "model", "created": 1686935002, "owned_by": "ConsiousAI"},
    {"id": "cognitivecomputations/dolphin-2.9.1-llama-3-70b", "object": "model", "created": 1686935002, "owned_by": "cognitivecomputations"},
    {"id": "nvidia/Nemotron-4-340B-Instruct", "object": "model", "created": 1686935002, "owned_by": "nvidia"},
    {"id": "Qwen/Qwen2-72B-Instruct", "object": "model", "created": 1686935002, "owned_by": "Qwen"},
    {"id": "microsoft/Phi-3-medium-4k-instruct", "object": "model", "created": 1686935002, "owned_by": "microsoft"},
{"id": "google/gemma-2-9b-it", "object": "model", "created": 1686935002, "owned_by": "ConsiousAI"}, | |
{"id": "openchat/openchat-3.6-8b", "object": "model", "created": 1686935002, "owned_by": "unknown"}, | |
{"id": "mistralai/Mistral-7B-Instruct-v0.3", "object": "model", "created": 1686935002, "owned_by": "mistral"}, | |
{"id": "meta-llama/Meta-Llama-3-8B-Instruct", "object": "model", "created": 1686935002, "owned_by": "meta"}, | |
{"id": "mistralai/Mixtral-8x22B-Instruct-v0.1", "object": "model", "created": 1686935002, "owned_by": "mistral"}, | |
{"id": "mistralai/Mixtral-8x7B-Instruct-v0.1", "object": "model", "created": 1686935002, "owned_by": "mistral"}, | |
{"id": "Qwen/Qwen2-7B-Instruct", "object": "model", "created": 1686935002, "owned_by": "Qwen"}, | |
{"id": "meta-llama/Meta-Llama-3.1-405B-Instruct", "object": "model", "created": 1686935002, "owned_by": "meta"} | |
] | |
# OpenAI-style chat completions endpoint (route path assumed from the
# "chat.completion" response format; the handlers need routes to be reachable).
@app.post("/v1/chat/completions")
def handle_completions(completion_request: CompletionRequest):
    # Pull the system prompt and the first user message out of the request.
    system_prompt = next((message.content for message in completion_request.messages if message.role == 'system'), None)
    user_query = next((message.content for message in completion_request.messages if message.role == 'user'), None)
    response_text = generative(query=user_query, system_prompt=system_prompt, model=completion_request.model)
    # Whitespace-split word counts serve as a rough stand-in for real token counts.
    prompt_tokens = sum(len(message.content.split()) for message in completion_request.messages)
    response = CompletionResponse(
        id="chatcmpl-1",
        object="chat.completion",
        created=int(time.time()),
        model=completion_request.model,
        choices=[{"index": 0, "message": {"role": "assistant", "content": response_text}, "finish_reason": "stop"}],
        usage={"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens + len(response_text.split())},
    )
    return response
# OpenAI-style model listing endpoint (route path assumed from the response format).
@app.get("/v1/models")
def get_models():
    return {"object": "list", "data": models}
# Legacy text-completion endpoint (route path assumed from the "text_completion"
# response format). Note generative() takes (system_prompt, query, model), so the
# system prompt comes first and the user prompt second.
@app.post("/v1/completions")
def create_completion(prompt: str, model: str, best_of: int = 1, echo: bool = False, frequency_penalty: float = 0.0):
    response_text = generative("You are a helpful assistant.", prompt, model)
    response = {
        "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
        "object": "text_completion",
        "created": int(time.time()),
        "model": model,
        "system_fingerprint": "fp_44709d6fcb",
        "choices": [{"text": response_text, "index": 0, "logprobs": None, "finish_reason": "length"}],
    }
    return response
def generative(system_prompt, query, model):
    # Run a single-turn chat through webscout's LLM wrapper and return the reply text.
    # Fall back to a generic system prompt when the request carries none.
    llm = LLM(model=model, system_message=system_prompt or "You are a helpful assistant.")
    messages = [{"role": "user", "content": query}]
    response = llm.chat(messages)
    return response
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
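
# Example usage (a sketch; assumes the server is running locally on port 8000
# and that the OpenAI-compatible route paths registered above are in place):
#
#   curl http://localhost:8000/v1/models
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "meta-llama/Meta-Llama-3-8B-Instruct",
#          "messages": [{"role": "user", "content": "Hello!"}]}'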