|
import fastapi |
|
from fastapi.responses import JSONResponse |
|
from fastapi_users import schemas |
|
from time import time |
|
|
|
|
|
import logging |
|
import llama_cpp |
|
import llama_cpp.llama_tokenizer |
|
from pydantic import BaseModel |
|
from fastapi import APIRouter |
|
from app.users import current_active_user |
|
|
|
|
|
|
|
from transformers import AutoTokenizer, pipeline |
|
from optimum.onnxruntime import ORTModelForCausalLM |
|
|
|
|
|
# Hugging Face checkpoint shared by the tokenizer and the ONNX Runtime model.
_ONNX_MODEL_ID = "Qwen/Qwen1.5-0.5B-Chat"

# Both objects are created once at import time and reused by the endpoints.
tokenizer = AutoTokenizer.from_pretrained(_ONNX_MODEL_ID)
# export=True requests an on-the-fly conversion of the checkpoint to ONNX
# (see the Optimum documentation) -- this makes the import slow on first load.
model = ORTModelForCausalLM.from_pretrained(_ONNX_MODEL_ID, export=True)
|
class GenModel(BaseModel):
    """Request body carrying a prompt plus optional sampling settings."""

    # Text supplied by the caller; the only required field.
    question: str
    # Default system prompt.
    # NOTE(review): the text contains a typo ("atat") and missing spaces after
    # some full stops; left untouched here because it is a runtime default.
    system: str = (
        "You are a helpful medical AI chat assistant. Help as much as you can."
        "Also continuously ask for possible symptoms in order to atat a "
        "conclusive ailment or sickness and possible solutions."
        "Remember, response in English."
    )
    # Sampling temperature and RNG seed defaults.
    temperature: float = 0.8
    seed: int = 101
    # Mirostat settings; the names mirror the mirostat_* keyword arguments
    # used with llama_cpp elsewhere in this file.
    mirostat_mode: int = 2
    mirostat_tau: float = 4.0
    mirostat_eta: float = 1.1
|
|
|
class ChatModel(BaseModel):
    """Request body for a chat turn: a question plus optional settings."""

    # Text supplied by the caller; the only required field.
    question: str
    # Default system prompt describing the "chatDoctor" persona.
    # NOTE(review): the wording "in order to a conclusive ailment" looks like
    # it is missing a verb; left untouched because it is a runtime default.
    system: str = (
        "You are chatDoctor, a helpful health and medical assistant. "
        "You are chatting with a human. Help as much as you can. "
        "Also continuously ask for possible symptoms in order to a "
        "conclusive ailment or sickness and possible solutions."
        "Remember, response in English."
    )
    # Sampling temperature and RNG seed defaults.
    temperature: float = 0.8
    seed: int = 101
    # Mirostat settings; the names mirror the mirostat_* keyword arguments
    # used with llama_cpp elsewhere in this file.
    mirostat_mode: int = 2
    mirostat_tau: float = 4.0
    mirostat_eta: float = 1.1
|
|
|
# Hugging Face tokenizer wrapper handed to the llama.cpp chat model below.
_chat_tokenizer = llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
    "Qwen/Qwen1.5-0.5B-Chat"
)

# llama.cpp model intended for chat: 4-bit GGUF weights matched by the
# filename glob, a 512-token context and no GPU layers requested.
llm_chat = llama_cpp.Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q4_0.gguf",
    tokenizer=_chat_tokenizer,
    n_ctx=512,
    n_gpu_layers=0,
    verbose=False,
)
|
|
|
# Mirostat settings forwarded to the constructor call below.
# NOTE(review): mirostat_* are normally per-call sampling arguments in
# llama-cpp-python; confirm the constructor actually honours them.
_generate_mirostat = {
    "mirostat_mode": 2,
    "mirostat_tau": 4.0,
    "mirostat_eta": 1.1,
}

# llama.cpp model intended for free-form generation: same GGUF weights as the
# chat model but with a 4096-token context and the default tokenizer.
llm_generate = llama_cpp.Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q4_0.gguf",
    n_ctx=4096,
    n_gpu_layers=0,
    verbose=False,
    **_generate_mirostat,
)
|
|
|
|
|
# Module logger, with root logging configured at INFO level.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Disabled CORS configuration, kept for reference only. It was previously
# stored in a bare string literal, which had no runtime effect.
#
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins = ["*"],
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"]
# )

# Router that groups every endpoint of this module under the /llm prefix.
llm_router = APIRouter(prefix="/llm")
|
|
|
@llm_router.get("/health", tags=["llm"])
def health():
    """Liveness probe: return a fixed ``{"status": "ok"}`` payload."""
    payload = {"status": "ok"}
    return payload
|
|
|
|
|
@llm_router.post("/chat/", tags=["llm"])
async def chat(chatm: ChatModel):
    """Answer one chat turn with the llama.cpp chat model.

    The previous body wrapped the request fields in accidental 1-tuples
    (trailing commas), fed a causal LM to a "question-answering" pipeline,
    printed the literal string "pred" and always returned an empty string.
    This version sends the system prompt and the user's question to
    ``llm_chat`` as a proper chat message list and returns the completion.

    Args:
        chatm: Request body. ``system`` and ``question`` become the system
            and user messages; ``temperature``, ``seed`` and the
            ``mirostat_*`` fields are forwarded as sampling settings.

    Returns:
        The completion dict produced by ``create_chat_completion`` with an
        extra ``"time"`` key (elapsed seconds), or a 500 ``JSONResponse``
        with a generic message if inference fails.
    """
    messages = [
        {"role": "system", "content": chatm.system},
        {"role": "user", "content": chatm.question},
    ]
    try:
        started = time()
        # NOTE(review): this call is synchronous and runs on the event loop,
        # so other requests wait while the model is generating.
        output = llm_chat.create_chat_completion(
            messages=messages,
            temperature=chatm.temperature,
            seed=chatm.seed,
            mirostat_mode=chatm.mirostat_mode,
            mirostat_tau=chatm.mirostat_tau,
            mirostat_eta=chatm.mirostat_eta,
        )
        output["time"] = time() - started
        return output
    except Exception:
        # Endpoint boundary: log the full traceback, hide details from clients.
        logger.exception("Error in /chat endpoint")
        return JSONResponse(
            status_code=500, content={"message": "Internal Server Error"}
        )
|
|
|
|
|
@llm_router.post("/generate", tags=["llm"])
async def generate(gen: GenModel):
    """Generate text that continues ``gen.question``.

    Runs the module-level ONNX model and tokenizer through a transformers
    "text-generation" pipeline. Only ``gen.question`` is consumed here; the
    other request fields are accepted but not used by this endpoint.

    The function previously opened with a string literal holding old,
    disabled code; being the first statement, that string was the function's
    docstring. It has been replaced by this description.

    Args:
        gen: Request body; ``gen.question`` is the prompt.

    Returns:
        The raw pipeline result (per the transformers documentation, a list
        of ``{"generated_text": ...}`` dicts), or a 500 ``JSONResponse`` with
        a generic message if generation fails.
    """
    try:
        # NOTE(review): the pipeline wrapper is rebuilt on every request;
        # consider creating it once at module level.
        onnx_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
        # Named `result` so the local no longer shadows this function's name.
        result = onnx_gen(gen.question)
    except Exception:
        # Endpoint boundary: log the full traceback, hide details from clients.
        logger.exception("Error in /generate endpoint")
        return JSONResponse(
            status_code=500, content={"message": "Internal Server Error"}
        )
    return result
|
|
|
|
|
|