QnA / src /tools /retriever.py
YvesP's picture
updated version with extracts from documents
988c713
raw
history blame
1.29 kB
from src.model.doc import Doc
from src.model.block import Block
class Retriever:
    """Index a document's blocks in a vector collection and query them.

    Three parallel views of the same document are passed in: a "plan" doc,
    a "content" doc and a "content_fr" doc (presumably a French version —
    confirm). Blocks are paired by position: each content block receives the
    ``specials`` of the plan block at the same index, and the ``content`` /
    ``title`` of the content_fr block at the same index (stored on it as
    ``content_fr`` / ``title_fr``). The plan blocks' text is what gets indexed;
    the enriched content blocks, serialized with ``to_dict``, are stored as
    metadata and are what ``similarity_search`` returns.

    NOTE(review): the client/collection calls used here (``create_collection``,
    ``add``, ``query`` returning ``'metadatas'``/``'distances'``) match the
    chromadb API — presumably ``db_client`` is a Chroma client.
    """

    def __init__(self, db_client, plan_doc: Doc, content_doc: Doc, content_fr_doc: Doc, collection_name: str):
        """Enrich the content blocks and index them in a new collection.

        Args:
            db_client: Vector-store client exposing ``create_collection(name=...)``.
            plan_doc: Document whose block ``content`` is indexed (with block
                ``index`` as id) and whose ``specials`` are copied onto the
                content blocks.
            content_doc: Document whose blocks are stored, via ``to_dict()``,
                as the metadata of each indexed entry.
            content_fr_doc: Parallel document whose block ``content``/``title``
                are attached to the content blocks as ``content_fr``/``title_fr``.
            collection_name: Name of the collection to create.

        Raises:
            ValueError: if ``plan_doc`` and ``content_doc`` do not contain the
                same number of blocks.

        Note:
            The blocks of ``content_doc`` are mutated in place.
        """
        plan_blocks: list[Block] = plan_doc.blocks
        content_blocks: list[Block] = content_doc.blocks
        content_fr_blocks: list[Block] = content_fr_doc.blocks

        # documents/ids are built from plan_blocks and metadatas from
        # content_blocks, so the two must line up one-to-one. Fail early with a
        # clear message instead of letting zip() silently truncate and the
        # collection receive misaligned lists.
        if len(plan_blocks) != len(content_blocks):
            raise ValueError(
                f"plan_doc has {len(plan_blocks)} blocks but content_doc has "
                f"{len(content_blocks)}; they must be aligned one-to-one"
            )

        for plan_block, content_block in zip(plan_blocks, content_blocks):
            content_block.specials = plan_block.specials

        # NOTE(review): if content_fr_doc has fewer blocks, zip() stops early and
        # the trailing content blocks get no content_fr/title_fr — confirm intent.
        for content_block, fr_block in zip(content_blocks, content_fr_blocks):
            content_block.content_fr = fr_block.content
            content_block.title_fr = fr_block.title

        self.collection = db_client.create_collection(name=collection_name)
        self.collection.add(
            documents=[block.content for block in plan_blocks],
            ids=[block.index for block in plan_blocks],
            metadatas=[block.to_dict() for block in content_blocks],
        )

    def similarity_search(self, query: str, n_results=None) -> "list[Block]":
        """Return the stored content blocks matching ``query``.

        Args:
            query: Text to search for, sent to the collection as one query.
            n_results: Optional maximum number of hits. When ``None`` (the
                default) it is not passed, so the collection's own default
                applies, exactly as before this parameter existed.

        Returns:
            A list with, for each hit, the object produced by
            ``Block().from_dict(metadata)`` with a ``distance`` attribute set to
            the distance the collection reported for that hit, in the order the
            collection returned them.
        """
        query_kwargs = {"query_texts": query}
        if n_results is not None:
            query_kwargs["n_results"] = n_results
        res = self.collection.query(**query_kwargs)

        # A single query was sent, so take the first (only) result list.
        block_dict_sources = res['metadatas'][0]
        distances = res['distances'][0]

        blocks = []
        for block_dict, distance in zip(block_dict_sources, distances):
            block = Block().from_dict(block_dict)
            block.distance = distance
            blocks.append(block)
        return blocks