|
from smolagents import Tool |
|
from docling.document_converter import DocumentConverter |
|
from docling.chunking import HierarchicalChunker |
|
from sentence_transformers import SentenceTransformer, util |
|
import torch |
|
|
|
|
|
class ContentRetrieverTool(Tool):
    """Fetch a document and return the markdown chunks most relevant to a query.

    The document at ``url`` is converted with docling, split into hierarchical
    chunks, and ranked against the query using sentence embeddings.  Both the
    chunk body text and its surrounding context (headings etc., as produced by
    ``HierarchicalChunker.contextualize``) are scored separately, so a chunk
    can be selected either because its text matches the query or because its
    structural context does.
    """

    name = "retrieve_content"
    description = """Retrieve the content of a webpage or document in markdown format. Supports PDF, DOCX, XLSX, HTML, images, and more."""
    inputs = {
        "url": {
            "type": "string",
            "description": "The URL or local path of the webpage or document to retrieve.",
        },
        "query": {
            "type": "string",
            "description": "The subject on the page you are looking for. The shorter the more relevant content is returned.",
        },
    }
    output_type = "string"

    def __init__(
        self,
        model_name: str | None = None,
        threshold: float = 0.2,
        **kwargs,
    ):
        """Initialize the converter, embedding model, and chunker.

        Args:
            model_name: SentenceTransformer model id to use for embeddings.
                Defaults to ``"all-MiniLM-L6-v2"`` when ``None``.
            threshold: Cumulative softmax-probability mass of chunks to keep
                per query term (nucleus-style cutoff); higher values return
                more content.
            **kwargs: Forwarded to ``Tool.__init__``.
        """
        self.threshold = threshold
        self._document_converter = DocumentConverter()
        self._model = SentenceTransformer(
            model_name if model_name is not None else "all-MiniLM-L6-v2"
        )
        self._chunker = HierarchicalChunker()

        super().__init__(**kwargs)

    def forward(self, url: str, query: str) -> str:
        """Retrieve and rank document content relevant to ``query``.

        Args:
            url: URL or local path of the document to retrieve.
            query: Search subject; comma-separated terms are scored
                independently.

        Returns:
            The selected contextualized chunks joined by blank lines, ordered
            so the most relevant chunk comes last, or ``"No content found."``
            when nothing could be selected.
        """
        document = self._document_converter.convert(url).document

        chunks = list(self._chunker.chunk(dl_doc=document))
        if not chunks:
            return "No content found."

        chunk_texts = [chunk.text for chunk in chunks]
        contextualized = [self._chunker.contextualize(chunk) for chunk in chunks]
        # Context alone (headings/captions), with the chunk body stripped out,
        # so structure can be scored independently of the body text.
        contexts = [
            contextualized[i].replace(chunk_texts[i], "").strip()
            for i in range(len(chunks))
        ]

        text_embeddings = self._model.encode(chunk_texts, convert_to_tensor=True)
        context_embeddings = self._model.encode(contexts, convert_to_tensor=True)
        # Each non-empty comma-separated term gets its own embedding row, so
        # every term independently contributes chunks to the selection.
        query_embedding = self._model.encode(
            [term.strip() for term in query.split(",") if term.strip()],
            convert_to_tensor=True,
        )

        selected_indices: list[int] = []
        for embeddings in (context_embeddings, text_embeddings):
            # util.cos_sim supersedes the deprecated pytorch_cos_sim alias;
            # both return the same (num_terms x num_chunks) similarity matrix.
            for cos_scores in util.cos_sim(query_embedding, embeddings):
                probabilities = torch.nn.functional.softmax(cos_scores, dim=0)
                sorted_indices = torch.argsort(probabilities, descending=True)

                # Greedily take best-scoring chunks until `threshold` of the
                # probability mass is covered (nucleus-style selection).
                cumulative = 0.0
                for i in sorted_indices:
                    cumulative += probabilities[i].item()
                    selected_indices.append(i.item())
                    if cumulative >= self.threshold:
                        break

        # Deduplicate keeping first-seen (highest-relevance) order, then
        # reverse so the most relevant chunk ends up last in the output.
        selected_indices = list(dict.fromkeys(selected_indices))[::-1]

        if not selected_indices:
            return "No content found."

        return "\n\n".join(contextualized[idx] for idx in selected_indices)
|
|