Spaces:
Sleeping
Sleeping
File size: 2,455 Bytes
ffa839a 5fa76ab 09d4d49 b8a4dc6 a2ffc10 0af13be b8a4dc6 09d4d49 a2ffc10 b8a4dc6 a2ffc10 b8a4dc6 a2ffc10 09d4d49 a2ffc10 09d4d49 0af13be a2ffc10 b8a4dc6 a2ffc10 b8a4dc6 a2ffc10 e32c97b a2ffc10 b8a4dc6 e32c97b b8a4dc6 e32c97b b8a4dc6 ffa839a a2ffc10 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import uvicorn
from typing import List, Optional
# FastAPI application exposing the /generate/ text-generation endpoint.
app = FastAPI()
# Hugging Face Inference API client, pinned to the Mixtral-8x7B instruct model.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
class ChatMessage(BaseModel):
    """One turn of a chat conversation, as sent by the client."""

    # Speaker role; "user" turns are wrapped in [INST] tags by
    # format_prompt, any other role is treated as an assistant reply.
    role: str
    # Text content of the turn.
    content: str
class GenerationRequest(BaseModel):
    """Request body for POST /generate/."""

    # Primary input text. Required by the schema; if it is an empty
    # string the endpoint falls back to `message`.
    prompt: str
    # Alternate field for the input text (kept for compatibility).
    message: Optional[str] = None
    # Optional system instruction prepended to the formatted prompt.
    system_message: Optional[str] = None
    # Prior conversation turns, oldest first.
    history: Optional[List[ChatMessage]] = None
    # Sampling temperature; the endpoint clamps it to at least 0.01.
    temperature: Optional[float] = 0.7
    # Nucleus-sampling (top-p) cutoff.
    top_p: Optional[float] = 0.95
def format_prompt(
    message: str,
    history: Optional[List["ChatMessage"]] = None,
    system_message: Optional[str] = None,
) -> str:
    """Format a conversation into the Mixtral-instruct prompt template.

    Args:
        message: The current user message.
        history: Prior turns. Entries with role "user" are wrapped in
            ``[INST] ... [/INST]``; any other role is treated as an
            assistant reply and terminated with ``</s>``.
        system_message: Optional system instruction emitted as its own
            leading ``[INST]`` block.

    Returns:
        The full prompt string, starting with the BOS token ``<s>``.
    """
    # NOTE: the original annotations used `List[ChatMessage] = None` /
    # `str = None` (implicit Optional), which modern type checkers reject;
    # defaults and runtime behavior are unchanged.
    prompt = "<s>"
    if system_message:
        # NOTE(review): this template closes the system turn with </s>
        # immediately — presumably intentional for this deployment;
        # confirm against the model card before changing.
        prompt += f"[INST] {system_message} [/INST]</s>"
    if history:
        for msg in history:
            if msg.role == "user":
                prompt += f"[INST] {msg.content} [/INST]"
            else:
                prompt += f" {msg.content}</s>"
    # Append the current user message as the final instruction block.
    prompt += f"[INST] {message} [/INST]"
    return prompt
@app.post("/generate/")
async def generate_text(request: GenerationRequest):
    """Generate a completion for the given prompt/conversation.

    Returns:
        ``{"response": <generated text>}``.

    Raises:
        HTTPException 400: if neither 'prompt' nor 'message' is non-empty.
        HTTPException 500: if prompt formatting or inference fails.
    """
    # Accept either field name; 'prompt' wins when both are supplied.
    message = request.prompt if request.prompt else request.message
    # Validate BEFORE the try block: the original raised the 400 inside
    # the try, where the broad `except Exception` caught it and re-raised
    # it as a 500, losing the intended status code and detail.
    if not message:
        raise HTTPException(status_code=400, detail="Either 'prompt' or 'message' must be provided")
    # Guard against explicit nulls: both fields are Optional[float], so a
    # client may send null, and max(None, 0.01) would raise TypeError.
    temperature = 0.7 if request.temperature is None else request.temperature
    top_p = 0.95 if request.top_p is None else request.top_p
    try:
        # Build the Mixtral-instruct prompt from message + optional context.
        formatted_prompt = format_prompt(
            message=message,
            history=request.history,
            system_message=request.system_message,
        )
        params = {
            "temperature": max(temperature, 0.01),  # endpoint rejects ~0 temps
            "max_new_tokens": 1048,
            "top_p": top_p,
            "repetition_penalty": 1.0,
            "do_sample": True,
            "seed": 42,  # fixed seed for reproducible sampling
        }
        # text_generation returns the generated text as a plain string.
        response = client.text_generation(
            formatted_prompt,
            **params
        )
    except Exception as e:
        # Surface inference/formatting failures as a 500 with the message.
        raise HTTPException(status_code=500, detail=str(e))
    return {"response": response}
if __name__ == "__main__":
    # Development entry point: serve on all interfaces, port 8000.
    # (Removed a stray trailing "|" artifact from the original paste,
    # which made this line a syntax error.)
    uvicorn.run(app, host="0.0.0.0", port=8000)