|
import fastapi |
|
from fastapi.responses import JSONResponse |
|
from fastapi_users import schemas |
|
from time import time |
|
|
|
|
|
import logging |
|
from langchain_community.llms import LlamaCpp |
|
import llama_cpp |
|
import llama_cpp.llama_tokenizer |
|
from pydantic import BaseModel |
|
from fastapi import APIRouter |
|
from app.users import current_active_user |
|
|
|
from langchain_community.document_loaders import WebBaseLoader |
|
from langchain_text_splitters import RecursiveCharacterTextSplitter |
|
from langchain_chroma import Chroma |
|
from langchain_community.embeddings import GPT4AllEmbeddings |
|
from langchain_core.output_parsers import StrOutputParser |
|
from langchain_core.prompts import PromptTemplate |
|
|
|
from langchain import hub |
|
from langchain_core.runnables import RunnablePassthrough, RunnablePick |
|
|
|
# RAG prompt template used by RagChat's chain; pulled from LangChain Hub when
# this module is imported.
rag_prompt_llama = hub.pull("rlm/rag-prompt-llama")

# Model used by RagChat. The RAG prompt stuffs the retrieved chunks (up to 500
# characters each, several per question) plus the question into the context,
# which does not fit in a 512-token window, so a larger window is used here.
llm = llama_cpp.Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q4_0.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat"),
    verbose=False,
    n_ctx=2048,
    n_gpu_layers=0,
)
|
|
|
|
|
class RagChat:
    """Retrieval-augmented question answering over a single blog post.

    The post is scraped with WebBaseLoader and split into chunks. The chunks are
    embedded with GPT4AllEmbeddings into a Chroma store. For each question, the
    most similar chunks are inserted into ``rag_prompt_llama`` and answered by
    ``llm``.
    """

    def agent(self):
        """Load the source web page and split it into chunks of at most 500 characters.

        Returns:
            The list of split LangChain documents.
        """
        loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
        data = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
        return text_splitter.split_documents(data)

    def download_embedding(self):
        """Return a Chroma vector store built from ``agent()``'s chunks.

        Building the store scrapes the page and embeds every chunk, so it is
        built at most once per instance and reused by later calls.
        """
        vectorstore = getattr(self, "_vectorstore", None)
        if vectorstore is None:
            # from_documents needs the document list, so agent() must be called.
            vectorstore = Chroma.from_documents(
                documents=self.agent(), embedding=GPT4AllEmbeddings()
            )
            self._vectorstore = vectorstore
        return vectorstore

    @staticmethod
    def _format_docs(docs):
        """Join the page contents of retrieved documents, separated by blank lines."""
        return "\n\n".join(doc.page_content for doc in docs)

    @staticmethod
    def _complete(prompt_value):
        """Answer a rendered prompt with the module-level ``llm`` and return the text.

        ``llm`` is a raw llama_cpp model rather than a LangChain runnable. This
        adapter converts the prompt's messages into llama_cpp chat messages and
        extracts the reply text from the response dict.
        """
        roles = {"system": "system", "human": "user", "ai": "assistant"}
        messages = [
            {"role": roles.get(message.type, "user"), "content": message.content}
            for message in prompt_value.to_messages()
        ]
        # Explicit max_tokens so answers are not cut to llama_cpp's small default.
        response = llm.create_chat_completion(messages=messages, max_tokens=256)
        return response["choices"][0]["message"]["content"]

    def chat(self, question):
        """Answer ``question`` using the chunks retrieved for it as context.

        Args:
            question: The user's question as plain text.

        Returns:
            The model's answer as a string.
        """
        retriever = self.download_embedding().as_retriever()
        chain = (
            {"context": retriever | self._format_docs, "question": RunnablePassthrough()}
            | rag_prompt_llama
            | self._complete
            | StrOutputParser()
        )
        # The first step fans the raw question out to the retriever and the
        # passthrough, so the chain is invoked with the question string itself.
        return chain.invoke(question)

    def search(self, question):
        """Return the documents most similar to ``question`` from the vector store."""
        return self.download_embedding().similarity_search(question)
|
|
|
|
|
class GenModel(BaseModel):
    """Request body for the /generate endpoint."""

    # Prompt text handed to create_completion by /generate.
    question: str
    system: str = (
        "You are a helpful medical AI chat assistant. Help as much as you can. "
        "Also continuously ask for possible symptoms in order to arrive at a conclusive "
        "ailment or sickness and possible solutions. Remember, respond in English."
    )
    temperature: float = 0.8
    seed: int = 101
    # Presumably llama.cpp mirostat sampling settings — confirm against the endpoint.
    mirostat_mode: int = 2
    mirostat_tau: float = 4.0
    mirostat_eta: float = 1.1
|
|
|
class ChatModel(BaseModel):
    """Request body for the /chat and /rag endpoints."""

    # Message list; /chat forwards it as ``messages`` to create_chat_completion.
    question: list
    system: str = (
        "You are chatDoctor, a helpful health and medical assistant. You are chatting with a human. "
        "Help as much as you can. Also continuously ask for possible symptoms in order to reach "
        "a conclusive ailment or sickness and possible solutions. Remember, respond in English."
    )
    temperature: float = 0.8
    seed: int = 101
    # Presumably llama.cpp mirostat sampling settings — confirm against the endpoints.
    mirostat_mode: int = 2
    mirostat_tau: float = 4.0
    mirostat_eta: float = 1.1
|
|
|
# Model serving the /chat endpoint.
llm_chat = llama_cpp.Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q4_0.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat"),
    verbose=False,
    n_ctx=512,
    n_gpu_layers=0,
)

# Model serving the /generate endpoint. Mirostat is a per-call sampling option
# in llama_cpp (create_completion / create_chat_completion), not a model option.
# It is therefore not configured here; pass the mirostat_* settings when generating.
llm_generate = llama_cpp.Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q4_0.gguf",
    verbose=False,
    n_ctx=4096,
    n_gpu_layers=0,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
""" |
|
app.add_middleware( |
|
CORSMiddleware, |
|
allow_origins = ["*"], |
|
allow_credentials=True, |
|
allow_methods=["*"], |
|
allow_headers=["*"] |
|
) |
|
""" |
|
llm_router = APIRouter(prefix="/llm") |
|
|
|
@llm_router.get("/health", tags=["llm"]) |
|
def health(): |
|
return {"status": "ok"} |
|
|
|
|
|
|
|
|
|
@llm_router.post("/rag/", tags=["llm"]) |
|
async def ragchat(chatm:ChatModel): |
|
r = RagChat().chat(chatml.question) |
|
print(r) |
|
|
|
|
|
|
|
@llm_router.post("/chat/", tags=["llm"]) |
|
async def chat(chatm:ChatModel): |
|
|
|
try: |
|
st = time() |
|
output = llm_chat.create_chat_completion( |
|
messages = chatm.question, |
|
temperature = chatm.temperature, |
|
seed = chatm.seed, |
|
|
|
) |
|
print(output) |
|
|
|
et = time() |
|
output["time"] = et - st |
|
|
|
|
|
return output |
|
except Exception as e: |
|
logger.error(f"Error in /complete endpoint: {e}") |
|
return JSONResponse( |
|
status_code=500, content={"message": "Internal Server Error"} |
|
) |
|
|
|
|
|
@llm_router.post("/generate", tags=["llm"]) |
|
async def generate(gen:GenModel): |
|
gen.system = "You are an helpful medical AI assistant." |
|
gen.temperature = 0.5 |
|
gen.seed = 42 |
|
try: |
|
|
|
output = llm_generate.create_completion( |
|
|
|
|
|
|
|
|
|
gen.question, |
|
temperature = gen.temperature, |
|
seed= gen.seed, |
|
|
|
stream=True, |
|
echo = True |
|
) |
|
|
|
for chunk in output: |
|
delta = chunk['choices'][0] |
|
print(delta) |
|
if 'role' in delta: |
|
print(delta['role'], end=': ') |
|
elif 'content' in delta: |
|
print(delta['content'], end='') |
|
|
|
|
|
|
|
|
|
|
|
except Exception as e: |
|
logger.error(f"Error in /generate endpoint: {e}") |
|
return JSONResponse( |
|
status_code=500, content={"message": "Internal Server Error"} |
|
) |
|
|
|
|
|
|