import os
import tempfile
import itertools
import gradio as gr
from __init__ import *
from llama_cpp import Llama
from chromadb.config import Settings
from typing import List, Optional, Union
from langchain.vectorstores import Chroma
from langchain.docstore.document import Document
from huggingface_hub.file_download import http_get
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter


class LocalChatGPT:
    """Gradio chat assistant backed by a local llama.cpp model.

    When a Chroma index is available, the most similar document fragments are
    retrieved and prepended to the user's question as context.
    """

    def __init__(self):
        self.llama_model: Optional[Llama] = None
        self.embeddings: HuggingFaceEmbeddings = self.initialize_app()

    @staticmethod
    def _download_model(model_url: str, final_model_path: str) -> None:
        """Download ``model_url`` into ``final_model_path``.

        The bytes are written to a ``.part`` file that is renamed into place only
        after ``http_get`` returns, so an interrupted download never leaves a
        truncated file that a later run would mistake for a complete model.

        :param model_url: URL to fetch.
        :param final_model_path: destination path of the model file.
        """
        tmp_path = f"{final_model_path}.part"
        try:
            with open(tmp_path, "wb") as f:
                http_get(model_url, f)
            os.replace(tmp_path, final_model_path)
        finally:
            # Still present only if the download or the rename failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def _ensure_model(self, model_name: str) -> str:
        """Return the local path of ``model_name``, downloading it when missing.

        :param model_name: a value of ``DICT_REPO_AND_MODELS`` (path relative to ``MODELS_DIR``).
        :return: ``os.path.join(MODELS_DIR, model_name)``.
        :raises ValueError: if the file is absent and no URL in
            ``DICT_REPO_AND_MODELS`` maps to ``model_name``.
        """
        final_model_path = os.path.join(MODELS_DIR, model_name)
        os.makedirs(os.path.dirname(final_model_path) or ".", exist_ok=True)
        if not os.path.exists(final_model_path):
            # Look the URL up BEFORE touching the filesystem, so an unknown name
            # does not leave an empty placeholder file behind.
            model_url = next(
                (url for url, name in DICT_REPO_AND_MODELS.items() if name == model_name),
                None,
            )
            if model_url is None:
                raise ValueError(f"Unknown model {model_name!r}: no download URL configured")
            self._download_model(model_url, final_model_path)
        return final_model_path

    def _load_llama(self, model_path: str) -> None:
        """Load the model file at ``model_path`` into ``self.llama_model``."""
        self.llama_model = Llama(
            model_path=model_path,
            n_ctx=2000,
            n_parts=1,
        )

    def initialize_app(self) -> HuggingFaceEmbeddings:
        """
        Download (if needed) and load the default model, then create the embedder.
        The default model is the first entry of ``DICT_REPO_AND_MODELS``.
        :return: the embedder used to index and search documents.
        :raises ValueError: if ``DICT_REPO_AND_MODELS`` is empty.
        """
        os.makedirs(MODELS_DIR, exist_ok=True)
        if not DICT_REPO_AND_MODELS:
            raise ValueError("DICT_REPO_AND_MODELS is empty: no model to load")
        model_name = next(iter(DICT_REPO_AND_MODELS.values()))
        self._load_llama(self._ensure_model(model_name))
        return HuggingFaceEmbeddings(model_name=EMBEDDER_NAME, cache_folder=MODELS_DIR)

    def load_model(self, model_name):
        """
        Switch the active model, downloading it first if it is not on disk.
        :param model_name: a value of ``DICT_REPO_AND_MODELS``.
        :return: ``model_name``, echoed back to the dropdown.
        :raises ValueError: if the model is missing and has no configured URL.
        """
        self._load_llama(self._ensure_model(model_name))
        return model_name

    @staticmethod
    def load_single_document(file_path: str) -> Document:
        """
        Load one file with the loader registered for its extension.
        :param file_path: path of the file to load.
        :return: the first document produced by the loader.
        :raises ValueError: if the extension has no entry in ``LOADER_MAPPING``.
        """
        ext: str = "." + file_path.rsplit(".", 1)[-1]
        # An explicit check instead of `assert`, which is stripped under `python -O`.
        if ext not in LOADER_MAPPING:
            raise ValueError(f"Unsupported file extension {ext!r}: {file_path}")
        loader_class, loader_args = LOADER_MAPPING[ext]
        loader = loader_class(file_path, **loader_args)
        return loader.load()[0]

    @staticmethod
    def get_message_tokens(model: Llama, role: str, content: str) -> list:
        """
        Tokenize one chat message in the prompt format used by ``bot``.
        The role token and a line break are inserted right after the first token
        (presumably the BOS token that ``model.tokenize`` prepends — confirm for
        the configured model), and the EOS token is appended.
        :param model: the llama.cpp model used for tokenization.
        :param role: key into ``ROLE_TOKENS``.
        :param content: message text.
        :return: the token list.
        """
        message_tokens: list = model.tokenize(content.encode("utf-8"))
        message_tokens.insert(1, ROLE_TOKENS[role])
        message_tokens.insert(2, LINEBREAK_TOKEN)
        message_tokens.append(model.token_eos())
        return message_tokens

    def get_system_tokens(self, model: Llama) -> list:
        """
        Tokenize ``SYSTEM_PROMPT`` as a message with the "system" role.
        :param model: the llama.cpp model used for tokenization.
        :return: a new token list.
        """
        system_message: dict = {"role": "system", "content": SYSTEM_PROMPT}
        return self.get_message_tokens(model, **system_message)

    @staticmethod
    def upload_files(files: List[tempfile.TemporaryFile]) -> List[str]:
        """
        :param files: uploaded file objects.
        :return: their paths (the ``name`` attribute).
        """
        return [f.name for f in files]

    @staticmethod
    def process_text(text: str) -> Optional[str]:
        """
        Drop lines with at most 2 non-blank characters and strip the result.
        :param text: raw fragment text.
        :return: the cleaned text, or None if fewer than 10 characters remain.
        """
        lines: list = text.split("\n")
        lines = [line for line in lines if len(line.strip()) > 2]
        text = "\n".join(lines).strip()
        return None if len(text) < 10 else text

    @staticmethod
    def update_text_db(
        db: Optional[Chroma], fixed_documents: List[Document], ids: List[str]
    ) -> Optional[tuple]:
        """
        Refresh ``db`` in place when it already indexes exactly the uploaded files.
        :param db: the existing index, or None.
        :param fixed_documents: cleaned fragments to store.
        :param ids: one id per fragment, parallel to ``fixed_documents``.
        :return: ``(db, file_warning)`` when the existing index was refreshed;
            None when ``db`` is falsy or holds a different set of file names, in
            which case the caller has to build a new index.
        """
        if not db:
            return None
        data: dict = db.get()
        files_db = {metadata["source"].split("/")[-1] for metadata in data["metadatas"]}
        files_load = {doc.metadata["source"].split("/")[-1] for doc in fixed_documents}
        if files_load != files_db:
            return None
        # Replace every stored fragment rather than diffing ids.
        db.delete(data["ids"])
        db.add_texts(
            texts=[doc.page_content for doc in fixed_documents],
            metadatas=[doc.metadata for doc in fixed_documents],
            ids=ids,
        )
        file_warning = f"Uploaded {len(fixed_documents)} fragments! You can ask questions"
        return db, file_warning

    def build_index(
        self,
        file_paths: List[str],
        db: Optional[Chroma],
        chunk_size: int,
        chunk_overlap: int,
    ):
        """
        Load, split and clean the files, then index them in Chroma.
        Each kept fragment gets the id ``<file name without .txt><n>``, where n
        counts the kept fragments of that file starting at 1.
        :param file_paths: files to index.
        :param db: the current index, or None.
        :param chunk_size: splitter chunk size.
        :param chunk_overlap: splitter chunk overlap.
        :return: ``(db, file_warning)``.
        """
        text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        fixed_documents: List[Document] = []
        ids: List[str] = []
        # Split per file so every id is derived from the file its fragment came
        # from and there is exactly one id per kept fragment.
        for path in file_paths:
            prefix = path.split("/")[-1].replace(".txt", "")
            chunks = text_splitter.split_documents([self.load_single_document(path)])
            kept = 0
            for doc in chunks:
                doc.page_content = self.process_text(doc.page_content)
                if not doc.page_content:
                    continue
                kept += 1
                fixed_documents.append(doc)
                ids.append(f"{prefix}{kept}")

        if not fixed_documents:
            return db, "No text fragments found in the uploaded files."

        # Reuse the existing index when it covers the same files; otherwise the
        # same ids would be added a second time below.
        if updated := self.update_text_db(db, fixed_documents, ids):
            return updated

        db = Chroma.from_documents(
            documents=fixed_documents,
            embedding=self.embeddings,
            ids=ids,
            client_settings=Settings(
                anonymized_telemetry=False,
                persist_directory="db"
            )
        )
        file_warning = f"Uploaded {len(fixed_documents)} fragments! You can ask questions."
        return db, file_warning

    @staticmethod
    def user(message, history):
        """Append the user's message as a new unanswered turn and clear the input box."""
        new_history = history + [[message, None]]
        return "", new_history

    @staticmethod
    def regenerate_response(history):
        """
        Clear the input box; the unchanged history is answered again by ``bot``.
        :param history: chat history.
        :return: ``("", history)``.
        """
        return "", history

    @staticmethod
    def retrieve(history, db: Optional[Chroma], retrieved_docs):
        """
        Fetch the fragments most similar to the last user message.
        :param history: chat history; the last turn's user text is the query.
        :param db: the index, or None.
        :param retrieved_docs: previous value, returned unchanged when there is
            no index or no history.
        :return: a header naming the source files followed by the fragments.
        """
        if not db or not history:
            return retrieved_docs
        last_user_message = history[-1][0]
        try:
            docs = db.similarity_search(last_user_message, k=4)
        except RuntimeError:
            # Retry with a single hit; presumably raised when k exceeds the
            # number of indexed fragments — confirm for the installed chromadb.
            docs = db.similarity_search(last_user_message, k=1)
        # Only string metadata values are file paths; loaders may also store
        # non-string values (e.g. page numbers), which have no `.split`.
        source_docs = {
            value.split("/")[-1]
            for doc in docs
            for value in doc.metadata.values()
            if isinstance(value, str)
        }
        retrieved_docs = "\n\n".join(doc.page_content for doc in docs)
        return f"A document- {', '.join(sorted(source_docs))}.\n\n{retrieved_docs}"

    def bot(self, history, retrieved_docs):
        """
        Stream the model's answer to the last user message into ``history``.
        :param history: chat history; the last turn's answer slot is filled in.
        :param retrieved_docs: context prepended to the question when non-empty.
        :return: a generator yielding the updated history after each token.
        """
        if not history:
            return
        model = self.llama_model
        tokens = self.get_system_tokens(model)
        tokens.append(LINEBREAK_TOKEN)

        for user_message, bot_message in history[:-1]:
            tokens.extend(self.get_message_tokens(model=model, role="user", content=user_message))
            # Earlier answers belong in the prompt too; otherwise the model sees
            # a run of user turns with no replies. Assumes ROLE_TOKENS has a
            # "bot" entry, as the BOT_TOKEN-based prompt below suggests.
            if bot_message:
                tokens.extend(self.get_message_tokens(model=model, role="bot", content=bot_message))

        last_user_message = history[-1][0]
        if retrieved_docs:
            last_user_message = (
                f"Context: {retrieved_docs}\n\nUsing context, answer the question:"
                f"{last_user_message}"
            )
        tokens.extend(self.get_message_tokens(model=model, role="user", content=last_user_message))
        # Open the bot turn; generation continues from here.
        tokens.extend([model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN])

        generator = model.generate(
            tokens,
            top_k=30,
            top_p=0.9,
            temp=0.1
        )
        generated = b""
        for i, token in enumerate(generator):
            if token == model.token_eos() or (MAX_NEW_TOKENS is not None and i >= MAX_NEW_TOKENS):
                break
            # Decode the accumulated bytes rather than each token alone: a
            # multi-byte UTF-8 character may be split across tokens, and decoding
            # the pieces separately with "ignore" would drop it. An incomplete
            # trailing sequence is skipped until its remaining bytes arrive.
            generated += model.detokenize([token])
            history[-1][1] = generated.decode("utf-8", "ignore")
            yield history

    def run(self):
        """
        Build the Gradio UI, wire its events and start the server.
        :return: None.
        """
        with gr.Blocks(theme=gr.themes.Soft(), css=BLOCK_CSS) as demo:
            # Holds the Chroma index passed to `retrieve`; starts as None.
            db = gr.State(None)
            favicon = f''
            gr.Markdown(
                f"""

{favicon} GPT-based text assistant

"""
            )
            with gr.Row(elem_id="model_selector_row"):
                models: list = list(DICT_REPO_AND_MODELS.values())
                model_selector = gr.Dropdown(
                    choices=models,
                    value=models[0] if models else "",
                    interactive=True,
                    show_label=False,
                    container=False,
                )

            with gr.Row():
                with gr.Column(scale=5):
                    chatbot = gr.Chatbot(label="Dialogue", height=400)
                with gr.Column(min_width=200, scale=4):
                    retrieved_docs = gr.Textbox(
                        label="Extracted fragments",
                        placeholder="Will appear after asking questions",
                        interactive=False
                    )

            with gr.Row():
                with gr.Column(scale=20):
                    msg = gr.Textbox(
                        label="send a message",
                        show_label=False,
                        placeholder="send a message",
                        container=False
                    )
                with gr.Column(scale=3, min_width=100):
                    submit = gr.Button("📤 Send", variant="primary")

            with gr.Row():
                stop = gr.Button(value="⛔ Stop")
                regenerate = gr.Button(value="🔄 Repeat")
                clear = gr.Button(value="🗑️ Clear")

            # NOTE: document upload is not wired into this layout. The intended
            # chain is upload_files -> build_index with outputs [db, file_warning];
            # the file_output / file_paths / chunk_size / chunk_overlap /
            # file_warning components it needs are not defined here.

            def answer(trigger):
                """Chain retrieval and streamed generation onto an event that updated the chat."""
                return trigger.success(
                    fn=self.retrieve,
                    inputs=[chatbot, db, retrieved_docs],
                    outputs=[retrieved_docs],
                    queue=True,
                ).success(
                    fn=self.bot,
                    inputs=[chatbot, retrieved_docs],
                    outputs=chatbot,
                    queue=True,
                )

            model_selector.change(
                fn=self.load_model,
                inputs=[model_selector],
                outputs=[model_selector]
            )

            # Pressing Enter
            submit_event = answer(msg.submit(
                fn=self.user,
                inputs=[msg, chatbot],
                outputs=[msg, chatbot],
                queue=False,
            ))

            # Pressing the button
            submit_click_event = answer(submit.click(
                fn=self.user,
                inputs=[msg, chatbot],
                outputs=[msg, chatbot],
                queue=False,
            ))

            # Regenerate (registered before Stop so Stop can cancel it too)
            regenerate_event = answer(regenerate.click(
                fn=self.regenerate_response,
                inputs=[chatbot],
                outputs=[msg, chatbot],
                queue=False,
            ))

            # Stop generation
            stop.click(
                fn=None,
                inputs=None,
                outputs=None,
                cancels=[submit_event, submit_click_event, regenerate_event],
                queue=False,
            )

            # Clear history
            clear.click(lambda: None, None, chatbot, queue=False)

        demo.queue(max_size=128, default_concurrency_limit=10, api_open=False)
        demo.launch(server_name="0.0.0.0", max_threads=200)


if __name__ == "__main__":
    local_chat_gpt = LocalChatGPT()
    local_chat_gpt.run()