Llama3_Physics / app /main.py
PawinC's picture
Update app/main.py
c578562 verified
#!/usr/bin/env python
# coding: utf-8
from enum import Enum
from os import listdir
from os.path import isdir
from typing import Dict, List, Literal, Optional

from fastapi import Body, FastAPI, HTTPException, Request, responses
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama
from pydantic import BaseModel
# ---- Model loading (helper functions and start-up tests follow below) ----
print("Loading model...")
# llama_cpp options: use_mmap=False loads the GGUF file into memory instead of
# memory-mapping it; use_mlock=True asks the system to keep the model in RAM.
PHllm = Llama(
    model_path="/models/final-Physics_llama3.gguf",
    use_mmap=False,
    use_mlock=True,
)
# Optional Llama() settings, currently disabled:
#     n_gpu_layers=28,  # offload layers to the GPU for acceleration
#     seed=1337,        # fix the sampling seed
#     n_ctx=2048,       # enlarge the context window
# Second (chat) model, currently disabled together with the multi-turn endpoint:
# MIllm = Llama(model_path="/models/final-LlamaTuna_Q8_0.gguf", use_mmap=False, use_mlock=True)
print("Loading Translators.")
from pythainlp.translate.en_th import EnThTranslator, ThEnTranslator

# Module-level translators shared by the endpoints:
#   t: English -> Thai (turns the model's English answer into Thai)
#   e: Thai -> English (turns a Thai question into English for the model)
t, e = EnThTranslator(), ThEnTranslator()
def extract_restext(response, is_chat=False):
    """Return the stripped generated text of the first choice in a llama_cpp response.

    Args:
        response: A completion result (text at ``choices[0]["text"]``) or, when
            ``is_chat`` is true, a chat-completion result (text at
            ``choices[0]["message"]["content"]``).
        is_chat: Whether ``response`` came from ``create_chat_completion``.

    Returns:
        The generated text with surrounding whitespace removed ("" when a chat
        message carries no content).
    """
    choice = response["choices"][0]
    if is_chat:
        # Chat completions nest the text inside a message object
        # ({"role": ..., "content": ...}); calling .strip() on that dict would
        # raise AttributeError. The OpenAI-style schema allows content=None.
        return (choice["message"].get("content") or "").strip()
    return choice["text"].strip()
def ask_llama(llm: Llama, question: str, max_new_tokens=200, temperature=0.5, repeat_penalty=2.0):
    """Ask ``llm`` one question and return its answer text.

    The question is wrapped in a single user turn followed by an open assistant
    header, using Llama-3 special tokens. Generation stops at ``<|eot_id|>`` or
    ``<|end_of_text|>``. The reply is stripped, and then any of those markers
    still present are deleted.

    Args:
        llm: The loaded llama_cpp model to query.
        question: The question text, inserted verbatim into the prompt.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        repeat_penalty: Penalty applied to repeated tokens.

    Returns:
        The generated answer as a string.
    """
    # NOTE(review): the spaces around the role names differ from Meta's
    # reference Llama-3 template; presumably this matches the fine-tuning data,
    # so keep it unless confirmed otherwise. Also check whether llama_cpp
    # prepends its own BOS token, which would duplicate <|begin_of_text|>.
    prompt = (
        "<|begin_of_text|>\n"
        f"<|start_header_id|> user <|end_header_id|> {question} <|eot_id|>\n"
        "<|start_header_id|> assistant <|end_header_id|>"
    )
    end_markers = ("<|eot_id|>", "<|end_of_text|>")
    completion = llm(
        prompt,
        max_tokens=max_new_tokens,
        temperature=temperature,
        repeat_penalty=repeat_penalty,
        stop=list(end_markers),
    )
    answer = extract_restext(completion)
    for marker in end_markers:
        answer = answer.replace(marker, "")
    return answer
# def chat_llama(llm: Llama, chat_history: dict, max_new_tokens=200, temperature=0.5, repeat_penalty=2.0):
# result = extract_restext(llm.create_chat_completion(chat_history, max_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty, stop=["<|eot_id|>", "<|end_of_text|>"]), is_chat=True)
# return result
# TESTING THE MODEL
def _run_startup_checks() -> None:
    """Smoke-test the model and both translators once, at import time.

    Explicit ``raise`` is used instead of ``assert`` so the checks (including
    the model warm-up call) still run under ``python -O``, which strips assert
    statements.

    Raises:
        RuntimeError: if the model returns an empty reply or a translator
            does not produce the expected text.
    """
    print("Testing model...")
    # Just checking that the model can run: any non-empty reply is accepted.
    reply = ask_llama(PHllm, "Hello!, How are you today?", max_new_tokens=5)
    if not reply:
        raise RuntimeError("Model smoke test failed: the model returned an empty reply.")

    print("Checking Translators.")
    en_to_th = t.translate("Hello!")
    if en_to_th != "สวัสดี!":
        raise RuntimeError(f"EN->TH translator check failed: expected 'สวัสดี!', got {en_to_th!r}.")
    th_to_en = e.translate("สวัสดี!")
    if th_to_en != "Hello!":
        raise RuntimeError(f"TH->EN translator check failed: expected 'Hello!', got {th_to_en!r}.")
    print("Ready.")


_run_startup_checks()
# ---- FastAPI application ----
# NOTE(review): the title/description mention "Gemma", but the model loaded by
# this file is a Llama-3 physics fine-tune; confirm which name is intended.
app = FastAPI(
    version="1.0.0",
    title="Gemma Finetuned API",
    description="Gemma Finetuned API for Thai Open-ended question answering.",
)

# CORS is fully open: every origin, method and header is accepted, and
# credentialed requests are allowed.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=origins,
)
# API DATA CLASSES
class QuestionResponse(BaseModel):
    """Response body of the question-answering endpoint."""

    # HTTP-style status code mirrored in the body (200 on success).
    code: int = 200
    # The question that was passed to the model, if any.
    question: Optional[str] = None
    # The model's answer, or an error message when code is not 200.
    # Declared Optional because the default is None.
    answer: Optional[str] = None
    # Generation settings used for the answer (e.g. temperature,
    # max_new_tokens, repeat_penalty), if reported.
    config: Optional[dict] = None
class ChatHistoryResponse(BaseModel):
    """Response body for multi-turn chat answers."""

    # HTTP-style status code mirrored in the body (200 on success).
    code: int = 200
    # NOTE(review): the commented-out multi-turn endpoint receives the history
    # as List[LlamaChatMessage]; a Dict[str, str] cannot represent that list.
    # Confirm the intended type before re-enabling the endpoint.
    chat_history: Optional[Dict[str, str]] = None
    # The model's reply. Declared Optional because the default is None.
    answer: Optional[str] = None
    # Generation settings used for the reply, if reported.
    config: Optional[dict] = None
# One message of a chat conversation, the element type of the chat history
# accepted by the (currently commented-out) multi-turn chat endpoint.
class LlamaChatMessage(BaseModel):
    # Author of the message; validation accepts only these two roles.
    role: Literal["user", "assistant"]
    # Plain-text body of the message.
    content: str
# API ROUTES
@app.get('/')
def docs():
    """Send visitors of the root URL to the auto-generated API docs at ./docs."""
    docs_location = './docs'
    return responses.RedirectResponse(docs_location)
def _question_error(status_code: int, message: str, question: Optional[str] = None) -> HTTPException:
    """Build an HTTPException whose ``detail`` is a QuestionResponse-shaped dict.

    The model is passed through ``jsonable_encoder`` because FastAPI's default
    HTTPException handler puts ``detail`` into a JSONResponse (``json.dumps``),
    which cannot serialize a pydantic model directly.
    """
    body = QuestionResponse(code=status_code, answer=message, question=question)
    return HTTPException(status_code=status_code, detail=jsonable_encoder(body))


@app.post('/questions/physics')
async def ask_gemmaPhysics(
    prompt: str = Body(..., embed=True, example="Why do ice cream melt so fast?"),
    temperature: float = Body(0.5, embed=True),
    repeat_penalty: float = Body(1.0, embed=True),
    max_new_tokens: int = Body(200, embed=True),
    translate_from_thai: bool = Body(False, embed=True),
) -> QuestionResponse:
    """
    Ask the finetuned Llama-3 physics model a physics question.

    If `translate_from_thai` is true, the question is translated from Thai to English
    before it is sent to the model, and the answer is translated back to Thai.

    NOTICE: Answers may be random / inaccurate. Always do your research & confirm its responses before doing anything.
    """
    # Reject empty or whitespace-only prompts before touching the model.
    if not prompt.strip():
        raise _question_error(400, "Request argument 'prompt' not provided.")

    # NOTE: the model and translators run synchronously inside this coroutine
    # (there is no await), so the event loop is occupied until generation
    # finishes. Within one worker process, requests therefore reach the shared
    # model one at a time.
    try:
        print(f'Asking LlamaPhysics with the question "{prompt}", translation is {"enabled" if translate_from_thai else "disabled"}')
        if translate_from_thai:
            print("Translating content to EN.")
            # `e` is the module-level Thai->English translator.
            prompt = e.translate(prompt)
            print(f"Asking the model with the question {prompt}")
        result = ask_llama(PHllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty)
        print(f"Got Model Response: {result}")
        if translate_from_thai:
            # `t` is the module-level English->Thai translator.
            result = t.translate(result)
            print(f"Translation Result: {result}")
    except Exception as exc:
        # Must not be named `e`: binding `e` anywhere in this function makes it
        # local throughout, shadowing the translator and raising
        # UnboundLocalError at `e.translate(...)` above.
        print(f"Error while answering the question: {exc!r}")
        raise _question_error(500, str(exc), question=prompt) from exc

    return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens, "repeat_penalty": repeat_penalty})
# @app.post('/chat/multiturn')
# async def ask_llama3_Tuna(
# chat_history: List[LlamaChatMessage] = Body(..., embed=True),
# temperature: float = Body(0.5, embed=True),
# repeat_penalty: float = Body(2.0, embed=True),
# max_new_tokens: int = Body(200, embed=True)
# ) -> ChatHistoryResponse:
# """
# Chat with a finetuned Llama-3 model (in Thai).
# Answers may be random / inaccurate. Always do your research & confirm its responses before doing anything.
# NOTICE: YOU MUST APPLY THE LLAMA3 PROMPT YOURSELF!
# """
# if chat_history:
# try:
# print(f'Asking Llama3Tuna with the question "{chat_history}"')
# result = chat_llama(MIllm, chat_history, max_new_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty)
# print(f"Result: {result}")
# return ChatHistoryResponse(answer=result, config={"temperature": temperature, "max_new_tokens": max_new_tokens, "repeat_penalty": repeat_penalty})
# except Exception as e:
# return HTTPException(500, QuestionResponse(code=500, answer=str(e), question=chat_history))
# else:
# return HTTPException(400, QuestionResponse(code=400, answer="Request argument 'prompt' not provided."))