#!/usr/bin/env python
# coding: utf-8
from fastapi import FastAPI, HTTPException, responses, Body
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama
from pydantic import BaseModel
from enum import Enum
from typing import Optional
print("Loading model...") | |
SAllm = Llama(model_path="/models/final-gemma2b_SA-Q8_0.gguf")#, | |
# n_gpu_layers=28, # Uncomment to use GPU acceleration | |
# seed=1337, # Uncomment to set a specific seed | |
# n_ctx=2048, # Uncomment to increase the context window | |
#) | |
# The finance model is used by ask_gemmaFinanceTH() below, so it must be loaded too.
FIllm = Llama(model_path="/models/final-gemma2b_FI-Q8_0.gguf")
# def ask(question, max_new_tokens=200):
#     output = llm(
#         question,                  # Prompt
#         max_tokens=max_new_tokens, # Set to None to generate up to the end of the context window
#         stop=["\n"],               # Stop generating just before the model would generate a new question
#         echo=False,                # Do not echo the prompt back in the output
#         temperature=0.0,
#     )
#     return output
def extract_restext(response):
    """Pulls the generated text out of a llama_cpp completion response."""
    return response['choices'][0]['text'].strip()
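# For reference, a llama-cpp-python completion returns a dict shaped roughly
# like the sketch below (illustrative, abridged values):
# {
#     "choices": [
#         {"text": " positive", "index": 0, "finish_reason": "stop"}
#     ],
#     "usage": {"prompt_tokens": 42, "completion_tokens": 3, "total_tokens": 45}
# }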
def check_sentiment(text):
    prompt = f'Analyze the sentiment of the tweet enclosed in square brackets, determine if it is positive or negative, and return the answer as the corresponding sentiment label "positive" or "negative" [{text}] ='
    response = SAllm(prompt, max_tokens=3, stop=["\n"], echo=False, temperature=0.5)
    # print(response)
    result = extract_restext(response)
    if "positive" in result:
        return "positive"
    elif "negative" in result:
        return "negative"
    else:
        return "unknown"
print("Testing model...") | |
assert "positive" in check_sentiment("ดอกไม้ร้านนี้สวยจัง") | |
print("Ready.") | |
app = FastAPI(
    title="GemmaSA_2b",
    description="A simple sentiment analysis API for the Thai language, powered by a finetuned version of Gemma-2b",
    version="1.0.0",
)
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class SA_Result(str, Enum):
    positive = "positive"
    negative = "negative"
    unknown = "unknown"
class SA_Response(BaseModel):
    code: int = 200
    text: Optional[str] = None
    result: Optional[SA_Result] = None
class FI_Response(BaseModel):
    code: int = 200
    question: Optional[str] = None
    answer: Optional[str] = None
    config: Optional[dict] = None
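# Example of a successful FI_Response body (illustrative values; "answer" is
# whatever the finance model generates):
# {
#     "code": 200,
#     "question": "What's the best way to invest my money",
#     "answer": "...model-generated text...",
#     "config": {"temperature": 0.5, "max_new_tokens": 200}
# }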
# NOTE: the route path below is an assumption; the original decorator was lost.
@app.get("/")
def docs():
    "Redirects the user from the main page to the docs."
    return responses.RedirectResponse('./docs')
def add(a: int, b: int):
    return a + b
# NOTE: the route path below is an assumption; the original decorator was lost.
@app.post("/SA")
def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I like eating fried chicken")) -> SA_Response:
    """Performs a sentiment analysis using a finetuned version of Gemma-2b"""
    if prompt:
        try:
            print(f"Checking sentiment for {prompt}")
            result = check_sentiment(prompt)
            print(f"Result: {result}")
            return SA_Response(result=result, text=prompt)
        except Exception as e:
            # HTTPException must be raised, not returned, for FastAPI to send the error.
            raise HTTPException(status_code=500, detail=str(e))
    else:
        raise HTTPException(status_code=400, detail="Request argument 'prompt' not provided.")
# NOTE: the route path below is an assumption; the original decorator was lost.
@app.post("/FI")
def ask_gemmaFinanceTH(
    prompt: str = Body(..., embed=True, example="What's the best way to invest my money"),
    temperature: float = Body(0.5, embed=True),
    max_new_tokens: int = Body(200, embed=True)
) -> FI_Response:
    """
    Ask a finetuned Gemma a finance-related question, just for fun.
    NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
    """
    if prompt:
        try:
            print(f'Asking FI with the question "{prompt}"')
            # Keep the user's question separate from the chat-formatted prompt,
            # so the response echoes the original question rather than the template.
            chat_prompt = f"###User: {prompt}\n###Assistant:"
            result = extract_restext(FIllm(chat_prompt, max_tokens=max_new_tokens, temperature=temperature, stop=["###User:", "###Assistant:"], echo=False))
            print(f"Result: {result}")
            return FI_Response(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
        except Exception as e:
            # HTTPException must be raised, not returned, for FastAPI to send the error.
            raise HTTPException(status_code=500, detail=str(e))
    else:
        raise HTTPException(status_code=400, detail="Request argument 'prompt' not provided.")
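# To serve this API locally (a sketch; assumes uvicorn is installed and this
# file is saved as app.py -- adjust the module name and the assumed /SA route
# as needed):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
#   curl -X POST http://localhost:7860/SA \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "I like eating fried chicken"}'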