import concurrent.futures
import threading
import torch
from datetime import datetime
import json
import gradio as gr
import re
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from langchain.document_loaders import DirectoryLoader, TextLoader # Import these from langchain
from langchain.text_splitter import RecursiveCharacterTextSplitter # Import the text splitter
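
# Retrieval-augmented generation over a folder of text files: documents are
# chunked, embedded with a SentenceTransformer, indexed with FAISS on GPU,
# and a 4-bit quantized causal LM answers queries against the retrieved
# chunks through a Gradio-facing entry point.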
class DocumentRetrievalAndGeneration:
    def __init__(self, embedding_model_name, lm_model_id, data_folder):
        self.all_splits = self.load_documents(data_folder)
        self.embeddings = SentenceTransformer(embedding_model_name)
        self.gpu_index = self.create_faiss_index()
        self.llm = self.initialize_llm(lm_model_id)
        self.cancel_flag = threading.Event()
    def load_documents(self, folder_path):
        loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
        all_splits = text_splitter.split_documents(documents)
        print('Length of documents:', len(documents))
        print('Length of all_splits:', len(all_splits))
        # Preview the first few chunks (guard against there being fewer than 5)
        for split in all_splits[:5]:
            print(split.page_content)
        return all_splits
    def create_faiss_index(self):
        all_texts = [split.page_content for split in self.all_splits]
        embeddings = self.embeddings.encode(all_texts, convert_to_tensor=True).cpu().numpy()
        # Build an exact L2 index on CPU, then move it to GPU 0 for faster search
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        gpu_resource = faiss.StandardGpuResources()
        gpu_index = faiss.index_cpu_to_gpu(gpu_resource, 0, index)
        return gpu_index
    def initialize_llm(self, model_id):
        # 4-bit NF4 quantization so the model fits in limited GPU memory
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        generate_text = pipeline(
            model=model,
            tokenizer=tokenizer,
            return_full_text=True,
            task='text-generation',
            temperature=0.6,
            max_new_tokens=256,
        )
        return generate_text
    def generate_response_with_timeout(self, model_inputs):
        def target(future):
            try:
                if self.cancel_flag.is_set():
                    future.set_exception(TimeoutError("Text generation process was canceled"))
                    return
                generated_ids = self.llm.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
                if not self.cancel_flag.is_set():
                    future.set_result(generated_ids)
                else:
                    future.set_exception(TimeoutError("Text generation process was canceled"))
            except Exception as exc:
                # Surface generation errors instead of making the caller wait for the timeout
                if not future.done():
                    future.set_exception(exc)

        future = concurrent.futures.Future()
        thread = threading.Thread(target=target, args=(future,), daemon=True)
        thread.start()
        try:
            return future.result(timeout=60)  # Give generation up to 60 seconds
        except concurrent.futures.TimeoutError:
            self.cancel_flag.set()
            raise TimeoutError("Text generation process timed out")
    def qa_infer_gradio(self, query):
        # Clear the cancel flag for the new query
        self.cancel_flag.clear()
        try:
            # Embed the query and retrieve the top-5 most similar chunks
            query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
            distances, indices = self.gpu_index.search(np.array([query_embedding]), k=5)
            content = ""
            for idx in indices[0]:
                content += "-" * 50 + "\n"
                content += self.all_splits[idx].page_content + "\n"
            # Include the retrieved chunks in the prompt so the model can actually use them
            prompt = f"""<s>
Retrieved chunks:
{content}
Here's my question:
Query: {query}
Solution:
RETURN ONLY SOLUTION. IF THERE IS NO ANSWER RELATABLE IN RETRIEVED CHUNKS, RETURN "NO SOLUTION AVAILABLE"
</s>
"""
            messages = [{"role": "user", "content": prompt}]
            encodeds = self.llm.tokenizer.apply_chat_template(messages, return_tensors="pt")
            model_inputs = encodeds.to(self.llm.device)
            start_time = datetime.now()
            generated_ids = self.generate_response_with_timeout(model_inputs)
            elapsed_time = datetime.now() - start_time
            # Decode only the newly generated tokens so the prompt echo is not mistaken for the answer
            new_tokens = generated_ids[:, model_inputs.shape[-1]:]
            generated_response = self.llm.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]
            # Best-effort extraction of an explicit "Solution:" section, otherwise use the raw output
            match = re.search(r'Solution:(.*)', generated_response, re.DOTALL | re.IGNORECASE)
            solution_text = match.group(1).strip() if match else generated_response.strip()
            if not solution_text:
                solution_text = "NO SOLUTION AVAILABLE"
            print("Generated response:", generated_response)
            print("Time elapsed:", elapsed_time)
            print("Device in use:", self.llm.device)
            return solution_text, content
        except TimeoutError:
            return "timeout", content