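# RAG demo: load text files, split them into chunks, embed the chunks with a
# sentence-transformers model, index them with FAISS, and answer queries with a
# hosted LLM via the Hugging Face InferenceClient, served through a Gradio UI.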
import os
import multiprocessing
import concurrent.futures
from langchain.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from sentence_transformers import SentenceTransformer
import faiss
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from datetime import datetime
import json
import gradio as gr
import re
from huggingface_hub import InferenceClient
# from unsloth import FastLanguageModel
import transformers
from transformers import BloomForCausalLM
from transformers import BloomForTokenClassification
from transformers import BloomTokenizerFast


class DocumentRetrievalAndGeneration:
    def __init__(self, embedding_model_name, lm_model_id, data_folder):
        # Read the Hugging Face token from the environment instead of hardcoding it.
        self.hf_token = os.getenv('HF_TOKEN')
        self.all_splits = self.load_documents(data_folder)
        self.embeddings = SentenceTransformer(embedding_model_name)
        self.cpu_index = self.create_faiss_index()
        self.llm = self.initialize_llm2(lm_model_id)
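        # Note: initialize_llm2 sets self.client and returns None, so self.llm
        # stays unused on this remote-inference path.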

    def load_documents(self, folder_path):
        loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
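        # chunk_size and chunk_overlap are measured in characters here (the
        # splitter's default length function is len), not tokens.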
        all_splits = text_splitter.split_documents(documents)
        print('Length of documents:', len(documents))
        print("LEN of all_splits", len(all_splits))
        return all_splits

    def create_faiss_index(self):
        all_texts = [split.page_content for split in self.all_splits]
        embeddings = self.embeddings.encode(all_texts, convert_to_tensor=True).cpu().numpy()
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
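        # Alternative (sketch, not used here): for cosine similarity instead of
        # raw L2 distance, normalize the embeddings before adding them and use
        # an inner-product index:
        #   faiss.normalize_L2(embeddings)
        #   index = faiss.IndexFlatIP(embeddings.shape[1])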
        return index

    def initialize_llm(self, model_id):
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
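        # NF4 4-bit weights with double quantization and bfloat16 compute reduce
        # the model's memory footprint so it can fit on a single GPU.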
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, token=self.hf_token)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        generate_text = pipeline(
            model=model,
            tokenizer=tokenizer,
            return_full_text=True,
            task='text-generation',
            temperature=0.6,
            max_new_tokens=256,
        )
        return generate_text

    def initialize_llm2(self, model_id):
        self.client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
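        # model_id is not used on this path; generation is delegated to the
        # hosted zephyr-7b-beta endpoint through the InferenceClient.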
        # except:
        #     try:
        #         pipe = pipeline("text-generation", model="microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True)
        #     except:
        #         pipe = pipeline("text-generation", model="microsoft/Phi-3-mini-4k-instruct")
        # pipe = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.2")
        # model_name = "mistralai/Mistral-7B-Instruct-v0.2"
        # pipeline = transformers.pipeline(
        #     "text-generation",
        #     model=model_name,
        #     model_kwargs={"torch_dtype": torch.bfloat16},
        #     device="cpu",
        # )
        # return generate_text
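    # Only needed by the local-pipeline path (see the commented-out generation
    # code further below); unused when responses come from the InferenceClient.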
    def generate_response_with_timeout(self, model_inputs):
        try:
            with concurrent.futures.ThreadPoolExecutor() as executor:
                future = executor.submit(self.llm.model.generate, model_inputs, max_new_tokens=1000, do_sample=True)
                generated_ids = future.result(timeout=800)  # timeout in seconds
            return generated_ids
        except concurrent.futures.TimeoutError:
            return "Text generation process timed out"

    def query_and_generate_response(self, query):
        query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
        distances, indices = self.cpu_index.search(np.array([query_embedding]), k=5)
        content = ""
        for rank, idx in enumerate(indices[0]):
            if 0 <= idx < len(self.all_splits):
                content += "-" * 50 + "\n"
                content += self.all_splits[idx].page_content + "\n"
                # distances is aligned with the result rank, not the chunk id
                distance = distances[0][rank]
                print("CHUNK", idx)
                print("Distance :", distance)
                print(self.all_splits[idx].page_content)
                print("############################")
            else:
                print(f"Index {idx} is out of bounds. Skipping.")
prompt = f"""<s> | |
You are a knowledgeable assistant with access to a comprehensive database. | |
I need you to answer my question and provide related information in a specific format. | |
I have provided five relatable json files {content}, choose the most suitable chunks for answering the query | |
Here's what I need: | |
Include a final answer without additional comments, sign-offs, or extra phrases. Be direct and to the point. | |
content | |
Here's my question: | |
Query: | |
Solution==> | |
RETURN ONLY SOLUTION . IF THEIR IS NO ANSWER RELATABLE IN RETRIEVED CHUNKS , RETURN " NO SOLUTION AVAILABLE" | |
IF THE QUERY AND THE RETRIEVED CHUNKS DO NOT CORRELATE MEANINGFULLY, OR IF THE QUERY IS NOT RELEVANT TO TDA2 OR RELATED TOPICS, THEN "NO SOLUTION AVAILABLE." | |
Example1 | |
Query: "How to use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM", | |
Solution: "To use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM, you need to modify the configuration file of the NDK application. Specifically, change the processor reference from 'A15_0' to 'IPU1_0'.", | |
Example2 | |
Query: "Can BQ25896 support I2C interface?", | |
Solution: "Yes, the BQ25896 charger supports the I2C interface for communication." | |
Example3 | |
Query: "Who is the fastest runner in the world", | |
Solution:"NO SOLUTION AVAILABLE" | |
Example4 | |
Query:"What is the price of latest apple MACBOOK " | |
Solution:"NO SOLUTION AVAILABLE" | |
</s> | |
""" | |
messages = [{"role": "system", "content": prompt}] | |
messages.append({"role": "user", "content": query}) | |
response = "" | |
for message in self.client.chat_completion(messages,max_tokens=2048,stream=True,temperature=0.7): | |
token = message.choices[0].delta.content | |
response += token | |
# yield response | |
generated_response=response | |
# messages = [{"role": "user", "content": prompt}] | |
# encodeds = self.llm.tokenizer.apply_chat_template(messages, return_tensors="pt") | |
# model_inputs = encodeds.to(self.llm.device) | |
# start_time = datetime.now() | |
# generated_ids = self.generate_response_with_timeout(model_inputs) | |
# elapsed_time = datetime.now() - start_time | |
# decoded = self.llm.tokenizer.batch_decode(generated_ids) | |
# generated_response = decoded[0] | |
######################################################### | |
# messages = [] | |
# # Check if history is None or empty and handle accordingly | |
# if history: | |
# for user_msg, assistant_msg in history: | |
# messages.append({"role": "user", "content": user_msg}) | |
# messages.append({"role": "assistant", "content": assistant_msg}) | |
# # Always add the current user message | |
# messages.append({"role": "user", "content": message}) | |
# # Construct the prompt using the pipeline's tokenizer | |
# prompt = pipeline.tokenizer.apply_chat_template( | |
# messages, | |
# tokenize=False, | |
# add_generation_prompt=True | |
# ) | |
# # Generate the response | |
# terminators = [ | |
# pipeline.tokenizer.eos_token_id, | |
# pipeline.tokenizer.convert_tokens_to_ids("") | |
# ] | |
# # Adjust the temperature slightly above given to ensure variety | |
# adjusted_temp = temperature + 0.1 | |
# # Generate outputs with adjusted parameters | |
# outputs = pipeline( | |
# prompt, | |
# max_new_tokens=max_new_tokens, | |
# do_sample=True, | |
# temperature=adjusted_temp, | |
# top_p=0.9 | |
# ) | |
# # Extract the generated text, skipping the length of the prompt | |
# generated_text = outputs[0]["generated_text"] | |
# generated_response = generated_text[len(prompt):] | |
        match1 = re.search(r'\[/INST\](.*?)</s>', generated_response, re.DOTALL)
        match2 = re.search(r'Solution:(.*?)</s>', generated_response, re.DOTALL | re.IGNORECASE)
        if match1:
            solution_text = match1.group(1).strip()
            if "Solution:" in solution_text:
                solution_text = solution_text.split("Solution:", 1)[1].strip()
        elif match2:
            solution_text = match2.group(1).strip()
        else:
            solution_text = generated_response
        # print("Generated response:", generated_response)
        # print("Time elapsed:", elapsed_time)
        # print("Device in use:", self.llm.device)
        return solution_text, content

    def qa_infer_gradio(self, query):
        response = self.query_and_generate_response(query)
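        # The (solution_text, content) tuple maps onto the two Gradio output boxes.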
        return response
if __name__ == "__main__": | |
print("starting...") | |
embedding_model_name = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12' | |
# lm_model_id = "mistralai/Mistral-7B-Instruct-v0.2" | |
lm_model_id= "unsloth/Phi-3-mini-4k-instruct-bnb-4bit" | |
data_folder = 'text_files' | |
doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder) | |

    def launch_interface():
        css_code = """
            .gradio-container {
                background-color: #ffffff;
            }
            /* Button styling for all buttons */
            button {
                background-color: #999999; /* Default color for all other buttons */
                color: black;
                border: 1px solid black;
                padding: 10px;
                margin-right: 10px;
                font-size: 16px; /* Increase font size */
                font-weight: bold; /* Make text bold */
            }
        """
EXAMPLES = ["What are the main types of blood cancer, and how do they differ in terms of symptoms, progression, and treatment options? ", | |
"What are the latest advancements in the treatment of blood cancer, and how do they improve patient outcomes compared to traditional therapies?", | |
"How do genetic factors and environmental exposures contribute to the risk of developing blood cancer, and what preventive measures can be taken?"] | |
        interface = gr.Interface(
            fn=doc_retrieval_gen.qa_infer_gradio,
            inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
            allow_flagging='never',
            examples=EXAMPLES,
            cache_examples=False,
            outputs=[gr.Textbox(label="SOLUTION"), gr.Textbox(label="RETRIEVED CONTEXT")],
            css=css_code
        )
        interface.launch(debug=True)

    launch_interface()
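    # Example (sketch): querying the pipeline directly, without the Gradio UI.
    # The sample question below is illustrative only.
    #   solution, context = doc_retrieval_gen.query_and_generate_response(
    #       "Can BQ25896 support I2C interface?")
    #   print(solution)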