Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
# coding: utf-8 | |
from os import listdir | |
from os.path import isdir | |
from fastapi import FastAPI, HTTPException, Request, responses, Body | |
from fastapi.middleware.cors import CORSMiddleware | |
from llama_cpp import Llama | |
from pydantic import BaseModel | |
from enum import Enum | |
from typing import Optional, Literal, Dict, List | |
# MODEL LOADING, FUNCTIONS, AND TESTING
print("Loading model...")
PHllm = Llama(
    model_path="/models/final-Physics_llama3.gguf",
    use_mmap=False,  # do not memory-map the model file
    use_mlock=True,  # mlock the model's memory
    # Optional settings, currently disabled:
    # n_gpu_layers=28, # Uncomment to use GPU acceleration
    # seed=1337, # Uncomment to set a specific seed
    # n_ctx=2048, # Uncomment to increase the context window
)
# Second model, used only by the commented-out /chat/multiturn endpoint:
# MIllm = Llama(model_path="/models/final-LlamaTuna_Q8_0.gguf", use_mmap=False, use_mlock=True)

print("Loading Translators.")
# Imported here rather than at the top of the file, presumably so the progress
# message above is printed before the (slow) import runs.
from pythainlp.translate.en_th import EnThTranslator, ThEnTranslator

t = EnThTranslator()  # English -> Thai
e = ThEnTranslator()  # Thai -> English
def extract_restext(response, is_chat=False):
    """Return the stripped generated text from a llama_cpp completion response.

    Args:
        response: the dict returned by ``Llama.create_chat_completion``
            (when ``is_chat`` is true) or by a plain text completion call.
        is_chat: selects which of the two response layouts to read.

    Returns:
        The first choice's text with surrounding whitespace removed; ``""``
        when a chat message carries no ``content``.
    """
    choice = response['choices'][0]
    if is_chat:
        # Chat completions (OpenAI-compatible format) nest the reply under
        # choices[0]["message"]["content"]; content may be None.
        return (choice['message'].get('content') or '').strip()
    # Plain text completions put the output directly under choices[0]["text"].
    return choice['text'].strip()
def ask_llama(llm: Llama, question: str, max_new_tokens=200, temperature=0.5, repeat_penalty=2.0):
    """Ask ``llm`` a single-turn question and return its stripped reply.

    Args:
        llm: the loaded llama_cpp model.
        question: the user message, sent as the only turn of the conversation.
        max_new_tokens: forwarded as ``max_tokens``.
        temperature: sampling temperature.
        repeat_penalty: repetition penalty.

    Returns:
        The assistant's reply text.
    """
    # create_chat_completion takes a *list* of messages, so the single user
    # turn is wrapped in a list; chat_llama holds the shared call settings.
    return chat_llama(
        llm,
        [{"role": "user", "content": question}],
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        repeat_penalty=repeat_penalty,
    )
def chat_llama(llm: Llama, chat_history: List[Dict[str, str]], max_new_tokens=200, temperature=0.5, repeat_penalty=2.0):
    """Run a chat completion over ``chat_history`` and return the stripped reply.

    Args:
        llm: the loaded llama_cpp model.
        chat_history: the conversation so far, as a list of messages with
            ``role`` and ``content``. Plain dicts and pydantic message models
            (e.g. ``LlamaChatMessage``) are both accepted.
        max_new_tokens: forwarded as ``max_tokens``.
        temperature: sampling temperature.
        repeat_penalty: repetition penalty.

    Returns:
        The assistant's reply text.
    """
    # llama_cpp types chat messages as dicts; pydantic models are converted
    # with dict(), which yields their field/value pairs.
    messages = [m if isinstance(m, dict) else dict(m) for m in chat_history]
    response = llm.create_chat_completion(
        messages,
        max_tokens=max_new_tokens,
        temperature=temperature,
        repeat_penalty=repeat_penalty,
        # Llama-3 end-of-turn / end-of-text special tokens.
        stop=["<|eot_id|>", "<|end_of_text|>"],
    )
    return extract_restext(response, is_chat=True)
# TESTING THE MODEL
# Startup smoke tests: fail fast at import time if the model or the translators
# are broken. Explicit `raise` is used instead of `assert` so the checks still
# run when Python is started with -O (which strips assert statements).
def _check_translation(translator, text, expected):
    """Raise RuntimeError unless ``translator.translate(text) == expected``."""
    got = translator.translate(text)
    if got != expected:
        raise RuntimeError(
            f"Translator smoke test failed: {text!r} -> {got!r}, expected {expected!r}"
        )


print("Testing model...")
# Just checking that it can run: any non-empty reply passes.
if not ask_llama(PHllm, "Hello!, How are you today?", max_new_tokens=5):
    raise RuntimeError("Model smoke test failed: PHllm returned an empty reply.")
print("Checking Translators.")
# NOTE(review): exact-match checks on machine-translation output are brittle;
# a translator model update could make startup fail.
_check_translation(t, "Hello!", "สวัสดี!")
_check_translation(e, "สวัสดี!", "Hello!")
print("Ready.")
# START OF FASTAPI APP
app = FastAPI(
    title="Gemma Finetuned API",
    description="Gemma Finetuned API for Thai Open-ended question answering.",
    version="1.0.0",
)

# Cross-origin policy: any origin, any HTTP method, any request header,
# credentials allowed.
# NOTE(review): browsers reject `Access-Control-Allow-Origin: *` on credentialed
# requests, so wildcard origins plus allow_credentials=True is a questionable
# combination. Check how Starlette handles it before adding any cookie-based auth.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# API DATA CLASSES
class QuestionResponse(BaseModel):
    """Response body for the single-question endpoint, also used as the error detail."""

    code: int = 200  # status code reported in the body (400/500 on errors)
    question: Optional[str] = None  # the prompt that was sent to the model
    answer: Optional[str] = None  # model answer, or the error message when code != 200
    config: Optional[dict] = None  # generation settings used for this answer
class ChatHistoryResponse(BaseModel):
    """Response body for a multi-turn chat request."""

    code: int = 200  # status code reported in the body
    # typing.Dict needs both key and value types. Assumed to hold the list of
    # role/content messages received by the (commented-out) /chat/multiturn
    # endpoint — TODO confirm the intended shape.
    chat_history: Optional[List[Dict[str, str]]] = None
    answer: Optional[str] = None  # the model's reply
    config: Optional[dict] = None  # generation settings used for this reply
class LlamaChatMessage(BaseModel):
    """One turn of a chat history, as taken by the commented-out /chat/multiturn endpoint."""

    # Author of this turn.
    role: Literal["user", "assistant"]
    # Text of this turn.
    content: str
# API ROUTES | |
@app.get("/")
def docs():
    "Redirects the user from the main page to the docs."
    # Without the route decorator above, FastAPI never registers this handler.
    return responses.RedirectResponse('./docs')
async def ask_gemmaPhysics(
    prompt: str = Body(..., embed=True, example="Why do ice cream melt so fast?"),
    temperature: float = Body(0.5, embed=True),
    repeat_penalty: float = Body(1.0, embed=True),
    max_new_tokens: int = Body(200, embed=True),
    translate_from_thai: bool = Body(False, embed=True)
) -> QuestionResponse:
    """
    Ask a finetuned Llama-3 model a physics question.
    NOTICE: Answers may be random / inaccurate. Always do your research & confirm its responses before doing anything.
    """
    # NOTE(review): no route decorator (e.g. @app.post(...)) is visible for this
    # handler in this file; without one FastAPI never serves it. Path unknown.
    # NOTE(review): ask_llama is synchronous, so it blocks the event loop for the
    # whole generation. This also serialises use of the shared PHllm instance;
    # do not switch to a plain `def` (thread pool) without adding a lock.
    from fastapi.encoders import jsonable_encoder

    if not prompt:
        # HTTPException must be raised, not returned. Its detail must be
        # JSON-serialisable, hence jsonable_encoder on the pydantic model.
        raise HTTPException(
            400,
            jsonable_encoder(QuestionResponse(code=400, answer="Request argument 'prompt' not provided.")),
        )
    try:
        print(f'Asking LlamaPhysics with the question "{prompt}", translation is {"enabled" if translate_from_thai else "disabled"}')
        if translate_from_thai:
            # `e` is the module-level Thai -> English translator.
            prompt = e.translate(prompt)
        result = ask_llama(PHllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty)
        print(f"Result: {result}")
        if translate_from_thai:
            # `t` is the module-level English -> Thai translator.
            result = t.translate(result)
        return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens, "repeat_penalty": repeat_penalty})
    except Exception as err:
        # Named `err`, not `e`: binding `e` anywhere in this function makes it
        # local for the whole body and shadows the translator used above.
        print(f"Error: {err}")
        raise HTTPException(
            500,
            jsonable_encoder(QuestionResponse(code=500, answer=str(err), question=prompt)),
        ) from err
# @app.post('/chat/multiturn') | |
# async def ask_llama3_Tuna( | |
# chat_history: List[LlamaChatMessage] = Body(..., embed=True), | |
# temperature: float = Body(0.5, embed=True), | |
# repeat_penalty: float = Body(2.0, embed=True), | |
# max_new_tokens: int = Body(200, embed=True) | |
# ) -> ChatHistoryResponse: | |
# """ | |
# Chat with a finetuned Llama-3 model (in Thai). | |
# Answers may be random / inaccurate. Always do your research & confirm its responses before doing anything. | |
# NOTICE: YOU MUST APPLY THE LLAMA3 PROMPT YOURSELF! | |
# """ | |
# if chat_history: | |
# try: | |
# print(f'Asking Llama3Tuna with the question "{chat_history}"') | |
# result = chat_llama(MIllm, chat_history, max_new_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty) | |
# print(f"Result: {result}") | |
# return ChatHistoryResponse(answer=result, config={"temperature": temperature, "max_new_tokens": max_new_tokens, "repeat_penalty": repeat_penalty}) | |
# except Exception as e: | |
# return HTTPException(500, QuestionResponse(code=500, answer=str(e), question=chat_history)) | |
# else: | |
# return HTTPException(400, QuestionResponse(code=400, answer="Request argument 'prompt' not provided.")) | |