#!/usr/bin/env python
# coding: utf-8
from fastapi import FastAPI, HTTPException, responses, Body
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama
from pydantic import BaseModel
from enum import Enum
from typing import Optional
# MODEL LOADING, FUNCTIONS, AND TESTING
print("Loading model...")
SAllm = Llama(model_path="/models/final-gemma2b_SA-Q8_0.gguf", use_mmap=False, use_mlock=True)
FIllm = Llama(model_path="/models/final-gemma7b_FI-Q8_0.gguf", use_mmap=False, use_mlock=True)
# WIllm = Llama(
#     model_path="/models/final-GemmaWild7b-Q8_0.gguf",
#     use_mmap=False, use_mlock=True,
#     # n_gpu_layers=28,  # Uncomment to use GPU acceleration
#     # seed=1337,        # Uncomment to set a specific seed
#     # n_ctx=2048,       # Uncomment to increase the context window
# )
def extract_restext(response):
    # llama_cpp returns an OpenAI-style completion dict; grab the generated text.
    return response['choices'][0]['text'].strip()
def ask_llm(llm, question, max_new_tokens=200, temperature=0.5):
    prompt = f"""###User: {question}\n###Assistant:"""
    result = extract_restext(llm(prompt, max_tokens=max_new_tokens, temperature=temperature, stop=["###User:", "###Assistant:"], echo=False))
    return result
def check_sentiment(text):
    prompt = f'Analyze the sentiment of the tweet enclosed in square brackets, determine if it is positive or negative, and return the answer as the corresponding sentiment label "positive" or "negative" [{text}] ='
    response = SAllm(prompt, max_tokens=3, stop=["\n"], echo=False, temperature=0.5)
    # print(response)
    result = extract_restext(response)
    if "positive" in result:
        return "positive"
    elif "negative" in result:
        return "negative"
    else:
        return "unknown"
# TESTING THE MODEL
print("Testing model...")
assert "positive" in check_sentiment("ดอกไม้ร้านนี้สวยจัง")
assert ask_llm(FIllm, "Hello!, How are you today?", max_new_tokens=1) #Just checking that it can run
# assert ask_llm(WIllm, "Hello!, How are you today?", max_new_tokens=1) #Just checking that it can run
print("Ready.")
# START OF FASTAPI APP
app = FastAPI(
    title="Gemma Finetuned API",
    description="Gemma Finetuned API for Sentiment Analysis and Finance Questions.",
    version="1.0.0",
)
origins = ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"]
)
# API DATA CLASSES
class SA_Result(str, Enum):
    positive = "positive"
    negative = "negative"
    unknown = "unknown"

class SAResponse(BaseModel):
    code: int = 200
    text: Optional[str] = None
    result: Optional[SA_Result] = None

class QuestionResponse(BaseModel):
    code: int = 200
    question: Optional[str] = None
    answer: Optional[str] = None
    config: Optional[dict] = None
# API ROUTES
@app.get('/')
def docs():
    """Redirects the user from the main page to the docs."""
    return responses.RedirectResponse('./docs')
@app.post('/classifications/sentiment')
async def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I like eating fried chicken")) -> SAResponse:
    """Performs a sentiment analysis using a finetuned version of Gemma-2b."""
    if prompt:
        try:
            print(f"Checking sentiment for {prompt}")
            result = check_sentiment(prompt)
            print(f"Result: {result}")
            return SAResponse(result=result, text=prompt)
        except Exception as e:
            raise HTTPException(500, detail=str(e))
    else:
        raise HTTPException(400, detail="Request argument 'prompt' not provided.")
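# Example request (a hypothetical client sketch; assumes the app is served locally on port 8000):
# import requests
# r = requests.post("http://localhost:8000/classifications/sentiment",
#                   json={"prompt": "I like eating fried chicken"})
# print(r.json())  # e.g. {"code": 200, "text": "I like eating fried chicken", "result": "positive"}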
@app.post('/questions/finance')
async def ask_gemmaFinanceTH(
    prompt: str = Body(..., embed=True, example="What's the best way to invest my money"),
    temperature: float = Body(0.5, embed=True),
    max_new_tokens: int = Body(200, embed=True)
) -> QuestionResponse:
    """
    Ask a finetuned Gemma a finance-related question, just for fun.
    NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
    """
    if prompt:
        try:
            print(f'Asking GemmaFinance with the question "{prompt}"')
            result = ask_llm(FIllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
            print(f"Result: {result}")
            return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
        except Exception as e:
            raise HTTPException(500, detail=str(e))
    else:
        raise HTTPException(400, detail="Request argument 'prompt' not provided.")
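# Example request (hypothetical; "temperature" and "max_new_tokens" are optional and default to 0.5 and 200):
# import requests
# r = requests.post("http://localhost:8000/questions/finance",
#                   json={"prompt": "What's the best way to invest my money", "temperature": 0.7})
# print(r.json()["answer"])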
# @app.post('/questions/open-ended')
# async def ask_gemmaWild(
#     prompt: str = Body(..., embed=True, example="Why is ice cream so delicious?"),
#     temperature: float = Body(0.5, embed=True),
#     max_new_tokens: int = Body(200, embed=True)
# ) -> QuestionResponse:
#     """
#     Ask a finetuned Gemma an open-ended question.
#     NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
#     """
#     if prompt:
#         try:
#             print(f'Asking GemmaWild with the question "{prompt}"')
#             result = ask_llm(WIllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
#             print(f"Result: {result}")
#             return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
#         except Exception as e:
#             raise HTTPException(500, detail=str(e))
#     else:
#         raise HTTPException(400, detail="Request argument 'prompt' not provided.")
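# To serve the app (a minimal sketch; the module name "main" and port 7860 are assumptions, adjust to your deployment):
# uvicorn main:app --host 0.0.0.0 --port 7860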