Spaces:
Runtime error
Runtime error
File size: 2,108 Bytes
533ea62 f410bac 5fa76ab 887a39c 883ff33 9298929 883ff33 5fa76ab f410bac 6b10b2e 27abc54 5fa76ab e6fa3d8 abd2555 473963a 5fa76ab f441c05 5fa76ab 722d8dd 76e497a 722d8dd 0489a11 883ff33 f445a1c 887a39c 883ff33 533ea62 f445a1c 27bf18d dc7b6bf 27bf18d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import asyncio
import os
import secrets
from typing import Annotated

from fastapi import FastAPI, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer
from huggingface_hub import InferenceClient
from pydantic import BaseModel
# Extracts the bearer token from the Authorization header for /generate.
# `tokenUrl` is only advertised in the OpenAPI schema; no "/token" route is
# visible in this source.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
app = FastAPI()
# Allow all CORS: any origin, method and header may call this API from a browser.
# NOTE(review): the FastAPI CORS docs say wildcard origins/methods/headers should
# not be combined with allow_credentials=True. Auth here is a bearer header, so
# cookie credentials are presumably unnecessary -- confirm before tightening.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Shared Hugging Face Inference client used by generate().
# Models tried previously: mistralai/Mistral-7B-Instruct-v0.2,
# mistralai/Mistral-Nemo-Instruct-2407.
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
class Item(BaseModel):
    """JSON body accepted by ``POST /generate``.

    ``prompt``, ``history`` and ``system_prompt`` are required; the sampling
    controls fall back to the defaults below when omitted.
    """

    # User message; generate() appends it after the system prompt.
    prompt: str
    # Earlier conversation turns. Required by the schema, but the prompt
    # formatting visible in this file does not use it.
    history: list
    # Instruction text prepended to the prompt.
    system_prompt: str
    # Sampling controls forwarded to the text-generation call.
    temperature: float = 0.0
    max_new_tokens: int = 2_048
    top_p: float = 0.15
    repetition_penalty: float = 1.0
def format_prompt(message, history):
    """Build the model prompt from ``message``.

    ``history`` is accepted for call-site compatibility but is currently
    ignored: the prompt is exactly ``message``.
    """
    prompt_text = message
    return prompt_text
def generate(item: Item):
    """Generate text for ``item`` using the module-level Inference client.

    The system prompt and user prompt are joined as ``"<system>, <prompt>"``
    and passed through ``format_prompt``. Sampling settings come from ``item``.
    The response is consumed as a token stream and concatenated.

    Returns:
        The generated text, excluding tokens that the stream marks as
        special (e.g. the end-of-sequence marker).
    """
    # Clamp zero/tiny temperatures up to 0.01. Presumably the backend requires
    # a strictly positive temperature when do_sample=True.
    temperature = max(float(item.temperature), 1e-2)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=item.max_new_tokens,
        top_p=float(item.top_p),
        repetition_penalty=item.repetition_penalty,
        do_sample=True,
        # Fixed seed; presumably makes sampling repeatable for identical requests.
        seed=42,
    )
    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    # With details=True each streamed item carries token metadata. Special
    # tokens (e.g. an EOS marker such as "</s>") are not part of the answer
    # and must not be appended to it.
    return "".join(
        response.token.text for response in stream if not response.token.special
    )
@app.get("/")
async def root():
    """Health check for ``GET /``; always reports ``{"status": "ok"}``."""
    status = {"status": "ok"}
    return status
@app.head("/")
async def root_head():
    """Health check for ``HEAD /``; returns the same payload as ``GET /``."""
    # Named distinctly so it no longer rebinds (and shadows) the module-level
    # `root` GET handler. Both routes were registered regardless, because the
    # decorator runs at definition time.
    return {"status": "ok"}
@app.post("/generate")
async def generate_text(item: Item, token: Annotated[str, Depends(oauth2_scheme)]):
    """Generate a completion for ``item`` after checking the caller's API key.

    The bearer token extracted by ``oauth2_scheme`` must equal the
    ``API_KEY`` environment variable.

    Returns:
        ``{"response": <generated text>}``.

    Raises:
        HTTPException: 403 when ``API_KEY`` is unset/empty or the token
            does not match it.
    """
    api_key = os.environ.get("API_KEY")
    # Fail closed when no key is configured, so an empty bearer token can never
    # match an empty key. Compare in constant time so response timing does not
    # reveal how much of a guess was right. Compare bytes, because
    # compare_digest rejects non-ASCII str arguments.
    if not api_key or not secrets.compare_digest(
        token.encode("utf-8"), api_key.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid API key")
    # generate() iterates a synchronous network stream. Calling it directly in
    # this coroutine would block the event loop, and every other request
    # including the "/" health checks, until generation finishes.
    response_text = await asyncio.to_thread(generate, item)
    return {"response": response_text}
|