import datasets
from langchain.tools import Tool
from transformers import AutoTokenizer, TFAutoModel

# Load the guest list for the gala (train split only).
guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")

def concatenate_text(examples):
    """Flatten each guest record into a single text field suitable for embedding."""
    return {
        "text": (
            f"Name: {examples['name']}\n"
            f"Relation: {examples['relation']}\n"
            f"Description: {examples['description']}\n"
            f"Email: {examples['email']}"
        )
    }

# Add the flattened "text" column to every row of the dataset.
docs = guest_dataset.map(concatenate_text)

# Sentence-embedding checkpoint; from_pt=True converts the PyTorch weights
# to TensorFlow on the fly.
model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = TFAutoModel.from_pretrained(model_ckpt, from_pt=True)

def cls_pooling(model_output):
    # Use the hidden state of the [CLS] token as the sentence embedding.
    return model_output.last_hidden_state[:, 0]

def get_embeddings(text_list):
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="tf"
    )
    model_output = model(**encoded_input)
    return cls_pooling(model_output)

# Embed every guest entry; get_embeddings returns a batch of one,
# so [0] unwraps it to a single vector per row.
embeddings_dataset = docs.map(
    lambda x: {"embeddings": get_embeddings(x["text"]).numpy()[0]}
)

# Build a FAISS index over the embedding column for nearest-neighbour retrieval.
embeddings_dataset.add_faiss_index(column="embeddings")

def extract_text(query: str) -> str:
    """Retrieves detailed information about gala guests based on their name or relation."""
    query_embedding = get_embeddings([query]).numpy()
    scores, samples = embeddings_dataset.get_nearest_examples(
        "embeddings", query_embedding, k=2
    )
    if samples["text"]:
        return "\n\n".join(samples["text"])
    return "No matching guest information found."

guest_info_tool = Tool(
    name="guest_info_retriever",
    func=extract_text,
    description="Retrieves detailed information about gala guests based on their name or relation.",
)
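
# A minimal usage sketch, not part of the tool setup itself: Tool.run invokes
# the wrapped function with the query string. The guest name below is an
# assumption taken from the course example and presumes such an entry exists
# in the invitee dataset.
if __name__ == "__main__":
    print(guest_info_tool.run("Tell me about our guest named 'Lady Ada Lovelace'."))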