#!/usr/bin/env python
# coding: utf-8
from os import listdir
from os.path import isdir
from fastapi import FastAPI, HTTPException, Request, responses, Body
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama

from pydantic import BaseModel
from enum import Enum
from typing import Optional

# MODEL LOADING, FUNCTIONS, AND TESTING

print("Loading model...")
SAllm = Llama(model_path="/models/final-gemma2b_SA-Q8_0.gguf", use_mmap=False, use_mlock=True)
FIllm = Llama(model_path="/models/final-gemma7b_FI-Q8_0.gguf", use_mmap=False, use_mlock=True)
# WIllm = Llama(model_path="/models/final-GemmaWild7b-Q8_0.gguf", use_mmap=False, use_mlock=True)
# Optional Llama() arguments:
#   n_gpu_layers=28,  # uncomment to use GPU acceleration
#   seed=1337,        # uncomment to set a specific seed
#   n_ctx=2048,       # uncomment to increase the context window

def extract_restext(response):
  """Extracts the generated text from a llama_cpp completion response."""
  return response['choices'][0]['text'].strip()

def ask_llm(llm, question, max_new_tokens=200, temperature=0.5):
  """Asks the given model a question using the ###User:/###Assistant: prompt format."""
  prompt = f"###User: {question}\n###Assistant:"
  result = extract_restext(llm(prompt, max_tokens=max_new_tokens, temperature=temperature, stop=["###User:", "###Assistant:"], echo=False))
  return result
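
# Illustrative usage (the question text is a made-up example):
#   ask_llm(FIllm, "What is compound interest?", max_new_tokens=100, temperature=0.3)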
    
def check_sentiment(text):
  """Classifies the sentiment of the given text as "positive", "negative", or "unknown"."""
  prompt = f'Analyze the sentiment of the tweet enclosed in square brackets, determine if it is positive or negative, and return the answer as the corresponding sentiment label "positive" or "negative" [{text}] ='
  response = SAllm(prompt, max_tokens=3, stop=["\n"], echo=False, temperature=0.5)
  result = extract_restext(response)
  if "positive" in result:
    return "positive"
  elif "negative" in result:
    return "negative"
  else:
    return "unknown"

# TESTING THE MODEL
print("Testing model...")
assert "positive" in check_sentiment("ดอกไม้ร้านนี้สวยจัง")
assert ask_llm(FIllm, "Hello!, How are you today?", max_new_tokens=1) #Just checking that it can run
# assert ask_llm(WIllm, "Hello!, How are you today?", max_new_tokens=1) #Just checking that it can run
print("Ready.")


# START OF FASTAPI APP
app = FastAPI(
    title = "Gemma Finetuned API",
    description="Gemma Finetuned API for Sentiment Analysis and Finance Questions.",
    version="1.0.0",
)
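
# To serve the app locally (assuming this file is saved as main.py):
#   uvicorn main:app --host 0.0.0.0 --port 8000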

origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)


# API DATA CLASSES
class SA_Result(str, Enum):
  positive = "positive"
  negative = "negative"
  unknown = "unknown"

class SAResponse(BaseModel):
  code: int = 200
  text: Optional[str] = None
  result: Optional[SA_Result] = None

class QuestionResponse(BaseModel):
  code: int = 200
  question: Optional[str] = None
  answer: Optional[str] = None
  config: Optional[dict] = None


# API ROUTES
@app.get('/')
def docs():
  "Redirects the user from the main page to the docs."
  return responses.RedirectResponse('./docs')

@app.post('/classifications/sentiment')
async def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I like eating fried chicken")) -> SAResponse:
  """Performs a sentiment analysis using a finetuned version of Gemma-7b"""
  if prompt:
    try:
      print(f"Checking sentiment for {prompt}")
      result = check_sentiment(prompt)
      print(f"Result: {result}")
      return SAResponse(result=result, text=prompt)
    except Exception as e:
      raise HTTPException(status_code=500, detail=str(e))
  else:
    raise HTTPException(status_code=400, detail="Request argument 'prompt' not provided.")
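
# Example request (assuming the server listens on localhost:8000):
#   curl -X POST http://localhost:8000/classifications/sentiment \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "I like eating fried chicken"}'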


@app.post('/questions/finance')
async def ask_gemmaFinanceTH(
    prompt: str = Body(..., embed=True, example="What's the best way to invest my money"),
    temperature: float = Body(0.5, embed=True), 
    max_new_tokens: int = Body(200, embed=True)
) -> QuestionResponse:
  """
  Ask a finetuned Gemma a finance-related question, just for fun.
  NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
  """
  if prompt:
    try:
      print(f'Asking GemmaFinance with the question "{prompt}"')
      result = ask_llm(FIllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
      print(f"Result: {result}")
      return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
    except Exception as e:
      raise HTTPException(status_code=500, detail=str(e))
  else:
    raise HTTPException(status_code=400, detail="Request argument 'prompt' not provided.")
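
# Example request (assuming the server listens on localhost:8000):
#   curl -X POST http://localhost:8000/questions/finance \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "What is the best way to invest my money?", "temperature": 0.5, "max_new_tokens": 200}'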
  

# @app.post('/questions/open-ended')
# async def ask_gemmaWild(
#     prompt: str = Body(..., embed=True, example="Why is ice cream so delicious?"),
#     temperature: float = Body(0.5, embed=True), 
#     max_new_tokens: int = Body(200, embed=True)
# ) -> QuestionResponse:
#   """
#   Ask a finetuned Gemma an open-ended question.
#   NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
#   """
#   if prompt:
#     try:
#       print(f'Asking GemmaWild with the question "{prompt}"')
#       result = ask_llm(WIllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
#       print(f"Result: {result}")
#       return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
#     except Exception as e:
#       raise HTTPException(status_code=500, detail=str(e))
#   else:
#     raise HTTPException(status_code=400, detail="Request argument 'prompt' not provided.")