import concurrent.futures
import threading
import re
import json
from datetime import datetime

import torch
import gradio as gr
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from langchain.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter


class DocumentRetrievalAndGeneration:
    def __init__(self, embedding_model_name, lm_model_id, data_folder):
        self.all_splits = self.load_documents(data_folder)
        self.embeddings = SentenceTransformer(embedding_model_name)
        self.gpu_index = self.create_faiss_index()
        self.llm = self.initialize_llm(lm_model_id)
        self.cancel_flag = threading.Event()

    def load_documents(self, folder_path):
        """Load every text file in the folder and split it into overlapping chunks."""
        loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
        all_splits = text_splitter.split_documents(documents)
        print("Length of documents:", len(documents))
        print("Length of all_splits:", len(all_splits))
        # Preview the first few chunks (guard against folders with fewer than five splits).
        for i in range(min(5, len(all_splits))):
            print(all_splits[i].page_content)
        return all_splits

    def create_faiss_index(self):
        """Embed all chunks, build a flat L2 FAISS index, and move it to GPU 0."""
        all_texts = [split.page_content for split in self.all_splits]
        embeddings = self.embeddings.encode(all_texts, convert_to_tensor=True).cpu().numpy()
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        gpu_resource = faiss.StandardGpuResources()
        gpu_index = faiss.index_cpu_to_gpu(gpu_resource, 0, index)
        return gpu_index

    def initialize_llm(self, model_id):
        """Load the causal LM with 4-bit NF4 quantization and wrap it in a text-generation pipeline."""
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        generate_text = pipeline(
            model=model,
            tokenizer=tokenizer,
            return_full_text=True,
            task="text-generation",
            temperature=0.6,
            max_new_tokens=256,
        )
        return generate_text

    def generate_response_with_timeout(self, model_inputs):
        """Run model.generate in a worker thread and give up after 60 seconds."""
        def target(future):
            if self.cancel_flag.is_set():
                return
            try:
                generated_ids = self.llm.model.generate(
                    model_inputs, max_new_tokens=1000, do_sample=True
                )
            except Exception as exc:
                # Propagate generation errors so future.result() does not block until the timeout.
                future.set_exception(exc)
                return
            if not self.cancel_flag.is_set():
                future.set_result(generated_ids)
            else:
                future.set_exception(TimeoutError("Text generation process was canceled"))

        future = concurrent.futures.Future()
        thread = threading.Thread(target=target, args=(future,))
        thread.start()
        try:
            generated_ids = future.result(timeout=60)  # Timeout set to 60 seconds
            return generated_ids
        except concurrent.futures.TimeoutError:
            self.cancel_flag.set()
            raise TimeoutError("Text generation process timed out")

    def qa_infer_gradio(self, query):
        # Clear the cancel flag for the new query.
        self.cancel_flag.clear()
        content = ""
        try:
            # Embed the query and retrieve the five nearest chunks from the FAISS index.
            query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
            distances, indices = self.gpu_index.search(np.array([query_embedding]), k=5)
            for idx in indices[0]:
                content += "-" * 50 + "\n"
                content += self.all_splits[idx].page_content + "\n"

            # Include the retrieved chunks in the prompt so the model can actually use them.
            prompt = f"""Here are the retrieved chunks:
{content}

Here's my question:
Query: {query}

Solution: RETURN ONLY SOLUTION.
IF THERE IS NO ANSWER RELATABLE IN RETRIEVED CHUNKS, RETURN "NO SOLUTION AVAILABLE"
"""
            messages = [{"role": "user", "content": prompt}]
            encodeds = self.llm.tokenizer.apply_chat_template(messages, return_tensors="pt")
            model_inputs = encodeds.to(self.llm.device)

            start_time = datetime.now()
            generated_ids = self.generate_response_with_timeout(model_inputs)
            elapsed_time = datetime.now() - start_time

            decoded = self.llm.tokenizer.batch_decode(generated_ids)
            generated_response = decoded[0]

            # Extract the text that follows the "Solution:" marker in the decoded output.
            match = re.search(r"Solution:\s*(.*)", generated_response, re.DOTALL | re.IGNORECASE)
            if match:
                solution_text = match.group(1).strip()
            else:
                solution_text = "NO SOLUTION AVAILABLE"

            print("Generated response:", generated_response)
            print("Time elapsed:", elapsed_time)
            print("Device in use:", self.llm.device)

            return solution_text, content
        except TimeoutError:
            return "timeout", content
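

# Example usage: a minimal sketch of wiring qa_infer_gradio into a Gradio interface.
# The embedding model, LLM id, and data folder below are illustrative assumptions,
# not values taken from this module.
if __name__ == "__main__":
    doc_retrieval_gen = DocumentRetrievalAndGeneration(
        embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",  # assumed embedding model
        lm_model_id="mistralai/Mistral-7B-Instruct-v0.2",  # assumed chat-style LLM
        data_folder="data",  # assumed folder containing .txt documents
    )

    # qa_infer_gradio returns (solution_text, retrieved_content), so two output boxes are used.
    interface = gr.Interface(
        fn=doc_retrieval_gen.qa_infer_gradio,
        inputs=gr.Textbox(label="Query"),
        outputs=[gr.Textbox(label="Solution"), gr.Textbox(label="Retrieved context")],
        title="Document Retrieval and Generation",
    )
    interface.launch()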