File size: 2,108 Bytes
533ea62
f410bac
5fa76ab
 
887a39c
883ff33
9298929
883ff33
 
5fa76ab
 
 
f410bac
 
 
 
 
 
 
 
 
6b10b2e
27abc54
 
5fa76ab
 
 
 
 
e6fa3d8
abd2555
473963a
5fa76ab
 
 
f441c05
5fa76ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
722d8dd
 
 
 
76e497a
 
 
 
722d8dd
0489a11
883ff33
f445a1c
887a39c
883ff33
533ea62
f445a1c
27bf18d
dc7b6bf
27bf18d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import secrets
from typing import Annotated

from fastapi import FastAPI, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer
from huggingface_hub import InferenceClient
from pydantic import BaseModel

# Bearer-token extraction for protected routes; tokenUrl is nominal — no
# /token endpoint is visible in this file, the scheme is only used to pull
# the Authorization header value.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

app = FastAPI()

# Allow all CORS
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers per the CORS spec (credentials require an explicit
# origin) — confirm whether credentialed cross-origin calls are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model selection: the active client targets Mistral-7B-Instruct-v0.3; the
# commented lines are the previously/alternatively used checkpoints.
# client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
# client = InferenceClient("mistralai/Mistral-Nemo-Instruct-2407")

class Item(BaseModel):
    """Request body for POST /generate."""

    # The user's message for this turn.
    prompt: str
    # Prior conversation turns; currently ignored by format_prompt — TODO confirm.
    history: list
    # Instructions prepended (comma-joined) to the prompt before generation.
    system_prompt: str
    # Sampling temperature; floored to 0.01 in generate() since 0.0 is invalid
    # for sampling.
    temperature: float = 0.0
    # Upper bound on tokens produced by the model.
    max_new_tokens: int = 2048
    # Nucleus-sampling probability mass.
    top_p: float = 0.15
    # 1.0 disables the repetition penalty.
    repetition_penalty: float = 1.0

def format_prompt(message, history):
    """Return the prompt text to send to the model.

    The conversation *history* parameter is accepted but currently unused;
    *message* is forwarded unchanged.
    """
    formatted = message
    return formatted

def generate(item: Item) -> str:
    """Run text generation for *item* and return the concatenated output.

    Streams tokens from the Hugging Face Inference API and joins them into
    a single string.

    Args:
        item: Generation request carrying the prompt, system prompt, and
            sampling parameters.

    Returns:
        The full generated text (without the prompt, per
        ``return_full_text=False``).
    """
    # Floor the temperature: 0.0 (the Item default) is not a valid sampling
    # temperature for the backend.
    temperature = max(float(item.temperature), 1e-2)
    top_p = float(item.top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=item.max_new_tokens,
        top_p=top_p,
        repetition_penalty=item.repetition_penalty,
        do_sample=True,
        seed=42,  # fixed seed for reproducible sampling
    )

    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )

    # str.join avoids the quadratic cost of repeated += on an accumulator.
    return "".join(response.token.text for response in stream)

@app.get("/")
async def root():
    """Health-check endpoint: GET / reports service liveness."""
    status_payload = {"status": "ok"}
    return status_payload

@app.head("/")
async def root_head():
    """Health-check endpoint: HEAD / reports service liveness.

    Named distinctly from the GET handler — the original reused the name
    ``root``, shadowing the previous function (flake8 F811). It only worked
    because FastAPI registers routes at decoration time; the rename keeps
    both routes while restoring one name per function.
    """
    return {"status": "ok"}


@app.post("/generate")
async def generate_text(item: Item, token: Annotated[str, Depends(oauth2_scheme)]):
    """Generate text for *item*; requires a bearer token matching API_KEY.

    Args:
        item: Generation request body.
        token: Bearer token extracted by the OAuth2 scheme.

    Returns:
        A dict with the generated text under ``"response"``.

    Raises:
        HTTPException: 403 when the token does not match the API_KEY
            environment variable, or when API_KEY is not configured.
    """
    api_key = os.environ.get("API_KEY")
    # secrets.compare_digest gives a constant-time comparison, preventing
    # timing attacks against the key (the original `!=` short-circuits on
    # the first differing byte). An unset API_KEY rejects every request,
    # matching the original's behavior (None != token was always True).
    if api_key is None or not secrets.compare_digest(api_key, token):
        raise HTTPException(status_code=403, detail="Invalid API key")

    return {
        "response": generate(item)
    }