File size: 2,455 Bytes
ffa839a
5fa76ab
09d4d49
b8a4dc6
a2ffc10
0af13be
b8a4dc6
09d4d49
 
a2ffc10
 
 
 
 
b8a4dc6
a2ffc10
 
 
 
 
b8a4dc6
a2ffc10
09d4d49
a2ffc10
 
 
 
 
 
 
 
 
 
 
 
 
 
09d4d49
 
0af13be
a2ffc10
 
 
 
 
 
 
b8a4dc6
a2ffc10
 
 
 
 
 
b8a4dc6
a2ffc10
e32c97b
a2ffc10
 
 
 
 
 
 
b8a4dc6
e32c97b
 
 
 
 
b8a4dc6
e32c97b
 
b8a4dc6
ffa839a
a2ffc10
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import uvicorn
from typing import List, Optional

# Module-level FastAPI application and a single Hugging Face inference
# client, shared by all requests, pinned to the Mixtral-8x7B instruct model.
# NOTE(review): client is built at import time — confirm that credentials /
# connectivity issues are acceptable to surface on startup.
app = FastAPI()
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

class ChatMessage(BaseModel):
    """One turn of conversation history."""
    # "user" marks a user turn; any other value is treated as an
    # assistant reply by format_prompt.
    role: str
    content: str

class GenerationRequest(BaseModel):
    """Request payload for POST /generate/.

    Exactly one of ``prompt`` or ``message`` must carry the user input
    (``prompt`` wins when both are set); the endpoint returns 400 when
    both are missing.
    """
    # Optional (was required): with `prompt: str`, a request sending only
    # `message` was rejected by validation (422) before the handler ran,
    # making `message` and the handler's 400 branch dead code.
    prompt: Optional[str] = None
    message: Optional[str] = None
    system_message: Optional[str] = None
    history: Optional[List[ChatMessage]] = None
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95

def format_prompt(message: str, history: Optional[List["ChatMessage"]] = None,
                  system_message: Optional[str] = None) -> str:
    """Build a Mixtral-instruct prompt from a message and optional context.

    Args:
        message: The current user message.
        history: Prior turns; each item needs ``.role`` ("user" for user
            turns, anything else for assistant replies) and ``.content``.
        system_message: Optional instruction emitted as its own
            ``[INST] ... [/INST]</s>`` block before the history.

    Returns:
        The prompt string wrapped in the model's ``<s>[INST] ... [/INST]``
        template, ready for ``text_generation``.
    """
    # Collect pieces and join once instead of repeated string +=.
    parts = ["<s>"]
    if system_message:
        parts.append(f"[INST] {system_message} [/INST]</s>")
    if history:
        for msg in history:
            if msg.role == "user":
                parts.append(f"[INST] {msg.content} [/INST]")
            else:
                parts.append(f" {msg.content}</s>")
    # The current message always closes the prompt.
    parts.append(f"[INST] {message} [/INST]")
    return "".join(parts)

@app.post("/generate/")
async def generate_text(request: GenerationRequest):
    """Generate a completion for the request's prompt (or message).

    Returns:
        ``{"response": <generated text>}``.

    Raises:
        HTTPException: 400 when neither 'prompt' nor 'message' is given;
            500 when the inference call fails.
    """
    # Accept either field name; `prompt` takes precedence.
    message = request.prompt or request.message
    # Validation happens OUTSIDE the try below: in the original code the
    # 400 HTTPException was swallowed by `except Exception` and re-raised
    # as a 500, so clients never saw the intended status code.
    if not message:
        raise HTTPException(status_code=400, detail="Either 'prompt' or 'message' must be provided")

    formatted_prompt = format_prompt(
        message=message,
        history=request.history,
        system_message=request.system_message,
    )

    params = {
        # Clamp temperature away from 0 — the inference API rejects it.
        "temperature": max(request.temperature, 0.01),
        "max_new_tokens": 1048,
        "top_p": request.top_p,
        "repetition_penalty": 1.0,
        "do_sample": True,
        "seed": 42,
    }

    # Keep the try minimal: only the call that can legitimately fail.
    try:
        # text_generation returns the completion as a single string here.
        response = client.text_generation(formatted_prompt, **params)
    except HTTPException:
        # Never convert a deliberate HTTP error into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"response": response}

# Script entry point: serve the API with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)