|
import pandas as pd |
|
import hashlib |
|
import requests |
|
from typing import List, Optional |
|
from datetime import datetime |
|
from langchain.schema.embeddings import Embeddings |
|
from streamlit.runtime.uploaded_file_manager import UploadedFile |
|
from clickhouse_connect import get_client |
|
from multiprocessing.pool import ThreadPool |
|
from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings |
|
from .helper import create_retriever_tool |
|
|
|
# Endpoint that uploaded files are posted to for parsing into text elements.
parser_url = "https://api.unstructured.io/general/v0/general"


def parse_files(api_key, user_id, files: List[UploadedFile]):
    """Parse uploaded files into paragraph rows via the remote parser.

    Every file is posted to ``parser_url``. Of the returned elements, only
    those whose ``type`` is ``"NarrativeText"`` and whose text splits into
    more than 10 space-separated pieces are kept.

    Args:
        api_key: Value sent in the ``unstructured-api-key`` request header.
        user_id: Identifier copied into the ``user_id`` field of every row.
        files: Uploaded files to parse. They are dispatched to a pool of
            8 worker threads.

    Returns:
        A flat list of dicts with the keys ``text``, ``file_name``,
        ``entity_id``, ``user_id`` and ``created_by``. Results are collected
        with ``imap_unordered``, so the order across files is not guaranteed.

    Raises:
        ValueError: If the parser answers with a status other than 200. The
            message is the decoded JSON body, or the raw body text when the
            body is not valid JSON.
    """

    def parse_file(file: UploadedFile):
        headers = {
            "accept": "application/json",
            "unstructured-api-key": api_key,
        }
        data = {"strategy": "auto", "ocr_languages": ["eng"]}

        # Fetch the payload once with getvalue() and reuse it for both the
        # hash and the upload. read() only yields what is left after the
        # current stream position, so an already-consumed upload would hash
        # as empty and identical paragraphs from different files would end
        # up with the same entity_id.
        content = file.getvalue()
        file_hash = hashlib.sha256(content).hexdigest()
        file_data = {"files": (file.name, content, file.type)}

        response = requests.post(
            parser_url, headers=headers, data=data, files=file_data
        )

        # Check the status before trusting the body: error responses are not
        # guaranteed to be JSON, and the caller should still get a ValueError
        # that carries whatever the server said.
        if response.status_code != 200:
            try:
                detail = response.json()
            except ValueError:
                detail = response.text
            raise ValueError(str(detail))

        json_response = response.json()
        texts = [
            {
                "text": t["text"],
                "file_name": t["metadata"]["filename"],
                # Derived from file content plus paragraph text, so the same
                # paragraph of the same file always maps to the same id.
                "entity_id": hashlib.sha256(
                    (file_hash + t["text"]).encode()
                ).hexdigest(),
                "user_id": user_id,
                "created_by": datetime.now(),
            }
            for t in json_response
            if t["type"] == "NarrativeText" and len(t["text"].split(" ")) > 10
        ]
        return texts

    with ThreadPool(8) as p:
        rows = []
        for r in p.imap_unordered(parse_file, files):
            rows.extend(r)
        return rows
|
|
|
|
|
def extract_embedding(embeddings: Embeddings, texts):
    """Attach an embedding vector to every row in ``texts``.

    The ``"text"`` value of each row is passed to
    ``embeddings.embed_documents`` in a single call, and the i-th result is
    stored under the ``"vector"`` key of the i-th row. Rows are modified in
    place.

    Args:
        embeddings: Object providing ``embed_documents(list_of_texts)``.
        texts: Sequence of dict rows, each holding a ``"text"`` entry.

    Returns:
        The same ``texts`` object that was passed in, now with vectors.

    Raises:
        ValueError: If ``texts`` is empty.
    """
    # Reject empty input up front so the embedding backend is never called
    # with nothing to embed.
    if len(texts) == 0:
        raise ValueError("No texts extracted!")

    passages = [row["text"] for row in texts]
    vectors = embeddings.embed_documents(passages)

    for position, row in enumerate(texts):
        row["vector"] = vectors[position]
    return texts
|
|
|
|
|
class PrivateKnowledgeBase:
    """Per-user paragraph store and retriever-tool registry.

    Paragraph rows are kept in ``kb_table`` and tool definitions in
    ``tool_table``. Both tables carry a ``user_id`` column, and every method
    filters or tags rows by the ``user_id`` it is given.

    User-supplied values (``user_id``, tool names) are never spliced into SQL
    as raw text: driver calls use parameter binding, and SQL that is handed
    over as a plain string goes through ``_quote_literal``. Database and table
    names are interpolated directly and must come from trusted configuration.
    """

    def __init__(
        self,
        host,
        port,
        username,
        password,
        embedding: Embeddings,
        parser_api_key,
        db="chat",
        kb_table="private_kb",
        tool_table="private_tool",
    ) -> None:
        """Create both tables if they are missing and open the vector store.

        Args:
            host: Database host.
            port: Database port.
            username: Database user name.
            password: Database password.
            embedding: Embedding model handed to the vector store.
            parser_api_key: API key later passed to ``parse_files``.
            db: Database that holds both tables.
            kb_table: Name of the paragraph table.
            tool_table: Name of the tool-definition table.
        """
        super().__init__()
        kb_schema_ = f"""
            CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
                entity_id String,
                file_name String,
                text String,
                user_id String,
                created_by DateTime,
                vector Array(Float32),
                CONSTRAINT cons_vec_len CHECK length(vector) = 768,
                VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
            ) ENGINE = ReplacingMergeTree ORDER BY entity_id
        """
        tool_schema_ = f"""
            CREATE TABLE IF NOT EXISTS {db}.{tool_table}(
                tool_id String,
                tool_name String,
                file_names Array(String),
                user_id String,
                created_by DateTime,
                tool_description String
            ) ENGINE = ReplacingMergeTree ORDER BY tool_id
        """
        self.kb_table = kb_table
        self.tool_table = tool_table
        config = MyScaleSettings(
            host=host,
            port=port,
            username=username,
            password=password,
            database=db,
            table=kb_table,
        )
        # A separate client is used only to make sure the tables exist before
        # the vector store is constructed on top of kb_table.
        client = get_client(
            host=config.host,
            port=config.port,
            username=config.username,
            password=config.password,
        )
        # NOTE(review): presumably required by the server for these schemas;
        # confirm whether this setting is still needed.
        client.command("SET allow_experimental_object_type=1")
        client.command(kb_schema_)
        client.command(tool_schema_)
        self.parser_api_key = parser_api_key
        self.vstore = MyScaleWithoutJSON(
            embedding=embedding,
            config=config,
            must_have_cols=["file_name", "text", "created_by"],
        )

    @staticmethod
    def _quote_literal(value) -> str:
        """Return ``value`` as a single-quoted SQL string literal.

        Backslashes and single quotes are backslash-escaped so the value
        cannot end the literal early. Intended for SQL fragments that are
        passed on as plain text, where driver parameter binding is not
        available.
        """
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def list_files(self, user_id, tool_name=None):
        """List the files stored for ``user_id`` with per-file statistics.

        Args:
            user_id: Owner whose rows are inspected.
            tool_name: Accepted for interface compatibility; not used by the
                query.

        Returns:
            A list of dict rows with ``file_name``, ``num_paragraph`` and
            ``max_chars``.
        """
        query = f"""
            SELECT DISTINCT file_name, COUNT(entity_id) AS num_paragraph,
                arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars
            FROM {self.vstore.config.database}.{self.kb_table}
            WHERE user_id = %(user_id)s GROUP BY file_name
        """
        result = self.vstore.client.query(query, parameters={"user_id": user_id})
        return list(result.named_results())

    def add_by_file(
        self, user_id, files: List[UploadedFile], **kwargs
    ):
        """Parse ``files``, embed the extracted paragraphs and store them.

        Args:
            user_id: Owner recorded on every inserted row.
            files: Uploaded files to ingest.
            **kwargs: Accepted but not used.

        Raises:
            ValueError: Propagated from ``parse_files`` or
                ``extract_embedding``, e.g. when no paragraph was extracted.
        """
        data = parse_files(self.parser_api_key, user_id, files)
        data = extract_embedding(self.vstore.embeddings, data)
        self.vstore.client.insert_df(
            self.kb_table,
            pd.DataFrame(data),
            database=self.vstore.config.database,
        )

    def clear(self, user_id):
        """Delete every paragraph row and every tool owned by ``user_id``."""
        database = self.vstore.config.database
        parameters = {"user_id": user_id}
        for table in (self.kb_table, self.tool_table):
            self.vstore.client.command(
                f"DELETE FROM {database}.{table} WHERE user_id = %(user_id)s",
                parameters=parameters,
            )

    def create_tool(
        self, user_id, tool_name, tool_description, files: Optional[List[str]] = None
    ):
        """Register a retriever tool restricted to the given file names.

        Args:
            user_id: Owner of the tool.
            tool_name: Tool name; together with ``user_id`` it determines
                the ``tool_id``.
            tool_description: Description stored with the tool.
            files: File names the tool may search. ``None`` is stored as an
                empty list because ``file_names`` is a non-nullable array
                column.
        """
        tool_id = hashlib.sha256((user_id + tool_name).encode("utf-8")).hexdigest()
        row = {
            "tool_id": tool_id,
            "tool_name": tool_name,
            "file_names": list(files) if files else [],
            "user_id": user_id,
            "created_by": datetime.now(),
            "tool_description": tool_description,
        }
        self.vstore.client.insert_df(
            self.tool_table,
            pd.DataFrame([row]),
            database=self.vstore.config.database,
        )

    def list_tools(self, user_id, tool_name=None):
        """List the tools owned by ``user_id``.

        Args:
            user_id: Owner whose tools are listed.
            tool_name: When given (and non-empty), only the tool with this
                name is returned.

        Returns:
            A list of dict rows with ``tool_name``, ``tool_description`` and
            ``length(file_names)``.
        """
        conditions = "user_id = %(user_id)s"
        parameters = {"user_id": user_id}
        if tool_name:
            conditions += " AND tool_name = %(tool_name)s"
            parameters["tool_name"] = tool_name
        query = f"""
            SELECT tool_name, tool_description, length(file_names)
            FROM {self.vstore.config.database}.{self.tool_table}
            WHERE {conditions}
        """
        result = self.vstore.client.query(query, parameters=parameters)
        return list(result.named_results())

    def remove_tools(self, user_id, tool_names):
        """Delete the named tools of ``user_id``.

        Args:
            user_id: Owner of the tools.
            tool_names: Iterable of tool names to delete. When it is empty,
                no statement is sent.
        """
        tool_names = list(tool_names)
        if not tool_names:
            # Nothing to delete; avoids sending "IN []" to the server.
            return
        query = f"""DELETE FROM {self.vstore.config.database}.{self.tool_table}
            WHERE user_id = %(user_id)s AND tool_name IN %(tool_names)s"""
        self.vstore.client.command(
            query,
            parameters={"user_id": user_id, "tool_names": tool_names},
        )

    def as_tools(self, user_id, tool_name=None):
        """Build one retriever tool per stored tool definition.

        Each retriever is limited, through its ``where_str`` search argument,
        to rows of ``user_id`` whose ``file_name`` is listed in the tool's
        ``file_names``.

        Args:
            user_id: Owner whose tools are materialised.
            tool_name: Optional name restricting the result to one tool.

        Returns:
            A dict mapping tool name to the object returned by
            ``create_retriever_tool``.
        """
        tools = self.list_tools(user_id=user_id, tool_name=tool_name)
        # where_str is passed on as plain SQL text, so values are escaped
        # here instead of being bound by the driver.
        user_literal = self._quote_literal(user_id)
        tool_table = f"{self.vstore.config.database}.{self.tool_table}"
        retrievers = {}
        for tool in tools:
            name = tool["tool_name"]
            where_str = (
                f"user_id={user_literal} "
                f"AND file_name IN ("
                f"SELECT arrayJoin(file_names) FROM ("
                f"SELECT file_names FROM {tool_table} "
                f"WHERE user_id = {user_literal} "
                f"AND tool_name = {self._quote_literal(name)}))"
            )
            retrievers[name] = create_retriever_tool(
                self.vstore.as_retriever(search_kwargs={"where_str": where_str}),
                name=name,
                description=tool["tool_description"],
            )
        return retrievers
|
|