Spaces:
Running
Running
File size: 7,667 Bytes
0a65f9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
import sys
from typing import Iterable, Optional, Tuple

import pandas as pd
from loguru import logger
from tqdm import tqdm
from typing_extensions import Any, Dict, List

from .base_conversion_utils import (
    clean_query,
    build_schema_maps,
    convert_actual_code_to_modified_dict,
    convert_modified_to_actual_code_string,
)
from .line_based_parsing import (
    clean_modified_dict,
    convert_to_lines,
    parse_line_based_query,
)
from .schema_utils import schema_to_line_based

def modify_single_row_base_form(
    mongo_query: str, schema: Dict[str, Any]
) -> Tuple[Optional[str], Any, Optional[str], Any, Any, Optional[Dict[str, Any]]]:
    """Round-trip one MongoDB query through the schema-mapped "modified" form.

    The query is cleaned with ``clean_query``. It is converted to the modified
    representation using the ``out2in`` map from ``build_schema_maps(schema)``.
    It is then converted back with ``in2out`` and the name of the schema's
    first collection, and cleaned again. The row counts as valid only if the
    reconstruction equals the cleaned input.

    Args:
        mongo_query: Raw MongoDB query string.
        schema: Schema dict; must provide ``schema["collections"][0]["name"]``.

    Returns:
        On success, the 6-tuple
        ``(mongo_query, modified_query, collection_name, in2out, out2in, schema)``,
        where ``mongo_query`` is the *cleaned* query. On any failure (an
        exception in the conversion steps, or a round-trip mismatch), a tuple
        of six ``None`` values, so callers can always unpack six items.
    """
    failed = (None, None, None, None, None, None)
    try:
        mongo_query = clean_query(mongo_query)
        in2out, out2in = build_schema_maps(schema)
        modified_query = convert_actual_code_to_modified_dict(mongo_query, out2in)
        collection_name = schema["collections"][0]["name"]
        reconstructed_query = clean_query(
            convert_modified_to_actual_code_string(modified_query, in2out, collection_name)
        )
    except Exception as exc:
        # Best-effort filter: a row that cannot be converted is simply dropped.
        # TRACE is below loguru's default handler level, so this stays quiet
        # unless someone explicitly enables TRACE to investigate dropped rows.
        logger.trace(f"Base-form conversion failed: {exc!r}")
        return failed
    # Only keep rows whose conversion is lossless.
    if reconstructed_query != mongo_query:
        return failed
    return mongo_query, modified_query, collection_name, in2out, out2in, schema
def modify_all_rows_base_from(mongo_queries: List[str], schemas: List[Dict[str, Any]], nl_queries: List[str], additional_infos: List[str]) -> List[Dict[str, Any]]:
    """Apply the base-form round-trip to every (query, schema) pair and keep the survivors.

    ``mongo_queries`` and ``schemas`` are walked in lockstep. For each pair that
    ``modify_single_row_base_form`` accepts, a record is built. The natural-language
    query and additional info for that record are taken from ``nl_queries`` and
    ``additional_infos`` at the same position. Pairs that fail are skipped.

    Returns:
        One dict per accepted row, with keys ``mongo_query``,
        ``natural_language_query``, ``additional_info``, ``modified_query``,
        ``collection_name``, ``in2out``, ``out2in`` and ``schema``.
    """
    results: List[Dict[str, Any]] = []
    progress = tqdm(
        enumerate(zip(mongo_queries, schemas)),
        total=len(mongo_queries),
        desc="Modifying Queries",
    )
    for idx, (raw_query, raw_schema) in progress:
        (
            cleaned_query,
            modified_query,
            collection_name,
            in2out,
            out2in,
            row_schema,
        ) = modify_single_row_base_form(raw_query, raw_schema)
        # A None modified_query marks a row that failed the round-trip.
        if modified_query is None:
            continue
        record = {
            "mongo_query": cleaned_query,
            "natural_language_query": nl_queries[idx],
            "additional_info": additional_infos[idx],
            "modified_query": modified_query,
            "collection_name": collection_name,
            "in2out": in2out,
            "out2in": out2in,
            "schema": row_schema,
        }
        results.append(record)
    return results
def modify_line_based_parsing(modified_query_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Check that one entry's modified query survives the line-based round-trip.

    ``modified_query_data["modified_query"]`` is normalised with
    ``clean_modified_dict``, rendered to lines with ``convert_to_lines``, and
    parsed back with ``parse_line_based_query``.

    Args:
        modified_query_data: A record as built by ``modify_all_rows_base_from``;
            must contain the ``"modified_query"`` key.

    Returns:
        If the parsed result equals the cleaned modified query, the same dict
        object, mutated in place with a new ``"line_based_query"`` key holding
        the lines. Otherwise ``None``, including when any step raises.
    """
    try:
        modified_query = clean_modified_dict(modified_query_data["modified_query"])
        lines = convert_to_lines(modified_query)
        reconstructed_query = parse_line_based_query(lines)
        if modified_query != reconstructed_query:
            return None
    except Exception as exc:
        # Best-effort filter: unparseable entries are dropped, not fatal.
        # Logged at TRACE so failures can be inspected without flooding output.
        logger.trace(f"Line-based round-trip failed: {exc!r}")
        return None
    modified_query_data["line_based_query"] = lines
    return modified_query_data
def modify_all_line_based_parsing(modified_queries: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Run the line-based round-trip check over every modified query record.

    Each record is passed to ``modify_line_based_parsing``. Records it accepts
    (which gain a ``"line_based_query"`` key in place) are collected in input
    order; rejected records are dropped.
    """
    line_based_queries = []
    for query_data in tqdm(modified_queries, desc="Testing Line-based Parsing", total=len(modified_queries)):
        line_based_query = modify_line_based_parsing(query_data)
        # None is the explicit rejection sentinel; test for it directly
        # rather than relying on dict truthiness.
        if line_based_query is not None:
            line_based_queries.append(line_based_query)
    return line_based_queries
def modify_all_schema(query_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Attach a line-based rendering of each record's schema.

    For every record, ``schema_to_line_based(record["schema"])`` is stored
    under the ``"line_based_schema"`` key.

    Args:
        query_data: Records that each contain a ``"schema"`` key.

    Returns:
        The same record objects, mutated in place, in input order.

    Raises:
        Any exception from ``schema_to_line_based`` propagates; no record is
        skipped.
    """
    final_data = []
    for query in tqdm(query_data, desc="Converting Schemas to Line-based Format", total=len(query_data)):
        query["line_based_schema"] = schema_to_line_based(query["schema"])
        final_data.append(query)
    return final_data
def load_csv(file_path: str) -> pd.DataFrame:
    """Load a CSV file into a pandas DataFrame.

    Args:
        file_path: Path to the CSV file.

    Returns:
        The DataFrame returned by ``pd.read_csv`` with default options.

    Raises:
        Whatever ``pd.read_csv`` raises (e.g. ``FileNotFoundError``). The
        exception is logged and re-raised unchanged.
    """
    try:
        df = pd.read_csv(file_path)
    except Exception as e:
        logger.error(f"Error loading CSV file {file_path}: {e}")
        # Bare ``raise`` re-raises the active exception with its original traceback.
        raise
    logger.info(f"Loaded CSV file: {file_path}")
    return df
def modify_dataframe(df: pd.DataFrame) -> List[Dict[str, Any]]:
    """Run the full conversion pipeline over a DataFrame of MongoDB queries.

    Reads the ``mongo_query``, ``schema``, ``natural_language_query`` and
    ``additional_info`` columns, then applies in sequence:

    1. ``modify_all_rows_base_from``: base-form round-trip; failing rows are dropped.
    2. ``modify_all_line_based_parsing``: line-based query round-trip; failing rows are dropped.
    3. ``modify_all_schema``: adds ``"line_based_schema"`` to each surviving row.

    Args:
        df: Input frame. The ``schema`` cells are strings that are ``eval``-ed
            into Python objects.

    Returns:
        A list of per-row dicts (not a DataFrame) for the rows that passed every stage.
    """
    logger.info("Modifying DataFrame...")
    logger.debug(f"input DataFrame length: {len(df)}")
    mongo_queries = df["mongo_query"].tolist()
    # SECURITY(review): ``eval`` executes arbitrary code contained in the CSV.
    # Only run this on trusted data. If schema cells are plain Python literals,
    # ``ast.literal_eval`` is a safe drop-in replacement.
    schemas = df["schema"].apply(eval).tolist()
    nl_queries = df["natural_language_query"].tolist()
    additional_infos = df["additional_info"].tolist()
    modified_queries = modify_all_rows_base_from(mongo_queries, schemas, nl_queries, additional_infos)
    logger.debug(f"Modified queries length: {len(modified_queries)}")
    line_based_queries = modify_all_line_based_parsing(modified_queries)
    logger.debug(f"Line-based queries length: {len(line_based_queries)}")
    final_data = modify_all_schema(line_based_queries)
    logger.debug(f"Modified schemas length: {len(final_data)}")
    return final_data
def main(final_data: List[Dict[str, Any]], skip_indices: Iterable[int] = (746,)) -> None:
    """Verify that every line-based query converts back to its original MongoDB query.

    For each record, ``line_based_query`` is parsed back to the modified form.
    That form is converted to a query string with the record's ``in2out`` map
    and ``collection_name``, then cleaned and compared with ``mongo_query``.

    Args:
        final_data: Records as produced by ``modify_dataframe``.
        skip_indices: Positions in ``final_data`` to skip. The default ``(746,)``
            is the index the original harness hard-coded; the reason for that
            exclusion is not recorded in this file.

    Raises:
        AssertionError: on the first mismatching record, after its details are logged.
        SystemExit: with code 0 once every checked record matches. This function
            terminates the process on success.
    """
    skipped = frozenset(skip_indices)
    for i, entry in enumerate(final_data):
        if i in skipped:
            continue
        original_query = entry["mongo_query"]
        line_based_query = entry["line_based_query"]
        reconstructed_modified_query = parse_line_based_query(line_based_query)
        reconstructed_original_query = convert_modified_to_actual_code_string(
            reconstructed_modified_query, entry["in2out"], entry["collection_name"]
        )
        if original_query != clean_query(reconstructed_original_query):
            logger.error(f"index: {i}")
            logger.error(f"Original query: {original_query}")
            logger.error(f"Reconstructed original query: {reconstructed_original_query}")
            logger.error(f"Modified query: {entry['modified_query']}")
            logger.error(f"Reconstructed modified query: {reconstructed_modified_query}")
            logger.error(f"Line-based query: {line_based_query}")
            logger.warning("--------------------------------------------------")
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            raise AssertionError(f"Original query does not match reconstructed original query at index {i}")
    # ``sys.exit`` rather than the site-module ``exit`` helper, which is not
    # guaranteed to exist (e.g. under ``python -S`` or embedded interpreters).
    sys.exit(0)
if __name__ == "__main__":
    # Build the converted dataset from the CSV and preview the first few rows.
    csv_path = "./data_v3/data_v2.csv"
    df = load_csv(csv_path)
    final_data = modify_dataframe(df)
    # main(final_data)  # uncomment to run the round-trip verification harness (exits the process)
    logger.info(f"Final data length: {len(final_data)}")
    if not final_data:
        # Every row was filtered out; indexing final_data[0] would raise IndexError.
        logger.warning("No rows survived the conversion pipeline; nothing to preview.")
    else:
        logger.debug(f"First final data entry: {final_data[0]}\n\n")
        # Preview the first five converted rows.
        for i, query_data in enumerate(final_data[:5]):
            logger.debug(f"Modified schema {i}: {query_data['line_based_schema']}")
            logger.debug(f"Line-based query {i}: {query_data['line_based_query']}")
            logger.debug(f"NL query {i}: {query_data['natural_language_query']}")
            logger.debug(f"Additional info {i}: {query_data['additional_info']}")
            print('\n\n\n\n')
|