import concurrent.futures
import threading
import re
import json
from datetime import datetime

import numpy as np
import torch
import faiss
import gradio as gr
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from langchain.document_loaders import DirectoryLoader, TextLoader  # LangChain document loaders
from langchain.text_splitter import RecursiveCharacterTextSplitter  # text splitter for chunking


class DocumentRetrievalAndGeneration:
    def __init__(self, embedding_model_name, lm_model_id, data_folder):
        self.all_splits = self.load_documents(data_folder)
        self.embeddings = SentenceTransformer(embedding_model_name)
        self.gpu_index = self.create_faiss_index()
        self.llm = self.initialize_llm(lm_model_id)
        # Signals a running generation that its result should be discarded.
        self.cancel_flag = threading.Event()

    def load_documents(self, folder_path):
        """Load every text file in the folder and split it into overlapping chunks."""
        loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
        all_splits = text_splitter.split_documents(documents)
        print("Number of documents:", len(documents))
        print("Number of chunks:", len(all_splits))
        # Preview the first few chunks as a quick sanity check.
        for split in all_splits[:5]:
            print(split.page_content)
        return all_splits

    def create_faiss_index(self):
        """Embed all chunks, build a flat L2 index, and move it to GPU 0 (requires faiss-gpu)."""
        all_texts = [split.page_content for split in self.all_splits]
        embeddings = self.embeddings.encode(all_texts, convert_to_tensor=True).cpu().numpy()
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        gpu_resource = faiss.StandardGpuResources()
        gpu_index = faiss.index_cpu_to_gpu(gpu_resource, 0, index)
        return gpu_index

    def initialize_llm(self, model_id):
        """Load the causal LM with 4-bit NF4 quantization and wrap it in a text-generation pipeline."""
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # The pipeline mainly serves as a holder for model, tokenizer, and device;
        # generation itself goes through model.generate() in generate_response_with_timeout().
        generate_text = pipeline(
            task="text-generation",
            model=model,
            tokenizer=tokenizer,
            return_full_text=True,
            do_sample=True,
            temperature=0.6,
            max_new_tokens=256,
        )
        return generate_text

    def generate_response_with_timeout(self, model_inputs):
        """Run model.generate() in a worker thread and give up after 60 seconds.

        The worker cannot be forcibly killed; on timeout the cancel flag is set so the
        thread discards its result once generation eventually finishes.
        """
        def target(future):
            try:
                if self.cancel_flag.is_set():
                    return
                generated_ids = self.llm.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
                if not self.cancel_flag.is_set():
                    future.set_result(generated_ids)
            except Exception as exc:
                # Propagate generation errors to the waiting caller instead of hanging.
                future.set_exception(exc)

        future = concurrent.futures.Future()
        thread = threading.Thread(target=target, args=(future,), daemon=True)
        thread.start()
        try:
            return future.result(timeout=60)  # wait at most 60 seconds
        except concurrent.futures.TimeoutError:
            self.cancel_flag.set()
            raise TimeoutError("Text generation process timed out")

    def qa_infer_gradio(self, query):
        # Reset the cancel flag for the new query.
        self.cancel_flag.clear()
        content = ""
        try:
            # Retrieve the five most similar chunks from the FAISS index.
            query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
            distances, indices = self.gpu_index.search(np.array([query_embedding]), k=5)
            for idx in indices[0]:
                content += "-" * 50 + "\n"
                content += self.all_splits[idx].page_content + "\n"

            # Pass the retrieved chunks to the model so it can ground its answer;
            # the chat template adds the special tokens, so the prompt stays plain text.
            prompt = f"""Here are the retrieved document chunks:
{content}

Here's my question:
Query: {query}

Solution:
RETURN ONLY THE SOLUTION. IF THE RETRIEVED CHUNKS CONTAIN NO RELEVANT ANSWER, RETURN "NO SOLUTION AVAILABLE".
"""
            messages = [{"role": "user", "content": prompt}]
            encodeds = self.llm.tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, return_tensors="pt"
            )
            model_inputs = encodeds.to(self.llm.device)

            start_time = datetime.now()
            generated_ids = self.generate_response_with_timeout(model_inputs)
            elapsed_time = datetime.now() - start_time

            # Decode only the newly generated tokens so the prompt is not echoed back.
            new_token_ids = generated_ids[:, model_inputs.shape[-1]:]
            generated_response = self.llm.tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0]

            # Strip a leading "Solution:" label if the model repeats it.
            match = re.search(r"Solution:\s*(.*)", generated_response, re.DOTALL | re.IGNORECASE)
            solution_text = (match.group(1) if match else generated_response).strip()
            if not solution_text:
                solution_text = "NO SOLUTION AVAILABLE"

            print("Generated response:", generated_response)
            print("Time elapsed:", elapsed_time)
            print("Device in use:", self.llm.device)
            return solution_text, content
        except TimeoutError:
            return "Text generation timed out; no solution available.", content
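

# --- Usage sketch (illustrative, not from the original snippet) ---
# A minimal example of wiring the class to a Gradio interface. The model names and
# data folder below are placeholder assumptions, not values taken from this Space.
if __name__ == "__main__":
    embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"  # assumed embedding model
    lm_model_id = "mistralai/Mistral-7B-Instruct-v0.2"  # assumed chat-tuned LLM
    data_folder = "data"  # assumed folder of plain-text files

    doc_qa = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder)

    demo = gr.Interface(
        fn=doc_qa.qa_infer_gradio,
        inputs=gr.Textbox(label="Query"),
        outputs=[gr.Textbox(label="Solution"), gr.Textbox(label="Retrieved context")],
        title="Document Retrieval and Generation",
    )
    demo.launch()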