ByteMaster01's picture
initial commit
0a65f9d
import sys
from typing import Optional, Tuple

import pandas as pd
from loguru import logger
from tqdm import tqdm
from typing_extensions import Any, List, Dict

from .base_conversion_utils import (
    clean_query,
    build_schema_maps,
    convert_actual_code_to_modified_dict,
    convert_modified_to_actual_code_string
)
from .line_based_parsing import (
    clean_modified_dict,
    convert_to_lines,
    parse_line_based_query
)
from .schema_utils import schema_to_line_based
def modify_single_row_base_form(
    mongo_query: str, schema: Dict[str, Any]
) -> Tuple[Optional[str], Any, Optional[str], Any, Any, Optional[Dict[str, Any]]]:
    """
    Round-trip one MongoDB query through its modified form and validate it.

    The query is cleaned, converted to the modified representation using the
    maps built from ``schema``, converted back to a query string for the first
    collection listed in ``schema["collections"]``, cleaned again and compared
    with the cleaned input.

    Args:
        mongo_query: The raw MongoDB query string.
        schema: Schema dictionary; must provide ``schema["collections"][0]["name"]``.

    Returns:
        A 6-tuple ``(mongo_query, modified_query, collection_name, in2out,
        out2in, schema)`` where ``mongo_query`` is the cleaned query, when the
        round trip reproduces the cleaned query exactly. Otherwise, or when
        any step raises, a tuple of six ``None`` values.
    """
    failed = (None, None, None, None, None, None)
    try:
        mongo_query = clean_query(mongo_query)
        in2out, out2in = build_schema_maps(schema)
        modified_query = convert_actual_code_to_modified_dict(mongo_query, out2in)
        # Only the first collection of the schema is used for reconstruction.
        collection_name = schema["collections"][0]["name"]
        reconstructed_query = clean_query(
            convert_modified_to_actual_code_string(modified_query, in2out, collection_name)
        )
    except Exception:
        # Deliberate best effort: a row that cannot be converted is reported
        # as a failed row so the caller can drop it instead of aborting.
        return failed

    if reconstructed_query != mongo_query:
        return failed
    return mongo_query, modified_query, collection_name, in2out, out2in, schema
def modify_all_rows_base_from(mongo_queries: List[str], schemas: List[Dict[str, Any]], nl_queries: List[str], additional_infos: List[str]) -> List[Dict[str, Any]]:
    """
    Run the single-row conversion over every (query, schema) pair.

    Rows whose conversion yields no modified query are dropped. Each kept row
    becomes a dictionary holding the conversion results together with the
    natural-language query and additional info found at the same index.
    """
    accepted_rows = []
    indexed_pairs = enumerate(zip(mongo_queries, schemas))
    progress = tqdm(indexed_pairs, total=len(mongo_queries), desc="Modifying Queries")
    for row_index, (raw_query, raw_schema) in progress:
        outcome = modify_single_row_base_form(raw_query, raw_schema)
        cleaned_query, modified_query, collection_name, in2out, out2in, row_schema = outcome
        if modified_query is None:
            # Conversion failed or did not round-trip; skip this row.
            continue
        accepted_rows.append({
            "mongo_query": cleaned_query,
            "natural_language_query": nl_queries[row_index],
            "additional_info": additional_infos[row_index],
            "modified_query": modified_query,
            "collection_name": collection_name,
            "in2out": in2out,
            "out2in": out2in,
            "schema": row_schema
        })
    return accepted_rows
def modify_line_based_parsing(modified_query_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """
    Attach the line-based form of a modified query after a round-trip check.

    The value under ``"modified_query"`` is cleaned, converted to lines and
    parsed back. The row is kept only when the parsed result equals the
    cleaned modified query.

    Args:
        modified_query_data: Row dictionary containing a ``"modified_query"`` key.

    Returns:
        The same dictionary, mutated in place with a ``"line_based_query"``
        entry, when the round trip succeeds; ``None`` when it does not match
        or when any step raises.
    """
    try:
        modified_query = clean_modified_dict(modified_query_data["modified_query"])
        lines = convert_to_lines(modified_query)
        reconstructed_query = parse_line_based_query(lines)
        if modified_query != reconstructed_query:
            return None
        modified_query_data["line_based_query"] = lines
        return modified_query_data
    except Exception:
        # Deliberate best effort: rows that fail are filtered out by the caller.
        return None
def modify_all_line_based_parsing(modified_queries: List[Dict[str, Any]]):
    """
    Apply the line-based round-trip check to every row and keep the survivors.

    Rows for which ``modify_line_based_parsing`` returns a falsy value are
    discarded; the remaining rows are returned in their original order.
    """
    progress = tqdm(modified_queries, desc="Testing Line-based Parsing", total=len(modified_queries))
    checked_rows = (modify_line_based_parsing(entry) for entry in progress)
    return [row for row in checked_rows if row]
def modify_all_schema(query_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Add a line-based rendering of the schema to every row.

    Args:
        query_data: Row dictionaries, each containing a ``"schema"`` key.

    Returns:
        A new list holding the same row dictionaries, each mutated in place
        with a ``"line_based_schema"`` entry.

    Raises:
        Whatever ``schema_to_line_based`` raises; failures are not suppressed.
    """
    final_data = []
    for query in tqdm(query_data, desc="Converting Schemas to Line-based Format", total=len(query_data)):
        query["line_based_schema"] = schema_to_line_based(query["schema"])
        final_data.append(query)
    return final_data
def load_csv(file_path: str) -> pd.DataFrame:
    """
    Load a CSV file into a pandas DataFrame.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        The DataFrame produced by ``pd.read_csv``.

    Raises:
        Any exception raised by ``pd.read_csv``; it is logged and re-raised.
    """
    try:
        df = pd.read_csv(file_path)
    except Exception as e:
        logger.error(f"Error loading CSV file: {e}")
        # Bare raise re-raises the active exception with its original traceback.
        raise
    logger.info(f"Loaded CSV file: {file_path}")
    return df
def modify_dataframe(df: pd.DataFrame) -> List[Dict[str, Any]]:
    """
    Run the full conversion pipeline over a DataFrame of queries.

    The frame must provide the columns ``mongo_query``, ``schema``,
    ``natural_language_query`` and ``additional_info``. Rows pass through
    ``modify_all_rows_base_from``, ``modify_all_line_based_parsing`` and
    ``modify_all_schema`` in that order.

    Args:
        df: Input DataFrame with the columns listed above.

    Returns:
        The list of row dictionaries produced by ``modify_all_schema``
        (not a DataFrame).
    """
    logger.info("Modifying DataFrame...")
    logger.debug(f"input DataFrame length: {len(df)}")

    mongo_queries = df["mongo_query"].tolist()
    # SECURITY: eval executes whatever expression is stored in the "schema"
    # column, so this must only be run on trusted CSV files.
    # NOTE(review): ast.literal_eval would be safer if every schema is a plain
    # Python literal -- confirm against the data before switching.
    schemas = df["schema"].apply(eval).tolist()
    nl_queries = df["natural_language_query"].tolist()
    additional_infos = df["additional_info"].tolist()

    modified_queries = modify_all_rows_base_from(mongo_queries, schemas, nl_queries, additional_infos)
    logger.debug(f"Modified queries length: {len(modified_queries)}")

    line_based_queries = modify_all_line_based_parsing(modified_queries)
    logger.debug(f"Line-based queries length: {len(line_based_queries)}")

    final_data = modify_all_schema(line_based_queries)
    logger.debug(f"Modified schemas length: {len(final_data)}")
    return final_data
def main(final_data: List[Dict[str, Any]], skip_indices: Tuple[int, ...] = (746,)):
    """
    Check that every row's line-based query reconstructs the original query.

    For each row the line-based query is parsed back into the modified query,
    converted to a query string with the row's ``in2out`` map and collection
    name, cleaned, and compared with the stored ``mongo_query``.

    Args:
        final_data: Row dictionaries containing ``mongo_query``,
            ``line_based_query``, ``in2out``, ``collection_name`` and
            ``modified_query``.
        skip_indices: Row positions excluded from the check. Defaults to the
            single index that was previously hard-coded.

    Raises:
        AssertionError: On the first row whose reconstruction does not match;
            the details of that row are logged beforehand.
        SystemExit: With code 0 once every checked row has matched.
    """
    skipped = frozenset(skip_indices)
    for i, row in enumerate(final_data):
        if i in skipped:
            continue

        original_query = row["mongo_query"]
        line_based_query = row["line_based_query"]
        reconstructed_modified_query = parse_line_based_query(line_based_query)
        reconstructed_original_query = convert_modified_to_actual_code_string(
            reconstructed_modified_query, row["in2out"], row["collection_name"]
        )

        if original_query != clean_query(reconstructed_original_query):
            logger.error(f"index: {i}")
            logger.error(f"Original query: {original_query}")
            logger.error(f"Reconstructed original query: {reconstructed_original_query}")
            logger.error(f"Modified query: {row['modified_query']}")
            logger.error(f"Reconstructed modified query: {reconstructed_modified_query}")
            logger.error(f"Line-based query: {line_based_query}")
            logger.warning("--------------------------------------------------")
            # Raised explicitly so the check still fires under ``python -O``,
            # which strips assert statements.
            raise AssertionError(
                f"Original query does not match reconstructed original query at index {i}"
            )

    # NOTE(review): assumes the exit call belongs after the loop -- confirm.
    # sys.exit is used because the bare ``exit`` helper is only installed by
    # the site module.
    sys.exit(0)
if __name__ == "__main__":
    # Script entry: load the dataset, run the pipeline and preview the result.
    csv_path = "./data_v3/data_v2.csv"
    df = load_csv(csv_path)
    final_data = modify_dataframe(df)
    # Optional round-trip verification; main() terminates the process when done.
    # main(final_data)
    logger.info(f"Final data length: {len(final_data)}")
    # Guard the sample dump: every row may have been filtered out.
    if final_data:
        logger.debug(f"Final data type: {final_data[0]}\n\n")
    # Preview the first five rows only.
    for i, query_data in enumerate(final_data[:5]):
        logger.debug(f"Modified schema {i}: {query_data['line_based_schema']}")
        logger.debug(f"Line-based query {i}: {query_data['line_based_query']}")
        logger.debug(f"NL query {i}: {query_data['natural_language_query']}")
        logger.debug(f"Additional info {i}: {query_data['additional_info']}")
        print('\n\n\n\n')