|
from src.model.doc import Doc |
|
from src.model.block import Block |
|
|
|
|
|
class Retriever:
    """Index document blocks in a vector collection and search them by text.

    The plan document supplies the text that is indexed and the ids; the
    content document's blocks (enriched in ``__init__``) are stored as the
    metadata handed back by ``similarity_search``.

    NOTE(review): the ``create_collection`` / ``add`` / ``query`` calls look
    like the ChromaDB client API -- confirm against the caller.
    """

    def __init__(
        self,
        db_client,
        plan_doc: Doc,
        content_doc: Doc,
        content_fr_doc: Doc,
        collection_name: str,
    ):
        """Merge the three documents block by block and index the result.

        Args:
            db_client: Object exposing ``create_collection(name=...)``.
            plan_doc: Its blocks provide ``specials``, the indexed
                ``content`` and the ``index`` used as id.
            content_doc: Its blocks are enriched and stored as metadata
                through ``to_dict()``.
            content_fr_doc: Its blocks provide the ``content`` and ``title``
                copied to ``content_fr`` / ``title_fr`` (presumably the
                French translation -- confirm).
            collection_name: Name passed to ``create_collection``.

        Blocks are paired by position. The block objects returned by
        ``content_doc.blocks`` are modified in place.
        """
        plan_blocks: list[Block] = plan_doc.blocks
        content_blocks: list[Block] = content_doc.blocks
        content_fr_blocks: list[Block] = content_fr_doc.blocks

        # zip() stops at the shorter sequence: if the documents do not have
        # the same number of blocks, the surplus content blocks are silently
        # left without ``specials`` / ``*_fr`` values.
        for plan_block, content_block in zip(plan_blocks, content_blocks):
            content_block.specials = plan_block.specials

        for content_block, fr_block in zip(content_blocks, content_fr_blocks):
            content_block.content_fr = fr_block.content
            content_block.title_fr = fr_block.title

        self.collection = db_client.create_collection(name=collection_name)

        # Documents and ids come from the plan blocks while the metadata comes
        # from the content blocks, so both lists are expected to line up
        # one-to-one.
        self.collection.add(
            documents=[block.content for block in plan_blocks],
            ids=[block.index for block in plan_blocks],
            metadatas=[block.to_dict() for block in content_blocks],
        )

    def similarity_search(self, query: str) -> "list[Block]":
        """Return the blocks matching ``query``, each tagged with its distance.

        Args:
            query: Text passed to the collection as ``query_texts``.

        Returns:
            One ``Block`` per hit, rebuilt from the stored metadata, with a
            ``distance`` attribute set from the query result. The order is
            the one returned by the collection.
        """
        res = self.collection.query(query_texts=query)

        # A single query text was sent, so only the first result set is read.
        hit_metadatas = res['metadatas'][0]
        hit_distances = res['distances'][0]

        blocks = []
        for metadata, distance in zip(hit_metadatas, hit_distances):
            # Relies on from_dict() returning the populated block.
            block = Block().from_dict(metadata)
            block.distance = distance
            blocks.append(block)
        return blocks
|
|