# virtual-data-analyst/functions/query_functions.py
import ast
import sqlite3
from typing import AnyStr, List

import pandas as pd
import psycopg2
import pymongoarrow.monkey
from haystack import component
from pymongo import MongoClient

from utils import TEMP_DIR

# Render full DataFrames when they are converted to strings for the LLM reply.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', None)


@component
class SQLiteQuery:
    """Haystack component that runs SQL queries against an uploaded SQLite database."""

    def __init__(self, sql_database: str):
        self.connection = sqlite3.connect(sql_database, check_same_thread=False)

    @component.output_types(results=List[str], queries=List[str])
    def run(self, queries: List[str], session_hash):
        print("ATTEMPTING TO RUN SQLITE QUERY")
        dir_path = TEMP_DIR / str(session_hash)
        results = []
        for query in queries:
            result = pd.read_sql(query, self.connection)
            # Persist the full result set so large outputs can be read from disk later.
            result.to_csv(f'{dir_path}/file_upload/query.csv', index=False)
            results.append(f"{result}")
        self.connection.close()
        return {"results": results, "queries": queries}


def sqlite_query_func(queries: List[str], session_hash, **kwargs):
    dir_path = TEMP_DIR / str(session_hash)
    sql_query = SQLiteQuery(f'{dir_path}/file_upload/data_source.db')
    try:
        result = sql_query.run(queries, session_hash)
        # Only return the raw result text if it is small enough for the LLM context.
        if len(result["results"][0]) > 1000:
            print("QUERY TOO LARGE")
            return {"reply": "The query result is too large to be processed by the LLM; the results are in our query.csv file. If you need to display the results directly, consider using the table_generation_func function."}
        else:
            return {"reply": result["results"][0]}

    except Exception as e:
        reply = f"""There was an error running the SQL query: {queries}
        The error is {e}.
        You should probably try again.
        """
        return {"reply": reply}
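
# Illustrative usage (not part of the original module): sqlite_query_func is intended to be
# invoked as an LLM tool with a list of SQL strings and the current session hash. The call
# below assumes the upload step has already created TEMP_DIR/<session_hash>/file_upload/data_source.db;
# the query and session hash are hypothetical.
#
#     reply = sqlite_query_func(
#         ["SELECT name, COUNT(*) AS order_count FROM orders GROUP BY name;"],
#         session_hash="example-session",
#     )
#     print(reply["reply"])  # DataFrame text, or a pointer to query.csv for large results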


@component
class PostgreSQLQuery:
    """Haystack component that runs SQL queries against a PostgreSQL database."""

    def __init__(self, url: str, sql_port: int, sql_user: str, sql_pass: str, sql_db_name: str):
        self.connection = psycopg2.connect(
            database=sql_db_name,
            user=sql_user,
            password=sql_pass,
            host=url,       # e.g., "localhost" or an IP address
            port=sql_port   # default is 5432
        )

    @component.output_types(results=List[str], queries=List[str])
    def run(self, queries: List[str], session_hash):
        print("ATTEMPTING TO RUN POSTGRESQL QUERY")
        dir_path = TEMP_DIR / str(session_hash)
        results = []
        for query in queries:
            print(query)
            result = pd.read_sql_query(query, self.connection)
            # Persist the full result set so large outputs can be read from disk later.
            result.to_csv(f'{dir_path}/sql/query.csv', index=False)
            results.append(f"{result}")
        self.connection.close()
        return {"results": results, "queries": queries}


def sql_query_func(queries: List[str], session_hash, db_url, db_port, db_user, db_pass, db_name, **kwargs):
    sql_query = PostgreSQLQuery(db_url, db_port, db_user, db_pass, db_name)
    try:
        result = sql_query.run(queries, session_hash)
        print("RESULT")
        print(result)
        # Only return the raw result text if it is small enough for the LLM context.
        if len(result["results"][0]) > 1000:
            print("QUERY TOO LARGE")
            return {"reply": "The query result is too large to be processed by the LLM; the results are in our query.csv file. If you need to display the results directly, consider using the table_generation_func function."}
        else:
            return {"reply": result["results"][0]}

    except Exception as e:
        reply = f"""There was an error running the SQL query: {queries}
        The error is {e}.
        You should probably try again.
        """
        print(reply)
        return {"reply": reply}


@component
class DocDBQuery:
    """Haystack component that runs a MongoDB aggregation pipeline and returns the result as text."""

    def __init__(self, connection_string: str, doc_db_name: str):
        client = MongoClient(connection_string)
        self.client = client
        self.connection = client[doc_db_name]

    @component.output_types(results=List[str], queries=List[str])
    def run(self, aggregation_pipeline: str, db_collection, session_hash):
        # Adds aggregate_pandas_all() and other Arrow helpers to pymongo collections.
        pymongoarrow.monkey.patch_all()
        print("ATTEMPTING TO RUN MONGODB QUERY")
        dir_path = TEMP_DIR / str(session_hash)
        results = []
        print(aggregation_pipeline)
        # JSON-style true/false must become Python True/False before ast.literal_eval can parse the string.
        aggregation_pipeline = aggregation_pipeline.replace(" ", "")
        false_replace = [':false', ': false']
        false_value = ':False'
        true_replace = [':true', ': true']
        true_value = ':True'
        for replace in false_replace:
            aggregation_pipeline = aggregation_pipeline.replace(replace, false_value)
        for replace in true_replace:
            aggregation_pipeline = aggregation_pipeline.replace(replace, true_value)
        query_list = ast.literal_eval(aggregation_pipeline)
        print("QUERY List")
        print(query_list)
        print(db_collection)
        db = self.connection
        collection = db[db_collection]
        print(collection)
        docs = collection.aggregate_pandas_all(query_list)
        print("DATA FRAME COMPLETE")
        # Persist the full result set so large outputs can be read from disk later.
        docs.to_csv(f'{dir_path}/doc_db/query.csv', index=False)
        print("CSV COMPLETE")
        results.append(f"{docs}")
        self.client.close()
        return {"results": results, "queries": aggregation_pipeline}
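
# Illustrative note (inferred from the parsing in DocDBQuery.run, not stated in the original):
# the aggregation pipeline is expected to arrive as a JSON-style string rather than a Python list, e.g.
#     '[{"$match": {"active": true}}, {"$group": {"_id": "$region", "total": {"$sum": "$count"}}}]'
# which ast.literal_eval turns into a list of stage dicts for aggregate_pandas_all once the
# booleans have been rewritten.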


def doc_db_query_func(aggregation_pipeline: str, db_collection: AnyStr, session_hash, connection_string, doc_db_name, **kwargs):
    doc_db_query = DocDBQuery(connection_string, doc_db_name)
    try:
        result = doc_db_query.run(aggregation_pipeline, db_collection, session_hash)
        print("RESULT")
        # Only return the raw result text if it is small enough for the LLM context.
        if len(result["results"][0]) > 1000:
            print("QUERY TOO LARGE")
            return {"reply": "The query result is too large to be processed by the LLM; the results are in our query.csv file. If you need to display the results directly, consider using the table_generation_func function."}
        else:
            return {"reply": result["results"][0]}

    except Exception as e:
        reply = f"""There was an error running the NoSQL (Mongo) query: {aggregation_pipeline}
        The error is {e}.
        You should probably try again.
        """
        print(reply)
        return {"reply": reply}
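
# Illustrative usage (not part of the original module): the connection string, database,
# collection, and session hash below are hypothetical and would normally be supplied by the
# calling application via kwargs.
#
#     reply = doc_db_query_func(
#         '[{"$match": {"status": "active"}}, {"$limit": 5}]',
#         db_collection="customers",
#         session_hash="example-session",
#         connection_string="mongodb://localhost:27017",
#         doc_db_name="example_db",
#     )
#     print(reply["reply"])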