import pandas as pd from typing_extensions import Any, List, Dict from loguru import logger from tqdm import tqdm from .base_conversion_utils import ( clean_query, build_schema_maps, convert_actual_code_to_modified_dict, convert_modified_to_actual_code_string ) from .line_based_parsing import ( clean_modified_dict, convert_to_lines, parse_line_based_query ) from .schema_utils import schema_to_line_based def modify_single_row_base_form(mongo_query: str, schema: Dict[str, Any]) -> str: """ Modifies a single MongoDB query string based on the provided schema and schema maps. """ try: # Clean the query mongo_query = clean_query(mongo_query) # Build schema maps in2out, out2in = build_schema_maps(schema) # Convert the actual code to modified code modified_query = convert_actual_code_to_modified_dict(mongo_query, out2in) # Collection Name collection_name = schema["collections"][0]["name"] # Convert the modified code back to actual code reconstructed_query = convert_modified_to_actual_code_string(modified_query, in2out, collection_name) # Clean the reconstructed query reconstructed_query = clean_query(reconstructed_query) if reconstructed_query != mongo_query: return None, None, None, None, None, None else: return mongo_query, modified_query, collection_name, in2out, out2in, schema except Exception as _: return None, None, None, None, None, None def modify_all_rows_base_from(mongo_queries: List[str], schemas: List[Dict[str, Any]], nl_queries: List[str], additional_infos: List[str]) -> List[Dict[str, Any]]: """ Modifies all MongoDB queries based on the provided schemas. """ modified_queries = [] for i, (mongo_query, schema) in tqdm(enumerate(zip(mongo_queries, schemas)), total=len(mongo_queries), desc="Modifying Queries"): mongo_query, modified_query, collection_name, in2out, out2in, schema = modify_single_row_base_form(mongo_query, schema) if modified_query is not None: modified_queries.append({ "mongo_query": mongo_query, "natural_language_query": nl_queries[i], "additional_info": additional_infos[i], "modified_query": modified_query, "collection_name": collection_name, "in2out": in2out, "out2in": out2in, "schema": schema }) return modified_queries def modify_line_based_parsing(modified_query_data: str) -> Dict[str, Any]: """ Tests the line-based parsing of a modified MongoDB query. """ try: modified_query = clean_modified_dict(modified_query_data["modified_query"]) lines = convert_to_lines(modified_query) reconstructed_query = parse_line_based_query(lines) if modified_query != reconstructed_query: return None else: modified_query_data["line_based_query"] = lines return modified_query_data except Exception as e: return None def modify_all_line_based_parsing(modified_queries: List[Dict[str, Any]]): """ Tests the line-based parsing for all modified MongoDB queries. """ line_based_queries = [] for query_data in tqdm(modified_queries, desc="Testing Line-based Parsing", total=len(modified_queries)): line_based_query = modify_line_based_parsing(query_data) if line_based_query: line_based_queries.append(line_based_query) return line_based_queries def modify_all_schema(query_data: List[Dict[str, Any]]) -> List[str]: """ Converts all schemas to line-based format. """ final_data = [] for query in tqdm(query_data, desc="Converting Schemas to Line-based Format", total=len(query_data)): # try: line_based_schema = schema_to_line_based(query["schema"]) # if line_based_schema: query["line_based_schema"] = line_based_schema final_data.append(query) # except Exception as e: # pass # logger.debug(f"Line-based schema: {line_based_schema}") return final_data def load_csv(file_path: str) -> pd.DataFrame: """ Loads a CSV file into a pandas DataFrame. """ try: df = pd.read_csv(file_path) logger.info(f"Loaded CSV file: {file_path}") return df except Exception as e: logger.error(f"Error loading CSV file: {e}") raise e def modify_dataframe(df: pd.DataFrame) -> pd.DataFrame: """ Modifies a DataFrame by applying the modify_all_rows function. """ logger.info("Modifying DataFrame...") logger.debug(f"input DataFrame length: {len(df)}") mongo_queries = df["mongo_query"].tolist() schemas = df["schema"].apply(eval).tolist() nl_queries = df["natural_language_query"].tolist() additional_infos = df["additional_info"].tolist() modified_queries = modify_all_rows_base_from(mongo_queries, schemas, nl_queries, additional_infos) logger.debug(f"Modified queries length: {len(modified_queries)}") line_based_queries = modify_all_line_based_parsing(modified_queries) logger.debug(f"Line-based queries length: {len(line_based_queries)}") final_data = modify_all_schema(line_based_queries) logger.debug(f"Modified schemas length: {len(final_data)}") return final_data def main(final_data: List[Dict[str, Any]]): # try reconstructing original query from line-based query for i in range(len(final_data)): index_allowed = [746] if i in index_allowed: continue original_query = final_data[i]["mongo_query"] line_based_query = final_data[i]["line_based_query"] # reconstructed modified query reconstructed_modified_query = parse_line_based_query(line_based_query) # reconstructed original query reconstructed_original_query = convert_modified_to_actual_code_string(reconstructed_modified_query, final_data[i]["in2out"], final_data[i]["collection_name"]) if original_query != clean_query(reconstructed_original_query): logger.error(f"index: {i}") logger.error(f"Original query: {original_query}") logger.error(f"Reconstructed original query: {reconstructed_original_query}") logger.error(f"Modified query: {final_data[i]['modified_query']}") logger.error(f"Reconstructed modified query: {reconstructed_modified_query}") logger.error(f"Line-based query: {line_based_query}") # logger.error(f"Schema: {final_data[i]['schema']}") logger.warning("--------------------------------------------------") assert original_query == clean_query(reconstructed_original_query), f"Original query does not match reconstructed original query at index {i}" exit(0) if __name__ == "__main__": pdf_path = "./data_v3/data_v2.csv" df = load_csv(pdf_path) final_data = modify_dataframe(df) # main(final_data) logger.info(f"Final data length: {len(final_data)}") logger.debug(f"Final data type: {final_data[0]}\n\n") for i, (query_data) in enumerate(final_data): logger.debug(f"Modified schema {i}: {query_data['line_based_schema']}") logger.debug(f"Line-based query {i}: {query_data['line_based_query']}") logger.debug(f"NL query {i}: {query_data['natural_language_query']}") logger.debug(f"Additional info {i}: {query_data['additional_info']}") print('\n\n\n\n') if i > 3: break