#!/usr/bin/env python
# coding: utf-8
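"""FastAPI service for Thai physics question answering.

Loads a fine-tuned Llama-3 GGUF model with llama-cpp-python plus pythainlp's
English<->Thai translators, then exposes a /questions/physics endpoint that can
optionally translate Thai prompts to English before inference and translate the
model's answer back to Thai.
"""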
from os import listdir
from os.path import isdir
from fastapi import FastAPI, HTTPException, Request, responses, Body
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama
from pydantic import BaseModel
from enum import Enum
from typing import Optional, Literal, Dict, List
# MODEL LOADING, FUNCTIONS, AND TESTING
print("Loading model...")
PHllm = Llama(model_path="/models/final-Physics_llama3.gguf", use_mmap=False, use_mlock=True)
# MIllm = Llama(model_path="/models/final-LlamaTuna_Q8_0.gguf", use_mmap=False, use_mlock=True)
# n_gpu_layers=28, # Uncomment to use GPU acceleration
# seed=1337, # Uncomment to set a specific seed
# n_ctx=2048, # Uncomment to increase the context window
#)
print("Loading Translators.")
from pythainlp.translate.en_th import EnThTranslator, ThEnTranslator
t = EnThTranslator()
e = ThEnTranslator()
def extract_restext(response, is_chat=False):
    # Chat completions nest the text under ['message']['content']; plain completions use ['text'].
    text = response['choices'][0]['message']['content'] if is_chat else response['choices'][0]['text']
    return text.strip()
def ask_llama(llm: Llama, question: str, max_new_tokens=200, temperature=0.5, repeat_penalty=2.0):
    prompt = f"""<|begin_of_text|>
<|start_header_id|> user <|end_header_id|> {question} <|eot_id|>
<|start_header_id|> assistant <|end_header_id|>"""
    result = extract_restext(llm(prompt, max_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty, stop=["<|eot_id|>", "<|end_of_text|>"])).replace("<|eot_id|>", "").replace("<|end_of_text|>", "")
    return result
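# For reference, ask_llama(PHllm, "What is inertia?") renders a prompt like (illustrative):
#   <|begin_of_text|>
#   <|start_header_id|> user <|end_header_id|> What is inertia? <|eot_id|>
#   <|start_header_id|> assistant <|end_header_id|>
# Generation stops at <|eot_id|> / <|end_of_text|>, and those markers are stripped from the result.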
# def chat_llama(llm: Llama, chat_history: dict, max_new_tokens=200, temperature=0.5, repeat_penalty=2.0):
#     result = extract_restext(llm.create_chat_completion(chat_history, max_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty, stop=["<|eot_id|>", "<|end_of_text|>"]), is_chat=True)
#     return result
# TESTING THE MODEL
print("Testing model...")
assert ask_llama(PHllm, "Hello! How are you today?", max_new_tokens=5)  # Just checking that the model can run
print("Checking Translators.")
assert t.translate("Hello!") == "สวัสดี!"
assert e.translate("สวัสดี!") == "Hello!"
print("Ready.")
# START OF FASTAPI APP
app = FastAPI(
    title="Gemma Finetuned API",
    description="Gemma Finetuned API for Thai Open-ended question answering.",
    version="1.0.0",
)
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# API DATA CLASSES
class QuestionResponse(BaseModel):
    code: int = 200
    question: Optional[str] = None
    answer: Optional[str] = None
    config: Optional[dict] = None
class ChatHistoryResponse(BaseModel):
    code: int = 200
    chat_history: Optional[Dict[str, str]] = None
    answer: Optional[str] = None
    config: Optional[dict] = None
class LlamaChatMessage(BaseModel):
    role: Literal["user", "assistant"]
    content: str
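# The multi-turn route below (currently commented out) would take a list of these messages,
# e.g. (illustrative values only):
#   {"chat_history": [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi, how can I help?"}]}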
# API ROUTES
@app.get('/')
def docs():
    "Redirects the user from the main page to the docs."
    return responses.RedirectResponse('./docs')
@app.post('/questions/physics')
async def ask_gemmaPhysics(
    prompt: str = Body(..., embed=True, example="Why does ice cream melt so fast?"),
    temperature: float = Body(0.5, embed=True),
    repeat_penalty: float = Body(1.0, embed=True),
    max_new_tokens: int = Body(200, embed=True),
    translate_from_thai: bool = Body(False, embed=True)
) -> QuestionResponse:
    """
    Ask a finetuned Gemma a physics question.
    NOTICE: Answers may be random / inaccurate. Always do your own research and confirm the model's responses before acting on them.
    """
    if prompt:
        try:
            print(f'Asking LlamaPhysics with the question "{prompt}", translation is {"enabled" if translate_from_thai else "disabled"}')
            if translate_from_thai:
                print("Translating content to EN.")
                prompt = e.translate(prompt)
                print(f"Asking the model with the question {prompt}")
            result = ask_llama(PHllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty)
            print(f"Got Model Response: {result}")
            if translate_from_thai:
                result = t.translate(result)
                print(f"Translation Result: {result}")
            return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens, "repeat_penalty": repeat_penalty})
        except Exception as exc:  # use `exc`, not `e`, so the ThEnTranslator bound to `e` is not shadowed
            # Raise (not return) so FastAPI sends a proper error response instead of failing response validation.
            raise HTTPException(500, detail=QuestionResponse(code=500, answer=str(exc), question=prompt).dict())
    else:
        raise HTTPException(400, detail=QuestionResponse(code=400, answer="Request argument 'prompt' not provided.").dict())
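# Example request (illustrative sketch; host and port depend on how the app is served):
#   curl -X POST http://localhost:7860/questions/physics \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Why does ice cream melt so fast?", "temperature": 0.5, "max_new_tokens": 200, "translate_from_thai": false}'
# Because every parameter uses Body(..., embed=True), the JSON keys match the argument names above.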
# @app.post('/chat/multiturn')
# async def ask_llama3_Tuna(
#     chat_history: List[LlamaChatMessage] = Body(..., embed=True),
#     temperature: float = Body(0.5, embed=True),
#     repeat_penalty: float = Body(2.0, embed=True),
#     max_new_tokens: int = Body(200, embed=True)
# ) -> ChatHistoryResponse:
#     """
#     Chat with a finetuned Llama-3 model (in Thai).
#     Answers may be random / inaccurate. Always do your own research and confirm the model's responses before acting on them.
#     NOTICE: YOU MUST APPLY THE LLAMA3 PROMPT YOURSELF!
#     """
#     if chat_history:
#         try:
#             print(f'Asking Llama3Tuna with the question "{chat_history}"')
#             result = chat_llama(MIllm, chat_history, max_new_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty)
#             print(f"Result: {result}")
#             return ChatHistoryResponse(answer=result, config={"temperature": temperature, "max_new_tokens": max_new_tokens, "repeat_penalty": repeat_penalty})
#         except Exception as exc:
#             raise HTTPException(500, detail=ChatHistoryResponse(code=500, answer=str(exc)).dict())
#     else:
#         raise HTTPException(400, detail=ChatHistoryResponse(code=400, answer="Request argument 'chat_history' not provided.").dict())