PawinC's picture
Upload main.py
0c228e3 verified
raw
history blame
4.41 kB
#!/usr/bin/env python
# coding: utf-8
from os import listdir
from os.path import isdir
from fastapi import FastAPI, HTTPException, Request, responses, Body
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama
from pydantic import BaseModel
from enum import Enum
from typing import Optional
print("Loading model...")
# Path to the finetuned Gemma-2b sentiment-analysis model; the filename suggests
# a Q8_0-quantised GGUF build.
_SA_MODEL_PATH = "/models/final-gemma2b_SA-Q8_0.gguf"
# Loaded once at import time; check_sentiment() calls this instance directly.
SAllm = Llama(
    model_path=_SA_MODEL_PATH,
    # n_gpu_layers=28,  # Uncomment to use GPU acceleration
    # seed=1337,  # Uncomment to set a specific seed
    # n_ctx=2048,  # Uncomment to increase the context window
)
# FIllm = Llama(model_path="/models/final-gemma2b_FI-Q8_0.gguf")
# def ask(question, max_new_tokens=200):
# output = llm(
# question, # Prompt
# max_tokens=max_new_tokens, # Generate up to 32 tokens, set to None to generate up to the end of the context window
# stop=["\n"], # Stop generating just before the model would generate a new question
# echo=False, # Echo the prompt back in the output
# temperature=0.0,
# )
# return output
def extract_restext(response):
    """Return the whitespace-stripped text of the first choice in a completion.

    ``response`` is indexed as ``response['choices'][0]['text']``, which is the
    shape this file expects back from calling a ``Llama`` model.
    """
    first_choice = response['choices'][0]
    raw_text = first_choice['text']
    return raw_text.strip()
def check_sentiment(text, *, temperature=0.5, max_tokens=3):
    """Classify the sentiment of ``text`` with the SA model.

    The text is wrapped in an instruction prompt and completed by ``SAllm``.
    The completion is then mapped onto a fixed label.

    Args:
        text: The text (e.g. a Thai tweet) to classify.
        temperature: Sampling temperature passed to the model. The default
            (0.5) matches the original behaviour. Lower values make the
            result more repeatable.
        max_tokens: Maximum number of tokens to generate. Only a short label
            is expected back.

    Returns:
        "positive", "negative" or "unknown". If the completion mentions both
        labels, "positive" wins because it is checked first.
    """
    prompt = f'Analyze the sentiment of the tweet enclosed in square brackets, determine if it is positive or negative, and return the answer as the corresponding sentiment label "positive" or "negative" [{text}] ='
    response = SAllm(prompt, max_tokens=max_tokens, stop=["\n"], echo=False, temperature=temperature)
    # Match case-insensitively so completions such as "Positive" or "NEGATIVE"
    # are not reported as "unknown".
    result = extract_restext(response).lower()
    if "positive" in result:
        return "positive"
    if "negative" in result:
        return "negative"
    return "unknown"
print("Testing model...")
# Smoke-test the model once at import time. Use an explicit raise rather than
# `assert`, because assertions are stripped when Python runs with -O.
# The Thai sample means roughly "The flowers at this shop are so beautiful."
# NOTE(review): check_sentiment samples with a non-zero temperature by default,
# so this check may be flaky; consider a deterministic setting.
_self_test_result = check_sentiment("ดอกไม้ร้านนี้สวยจัง")
if "positive" not in _self_test_result:
    raise RuntimeError(
        f"Model self-test failed: expected 'positive', got {_self_test_result!r}"
    )
print("Ready.")
# The FastAPI application. Its metadata is shown in the generated API docs.
app = FastAPI(
    version="1.0.0",
    title="GemmaSA_2b",
    description="A simple sentiment analysis API for the Thai language, powered by a finetuned version of Gemma-2b",
)
# Accept cross-origin requests from any origin, with any HTTP method and header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive, and browsers reject a literal "*" for credentialed requests.
# Confirm whether credentials are actually needed by any client.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=origins,
)
class SA_Result(str, Enum):
    """Sentiment labels returned by the /SA endpoint.

    Subclasses ``str`` so that members compare equal to, and serialize as,
    their plain string values. These are the same strings that
    check_sentiment() returns.
    """
    positive = "positive"
    negative = "negative"
    unknown = "unknown"  # the model's output contained neither label
class SA_Response(BaseModel):
    """Response body of the /SA sentiment-analysis endpoint.

    ``result`` defaults to None, so it is annotated ``Optional``. This keeps
    the annotation, validation and generated schema consistent with that
    default; a bare ``SA_Result`` annotation would reject an explicit None.
    """
    code: int = 200  # status code echoed in the response body
    text: Optional[str] = None  # the text that was analysed
    result: Optional[SA_Result] = None  # the sentiment label, if any
class FI_Response(BaseModel):
    """Response body of the /FI finance-question endpoint.

    ``answer`` defaults to None, so it is annotated ``Optional[str]``. This
    keeps the annotation and generated schema consistent with that default.
    """
    code: int = 200  # status code echoed in the response body
    question: Optional[str] = None  # the user's original question
    answer: Optional[str] = None  # the model's answer text
    config: Optional[dict] = None  # generation settings used for the answer
@app.get('/')
def docs():
    "Redirects the user from the main page to the docs."
    docs_url = './docs'
    return responses.RedirectResponse(url=docs_url)
@app.get('/add/{a}/{b}')
def add(a: int, b: int):
    """Demo endpoint: return the sum of the integer path parameters ``a`` and ``b``."""
    total = a + b
    return total
@app.post('/SA')
def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I like eating fried chicken")) -> SA_Response:
    """Performs a sentiment analysis using a finetuned version of Gemma-2b.

    The request body is a JSON object with a single ``prompt`` key.

    Returns:
        SA_Response containing the original text and the label produced by
        check_sentiment ("positive", "negative" or "unknown").

    Raises:
        HTTPException(400): ``prompt`` is an empty string.
        HTTPException(500): the model call failed; ``detail["error"]`` holds
            the error message.
    """
    if not prompt:
        # Raise rather than return: a returned exception object would be
        # validated against SA_Response instead of producing a 400. The
        # detail is a plain dict so the default JSON handler can serialize it,
        # and "result" stays a valid label.
        raise HTTPException(
            status_code=400,
            detail={
                "code": 400,
                "text": None,
                "result": "unknown",
                "error": "Request argument 'prompt' not provided.",
            },
        )
    print(f"Checking sentiment for {prompt}")
    try:
        result = check_sentiment(prompt)
    except Exception as e:
        # Top-level request boundary: report the failure to the client as a 500.
        raise HTTPException(
            status_code=500,
            detail={"code": 500, "text": prompt, "result": "unknown", "error": str(e)},
        ) from e
    print(f"Result: {result}")
    return SA_Response(result=result, text=prompt)
@app.post('/FI')
def ask_gemmaFinanceTH(
    prompt: str = Body(..., embed=True, example="What's the best way to invest my money"),
    temperature: float = Body(0.5, embed=True),
    max_new_tokens: int = Body(200, embed=True)
) -> FI_Response:
    """
    Ask a finetuned Gemma a finance-related question, just for fun.
    NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
    """
    if not prompt:
        # Raise (not return) so FastAPI sends a real 400. A plain dict keeps
        # the detail JSON-serializable by the default handler.
        raise HTTPException(
            status_code=400,
            detail={"code": 400, "question": None, "answer": "Request argument 'prompt' not provided.", "config": None},
        )

    # The module-level FIllm model load is commented out in this file, so look
    # it up dynamically. Report "not available" instead of failing with a
    # NameError on every request.
    fi_llm = globals().get("FIllm")
    if fi_llm is None:
        raise HTTPException(
            status_code=503,
            detail={"code": 503, "question": prompt, "answer": "The finance model (FIllm) is not loaded on this server.", "config": None},
        )

    config = {"temperature": temperature, "max_new_tokens": max_new_tokens}
    print(f'Asking FI with the question "{prompt}"')
    # Use a separate variable so the response echoes the user's question,
    # not the chat template.
    chat_prompt = f"###User: {prompt}\n###Assistant:"
    try:
        completion = fi_llm(chat_prompt, max_tokens=max_new_tokens, temperature=temperature, stop=["###User:", "###Assistant:"], echo=False)
        result = extract_restext(completion)
    except Exception as e:
        # Top-level request boundary: report the failure to the client as a 500.
        raise HTTPException(
            status_code=500,
            detail={"code": 500, "question": prompt, "answer": str(e), "config": config},
        ) from e
    print(f"Result: {result}")
    return FI_Response(answer=result, question=prompt, config=config)