#!/usr/bin/env python
# coding: utf-8
from os import listdir
from os.path import isdir
from fastapi import FastAPI, HTTPException, Request, responses, Body
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama
from pydantic import BaseModel
from enum import Enum
from typing import Optional
# MODEL LOADING, FUNCTIONS, AND TESTING
print("Loading model...")
SAllm = Llama(model_path="/models/final-gemma2b_SA-Q8_0.gguf", mmap=False, mlock=True)
FIllm = Llama(model_path="/models/final-gemma7b_FI-Q8_0.gguf", mmap=False, mlock=True)
# Optional Llama() kwargs, left disabled:
#   n_gpu_layers=28,  # uncomment to use GPU acceleration
#   seed=1337,        # uncomment to set a specific seed
#   n_ctx=2048,       # uncomment to increase the context window
def extract_restext(response):
    """Pulls the generated text out of a llama.cpp completion result."""
    return response['choices'][0]['text'].strip()
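# For reference, a llama-cpp-python completion result is shaped roughly like
# the sketch below (values are illustrative, not real output):
#
#   {
#       "choices": [
#           {"text": " positive", "index": 0, "finish_reason": "stop"}
#       ],
#       "usage": {"prompt_tokens": 42, "completion_tokens": 2, "total_tokens": 44},
#   }
#
# extract_restext() only touches choices[0]["text"].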
def ask_fi(question, max_new_tokens=200, temperature=0.5):
    """Asks the finance-finetuned Gemma model a question and returns its answer."""
    prompt = f"""###User: {question}\n###Assistant:"""
    result = extract_restext(FIllm(prompt, max_tokens=max_new_tokens, temperature=temperature, stop=["###User:", "###Assistant:"], echo=False))
    return result
def check_sentiment(text):
    """Classifies text as positive/negative with the sentiment-finetuned model."""
    prompt = f'Analyze the sentiment of the tweet enclosed in square brackets, determine if it is positive or negative, and return the answer as the corresponding sentiment label "positive" or "negative" [{text}] ='
    response = SAllm(prompt, max_tokens=3, stop=["\n"], echo=False, temperature=0.5)
    result = extract_restext(response)
    if "positive" in result:
        return "positive"
    elif "negative" in result:
        return "negative"
    else:
        return "unknown"
# TESTING THE MODEL
print("Testing model...")
assert "positive" in check_sentiment("ดอกไม้ร้านนี้สวยจัง")  # Thai: "The flowers at this shop are beautiful."
assert ask_fi("Hello! How are you today?")
print("Ready.")
# START OF FASTAPI APP
app = FastAPI(
    title="Gemma Finetuned API",
    description="Gemma Finetuned API for Sentiment Analysis and Finance Questions.",
    version="1.0.0",
)
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
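# Note: wildcard origins are convenient for a demo but overly permissive for
# production; a real deployment would list explicit origins, e.g. (hypothetical):
#   origins = ["https://my-frontend.example.com"]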
# API DATA CLASSES
class SA_Result(str, Enum):
    positive = "positive"
    negative = "negative"
    unknown = "unknown"

class SA_Response(BaseModel):
    code: int = 200
    text: Optional[str] = None
    result: Optional[SA_Result] = None

class FI_Response(BaseModel):
    code: int = 200
    question: Optional[str] = None
    answer: Optional[str] = None
    config: Optional[dict] = None
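# A successful FI_Response serializes to JSON along these lines (illustrative):
#   {"code": 200, "question": "...", "answer": "...",
#    "config": {"temperature": 0.5, "max_new_tokens": 200}}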
# API ROUTES
@app.get("/")  # route path assumed; the original decorator was not preserved
def docs():
    "Redirects the user from the main page to the docs."
    return responses.RedirectResponse('./docs')
@app.post("/sentiment")  # route path assumed; the original decorator was not preserved
async def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I like eating fried chicken")) -> SA_Response:
    """Performs a sentiment analysis using a finetuned version of Gemma-2b"""
    if prompt:
        try:
            print(f"Checking sentiment for {prompt}")
            result = check_sentiment(prompt)
            print(f"Result: {result}")
            return SA_Response(result=result, text=prompt)
        except Exception as e:
            # HTTPException must be raised, not returned, for FastAPI to send the error.
            raise HTTPException(status_code=500, detail=str(e))
    else:
        raise HTTPException(status_code=400, detail="Request argument 'prompt' not provided.")
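# Example request (path and host are assumptions for illustration):
#   curl -X POST http://localhost:7860/sentiment \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "I like eating fried chicken"}'
# Expected shape of a successful reply:
#   {"code": 200, "text": "I like eating fried chicken", "result": "positive"}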
@app.post("/ask")  # route path assumed; the original decorator was not preserved
async def ask_gemmaFinanceTH(
    prompt: str = Body(..., embed=True, example="What's the best way to invest my money?"),
    temperature: float = Body(0.5, embed=True),
    max_new_tokens: int = Body(200, embed=True)
) -> FI_Response:
    """
    Ask a finetuned Gemma a finance-related question, just for fun.
    NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
    """
    if prompt:
        try:
            print(f'Asking FI with the question "{prompt}"')
            result = ask_fi(prompt, max_new_tokens=max_new_tokens, temperature=temperature)
            print(f"Result: {result}")
            return FI_Response(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
        except Exception as e:
            # HTTPException must be raised, not returned, for FastAPI to send the error.
            raise HTTPException(status_code=500, detail=str(e))
    else:
        raise HTTPException(status_code=400, detail="Request argument 'prompt' not provided.")
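# Example request (path and host are assumptions for illustration):
#   curl -X POST http://localhost:7860/ask \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "How should I invest my savings?", "temperature": 0.5}'

# Local development entrypoint. A sketch only: Hugging Face Spaces typically
# launches uvicorn itself, so the host/port below are assumptions.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)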