# Retriever Section: build a FAISS-backed semantic search over the gala guest
# list and expose it to the agent as a LangChain Tool.
import datasets
from langchain.docstore.document import Document  # NOTE(review): currently unused — kept in case other chunks of this file need it
from langchain.tools import Tool
from transformers import AutoTokenizer, TFAutoModel

# Load the guest invitee dataset (one record per guest: name, relation,
# description, email).
guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")


def concatenate_text(examples):
    """Flatten one guest record into a single searchable "text" field.

    Intended for `Dataset.map`; `examples` is a dict with the keys
    "name", "relation", "description" and "email".
    Returns a dict with a single "text" key so `map` adds that column.
    """
    return {
        "text": "metadata={name:" + examples["name"] + "},"
        + "page_content=Name:" + examples["name"] + "\n"
        + "Relation:" + examples["relation"] + "\n"
        + "Description:" + examples["description"] + "\n"
        + "Email:" + examples["email"]
    }


docs = guest_dataset.map(concatenate_text)

# Sentence-transformer checkpoint used for both queries and documents.
model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
# from_pt=True: the checkpoint ships PyTorch weights; convert to TensorFlow.
model = TFAutoModel.from_pretrained(model_ckpt, from_pt=True)


def cls_pooling(model_output):
    """Return the [CLS] token embedding (first position) for each sequence."""
    return model_output.last_hidden_state[:, 0]


def get_embeddings(text_list):
    """Embed a list of strings into a (batch, hidden) TF tensor.

    Tokenizes with padding/truncation, runs the TF encoder and CLS-pools.
    """
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="tf"
    )
    # BUGFIX: dropped the no-op `{k: v for k, v in encoded_input.items()}`
    # dict copy — a leftover from the device-placement idiom used with PyTorch.
    model_output = model(**encoded_input)
    return cls_pooling(model_output)


# Embed every guest document and index the vectors with FAISS.
embeddings_dataset = docs.map(
    lambda x: {"embeddings": get_embeddings(x["text"]).numpy()[0]}
)
embeddings_dataset.add_faiss_index(column="embeddings")


def extract_text(query: str) -> str:
    """Retrieves detailed information about gala guests based on their name or relation."""
    query_embedding = get_embeddings([query]).numpy()
    scores, samples = embeddings_dataset.get_nearest_examples(
        "embeddings", query_embedding, k=2
    )
    # BUGFIX: `samples` is a dict and is truthy even with zero matches
    # ({"text": []}), which made the fallback branch unreachable.
    # Check the retrieved texts themselves instead.
    if samples["text"]:
        return "\n\n".join(samples["text"])
    else:
        return "No matching guest information found."


guest_info_tool = Tool(
    name="guest_info_retriever",
    func=extract_text,
    description="Retrieves detailed information about gala guests based on their name or relation."
)