File size: 2,388 Bytes
c625a8c
 
 
654eaa0
30b9c64
654eaa0
 
c625a8c
654eaa0
 
 
 
 
 
 
a05fde6
654eaa0
c625a8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
654eaa0
c625a8c
 
 
 
 
 
a05fde6
c625a8c
66da31c
 
 
 
 
 
c625a8c
 
0b4fed2
c625a8c
 
 
 
 
 
6c8cc78
c625a8c
 
 
6c8cc78
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import fastapi
from fastapi.responses import JSONResponse
from time import time
#MODEL_PATH = "./qwen1_5-0_5b-chat-q4_0.gguf" #"./qwen1_5-0_5b-chat-q4_0.gguf"
import logging
import llama_cpp
import llama_cpp.llama_tokenizer

# Load the quantised Qwen1.5 chat model once, at import time, together with
# the Hugging Face tokenizer it is paired with.
llama = llama_cpp.Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q4_0.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B"),
    verbose=False,
    n_ctx=4096,
    n_gpu_layers=0,  # no layers offloaded to a GPU
    # Qwen1.5 chat models use ChatML-style prompts (<|im_start|> ... <|im_end|>),
    # not the Llama-2 [INST] template, so the prompt format must be "chatml".
    chat_format="chatml",
)

# Logger setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Earlier model-initialisation attempts, disabled by wrapping them in a bare
# string literal (it is evaluated and discarded, so nothing below runs).
# The model actually used by the handlers is the module-level `llama` object.
"""
try:
    llm = Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q4_0.gguf",
    verbose=False,
        n_ctx=4096,
        n_threads=4,
        n_gpu_layers=0,
)
    
    llm = Llama(
        model_path=MODEL_PATH,
        chat_format="llama-2",
        n_ctx=4096,
        n_threads=8,
        n_gpu_layers=0,
    )
    
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise
"""

# FastAPI application object; the route handlers in this file register on it.
app = fastapi.FastAPI()


@app.get("/")
def index():
    """Redirect requests for the site root to ``/docs``."""
    docs_redirect = fastapi.responses.RedirectResponse(url="/docs")
    return docs_redirect


@app.get("/health")
def health():
    """Health-check endpoint: always answers with a fixed "ok" status."""
    payload = dict(status="ok")
    return payload


# Chat Completion API
@app.get("/generate")
async def complete(
    question: str,
    system: str = "You are a story writing assistant.",
    temperature: float = 0.7,
    seed: int = 42,
) -> dict:
    """Generate a chat completion for ``question`` and return it as JSON.

    The completion is requested in streaming mode and the streamed deltas are
    collected into a single reply before responding.

    Args:
        question: User message sent to the model.
        system: System prompt placed before the user message.
        temperature: Sampling temperature forwarded to the model.
        seed: Sampling seed forwarded to the model.

    Returns:
        On success, a dict with ``role`` (speaker reported by the stream),
        ``content`` (the full generated text) and ``time`` (elapsed seconds).
        On any failure, a ``JSONResponse`` with status 500.
    """
    # NOTE: generation runs synchronously inside this coroutine, so the event
    # loop is occupied until the model has finished producing the reply.
    try:
        started = time()
        stream = llama.create_chat_completion(
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": question},
            ],
            temperature=temperature,
            seed=seed,
            stream=True,
        )
        # With stream=True the call yields chunk dicts instead of returning a
        # single response dict, so the reply has to be assembled from deltas.
        role = "assistant"
        pieces = []
        for chunk in stream:
            delta = chunk["choices"][0]["delta"]
            if "role" in delta:
                role = delta["role"]
            # Some chunks (e.g. the final one) carry no text; skip those.
            if delta.get("content"):
                pieces.append(delta["content"])
        elapsed = time() - started
    except Exception:
        # Boundary handler: log the traceback, hide details from the client.
        logger.exception("Error in /generate endpoint")
        return JSONResponse(
            status_code=500, content={"message": "Internal Server Error"}
        )
    return {"role": role, "content": "".join(pieces), "time": elapsed}


if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )