#!/usr/bin/env python
# coding: utf-8
from fastapi import FastAPI, HTTPException, responses, Body
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama
from pydantic import BaseModel
from typing import Optional
# MODEL LOADING, FUNCTIONS, AND TESTING
print("Loading model...")
WIllm = Llama(
    model_path="/models/final-GemmaWild7b-Q8_0.gguf",
    use_mmap=False,
    use_mlock=True,
    # n_gpu_layers=28,  # Uncomment to use GPU acceleration
    # seed=1337,        # Uncomment to set a specific seed
    # n_ctx=2048,       # Uncomment to increase the context window
)
COllm = Llama(
    model_path="/models/TunaCodes-Q8_0.gguf",
    use_mmap=False,
    use_mlock=True,
)
def extract_restext(response):
    """Pull the generated text out of a llama_cpp completion response."""
    return response['choices'][0]['text'].strip()

def ask_llm(llm, question, max_new_tokens=200, temperature=0.5):
    """Format the chat-style prompt, run the model, and return the stripped answer."""
    prompt = f"###User: {question}\n###Assistant:"
    result = extract_restext(llm(prompt, max_tokens=max_new_tokens, temperature=temperature, stop=["###User:", "###Assistant:"], echo=False))
    return result
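# A quick illustrative call (output is model-dependent; the answer shown is only an example):
# ask_llm(WIllm, "What is the capital of France?")
# -> e.g. "The capital of France is Paris."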
# TESTING THE MODEL
print("Testing model...")
assert isinstance(ask_llm(WIllm, "Hello! How are you today?", max_new_tokens=1), str)  # smoke test: just checking that the model can run
print("Ready.")
# START OF FASTAPI APP
app = FastAPI(
    title="Gemma Finetuned API",
    description="Gemma Finetuned API for Open-ended and Coding questions.",
    version="1.0.0",
)
origins = ["*"]  # NOTE: allows any origin; restrict this for production deployments
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# API DATA CLASSES
class SAResponse(BaseModel):
    code: int = 200
    text: Optional[str] = None
    result: Optional[str] = None  # the original referenced an undefined SA_Result type; Optional[str] keeps the model importable

class QuestionResponse(BaseModel):
    code: int = 200
    question: Optional[str] = None
    answer: Optional[str] = None
    config: Optional[dict] = None
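# Example response body for the routes below (shape only; values are illustrative):
# {"code": 200, "question": "Why is ice cream so delicious?",
#  "answer": "...", "config": {"temperature": 0.5, "max_new_tokens": 200}}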
# API ROUTES
@app.get("/", include_in_schema=False)  # route decorator was missing in the source; "/" is the natural guess for a docs redirect
def docs():
    "Redirects the user from the main page to the docs."
    return responses.RedirectResponse('./docs')
@app.post("/ask_gemmaWild")  # route decorator was missing in the source; the path is assumed from the handler name
async def ask_gemmaWild(
    prompt: str = Body(..., embed=True, example="Why is ice cream so delicious?"),
    temperature: float = Body(0.5, embed=True),
    max_new_tokens: int = Body(200, embed=True)
) -> QuestionResponse:
    """
    Ask a finetuned Gemma an open-ended question.
    NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
    """
    if prompt:
        try:
            print(f'Asking GemmaWild with the question "{prompt}"')
            result = ask_llm(WIllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
            print(f"Result: {result}")
            return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
        except Exception as e:
            raise HTTPException(500, detail=str(e))  # HTTPException must be raised, not returned
    else:
        raise HTTPException(400, detail="Request argument 'prompt' not provided.")
@app.post("/ask_gemmaCode")  # route decorator was missing in the source; the path is assumed from the handler name
async def ask_gemmaCode(
    prompt: str = Body(..., embed=True, example="Write a Python function that reverses a string."),
    temperature: float = Body(0.5, embed=True),
    max_new_tokens: int = Body(200, embed=True)
) -> QuestionResponse:
    """
    Ask a finetuned Gemma a coding question.
    NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
    """
    if prompt:
        try:
            print(f'Asking GemmaCode with the question "{prompt}"')
            result = ask_llm(COllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
            print(f"Result: {result}")
            return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
        except Exception as e:
            raise HTTPException(500, detail=str(e))  # HTTPException must be raised, not returned
    else:
        raise HTTPException(400, detail="Request argument 'prompt' not provided.")
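# Local run (a minimal sketch; the Space's actual launch command is not shown in the source;
# port 7860 is the Hugging Face Spaces default):
# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=7860)
#
# Example request (assuming the /ask_gemmaWild path above):
# curl -X POST http://localhost:7860/ask_gemmaWild \
#   -H "Content-Type: application/json" \
#   -d '{"prompt": "Why is ice cream so delicious?", "temperature": 0.5, "max_new_tokens": 200}'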