|
import fastapi |
|
from fastapi.responses import JSONResponse |
|
from fastapi_users import schemas |
|
from time import time |
|
|
|
|
|
import logging |
|
import llama_cpp |
|
import llama_cpp.llama_tokenizer |
|
from pydantic import BaseModel |
|
from fastapi import APIRouter |
|
from app.users import current_active_user |
|
|
|
class GenModel(BaseModel):
    """Request body for ``POST /llm/generate``: a single-turn question.

    NOTE(review): the ``/generate`` handler in this module overrides
    ``system``, ``temperature`` and ``seed`` with fixed server-side values,
    so client-supplied values for those three fields are not used there.
    """

    # The user's question, sent to the model as the single "user" message.
    question: str
    # System prompt; this default is also what the OpenAPI schema advertises.
    system: str = (
        "You are a helpful medical AI chat assistant. Help as much as you can. "
        "Also continuously ask for possible symptoms in order to arrive at a "
        "conclusive ailment or sickness and possible solutions. "
        "Remember, respond in English."
    )
    # Sampling temperature.
    temperature: float = 0.8
    # RNG seed for sampling.
    seed: int = 101
    # Mirostat sampling options (llama-cpp-python: mode 0 disables it,
    # 1 = Mirostat, 2 = Mirostat 2.0; tau = target entropy, eta = learning rate).
    mirostat_mode: int = 2
    mirostat_tau: float = 4.0
    mirostat_eta: float = 1.1
|
|
|
class ChatModel(BaseModel):
    """Request body for ``POST /llm/chat/``: a multi-turn conversation.

    ``question`` is handed unchanged to ``create_chat_completion`` as its
    ``messages`` argument. Presumably it is a list of
    ``{"role": ..., "content": ...}`` dicts (OpenAI style); confirm against
    clients.
    """

    # Full message history, forwarded to the model as-is.
    question: list

    # NOTE(review): the /chat/ handler in this module never reads `system`
    # or the mirostat_* fields; only `question`, `temperature` and `seed`
    # reach the model.
    system: str = (
        "You are a helpful AI assistant. "
        "You are chatting with a human. "
        "Help as much as you can."
    )

    # Sampling temperature and RNG seed.
    temperature: float = 0.8
    seed: int = 101

    # Mirostat sampling options (see the note above about usage).
    mirostat_mode: int = 2
    mirostat_tau: float = 4.0
    mirostat_eta: float = 1.1
|
# Model instance backing the /chat/ endpoint. The GGUF weights are resolved
# from the "moriire/healthcare-ai-q2_k" repository when this module is
# imported. Inference is CPU-only (n_gpu_layers=0).
# NOTE(review): n_ctx=256 is a very small window for multi-turn chat. Longer
# conversations will likely be rejected or have their replies cut short.
# llm_generate loads the same repo/file into a second Llama instance.
llm_chat = llama_cpp.Llama.from_pretrained(
    n_ctx=256,
    n_gpu_layers=0,
    verbose=False,
    filename="*.gguf",
    repo_id="moriire/healthcare-ai-q2_k",
)
|
# Model instance backing the /generate endpoint. It uses the same GGUF
# weights as llm_chat, but with a 4096-token context window. CPU-only.
#
# Mirostat settings are *sampling* options. llama_cpp.Llama's constructor has
# no such parameters: unknown keyword arguments end up in its **kwargs and are
# ignored. They must therefore be passed per call to create_chat_completion,
# not here.
llm_generate = llama_cpp.Llama.from_pretrained(
    repo_id="moriire/healthcare-ai-q2_k",
    filename="*.gguf",
    verbose=False,
    n_ctx=4096,
    n_gpu_layers=0,
)
|
|
|
# NOTE(review): configuring the root logger here is an import-time side
# effect of this router module; it usually belongs in the app entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# CORS is not configured in this module. APIRouter has no add_middleware(),
# so if CORS is needed, register it on the FastAPI application, e.g.:
#   app.add_middleware(CORSMiddleware, allow_origins=["*"],
#                      allow_credentials=True, allow_methods=["*"],
#                      allow_headers=["*"])

# All endpoints below are mounted under /llm.
llm_router = APIRouter(prefix="/llm")
|
|
|
@llm_router.get("/health", tags=["llm"])
def health():
    """Liveness probe for the LLM router.

    Always answers ``{"status": "ok"}``; it does not touch either model.
    """
    status = "ok"
    return dict(status=status)
|
|
|
|
|
@llm_router.post("/chat/", tags=["llm"])
async def chat(chatm: ChatModel):
    """Run a chat completion over the client-supplied message history.

    ``chatm.question`` is forwarded unchanged as ``messages``. Only
    ``temperature`` and ``seed`` are taken from the rest of the request.

    Returns:
        The llama-cpp completion dict, plus a ``"time"`` key holding the
        wall-clock seconds spent producing it. On any failure, the error is
        logged with its traceback and a generic 500 JSON response is returned.
    """
    # NOTE(review): create_chat_completion is synchronous and CPU-bound, so
    # calling it inside an async handler blocks the event loop (including
    # /health) until it returns. That also serializes access to the shared
    # llm_chat instance; offloading to a worker thread would need a lock.
    # NOTE(review): chatm.system and chatm.mirostat_* are accepted but unused
    # here (llm_chat was also configured without mirostat).
    try:
        started = time()
        output = llm_chat.create_chat_completion(
            messages=chatm.question,
            temperature=chatm.temperature,
            seed=chatm.seed,
        )
        # Debug level, not print(): avoid dumping full (medical) conversations
        # to stdout on every request.
        logger.debug("chat completion: %s", output)
        output["time"] = time() - started
        return output
    except Exception:
        # Top-level boundary: log with traceback under this endpoint's real
        # path, and hide internals from the client.
        logger.exception("Error in /chat endpoint")
        return JSONResponse(
            status_code=500, content={"message": "Internal Server Error"}
        )
|
|
|
|
|
@llm_router.post("/generate", tags=["llm"])
async def generate(gen: GenModel):
    """Answer a single question using a fixed medical system prompt.

    The system prompt, temperature and seed are fixed server-side; the
    client's values for those fields are ignored, as before. The mirostat
    options come from the request. Their defaults (2 / 4.0 / 1.1) match the
    values originally given to llm_generate's constructor, which ignored
    them.

    Returns:
        The llama-cpp completion dict, plus a ``"time"`` key holding the
        wall-clock seconds spent producing it. On any failure, the error is
        logged with its traceback and a generic 500 JSON response is returned.
    """
    # Server-side overrides, kept as locals rather than mutating the
    # validated request object.
    system_prompt = "You are a helpful medical AI assistant."
    temperature = 0.5
    seed = 42

    # NOTE(review): this blocking, CPU-bound call runs inside an async handler
    # and stalls the event loop while it runs.
    try:
        started = time()
        output = llm_generate.create_chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": gen.question},
            ],
            temperature=temperature,
            seed=seed,
            # Mirostat is a per-call sampling option; passing it to the
            # Llama constructor has no effect.
            mirostat_mode=gen.mirostat_mode,
            mirostat_tau=gen.mirostat_tau,
            mirostat_eta=gen.mirostat_eta,
        )
        output["time"] = time() - started
        return output
    except Exception:
        # Top-level boundary: keep the traceback in the log, return a
        # generic error to the client.
        logger.exception("Error in /generate endpoint")
        return JSONResponse(
            status_code=500, content={"message": "Internal Server Error"}
        )
|
|
|
|
|
|