import asyncio
import os

import numpy as np
import jsonlines

from factool.knowledge_qa.google_serper import GoogleSerperAPIWrapper
from factool.utils.openai_wrapper import OpenAIEmbed


class google_search:
    def __init__(self, snippet_cnt):
        self.serper = GoogleSerperAPIWrapper(snippet_cnt=snippet_cnt)

    async def run(self, queries):
        return await self.serper.run(queries)


class local_search:
    def __init__(self, snippet_cnt, data_link, embedding_link=None):
        self.snippet_cnt = snippet_cnt
        self.data_link = data_link
        self.embedding_link = embedding_link
        self.openai_embed = OpenAIEmbed()
        self.data = None
        self.embedding = None
        asyncio.run(self.init_async())

    async def init_async(self):
        print("init local search")
        self.load_data_by_link()
        if self.embedding_link is None:
            # No precomputed embeddings were provided: embed the corpus now
            # and cache the result next to the data file.
            await self.calculate_embedding()
        else:
            self.load_embedding_by_link()
        print("loaded data and embedding")

    def add_suffix_to_json_filename(self, filename):
        # e.g. "corpus.jsonl" -> "corpus_embed.jsonl"
        base_name, extension = os.path.splitext(filename)
        return base_name + '_embed' + extension

    def load_data_by_link(self):
        # Load the corpus from a JSON Lines file; each line must be an
        # object with a "text" field.
        self.data = []
        with jsonlines.open(self.data_link) as reader:
            for obj in reader:
                self.data.append(obj['text'])

    def load_embedding_by_link(self):
        # Load precomputed embeddings, one vector per line, aligned with
        # the order of self.data.
        self.embedding = []
        with jsonlines.open(self.embedding_link) as reader:
            for obj in reader:
                self.embedding.append(obj)

    def save_embeddings(self):
        # Cache the computed embeddings so later runs can pass them back in
        # via embedding_link instead of re-embedding the corpus.
        with jsonlines.open(self.add_suffix_to_json_filename(self.data_link), mode='w') as writer:
            writer.write_all(self.embedding)

    async def calculate_embedding(self):
        result = await self.openai_embed.process_batch(self.data, retry=3)
        self.embedding = [emb["data"][0]["embedding"] for emb in result]
        self.save_embeddings()

    async def search(self, query):
        # Embed the query and rank corpus entries by dot-product similarity,
        # keeping the top snippet_cnt matches.
        result = await self.openai_embed.create_embedding(query)
        query_embed = result["data"][0]["embedding"]
        dot_product = np.dot(self.embedding, query_embed)
        sorted_indices = np.argsort(dot_product)[::-1]
        top_k_indices = sorted_indices[:self.snippet_cnt]
        return [{"content": self.data[i], "source": "local"} for i in top_k_indices]

    async def run(self, queries):
        # queries is a list of query pairs, one pair per claim; a None entry
        # (a claim with no generated queries) becomes a pair of placeholder
        # strings so the pairing below stays aligned.
        flattened_queries = []
        for sublist in queries:
            if sublist is None:
                sublist = ['None', 'None']
            for item in sublist:
                flattened_queries.append(item)
        snippets = await asyncio.gather(*[self.search(query) for query in flattened_queries])
        # Merge the results of each pair of queries back into one evidence
        # list per claim.
        snippets_split = [snippets[i] + snippets[i + 1] for i in range(0, len(snippets), 2)]
        return snippets_split
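

# --- Usage sketch (illustrative only) ---
# A minimal, hedged example of driving local_search directly. The file name
# "corpus.jsonl" and the query pair below are hypothetical; the corpus must
# be a JSON Lines file whose lines each carry a "text" field (the format
# load_data_by_link expects), and OpenAIEmbed is assumed to pick up OpenAI
# credentials from its own configuration.
if __name__ == "__main__":
    searcher = local_search(snippet_cnt=3, data_link="corpus.jsonl")
    # One pair of search queries per claim, mirroring what run() expects.
    claim_queries = [["capital of France", "France capital city"]]
    evidences = asyncio.run(searcher.run(claim_queries))
    print(evidences)  # one merged list of top-k snippets per claim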