Spaces:
Sleeping
Sleeping
import os | |
import re | |
import json | |
import torch | |
import spaces | |
import pymupdf | |
import gradio as gr | |
from qdrant_client import QdrantClient | |
from utils import download_pdf_from_gdrive, merge_strings_with_prefix | |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
def rag_query(query: str):
    """
    Allows searching the vector database which contains
    information about a man named Suvaditya for a given query
    by performing semantic search. Returns results by
    looking at his resume, which contains a plethora of
    information about him.

    Args:
        query: The query against which the search will be run,
            in the form of a single string phrase no more than
            10 words.

    Returns:
        search_results: A list of results that come closest
            to the given query semantically,
            determined by Cosine Similarity.
    """
    # This docstring is not only documentation: generate_answer passes this
    # function as `tools=[rag_query]`, and transformers builds the tool schema
    # shown to the LLM from the signature and this Google-style docstring.
    # Keep the "Args:" layout intact when editing it.
    #
    # `client` is the module-level QdrantClient created in the __main__
    # section; calling this before that section has run raises NameError.
    return client.query(
        collection_name="resume",
        query_text=query,
    )
def generate_answer(chat_history, max_new_tokens=512):
    """
    Run one generation turn of the chat model over ``chat_history``.

    The conversation is rendered with the tokenizer's chat template while
    advertising ``rag_query`` as a callable tool, and the model's
    continuation is decoded back to text.

    Args:
        chat_history: List of chat message dicts accepted by
            ``tokenizer.apply_chat_template`` (``role``/``content``, plus
            ``tool_calls``/``name`` entries for tool turns).
        max_new_tokens: Upper bound on the number of generated tokens.
            Defaults to 512, the value previously hard-coded here.

    Returns:
        Only the newly generated text (the prompt tokens are sliced off).
        Special tokens are NOT skipped, so tool-call tags and any
        end-of-turn marker the model emitted remain in the string.
    """
    # Render the conversation plus the rag_query tool schema into model
    # inputs; add_generation_prompt opens a fresh assistant turn.
    model_inputs = tokenizer.apply_chat_template(
        chat_history,
        tools=[rag_query],
        return_tensors="pt",
        return_dict=True,
        add_generation_prompt=True,
    )
    model_inputs = model_inputs.to(model.device)

    output_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt + continuation; keep only the continuation
    # of the single sequence at batch index 0.
    prompt_length = model_inputs["input_ids"].shape[1]
    return tokenizer.decode(output_ids[0, prompt_length:])
def parse_tool_request(tool_call, top_k=5):
    """
    Detect a ``rag_query`` tool call in model output and execute it.

    The model is expected to wrap its request as
    ``<tool_call>{"name": ..., "arguments": {"query": ...}}</tool_call>``;
    only the first such block in the text is used.

    Args:
        tool_call: Raw text generated by the model.
        top_k: Maximum number of retrieved documents to return.

    Returns:
        ``(documents, query)``, where ``documents`` is a list of at most
        ``top_k`` document strings from the vector store and ``query`` is
        the search string the model asked for. Returns ``(None, None)``
        when no usable tool call is present: no tags, malformed JSON, a
        missing ``arguments``/``query`` field, or a non-string query.
    """
    match_result = re.search(r"<tool_call>(.*?)</tool_call>", tool_call, re.DOTALL)
    if match_result is None:
        return None, None

    try:
        payload = json.loads(match_result.group(1).strip())
        arguments = payload["arguments"]
        # Some chat formats serialize the arguments object as a JSON string.
        if isinstance(arguments, str):
            arguments = json.loads(arguments)
        query = arguments["query"]
    except (json.JSONDecodeError, KeyError, TypeError):
        # LLM output is not guaranteed to be well-formed; treat a broken
        # tool call like "no tool call" instead of crashing the chat handler.
        return None, None

    if not isinstance(query, str):
        return None, None

    documents = [hit.metadata["document"] for hit in rag_query(query)]
    return documents[:top_k], query
def update_chat_history(chat_history, tool_query, query_results):
    """
    Record a completed ``rag_query`` round-trip in the conversation.

    Two messages are appended, in this order, to ``chat_history``, which
    is modified in place and also returned:

    1. an assistant turn whose ``tool_calls`` entry invokes ``rag_query``
       with ``tool_query`` (formatted as a string);
    2. a ``tool`` turn whose content is ``query_results`` joined by newlines.

    Args:
        chat_history: Mutable list of chat message dicts.
        tool_query: The search string the model requested.
        query_results: Iterable of result strings returned by the search.

    Returns:
        The same ``chat_history`` list object, now two messages longer.
    """
    invoked_function = {
        "name": "rag_query",
        "arguments": {"query": f"{tool_query}"},
    }
    assistant_request = {
        "role": "assistant",
        "metadata": "🛠️ Using Qdrant Engine to search for the query 🛠️",
        "tool_calls": [{"type": "function", "function": invoked_function}],
    }
    tool_response = {
        "role": "tool",
        "name": "rag_query",
        "content": "\n".join(query_results),
    }

    for new_message in (assistant_request, tool_response):
        chat_history.append(new_message)
    return chat_history
if __name__ == "__main__":
    RESUME_PATH = os.path.join(os.getcwd(), "Resume.pdf")
    RESUME_URL = "https://drive.google.com/file/d/1YMF9NNTG5gubwJ7ipI5JfxAJKhlD9h2v/"

    # Download the resume PDF to RESUME_PATH.
    download_pdf_from_gdrive(RESUME_URL, RESUME_PATH)

    # Extract the text of every page as a flat list of lines. Previously only
    # page 0 was read, silently dropping later pages of a multi-page resume;
    # for a one-page PDF the resulting list is identical. The context manager
    # closes the document once extraction is done.
    with pymupdf.open(RESUME_PATH) as doc:
        fulltext = [
            line for page in doc for line in page.get_text().split("\n")
        ]
    fulltext = merge_strings_with_prefix(fulltext)

    # Embed the lines into an in-memory Qdrant collection named "resume".
    client = QdrantClient(":memory:")
    client.set_model("sentence-transformers/all-MiniLM-L6-v2")
    if not client.collection_exists(collection_name="resume"):
        client.create_collection(
            collection_name="resume",
            vectors_config=client.get_fastembed_vector_params(),
        )
    _ = client.add(
        collection_name="resume",
        documents=fulltext,
        ids=range(len(fulltext)),
        batch_size=100,
        parallel=0,
    )

    # FOR QWEN, THIS IS WORKING
    model_name = "Qwen/Qwen2.5-3B-Instruct"

    # NOTE(review): `spaces` is imported and the UI title targets ZeroGPU, but
    # no function is decorated with @spaces.GPU. On ZeroGPU hardware the
    # GPU-using entry point (rag_process) likely needs that decorator — confirm.
    def rag_process(message, chat_history):
        """
        Gradio ChatInterface callback: answer ``message`` using resume RAG.

        The user message is appended to a copy of ``chat_history`` and the
        model answers. If that answer contains a ``rag_query`` tool call,
        the search is executed, the call and its results are added to the
        conversation, and the model is asked again with that context.

        Args:
            message: The new user message text.
            chat_history: Prior conversation as message dicts (Gradio
                ``type="messages"`` format). It is not modified.

        Returns:
            The model's final reply text, with a trailing end-of-turn
            marker removed when one is present.
        """
        # Work on a copy so the history list owned by the caller (Gradio) is
        # not mutated by the user/tool messages appended below.
        chat_history = list(chat_history)
        chat_history.append({"role": "user", "content": message})

        # First pass: the model either answers directly or emits a
        # <tool_call> block asking for a resume search.
        generated_text = generate_answer(chat_history)

        # parse_tool_request returns (None, None) when no usable tool call
        # was requested.
        query_results, tool_query = parse_tool_request(generated_text)
        if query_results is not None and tool_query is not None:
            print(f"rag_query tool call requested: {tool_query!r}")
            # Feed the tool call and its results back, then regenerate so
            # the final answer can use the retrieved resume lines.
            chat_history = update_chat_history(
                chat_history, tool_query, query_results
            )
            generated_text = generate_answer(chat_history)

        # generate_answer decodes without skipping special tokens, so the
        # reply normally ends with the model's end-of-turn token (presumably
        # "<|im_end|>" for Qwen2.5-Instruct, matching the 10 characters the
        # old `[:-10]` removed unconditionally). Strip it only when present,
        # so a reply cut off by the token limit keeps its last characters.
        for end_marker in ("<|im_end|>", "<|endoftext|>", tokenizer.eos_token):
            if end_marker and generated_text.endswith(end_marker):
                generated_text = generated_text[: -len(end_marker)]
                break
        return generated_text

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    demo = gr.ChatInterface(
        fn=rag_process,
        type="messages",
        title="Resume RAG, a personal space on ZeroGPU!",
        examples=["Where did Suvaditya complete his Bachelor's Degree?", "Where is Suvaditya currently working?"],
        description="Ask any question about Suvaditya's resume and get an answer!",
        theme="ocean",
    )
    demo.launch()