File size: 1,942 Bytes
8618f46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# Retriever Section

import datasets
from langchain.docstore.document import Document
from langchain.tools import Tool
from transformers import AutoTokenizer, TFAutoModel

# Load the "train" split of the gala invitee dataset via datasets.load_dataset.
# concatenate_text below reads the "name", "relation", "description" and
# "email" fields of each row.
guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")

def concatenate_text(examples):
    """Flatten one guest record into a single searchable ``text`` string.

    Args:
        examples: Mapping holding string values under the keys ``name``,
            ``relation``, ``description`` and ``email``.

    Returns:
        A dict with the single key ``"text"``: a ``metadata={name:...},``
        prefix followed by ``page_content=`` and the Name / Relation /
        Description / Email lines separated by newlines.

    Raises:
        KeyError: If one of the four keys is missing.
        TypeError: If one of the values is not a string.
    """
    guest_name = examples["name"]
    pieces = (
        "metadata={name:", guest_name, "},",
        "page_content=Name:", guest_name, "\n",
        "Relation:", examples["relation"], "\n",
        "Description:", examples["description"], "\n",
        "Email:", examples["email"],
    )
    # str.join refuses non-string items, just as string concatenation does.
    return {"text": "".join(pieces)}

# Run concatenate_text over every row (map is not batched, so it is called
# once per record) to add the combined "text" field.
docs = guest_dataset.map(concatenate_text)

# One checkpoint name feeds both the tokenizer and the encoder so the two
# always come from the same model.
model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
# NOTE(review): from_pt=True presumably converts the checkpoint's PyTorch
# weights for the TensorFlow model class - confirm against transformers docs.
model = TFAutoModel.from_pretrained(model_ckpt, from_pt=True)

def cls_pooling(model_output):
    """Pick the first-token hidden state of every sequence.

    Args:
        model_output: Object exposing a ``last_hidden_state`` attribute that
            supports ``[:, 0]`` indexing.

    Returns:
        ``last_hidden_state`` sliced at index 0 of its second axis, i.e. the
        vector at the first token position for each item of the first axis.
    """
    hidden_states = model_output.last_hidden_state
    # Position 0 on the second axis is the slot this function treats as CLS.
    first_token_states = hidden_states[:, 0]
    return first_token_states

def get_embeddings(text_list):
    """Encode text with the module-level tokenizer and model.

    Args:
        text_list: Text handed straight to ``tokenizer``; callers in this
            file pass either a single string or a list of strings.

    Returns:
        The result of ``cls_pooling`` applied to the model output, i.e. the
        first-token hidden state for each encoded input.
    """
    tokenized = tokenizer(
        text_list,
        padding=True,
        truncation=True,
        return_tensors="tf",
    )
    # Unpack the tokenizer output into a plain dict of keyword arguments.
    model_inputs = dict(tokenized.items())
    return cls_pooling(model(**model_inputs))

# Embed each record's "text" and store the vector under "embeddings".
# get_embeddings is given one string per call here, so the pooled output is
# presumably a batch of one; [0] takes that single row.
embeddings_dataset = docs.map(
    lambda x: {"embeddings": get_embeddings(x["text"]).numpy()[0]}
)

# Build the FAISS index over the "embeddings" column; extract_text searches it.
embeddings_dataset.add_faiss_index(column="embeddings")

def extract_text(query: str) -> str:
    """Retrieves detailed information about gala guests based on their name or relation.

    The query is embedded with ``get_embeddings`` and matched against the
    FAISS index on ``embeddings_dataset``; the ``text`` fields of the two
    nearest records are returned.

    Args:
        query: Free-text search string, e.g. a guest name or relation.

    Returns:
        The matching records' ``text`` values joined by a blank line, or
        ``"No matching guest information found."`` when the search yields
        no records.
    """
    query_embedding = get_embeddings([query]).numpy()
    _scores, samples = embeddings_dataset.get_nearest_examples(
        "embeddings", query_embedding, k=2
    )
    # ``samples`` maps column names to lists, so it is truthy even when the
    # search returned zero rows; emptiness has to be checked on the column.
    matches = samples["text"] if samples else []
    if not matches:
        return "No matching guest information found."
    return "\n\n".join(matches)

# Wrap extract_text in a LangChain Tool registered as "guest_info_retriever".
# The description repeats the function's docstring.
guest_info_tool = Tool(
    name="guest_info_retriever",
    func=extract_text,
    description="Retrieves detailed information about gala guests based on their name or relation."
)