File size: 7,027 Bytes
fe0a782
03004bb
fe0a782
db91191
fe0a782
 
 
 
 
 
 
41ef5eb
fe0a782
6338c11
6ce7039
a816dca
 
fe0a782
41ef5eb
 
fe0a782
 
e767020
3515ac0
15e16ab
fe0a782
41ef5eb
 
0400fe2
41ef5eb
 
fe0a782
 
 
db91191
fe0a782
 
41ef5eb
fe0a782
 
 
 
 
 
41ef5eb
fe0a782
 
 
 
41ef5eb
03004bb
 
fe0a782
03004bb
 
fe0a782
 
 
 
 
 
 
 
 
 
 
41ef5eb
fe0a782
03004bb
 
fe0a782
3252ba7
082d0d6
fe0a782
41ef5eb
ea20e6e
 
41ef5eb
fe0a782
 
 
 
2b6142b
41ef5eb
fe0a782
 
8d4ff63
 
fe0a782
8d4ff63
fe0a782
41ef5eb
 
 
334d10d
2b6142b
 
 
 
 
41ef5eb
 
fe0a782
 
b9940a5
41ef5eb
03004bb
41ef5eb
9fcbef3
03004bb
 
 
 
 
 
 
 
 
 
 
41ef5eb
 
 
9fcbef3
 
15e16ab
9fcbef3
 
 
03004bb
41ef5eb
fe0a782
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95b11af
fe0a782
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41ef5eb
fe0a782
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41ef5eb
 
 
 
fe0a782
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import threading  # to allow streaming response
import time  # to pave the delivery of the message

import datasets  # for loading RAG database
import faiss  # to create a search index
import gradio  # for the interface
import numpy  # to work with vectors
import sentence_transformers  # to load an embedding model
import spaces  # for GPU
import transformers  # to load an LLM

# Opening message the agent shows before the user has typed anything.
GREETING = "".join(
    [
        "Howdy! I'm an AI agent that uses [retrieval-augmented generation](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) ",
        "to answer questions about research published at [ASME IDETC](https://asmedigitalcollection.asme.org/IDETC-CIE) within the last 10 years or so. ",
        "I always try to cite my sources, but sometimes things get a little weird. ",
        "What can I tell you about today?",
    ]
)

# Canned prompts shown in the interface so users can try the agent with one click.
EXAMPLE_QUERIES = [
    "What's the difference between a markov chain and a hidden markov model?",
    "What can you tell me about analytical target cascading?",
    "What is known about different modes for human-AI teaming?",
    (
        "What are some examples of opportunistic versus restrictive design for additive manufacturing? "
        "Format your answer as a table with two columns (opportunistic, restrictive)."
    ),
]

# Model identifiers: a sentence-transformers model used to embed queries for retrieval,
# and a causal chat LLM used to generate the answers.
EMBEDDING_MODEL_NAME, LLM_MODEL_NAME = "all-MiniLM-L6-v2", "Qwen/Qwen2-7B-Instruct"

# Load the RAG corpus (one row per excerpt, with precomputed embeddings) as pandas
data = datasets.load_dataset("ccm/rag-idetc")["train"].to_pandas()

# Model used to embed incoming queries at search time
embedding_model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)

# Tokenizer and causal LM used to generate answers
tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
# Module-level streamer kept for backward compatibility.
# NOTE(review): a single shared streamer is not safe to use from concurrent generations.
streamer = transformers.TextIteratorStreamer(
    tokenizer, skip_prompt=True, skip_special_tokens=True
)
chat_model = transformers.AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_NAME, torch_dtype="auto", device_map="auto"
)

# Exact (brute-force) FAISS index over the excerpt embeddings. The vectors are
# L2-normalized and searched by inner product, which ranks by cosine similarity.
# Row i of the index corresponds to row position i of `data`.
vectors = numpy.stack(data["embedding"].tolist(), axis=0).astype("float32")
faiss.normalize_L2(vectors)  # normalizes in place
excerpt_index = faiss.IndexFlatIP(vectors.shape[1])  # flat indexes need no training
excerpt_index.add(vectors)


def preprocess(query: str, k: int) -> tuple[str, str]:
    """
    Retrieves the k excerpts most similar to the query and builds the LLM prompt and a references footer
    Args:
        query (str): The user's query
        k (int): The number of excerpts to retrieve
    Returns:
        tuple[str, str]: The prompt to send to the LLM, and a collapsible "References"
        section (Markdown inside an HTML <details> element) listing the cited excerpts
    """
    # Embed the query as a batch of one and normalize it, matching the normalized index.
    encoded_query = numpy.expand_dims(embedding_model.encode(query), axis=0)
    faiss.normalize_L2(encoded_query)
    _, indices = excerpt_index.search(encoded_query, k)

    # FAISS ids are row positions in `data`; it pads with -1 when fewer than k
    # vectors are available, so drop those before indexing.
    hits = data.iloc[[row for row in indices[0] if row >= 0]]

    print(hits["text"].values)

    references: dict[str, list[str]] = {}
    excerpt_lines = []

    for rank, (title, doc_id, text) in enumerate(
        zip(hits["title"].values, hits["id"].values, hits["text"].values), start=1
    ):
        url = f"https://doi.org/10.1115/{doc_id}"

        # Numbered (1-based) excerpt so the model can refer to its sources
        excerpt_lines.append(f"{rank}. This excerpt is from: '{title}':\n{text}\n")

        # Excerpts from the same paper are grouped under one hyperlinked heading
        header = f"[{title.title()}]({url})\n"
        references.setdefault(header, []).append(text)

    research_excerpts = "".join(excerpt_lines)

    # Adjacent literals are concatenated, so each sentence ends with its own space.
    prompt = (
        "You are an AI assistant who delights in helping people learn about research from the IDETC Conference. "
        "Your main task is to provide an ANSWER to the USER_QUERY based on the RESEARCH_EXCERPTS. "
        "Your ANSWER should be concise.\n\n"
        f"RESEARCH_EXCERPTS:\n{research_excerpts}\n\n"
        f"USER_QUERY:\n{query}\n\n"
        "ANSWER:\n"
    )

    print(references)

    # One "### [Title](url)" heading per paper, followed by its excerpts as block quotes
    list_of_references = "\n".join(
        "### "
        + header
        + "".join(f'\n\n> "...{excerpt}..."' for excerpt in excerpts)
        for header, excerpts in references.items()
    )

    return (
        prompt,
        "\n\n<details><summary><h3>References</h3></summary>\n\n"
        + list_of_references
        + "\n\n</details>",
    )


def postprocess(response: str, bypass_from_preprocessing: str) -> str:
    """
    Final hook applied to the model's answer before it is shown to the user
    Args:
        response (str): Text generated by the LLM
        bypass_from_preprocessing (str): Text handed through from the preprocessing step
    Returns:
        str: The response with the bypass text appended directly after it
    """
    combined = "".join((response, bypass_from_preprocessing))
    return combined


@spaces.GPU
def reply(message: str, history: list[str]) -> str:
    """
    Streams a retrieval-grounded answer to the user's message
    This is a generator: it yields the growing partial answer as tokens arrive and
    finally yields the complete answer with the references section appended.
    Args:
        message (str): The user's latest message
        history (list[str]): Earlier turns; each entry is indexed as a
            [user, assistant] pair, either side of which may be None
    Yields:
        str: The answer produced so far
    """

    # Replace the raw message with the retrieval-augmented prompt; `bypass` carries
    # the references section that is appended once generation finishes.
    message, bypass = preprocess(message, 10)

    # Flatten [user, assistant] pairs into chat-template messages, skipping None sides
    history_transformer_format = [
        {"role": role, "content": message_pair[idx]}
        for message_pair in history
        for idx, role in enumerate(["user", "assistant"])
        if message_pair[idx] is not None
    ] + [{"role": "user", "content": message}]

    text = tokenizer.apply_chat_template(
        history_transformer_format, tokenize=False, add_generation_prompt=True
    )
    # Place the inputs wherever the model lives rather than assuming cuda:0
    model_inputs = tokenizer([text], return_tensors="pt").to(chat_model.device)

    # A fresh streamer per request: a shared one would interleave tokens from
    # concurrent generations.
    request_streamer = transformers.TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # generate() blocks, so run it in a background thread and consume the streamer here
    generate_kwargs = dict(model_inputs, streamer=request_streamer, max_new_tokens=512)
    generation_thread = threading.Thread(target=chat_model.generate, kwargs=generate_kwargs)
    generation_thread.start()

    partial_message = ""
    for new_token in request_streamer:
        # NOTE(review): chunks consisting solely of "<" are dropped — presumably to
        # hide stray markup; confirm intent
        if new_token != "<":
            partial_message += new_token
            time.sleep(0.05)  # pace the delivery of the message
            yield partial_message

    # The streamer is exhausted once generate() has finished; reap the thread.
    generation_thread.join()

    yield postprocess(partial_message, bypass)


# Chat panel seeded with the greeting; user messages have no avatar, the agent
# uses the IDETC logo.
idetc_chatbot = gradio.Chatbot(
    avatar_images=(
        None,
        "https://event.asme.org/Events/media/library/images/IDETC-CIE/IDETC-Logo-Announcements.png?ext=.png",
    ),
    show_label=False,
    show_share_button=False,
    show_copy_button=False,
    value=[[None, GREETING]],
    height="60vh",
    bubble_full_width=False,
)

# Wrap the panel in a ChatInterface that routes every message through `reply`,
# with the retry / undo / clear buttons hidden, then serve the app.
demo = gradio.ChatInterface(
    reply,
    examples=EXAMPLE_QUERIES,
    chatbot=idetc_chatbot,
    retry_btn=None,
    undo_btn=None,
    clear_btn=None,
)
demo.launch(debug=True)