# Retriever Section
import datasets
from langchain.docstore.document import Document
from langchain.tools import Tool
from transformers import AutoTokenizer, TFAutoModel
# Load the dataset
guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
def concatenate_text(examples):
    return {
        "text": "metadata={name:" + examples["name"] + "},"
        + "page_content=Name:" + examples["name"] + "\n"
        + "Relation:" + examples["relation"] + "\n"
        + "Description:" + examples["description"] + "\n"
        + "Email:" + examples["email"]
    }
docs = guest_dataset.map(concatenate_text)
# Embedding model: a sentence-transformers checkpoint loaded as a TensorFlow model
# (from_pt=True converts the PyTorch weights)
model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = TFAutoModel.from_pretrained(model_ckpt, from_pt=True)
def cls_pooling(model_output):
    # Use the [CLS] token embedding as the sentence-level representation
    return model_output.last_hidden_state[:, 0]

def get_embeddings(text_list):
    # Tokenize the input texts and embed them with the TF model
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="tf"
    )
    model_output = model(**encoded_input)
    return cls_pooling(model_output)
# Compute one embedding per guest document
embeddings_dataset = docs.map(
    lambda x: {"embeddings": get_embeddings(x["text"]).numpy()[0]}
)
embeddings_dataset.add_faiss_index(column="embeddings")
def extract_text(query: str) -> str:
    """Retrieves detailed information about gala guests based on their name or relation."""
    query_embedding = get_embeddings([query]).numpy()
    # Look up the 2 nearest guest entries in the FAISS index
    scores, samples = embeddings_dataset.get_nearest_examples(
        "embeddings", query_embedding, k=2
    )
    if samples["text"]:
        return "\n\n".join(samples["text"])
    else:
        return "No matching guest information found."
guest_info_tool = Tool(
    name="guest_info_retriever",
    func=extract_text,
    description="Retrieves detailed information about gala guests based on their name or relation."
)