import os
import re
import json
import torch
import spaces
import pymupdf
import gradio as gr
from qdrant_client import QdrantClient
from utils import download_pdf_from_gdrive, merge_strings_with_prefix
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

def rag_query(query: str):
    """
    Searches a vector database containing information about a man named
    Suvaditya by performing semantic search over his resume, which holds
    a wealth of information about him, and returns the closest matches
    for a given query.

    Args:
        query: The query against which the search will be run, in the
               form of a single string phrase of no more than 10 words.

    Returns:
        search_results: A list of results that come closest to the given
                        query semantically, as determined by cosine
                        similarity.
    """
    return client.query(
        collection_name="resume",
        query_text=query
    )
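
# Note: with qdrant-client's fastembed integration, client.query returns
# QueryResponse objects whose metadata carries the matched text chunk,
# e.g. (illustrative):
#   rag_query("current employer")[0].metadata["document"]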

def generate_answer(chat_history):
    # Render the chat history, with the rag_query tool schema attached,
    # into model inputs
    tool_prompt = tokenizer.apply_chat_template(
        chat_history,
        tools=[rag_query],
        return_tensors="pt",
        return_dict=True,
        add_generation_prompt=True,
    )
    tool_prompt = tool_prompt.to(model.device)
    out = model.generate(**tool_prompt, max_new_tokens=512)
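    # Keep only the newly generated tokens; `generate` echoes the prompt
    # tokens at the start of `out`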
    generated_text = out[0, tool_prompt['input_ids'].shape[1]:]
    generated_text = tokenizer.decode(generated_text)
    return generated_text

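# Qwen2.5's chat template emits tool calls as XML-wrapped JSON, e.g.
# (illustrative, not verbatim model output):
#   <tool_call>
#   {"name": "rag_query", "arguments": {"query": "bachelor's degree"}}
#   </tool_call>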
def parse_tool_request(tool_call, top_k=5):
    """
    Extracts the <tool_call> JSON payload from the generated text, runs
    the requested query, and returns (top_k results, query), or
    (None, None) if no tool call was present.
    """
    pattern = r"<tool_call>(.*?)</tool_call>"
    match_result = re.search(pattern, tool_call, re.DOTALL)
    if match_result:
        result = match_result.group(1).strip()
    else:
        return None, None

    query = json.loads(result)["arguments"]["query"]
    query_results = [
        query_piece.metadata["document"] for query_piece in rag_query(query)
    ]

    return query_results[:top_k], query

def update_chat_history(chat_history, tool_query, query_results):
    """
    Appends the assistant's tool call and the tool's results to the chat
    history so the model can ground its final answer in them.
    """
    assistant_tool_message = {
        "role": "assistant",
        "metadata": "🛠️ Using Qdrant Engine to search for the query 🛠️",
        "tool_calls": [{
            "type": "function",
            "function": {
                "name": "rag_query",
                "arguments": {"query": f"{tool_query}"}
            }
        }]
    }
    result_tool_message = {
        "role": "tool",
        "name": "rag_query",
        "content": "\n".join(query_results)
    }

    chat_history.append(assistant_tool_message)
    chat_history.append(result_tool_message)
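
    # The history now ends with: user question -> assistant tool call ->
    # tool results, the sequence the chat template expects when rendering
    # tool outputs for the second generation pass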

    return chat_history

if __name__ == "__main__":
    RESUME_PATH = os.path.join(os.getcwd(), "Resume.pdf")
    RESUME_URL = "https://drive.google.com/file/d/1YMF9NNTG5gubwJ7ipI5JfxAJKhlD9h2v/"

    # Download the resume PDF from Google Drive
    download_pdf_from_gdrive(RESUME_URL, RESUME_PATH)

    doc = pymupdf.open(RESUME_PATH)
    fulltext = doc[0].get_text().split("\n")
    fulltext = merge_strings_with_prefix(fulltext)

    # Embed the resume chunks into an in-memory Qdrant collection
    client = QdrantClient(":memory:")

    client.set_model("sentence-transformers/all-MiniLM-L6-v2")

    if not client.collection_exists(collection_name="resume"):
        client.create_collection(
            collection_name="resume",
            vectors_config=client.get_fastembed_vector_params(),
        )

    _ = client.add(
        collection_name="resume",
        documents=fulltext,
        ids=range(len(fulltext)),
        batch_size=100,
        parallel=0,
    )
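
    # Optional sanity check (assumes qdrant-client's fastembed query API):
    # hits = client.query(collection_name="resume", query_text="education")
    # print(hits[0].metadata["document"])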

    # Qwen2.5-Instruct supports tool calling through its chat template

    model_name = "Qwen/Qwen2.5-3B-Instruct"

    @spaces.GPU
    def rag_process(message, chat_history):
        # Append current user message to chat history
        current_message = {
            "role": "user",
            "content": message
        }
        chat_history.append(current_message)

        # Generate LLM answer
        generated_text = generate_answer(chat_history)

        # Detect whether the LLM requested a tool call; if so, execute
        # the tool, otherwise both values come back as None
        query_results, tool_query = parse_tool_request(generated_text)

        # If a tool call was requested
        if query_results is not None and tool_query is not None:
            print("Tool call detected, querying the vector database")
            # Update chat history with the result of the tool call
            chat_history = update_chat_history(
                chat_history, tool_query, query_results
            )
            # Generate the final answer, grounded in the tool results
            generated_text = generate_answer(chat_history)

        # Strip the trailing "<|im_end|>" token (10 characters)
        return generated_text[:-10]

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_8bit=True)
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    demo = gr.ChatInterface(
        fn=rag_process,
        type="messages",
        title="Resume RAG, a personal space on ZeroGPU!",
        examples=["Where did Suvaditya complete his Bachelor's Degree?", "Where is Suvaditya currently working?"],
        description="Ask any question about Suvaditya's resume and get an answer!",
        theme="ocean"
    )
    demo.launch()