#!/usr/bin/env python
# coding: utf-8
from os import listdir
from os.path import isdir
from fastapi import FastAPI, HTTPException, Request, responses, Body
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama

from pydantic import BaseModel
from enum import Enum
from typing import Optional, Literal, Dict, List

# MODEL LOADING, FUNCTIONS, AND TESTING

print("Loading model...")
PHllm = Llama(
    model_path="/models/final-Physics_llama3.gguf",
    use_mmap=False,
    use_mlock=True,
    # n_gpu_layers=28, # Uncomment to use GPU acceleration
    # seed=1337, # Uncomment to set a specific seed
    # n_ctx=2048, # Uncomment to increase the context window
)
# MIllm = Llama(model_path="/models/final-LlamaTuna_Q8_0.gguf", use_mmap=False, use_mlock=True)

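# The physics model is prompted in English; these translators are used by the API routes to
# translate Thai prompts into English and the model's answer back into Thai when requested.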
print("Loading Translators.")
from pythainlp.translate.en_th import EnThTranslator, ThEnTranslator
t = EnThTranslator()
e = ThEnTranslator()

def extract_restext(response, is_chat=False):
  # Chat completions nest the generated text under ['message']['content'];
  # plain completions return it directly under ['text'].
  choice = response['choices'][0]
  return (choice['message']['content'] if is_chat else choice['text']).strip()

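# Wrap a single question in the Llama-3 chat template (header / end-of-turn special tokens),
# run the completion, and return the stripped text with any leftover special tokens removed.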
def ask_llama(llm: Llama, question: str, max_new_tokens=200, temperature=0.5, repeat_penalty=2.0):
  prompt = f"""<|begin_of_text|>
<|start_header_id|> user  <|end_header_id|> {question} <|eot_id|>
<|start_header_id|> assistant <|end_header_id|>"""
  response = llm(prompt, max_tokens=max_new_tokens, temperature=temperature,
                 repeat_penalty=repeat_penalty, stop=["<|eot_id|>", "<|end_of_text|>"])
  result = extract_restext(response).replace("<|eot_id|>", "").replace("<|end_of_text|>", "")
  return result

# def chat_llama(llm: Llama, chat_history: dict, max_new_tokens=200, temperature=0.5, repeat_penalty=2.0):
#   result = extract_restext(llm.create_chat_completion(chat_history, max_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty, stop=["<|eot_id|>", "<|end_of_text|>"]), is_chat=True)
#   return result

# TESTING THE MODEL
print("Testing model...")
assert ask_llama(PHllm, "Hello! How are you today?", max_new_tokens=5)  # Just checking that it can run
print("Checking Translators.")
assert t.translate("Hello!") == "สวัสดี!"
assert e.translate("สวัสดี!") == "Hello!"
print("Ready.")


# START OF FASTAPI APP
app = FastAPI(
    title="Llama-3 Finetuned API",
    description="Finetuned Llama-3 API for Thai open-ended question answering.",
    version="1.0.0",
)
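
# A typical way to serve this app (assuming this file is saved as main.py):
#   uvicorn main:app --host 0.0.0.0 --port 8000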

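# Allow cross-origin requests from any origin (e.g. a web frontend hosted on another domain).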
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)


# API DATA CLASSES
class QuestionResponse(BaseModel):
  code: int = 200
  question: Optional[str] = None
  answer: Optional[str] = None
  config: Optional[dict] = None

class ChatHistoryResponse(BaseModel):
  code: int = 200
  chat_history: Optional[Dict[str, str]] = None
  answer: Optional[str] = None
  config: Optional[dict] = None

class LlamaChatMessage(BaseModel):
  role: Literal["user", "assistant"]
  content: str


# API ROUTES
@app.get('/')
def docs():
  "Redirects the user from the main page to the docs."
  return responses.RedirectResponse('./docs')

@app.post('/questions/physics')
async def ask_gemmaPhysics(
    prompt: str = Body(..., embed=True, example="Why does ice cream melt so fast?"),
    temperature: float = Body(0.5, embed=True), 
    repeat_penalty: float = Body(1.0, embed=True),
    max_new_tokens: int = Body(200, embed=True),
    translate_from_thai: bool = Body(False, embed=True)
) -> QuestionResponse:
  """
  Ask the finetuned Llama-3 physics model a question.
  NOTICE: Answers may be random / inaccurate. Always do your own research and confirm its responses before acting on them.
  """
  if prompt:
    try:
      print(f'Asking LlamaPhysics with the question "{prompt}", translation is {"enabled" if translate_from_thai else "disabled"}')
      if translate_from_thai:
        print("Translating content to EN.")
        prompt = e.translate(prompt)
      print(f"Asking the model with the question {prompt}")
      result = ask_llama(PHllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty)
      print(f"Got Model Response: {result}")
      if translate_from_thai:
        result = t.translate(result)
        print(f"Translation Result: {result}")
      return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens, "repeat_penalty": repeat_penalty})
    except Exception as err:
      # Keep this variable named `err`: naming it `e` would shadow the module-level ThEnTranslator
      # and make the `e.translate(prompt)` call above fail with an UnboundLocalError.
      raise HTTPException(status_code=500, detail=str(err))
  else:
    raise HTTPException(status_code=400, detail="Request argument 'prompt' not provided.")

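# Example request (assuming the server is running locally on uvicorn's default port 8000):
#   curl -X POST http://localhost:8000/questions/physics \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Why does ice cream melt so fast?", "translate_from_thai": false}'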

# @app.post('/chat/multiturn')
# async def ask_llama3_Tuna(
#     chat_history: List[LlamaChatMessage] = Body(..., embed=True),
#     temperature: float = Body(0.5, embed=True), 
#     repeat_penalty: float = Body(2.0, embed=True),
#     max_new_tokens: int = Body(200, embed=True)
# ) -> ChatHistoryResponse:
#   """
#   Chat with a finetuned Llama-3 model (in Thai).
#   Answers may be random / inaccurate. Always do your research & confirm its responses before doing anything.
#   NOTICE: YOU MUST APPLY THE LLAMA3 PROMPT YOURSELF!
#   """
#   if chat_history:
#     try:
#       print(f'Asking Llama3Tuna with the question "{chat_history}"')
#       result = chat_llama(MIllm, chat_history, max_new_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty)
#       print(f"Result: {result}")
#       return ChatHistoryResponse(answer=result, config={"temperature": temperature, "max_new_tokens": max_new_tokens, "repeat_penalty": repeat_penalty})
#     except Exception as err:
#       raise HTTPException(status_code=500, detail=str(err))
#   else:
#     raise HTTPException(status_code=400, detail="Request argument 'chat_history' not provided.")