Commit · 0a65f9d
1 Parent(s): f4a7a03
initial commit

Files changed:
- app.py +262 -0
- configs/__pycache__/prompt_config.cpython-313.pyc +0 -0
- configs/prompt_config.py +32 -0
- data_utils/__init__.py +0 -0
- data_utils/__pycache__/__init__.cpython-313.pyc +0 -0
- data_utils/__pycache__/base_conversion_utils.cpython-313.pyc +0 -0
- data_utils/__pycache__/line_based_parsing.cpython-313.pyc +0 -0
- data_utils/__pycache__/schema_utils.cpython-313.pyc +0 -0
- data_utils/base_conversion_utils.py +572 -0
- data_utils/line_based_parsing.py +180 -0
- data_utils/schema_utils.py +149 -0
- data_utils/utils.py +183 -0
- requirements.txt +62 -0
app.py
ADDED
@@ -0,0 +1,262 @@
import gradio as gr
import json
import requests
import os
import subprocess
import wget
from loguru import logger
from data_utils.line_based_parsing import parse_line_based_query, convert_to_lines
from data_utils.base_conversion_utils import (
    build_schema_maps,
    convert_modified_to_actual_code_string
)
from data_utils.schema_utils import schema_to_line_based
from configs.prompt_config import SYSTEM_PROMPT_V3, MODEL_PROMPT_V3

LLAMA_SERVER_URL = "http://127.0.0.1:8080/v1/chat/completions"
MODEL_PATH = "./models/unsloth.Q8_0.gguf"

def download_model():
    """Download the model if it doesn't exist"""
    os.makedirs("./models", exist_ok=True)
    if not os.path.exists(MODEL_PATH):
        logger.info("Downloading model weights...")
        wget.download(
            "https://huggingface.co/ByteMaster01/NL2SQL/resolve/main/unsloth.Q8_0.gguf",
            MODEL_PATH
        )
        logger.info("\nModel download complete!")

def start_llama_server():
    """Start the llama.cpp server with the downloaded model"""
    try:
        logger.info("Starting llama.cpp server...")
        subprocess.Popen([
            "python", "-m", "llama_cpp.server",
            "--model", MODEL_PATH,
            "--port", "8080"
        ])
        logger.info("Server started successfully!")
    except Exception as e:
        logger.error(f"Failed to start server: {e}")
        raise

def convert_line_parsed_to_mongo(line_parsed: str, schema: dict) -> str:
    try:
        modified_query = parse_line_based_query(line_parsed)
        collection_name = schema["collections"][0]["name"]
        in2out, _ = build_schema_maps(schema)
        reconstructed_query = convert_modified_to_actual_code_string(modified_query, in2out, collection_name)
        return reconstructed_query
    except Exception as e:
        logger.error(f"Error converting line parsed to MongoDB query: {e}")
        return ""

def process_query(schema_text: str, nl_query: str, additional_info: str = "") -> list:
    try:
        # Parse schema from string to dict
        schema = json.loads(schema_text)

        # Convert schema to line-based format
        line_based_schema = schema_to_line_based(schema)

        # Format prompt with line-based schema
        prompt = MODEL_PROMPT_V3.format(
            schema=line_based_schema,
            natural_language_query=nl_query,
            additional_info=additional_info
        )

        # Prepare request payload
        payload = {
            "slot_id": 0,
            "temperature": 0.1,
            "n_keep": -1,
            "cache_prompt": True,
            "messages": [
                {
                    "role": "system",
                    "content": SYSTEM_PROMPT_V3,
                },
                {
                    "role": "user",
                    "content": prompt
                },
            ]
        }

        # Make request to llama.cpp server
        response = requests.post(LLAMA_SERVER_URL, json=payload)
        response.raise_for_status()

        # Extract output from response
        output = response.json()["choices"][0]["message"]["content"].strip()
        logger.info(f"Model output: {output}")

        # Convert line-based output to MongoDB query
        mongo_query = convert_line_parsed_to_mongo(output, schema)

        return [
            mongo_query,
            output
        ]
    except Exception as e:
        logger.error(f"Error processing query: {e}")
        error_msg = f"Error: {str(e)}"
        # Two values to match the two Gradio output components
        return [error_msg, error_msg]

def create_interface():
    # Create Gradio interface
    iface = gr.Interface(
        fn=process_query,
        inputs=[
            gr.Textbox(
                label="Schema (JSON format)",
                placeholder="Enter your MongoDB schema in JSON format...",
                lines=10
            ),
            gr.Textbox(
                label="Natural Language Query",
                placeholder="Enter your query in natural language..."
            ),
            gr.Textbox(
                label="Additional Info (Optional)",
                placeholder="Enter any additional context (timestamps, etc)..."
            )
        ],
        outputs=[
            gr.Code(label="MongoDB Query", language="javascript", lines=1),
            gr.Textbox(label="Line-based Query")
        ],
        title="Natural Language to MongoDB Query Converter",
        description="Convert natural language queries to MongoDB queries based on your schema.",
        examples=[
            [
                '''{
    "collections": [{
        "name": "events",
        "document": {
            "properties": {
                "timestamp": {"bsonType": "int"},
                "severity": {"bsonType": "int"},
                "location": {
                    "bsonType": "object",
                    "properties": {
                        "lat": {"bsonType": "double"},
                        "lon": {"bsonType": "double"}
                    }
                }
            }
        }
    }]}''',
                "Find all events with severity greater than 5",
                ""
            ],
            [
                '''{
    "collections": [{
        "name": "vehicles",
        "document": {
            "properties": {
                "timestamp": {"bsonType": "int"},
                "vehicle_details": {
                    "bsonType": "object",
                    "properties": {
                        "license_plate": {"bsonType": "string"},
                        "make": {"bsonType": "string"},
                        "model": {"bsonType": "string"},
                        "year": {"bsonType": "int"},
                        "color": {"bsonType": "string"}
                    }
                },
                "speed": {"bsonType": "double"},
                "location": {
                    "bsonType": "object",
                    "properties": {
                        "lat": {"bsonType": "double"},
                        "lon": {"bsonType": "double"}
                    }
                }
            }
        }
    }]}''',
                "Find red Toyota vehicles manufactured after 2020 with speed above 60",
                ""
            ],
            [
                '''{
    "collections": [{
        "name": "sensors",
        "document": {
            "properties": {
                "sensor_id": {"bsonType": "string"},
                "readings": {
                    "bsonType": "object",
                    "properties": {
                        "temperature": {"bsonType": "double"},
                        "humidity": {"bsonType": "double"},
                        "pressure": {"bsonType": "double"}
                    }
                },
                "timestamp": {"bsonType": "date"},
                "status": {"bsonType": "string"}
            }
        }
    }]}''',
                "Find active sensors with temperature above 30 degrees in the last one day",
                '''current date is 21 january 2025'''
            ],
            [
                '''{
    "collections": [{
        "name": "orders",
        "document": {
            "properties": {
                "order_id": {"bsonType": "string"},
                "customer": {
                    "bsonType": "object",
                    "properties": {
                        "id": {"bsonType": "string"},
                        "name": {"bsonType": "string"},
                        "email": {"bsonType": "string"}
                    }
                },
                "items": {
                    "bsonType": "array",
                    "items": {
                        "bsonType": "object",
                        "properties": {
                            "product_id": {"bsonType": "string"},
                            "quantity": {"bsonType": "int"},
                            "price": {"bsonType": "double"}
                        }
                    }
                },
                "total_amount": {"bsonType": "double"},
                "status": {"bsonType": "string"},
                "created_at": {"bsonType": "int"}
            }
        }
    }]}''',
                "Find orders with total amount greater than $100 that contain more than 3 items and were created in the last 24 hours",
                '''{"current_time": 1685890800, "last_24_hours": 1685804400}'''
            ]
        ]
    )
    return iface

if __name__ == "__main__":
    # Download the model
    download_model()

    # Start the llama.cpp server
    start_llama_server()

    # Give the server a moment to start
    import time
    time.sleep(5)

    # Launch the Gradio interface
    print("Starting Gradio interface...")
    iface = create_interface()
    iface.launch()
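
For reference, a minimal sketch (not part of the commit) of exercising process_query without the Gradio UI, assuming the llama.cpp server started by start_llama_server is already listening on port 8080. The schema here is a hypothetical one-collection example; the model output is nondeterministic, so the printed values are only indicative:

import json
from app import process_query

schema = json.dumps({"collections": [{"name": "events", "document": {"properties": {
    "severity": {"bsonType": "int"},
    "timestamp": {"bsonType": "int"}
}}}]})

mongo_query, line_based = process_query(schema, "Find all events with severity greater than 5")
print(line_based)   # e.g. severity $gt 5
print(mongo_query)  # e.g. db.events.find({"severity":{"$gt":5}})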
configs/__pycache__/prompt_config.cpython-313.pyc
ADDED
Binary file (1.35 kB)
configs/prompt_config.py
ADDED
@@ -0,0 +1,32 @@
SYSTEM_PROMPT_V3 = """You are a MongoDB query parsing assistant. Your task is to convert a natural language query into a structured, line-by-line parsed format suitable for building MongoDB queries.

You will receive:
- schema: <MongoDB schema fields and their descriptions>
- natural_language_query: <A plain English query describing the intent of the user.>
- additional_info: <optional context or constraints>

Your job is to extract the relevant conditions and represent them in the following parsed format:
- Each filter is on a separate line
- Use operators like:
    = - equality
    $gt - greater than
    $lt - less than
    $gte - greater than or equal to
    $lte - less than or equal to
    $in - inclusion list (comma-separated values)
    $regex - regular expression for matching
- Optionally, include:
    sort = <field_name> (ascending or descending)
    limit = <number>

Follow the schema strictly. Do not hallucinate field names. Output only the parsed query format with no explanations.
"""

MODEL_PROMPT_V3 = """schema:
{schema}

natural_language_query: {natural_language_query}

additional_info: {additional_info}

parsed_mongo_query:"""
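
To make the contract concrete, here is a hypothetical exchange (not part of the commit) under these prompts, together with how data_utils.line_based_parsing would interpret it:

from data_utils.line_based_parsing import parse_line_based_query

# A plausible model response for "Find the 10 most recent events
# with severity greater than 5" (illustrative, not a recorded output)
model_output = """severity $gt 5
sort = {'timestamp': -1}
limit = 10"""

print(parse_line_based_query(model_output))
# {'severity': {'$gt': 5}, 'sort': {'timestamp': -1}, 'limit': 10}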
data_utils/__init__.py
ADDED
File without changes
data_utils/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (147 Bytes)
data_utils/__pycache__/base_conversion_utils.cpython-313.pyc
ADDED
Binary file (22.1 kB)
data_utils/__pycache__/line_based_parsing.cpython-313.pyc
ADDED
Binary file (7.58 kB)
data_utils/__pycache__/schema_utils.cpython-313.pyc
ADDED
Binary file (4.13 kB)
data_utils/base_conversion_utils.py
ADDED
@@ -0,0 +1,572 @@
from typing import Any, Dict, Tuple, List
from loguru import logger
import json
import re

def _normalize_number(match):
    num_str = match.group(0)
    if '.' in num_str:
        # Normalize float by removing trailing zeros and decimal point if needed
        return str(float(num_str))
    return num_str  # Leave integers as is


def clean_query(query: str) -> str:
    """
    Cleans the MongoDB query string by removing unnecessary whitespace and formatting.

    to do:
    - replace ' with "
    - remove all spaces
    - strip the query
    - convert '''<query>''' to <query>
    - remove newlines
    - remove empty brackets {}
    """
    # replace ' with "
    query = query.replace("'", "\"")
    # Remove all spaces
    query = query.replace(" ", "")
    # Strip the query
    query = query.strip()
    # Convert '''<query>''' to <query>
    if query.startswith("'''") and query.endswith("'''"):
        query = query[3:-3]
    # Remove newlines
    query = query.replace("\n", "")
    # Remove empty brackets {}
    query = query.replace("{}", "")
    # Replace .toArray() with ""
    query = query.replace(".toArray()", "")
    # Normalize number strings
    query = re.sub(r'(?<!["\w])(-?\d+\.\d+)(?!["\w])', _normalize_number, query)
    return query


def extract_field_paths(properties: Dict[str, Any], prefix: str = "") -> Dict[str, str]:
    """
    Recursively extract all leaf property names to full dot-paths
    from a Mongo JSON Schema 'properties' dict.
    Handles nested objects and arrays of objects.
    Returns {field_name: full_path}
    """
    paths: Dict[str, str] = {}
    for key, val in properties.items():
        current = prefix + key
        # If nested object, recurse
        if val.get("bsonType") == "object" and "properties" in val:
            paths.update(extract_field_paths(val["properties"], current + "."))
        # If array of objects, recurse into items
        elif val.get("bsonType") == "array" and "items" in val and val["items"].get("bsonType") == "object" and "properties" in val["items"]:
            paths.update(extract_field_paths(val["items"]["properties"], current + "."))
        else:
            paths[key] = current
    return paths


def build_schema_maps(schema: Dict[str, Any]) -> Tuple[Dict[str, str], Dict[str, str]]:
    """
    From a full JSON Schema, return two maps:
    - input_to_output: field_name -> nested field path
    - output_to_input: nested field path -> field_name
    Handles both nested and flat schemas correctly.
    """
    props = schema["collections"][0]["document"]["properties"]
    input_to_output = extract_field_paths(props)
    output_to_input = {v: k for k, v in input_to_output.items()}
    return input_to_output, output_to_input


def set_nested(d: Dict[str, Any], keys: List[str], value: Any) -> None:
    """
    Helper to set a nested value in a dict given a list of keys.
    """
    for k in keys[:-1]:
        d = d.setdefault(k, {})
    d[keys[-1]] = value


def dot_notation_to_nested(dot: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert a dict with dot-notation keys to nested dict structure.
    E.g. {"a.b": v} -> {"a": {"b": v}}
    """
    out: Dict[str, Any] = {}
    for key, val in dot.items():
        parts = key.split('.')
        set_nested(out, parts, val)
    return out


def nested_to_dot(d: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """
    Convert nested dict to dot-notation keys. Treat operator-dicts as leaves.
    """
    out: Dict[str, Any] = {}
    for k, v in d.items():
        new_pref = f"{prefix}.{k}" if prefix else k
        # operator-dict leaf?
        if isinstance(v, dict) and v and all(str(kk).startswith("$") for kk in v):
            out[new_pref] = v
        elif isinstance(v, dict):
            out.update(nested_to_dot(v, new_pref))
        else:
            out[new_pref] = v
    return out


def modified_to_actual_query(modified: Dict[str, Any],
                             input_to_output: Dict[str, str]) -> Dict[str, Any]:
    """
    Convert a flat filter dict (field_name -> value/operator) into
    a nested Mongo query dict according to the schema map.
    If a key is not in the schema, treat it as dot notation.
    """
    query: Dict[str, Any] = {}
    for field_name, val in modified.items():
        if field_name in input_to_output:
            path = input_to_output[field_name].split('.')
            set_nested(query, path, val)
        else:
            # fallback: treat as dot notation
            set_nested(query, field_name.split('.'), val)
    return query


def actual_to_modified_query(actual: Dict[str, Any],
                             output_to_input: Dict[str, str]) -> Dict[str, Any]:
    """
    Flatten a nested Mongo query dict back into field_name -> value/operator.
    Operator-dicts (keys starting with $) are treated as leaves.
    If a path is not in output_to_input mapping, preserve it as-is.
    """
    flat: Dict[str, Any] = {}

    def recurse(d: Any, prefix: str = "") -> None:
        # operator-dict leaf
        if isinstance(d, dict) and d and all(k.startswith("$") for k in d):
            if prefix in output_to_input:
                flat[output_to_input[prefix]] = d
            else:
                flat[prefix] = d
            return

        # leaf non-dict
        if not isinstance(d, dict):
            if prefix in output_to_input:
                flat[output_to_input[prefix]] = d
            else:
                flat[prefix] = d
            return

        # recurse deeper
        for k, v in d.items():
            new_pref = f"{prefix}.{k}" if prefix else k
            recurse(v, new_pref)

    recurse(actual)
    return flat


def build_query_and_options(
    modified: Dict[str, Any],
    input_to_output: Dict[str, str]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    From a flat input dict that may include filter fields plus
    special options (limit, skip, sort, projection), return:
    - nested Mongo filter dict
    - options dict with keys: limit, skip, sort, projection
    """
    # extract special keys
    options: Dict[str, Any] = {}
    for opt in ("limit", "skip", "sort", "projection"):  # in this order
        if opt in modified:
            options[opt] = modified.pop(opt)

    # build nested filter
    query = modified_to_actual_query(modified, input_to_output)
    return query, options


def convert_modified_to_actual_code_string(
    modified_input: dict,
    in2out: dict,
    collection_name: str = "events"
) -> str:
    """
    Converts a modified (flat) dict into a MongoDB code string.
    Omits the projection argument if opts['projection'] is empty.
    Prints filter in dot-notation to match db.find syntax.
    """
    import re

    # Remember the original query format before stripping metadata,
    # since the filter below removes all underscore-prefixed keys
    original_query_format = modified_input.get("_original_query_format", "")
    # Remove internal metadata fields before processing
    modified_input = {k: v for k, v in modified_input.items() if not k.startswith('_')}

    filter_dict, opts = build_query_and_options(modified_input.copy(), in2out)

    # 1) dot-ify the filter dict
    dot_filter = nested_to_dot(filter_dict)
    filter_str = json.dumps(dot_filter, separators=(",", ":"))

    # 2) Convert date strings back to appropriate MongoDB date format
    # This regex matches ISO date strings like "2024-01-01T00:00:00Z"
    date_pattern = r'"(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z)"'

    # Check if there was a newDate string in the original query
    # If so, we need to preserve that format instead of using ISODate
    if "newDate" in original_query_format:
        filter_str = re.sub(date_pattern, r'newDate("\1")', filter_str)
    else:
        # Default to ISODate format
        filter_str = re.sub(date_pattern, r'ISODate("\1")', filter_str)

    # 3) Restore special time expressions that might have been converted
    time_expr_pattern = r'"(newDate\.getTime\(\)-\d+)"'
    filter_str = re.sub(time_expr_pattern, r'\1', filter_str)

    # 4) only include projection if non-empty
    projection = opts.get("projection", None)
    projection_str = ""
    if projection:
        projection_str = json.dumps(projection, separators=(',', ':'))
        # Also convert date strings in projection if any
        if "newDate" in original_query_format:
            projection_str = re.sub(date_pattern, r'newDate("\1")', projection_str)
        else:
            projection_str = re.sub(date_pattern, r'ISODate("\1")', projection_str)

    parts = [f"db.{collection_name}.find({filter_str}"
             + (f", {projection_str}" if projection else "")
             + ")"]

    # 5) chain optional methods
    if opts.get("sort"):
        # Handle different sort formats
        sort_value = opts['sort']
        if isinstance(sort_value, list):
            # Convert array format to object format
            sort_obj = {}
            for key, direction in sort_value:
                sort_obj[key] = direction
            sort_value = sort_obj

        # For sort parameters, we want to preserve the MongoDB format exactly
        # Convert the sort object to a string without quotes around the entire thing
        if isinstance(sort_value, dict):
            sort_items = []
            for k, v in sort_value.items():
                sort_items.append(f'"{k}":{v}')
            sort_str = '{' + ','.join(sort_items) + '}'
        else:
            sort_str = str(sort_value)

        parts.append(f".sort({sort_str})")

    if opts.get("skip"):
        parts.append(f".skip({opts['skip']})")

    if opts.get("limit"):
        parts.append(f".limit({opts['limit']})")

    return "".join(parts)


def convert_actual_code_to_modified_dict(actual_code: str, out2in: dict) -> dict:
    """
    Converts an actual MongoDB query string into a modified flat dictionary.
    WARNING: This assumes the input is sanitized and safe (e.g., evaluated from a trusted source).
    """
    import ast
    import re
    import json
    from datetime import datetime, timedelta

    # Store original number strings
    original_numbers = {}

    def store_number_strings(s: str) -> str:
        def replace_number(match):
            num_str = match.group(0)
            # Only store if it has a decimal point (to preserve trailing zeros)
            if '.' in num_str:
                try:
                    num = float(num_str)
                    # Store the longest representation for this float
                    key = str(num)
                    if key not in original_numbers or len(num_str) > len(original_numbers[key]):
                        original_numbers[key] = num_str
                except ValueError:
                    pass
            return num_str

        # Match numbers with optional decimal places and trailing zeros
        number_pattern = r'-?\d+\.\d+'
        re.sub(number_pattern, replace_number, s)
        return s

    def preprocess_mongo_syntax(query_str):
        store_number_strings(query_str)

        # Replace ISODate("..."), ISODate('...') with the date string
        query_str = re.sub(r'ISODate\("([^"]+)"\)', r'"\1"', query_str)
        query_str = re.sub(r"ISODate\('([^']+)'\)", r'"\1"', query_str)

        # Handle newDate(newDate().getTime()-<expr>)
        def newdate_minus_expr(match):
            expr = match.group(1)
            try:
                # Evaluate the expression (numbers and arithmetic operators only)
                ms = int(eval(expr, {"__builtins__": None}, {}))
                dt = datetime.utcnow() + timedelta(milliseconds=ms)
                return '"' + dt.strftime('%Y-%m-%dT%H:%M:%SZ') + '"'
            except Exception:
                return '"1970-01-01T00:00:00Z"'  # fallback
        query_str = re.sub(r'newDate\(newDate\(\)\.getTime\(\)([-+*/0-9 ]+)\)', newdate_minus_expr, query_str)

        # Replace newDate() with current UTC time
        now = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ')
        query_str = re.sub(r'newDate\(\)', f'"{now}"', query_str)

        # Replace newDate(expr) with a string (handle both quote types)
        query_str = re.sub(r'newDate\("([^"]+)"\)', r'"\1"', query_str)
        query_str = re.sub(r"newDate\('([^']+)'\)", r'"\1"', query_str)
        query_str = re.sub(r'newDate\((.*?)\)', r'"\1"', query_str)

        # Fix unbalanced brackets
        if query_str.count('{') > query_str.count('}'):
            query_str += "}" * (query_str.count('{') - query_str.count('}'))
        return query_str

    # Extract filter dictionary from find() call using regex
    def extract_filter_dict(code):
        # Match db.collection.find(...) pattern
        find_pattern = r'db\.[^.]+\.find\((.*?)(?:\)|,\s*{)'
        find_match = re.search(find_pattern, code)
        if not find_match:
            raise ValueError("Could not extract filter parameters from find() call")

        filter_str = find_match.group(1)

        # If empty, return empty dict
        if not filter_str.strip():
            return {}

        try:
            # Try parsing as JSON
            return json.loads(filter_str)
        except json.JSONDecodeError:
            # Try with ast.literal_eval
            try:
                return ast.literal_eval(filter_str)
            except Exception:
                # Last resort - try fixing common issues and retry
                fixed_str = filter_str.replace("'", '"')
                try:
                    return json.loads(fixed_str)
                except Exception:
                    raise ValueError(f"Could not parse filter dictionary: {filter_str}")

    # Extract projection dictionary from find() call using regex
    def extract_projection_dict(code):
        # Match find(..., {projection}) pattern
        proj_pattern = r'find\([^{]*({[^}]*})[^{]*,\s*{([^}]*)}\s*\)'
        proj_match = re.search(proj_pattern, code)
        if not proj_match:
            return None

        # group(2) captures the projection body without braces, so re-wrap it
        proj_str = '{' + proj_match.group(2) + '}'
        try:
            # Try parsing as JSON
            return json.loads(proj_str.replace("'", '"'))
        except Exception:
            # Try with ast.literal_eval
            try:
                return ast.literal_eval(proj_str)
            except Exception:
                return None

    # Extract method parameters using regex for cases where ast.literal_eval fails
    def extract_method_params(code, method_name):
        # Look for .method_name({...}) or .method_name([...]) or .method_name(123) pattern
        pattern = fr'\.{method_name}\s*\((.*?)\)(?:\.|\s*$)'
        match = re.search(pattern, code)
        if not match:
            return None

        param_str = match.group(1).strip()

        # Empty parameter
        if not param_str:
            return None

        # Try to handle different parameter types
        try:
            # Simple number?
            if param_str.isdigit():
                return int(param_str)

            # JSON object or array?
            try:
                # Handle MongoDB format with double quotes
                return json.loads(param_str.replace("'", '"'))
            except json.JSONDecodeError:
                # If direct JSON parsing fails, try to use ast.literal_eval
                try:
                    return ast.literal_eval(param_str)
                except Exception:
                    # Return as is if all else fails
                    return param_str
        except Exception as e:
            # Return None if all parsing fails
            logger.warning(f"Failed to parse parameter for {method_name}: {e}")
            return None

    # Pre-process the query
    preprocessed_code = preprocess_mongo_syntax(actual_code)

    try:
        # Try to use our more robust regex-based parsing first
        filter_dict = extract_filter_dict(preprocessed_code)
        projection = extract_projection_dict(preprocessed_code)

        # Handle empty projection
        options = {"projection": projection} if projection else {}

        # Extract sort, limit and skip parameters
        sort_param = extract_method_params(preprocessed_code, "sort")
        if sort_param is not None:
            options["sort"] = sort_param

        limit_param = extract_method_params(preprocessed_code, "limit")
        if limit_param is not None:
            options["limit"] = int(limit_param) if isinstance(limit_param, (int, str)) else limit_param

        skip_param = extract_method_params(preprocessed_code, "skip")
        if skip_param is not None:
            options["skip"] = int(skip_param) if isinstance(skip_param, (int, str)) else skip_param

        # Convert actual filter_dict back to modified
        flat_filter = actual_to_modified_query(filter_dict, out2in)

        # Merge projection, sort, limit into modified if relevant
        for key in ("projection", "sort", "skip", "limit"):
            if key in options and options[key] is not None:
                flat_filter[key] = options[key]

        # Add original number strings to the result
        flat_filter['_original_numbers'] = original_numbers

        return flat_filter

    except Exception as e:
        # Fall back to traditional AST-based parsing if regex fails
        try:
            node = ast.parse(preprocessed_code.strip(), mode='eval')
            if not isinstance(node.body, ast.Call) or not hasattr(node.body.func, 'attr') or node.body.func.attr != "find":
                raise ValueError("Expected .find(...) style query")

            # extract find(filter, projection)
            args = node.body.args
            filter_dict = ast.literal_eval(args[0])
            projection = ast.literal_eval(args[1]) if len(args) > 1 else None

            # extract chained methods: sort, skip, limit
            options = {"projection": projection} if projection else {}
            current = node.body
            while isinstance(current, ast.Call):
                func = current.func
                if hasattr(func, "attr"):
                    if func.attr == "sort":
                        options["sort"] = ast.literal_eval(current.args[0])
                    elif func.attr == "skip":
                        options["skip"] = ast.literal_eval(current.args[0])
                    elif func.attr == "limit":
                        options["limit"] = ast.literal_eval(current.args[0])
                current = func.value if hasattr(func, "value") else None

            # Convert actual filter_dict back to modified
            flat_filter = actual_to_modified_query(filter_dict, out2in)

            # Merge projection, sort, limit into modified if relevant
            for key in ("projection", "sort", "skip", "limit"):
                if key in options:
                    flat_filter[key] = options[key]

            return flat_filter
        except Exception as nested_e:
            raise ValueError(f"Failed to parse MongoDB query string: {e}. AST fallback also failed: {nested_e}")


# -------------------- Example Usage --------------------
if __name__ == "__main__":
    # Example JSON Schema
    schema = {
        "collections": [{
            "name": "events",
            "document": {
                "properties": {
                    "event_id": {"bsonType": "int"},
                    "timestamp": {"bsonType": "int"},
                    "severity_level": {"bsonType": "int"},
                    "camera_id": {"bsonType": "int"},
                    "vehicle_details": {"bsonType": "object", "properties": {
                        "license_plate_number": {"bsonType": "string"},
                        "vehicle_type": {"bsonType": "string"},
                        "color": {"bsonType": "string"}
                    }},
                    "person_details": {"bsonType": "object", "properties": {
                        "match_id": {"bsonType": "int"},
                        "age": {"bsonType": "int"},
                        "gender": {"bsonType": "string"},
                        "clothing_description": {"bsonType": "string"}
                    }},
                    "location": {"bsonType": "object", "properties": {
                        "latitude": {"bsonType": "double"},
                        "longitude": {"bsonType": "double"}
                    }},
                    "sensor_readings": {"bsonType": "object", "properties": {
                        "temperature": {"bsonType": "double"},
                        "speed": {"bsonType": "double"},
                        "distance": {"bsonType": "double"}
                    }},
                    "incident_type": {"bsonType": "string"}
                }
            }
        }],
        "version": 1
    }

    # Build mappings once
    in2out, out2in = build_schema_maps(schema)

    # Flat user input including filters + options
    modified_input = {
        "license_plate_number": {"$regex": "^MH12"},
        "timestamp": {"$gte": 1684080000, "$lte": 1684166400},
        "severity_level": 3,
        "limit": 50,
        "skip": 10,
        "sort": [("timestamp", -1)],
        "projection": {
            "vehicle_details.license_plate_number": 1,
            "timestamp": 1,
            "_id": 0
        }
    }

    # Build actual nested query + options
    filter_dict, opts = build_query_and_options(modified_input.copy(), in2out)

    print("filter_dict =", filter_dict)
    print("options =", opts)
    # You can then do:
    # cursor = (
    #     db.events.find(filter_dict, opts.get("projection"))
    #     .sort(opts.get("sort", []))
    #     .skip(opts.get("skip", 0))
    #     .limit(opts.get("limit", 0))
    # )
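
For reference, a minimal sketch (not part of the commit) of the full string reconstruction: convert_modified_to_actual_code_string maps flat field names back to their nested dot-paths and chains the cursor options. The tiny schema below is a subset of the example above:

from data_utils.base_conversion_utils import build_schema_maps, convert_modified_to_actual_code_string

schema = {"collections": [{"name": "events", "document": {"properties": {
    "severity_level": {"bsonType": "int"},
    "vehicle_details": {"bsonType": "object", "properties": {
        "license_plate_number": {"bsonType": "string"}
    }}
}}}]}
in2out, out2in = build_schema_maps(schema)

flat = {"severity_level": 3, "license_plate_number": {"$regex": "^MH12"}, "limit": 5}
print(convert_modified_to_actual_code_string(flat, in2out, "events"))
# db.events.find({"severity_level":3,"vehicle_details.license_plate_number":{"$regex":"^MH12"}}).limit(5)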
data_utils/line_based_parsing.py
ADDED
@@ -0,0 +1,180 @@
from typing import Any, Dict
import ast


def clean_modified_dict(modified_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Cleans the modified dictionary by removing only values that are:
    - None
    - empty list []
    - empty dict {}
    - empty string ''
    But keeps values like 0, False, etc.
    """
    def is_meaningfully_empty(value):
        return value in (None, '', []) or (isinstance(value, dict) and not value)

    return {k: v for k, v in modified_dict.items() if not is_meaningfully_empty(v)}


def convert_to_lines(query_dict):
    lines = []
    for field, condition in query_dict.items():
        if isinstance(condition, dict):
            for operator, value in condition.items():
                # Special handling for $ne with '' or []
                if operator in ['$ne', 'ne']:
                    if value == '':
                        value_str = "''"
                    elif value == []:
                        value_str = '[]'
                    elif isinstance(value, list):
                        value_str = ','.join(map(str, value))
                    elif isinstance(value, str):
                        value_str = f"'{value}'"
                    else:
                        value_str = str(int(value) if isinstance(value, float) and value.is_integer() else value)
                elif isinstance(value, list):
                    # Output lists as valid Python lists for complex cases
                    value_str = repr(value)
                elif isinstance(value, str):
                    value_str = f"'{value}'"
                else:
                    value_str = str(int(value) if isinstance(value, float) and value.is_integer() else value)
                lines.append(f"{field} {operator} {value_str}")
        else:
            if isinstance(condition, str):
                condition_str = f"'{condition}'"
            else:
                condition_str = str(condition)
            lines.append(f"{field} = {condition_str}")
    return '\n'.join(lines)


def parse_line_based_query(lines):
    query = {}
    for line in lines.strip().split('\n'):
        if not line.strip():
            continue
        parts = line.split(maxsplit=2)
        if len(parts) < 3:
            # If operator is present but value is empty, set value to empty string
            if len(parts) == 2:
                field, operator = parts
                value = ''
            else:
                continue  # Skip invalid lines
        else:
            field, operator, value = parts

        # Special handling for sort, limit, skip, etc.
        if field in {"sort", "order_by"}:
            # Handle both 'sort field value' and 'sort = {field: value}'
            if operator == "=":
                query[field] = _convert_value(value)
            else:
                if field not in query:
                    query[field] = {}
                query[field][operator] = _convert_value(value)
            continue
        if field in {"limit", "skip", "offset"}:
            query[field] = _convert_value(value)
            continue
        # Special handling for _original_numbers (parse value as string if quoted, else as number)
        if field == "_original_numbers":
            if field not in query:
                query[field] = {}
            v = value.strip()
            if (v.startswith("'") and v.endswith("'")) or (v.startswith('"') and v.endswith('"')):
                query[field][operator] = v[1:-1]
            else:
                try:
                    # Try to parse as int or float
                    query[field][operator] = int(v)
                except ValueError:
                    try:
                        query[field][operator] = float(v)
                    except ValueError:
                        query[field][operator] = v
            continue

        # Handle equality operator
        if operator == "=":
            query[field] = _convert_value(value)
            continue

        # Handle other operators
        # If operator is $in, $nin, $all and value is empty, use []
        empty_list_ops = {'in', '$in', 'nin', '$nin', 'all', '$all'}
        op_key = operator if operator.startswith('$') else f'${operator}'
        if operator in empty_list_ops and value == '':
            value_obj = []
        elif operator in {'ne', '$ne'}:
            if value.strip() == '[]':
                value_obj = []
            elif value.strip() == "''" or value.strip() == '""':
                value_obj = ''
            elif value == '':
                value_obj = []
            else:
                value_obj = _convert_value(value, operator)
        else:
            value_obj = _convert_value(value, operator)
        if field in query:
            if isinstance(query[field], dict):
                query[field][op_key] = value_obj
            else:
                raise ValueError(f"Conflict in {field}: direct value and operator")
        else:
            query[field] = {op_key: value_obj}
    return query


def _convert_value(value_str, operator=None):
    """Convert string values to appropriate types"""
    # Handle lists for $in and $all operators
    if operator in ('in', '$in', 'all', '$all'):
        s = value_str.strip()
        if s.startswith('[') and s.endswith(']'):
            try:
                return ast.literal_eval(s)
            except Exception:
                pass
        if ',' in value_str:
            return [_parse_single_value(v) for v in value_str.split(',')]

    # Handle regex flags (e.g., "pattern i" → "pattern" with $options: 'i')
    if operator == 'regex' and ' ' in value_str:
        pattern, *flags = value_str.split()
        return {'$regex': pattern, '$options': ''.join(flags)}

    return _parse_single_value(value_str)


def _parse_single_value(s):
    """Convert individual values to int/float/string/dict/bool"""
    s = s.strip()
    # Remove surrounding quotes if present
    if (s.startswith("'") and s.endswith("'")) or (s.startswith('"') and s.endswith('"')):
        return s[1:-1].strip()  # Always return as string if quoted
    # Handle None
    if s == 'None':
        return None
    # Try to parse as dict if it looks like one
    if (s.startswith('{') and s.endswith('}')) or (s.startswith('[') and s.endswith(']')):
        try:
            return ast.literal_eval(s)
        except Exception:
            pass
    # Handle booleans
    if s.lower() == 'true':
        return True
    if s.lower() == 'false':
        return False
    try:
        return int(s)
    except ValueError:
        try:
            return float(s)
        except ValueError:
            return s
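
A minimal round-trip sketch (not part of the commit) showing how convert_to_lines and parse_line_based_query invert each other for the common flat-dict shape:

from data_utils.line_based_parsing import convert_to_lines, parse_line_based_query

query = {"severity": {"$gt": 5}, "incident_type": "accident", "limit": 10}
lines = convert_to_lines(query)
print(lines)
# severity $gt 5
# incident_type = 'accident'
# limit = 10
assert parse_line_based_query(lines) == query  # lossless for this shape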
data_utils/schema_utils.py
ADDED
@@ -0,0 +1,149 @@
from typing import Dict, Any


def schema_to_line_based(schema: dict) -> str:
    """
    Converts a schema dictionary to a line-based format:
        field // description and format (str, int, ...)
    Only shows field names without parent prefix (e.g. 'age' instead of 'involved_persons.age')
    """
    def get_type(info):
        return info.get("bsonType") or info.get("type") or ""

    def process_properties(properties: dict) -> list:
        lines = []
        for field, info in properties.items():
            typ = get_type(info)
            desc = info.get("description", "")
            fmt = info.get("format", "")

            # Compose type/format string
            type_fmt = typ
            if fmt:
                type_fmt += f", {fmt}"

            # Compose comment
            comment = desc.strip()
            if type_fmt:
                comment = f"{comment} ({type_fmt})" if comment else f"({type_fmt})"

            lines.append(f"{field} // {comment}" if comment else field)

            # Recursively process nested objects and arrays, but only add the field names without prefix
            if typ == "object" and "properties" in info:
                for nested_line in process_properties(info["properties"]):
                    lines.append(nested_line)
            elif typ == "array" and "items" in info:
                items = info["items"]
                if get_type(items) == "object" and "properties" in items:
                    for nested_line in process_properties(items["properties"]):
                        lines.append(nested_line)

        return lines

    collections = schema.get("collections", [])
    if not collections:
        return ""
    collection = collections[0]
    # Support both "document" and direct "properties"
    if "document" in collection and "properties" in collection["document"]:
        properties = collection["document"]["properties"]
    else:
        properties = collection.get("properties", {})
    return "\n".join(process_properties(properties))


if __name__ == "__main__":
    example_schema = {
        "collections": [
            {
                "name": "events",
                "document": {
                    "bsonType": "object",
                    "properties": {
                        "identifier": {
                            "bsonType": "object",
                            "properties": {
                                "camgroup_id": {
                                    "bsonType": "string",
                                    "description": "Use this to filter events by group"
                                },
                                "task_id": {
                                    "bsonType": "string",
                                    "description": "Use this to filter events by tasks"
                                },
                                "camera_id": {
                                    "bsonType": "string",
                                    "description": "Use this to filter events by camera"
                                }
                            }
                        },
                        "response": {
                            "bsonType": "object",
                            "properties": {
                                "event": {
                                    "bsonType": "object",
                                    "properties": {
                                        "severity": {
                                            "bsonType": "string",
                                            "description": "Can be Low, Medium, Critical"
                                        },
                                        "type": {
                                            "bsonType": "string",
                                            "description": "Type of the event. Use this to filter events of person and vehicle"
                                        },
                                        "blobs": {
                                            "bsonType": "array",
                                            "items": {
                                                "bsonType": "object",
                                                "properties": {
                                                    "url": {
                                                        "bsonType": "string"
                                                    },
                                                    "attribs": {
                                                        "bsonType": "object",
                                                        "description": "Use this for attributes like Gender (Only Male, Female), Upper Clothing, Lower Clothing, Age (Ranges like 20-30, 30-40 and so on) for people and Make (like maruti suzuki, toyota, tata), Color, Type (like Hatchback, sedan, xuv), label (like car, truck, van, three wheeler, motorcycle) for Vehicles"
                                                    },
                                                    "label": {
                                                        "bsonType": "string",
                                                        "description": "Use this label for number plate"
                                                    },
                                                    "score": {
                                                        "bsonType": "number",
                                                        "description": "Use this for confidence for the blob"
                                                    },
                                                    "match_id": {
                                                        "bsonType": "string",
                                                        "description": "Use this match_id for name of the person"
                                                    },
                                                    "severity": {
                                                        "bsonType": "string"
                                                    },
                                                    "subclass": {
                                                        "bsonType": "string",
                                                        "description": "Use this for subclass for the blob"
                                                    }
                                                }
                                            }
                                        },
                                        "c_timestamp": {
                                            "bsonType": "date",
                                            "description": "Use this for timestamp"
                                        },
                                        "label": {
                                            "bsonType": "string",
                                            "description": "Use this label for number plate"
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        ],
        "version": 1
    }

    parsed_schema = schema_to_line_based(example_schema)
    print(parsed_schema)
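
For reference, the first lines the example above should print, assuming the recursion behaves as described (nested fields are listed flat, without parent prefixes):

identifier // (object)
camgroup_id // Use this to filter events by group (string)
task_id // Use this to filter events by tasks (string)
camera_id // Use this to filter events by camera (string)
response // (object)
event // (object)
severity // Can be Low, Medium, Critical (string)
...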
data_utils/utils.py
ADDED
@@ -0,0 +1,183 @@
import pandas as pd
from typing_extensions import Any, List, Dict
from loguru import logger
from tqdm import tqdm
from .base_conversion_utils import (
    clean_query,
    build_schema_maps,
    convert_actual_code_to_modified_dict,
    convert_modified_to_actual_code_string
)
from .line_based_parsing import (
    clean_modified_dict,
    convert_to_lines,
    parse_line_based_query
)
from .schema_utils import schema_to_line_based


def modify_single_row_base_form(mongo_query: str, schema: Dict[str, Any]) -> tuple:
    """
    Modifies a single MongoDB query string based on the provided schema and schema maps.
    """
    try:
        # Clean the query
        mongo_query = clean_query(mongo_query)
        # Build schema maps
        in2out, out2in = build_schema_maps(schema)
        # Convert the actual code to modified code
        modified_query = convert_actual_code_to_modified_dict(mongo_query, out2in)
        # Collection Name
        collection_name = schema["collections"][0]["name"]
        # Convert the modified code back to actual code
        reconstructed_query = convert_modified_to_actual_code_string(modified_query, in2out, collection_name)
        # Clean the reconstructed query
        reconstructed_query = clean_query(reconstructed_query)
        if reconstructed_query != mongo_query:
            return None, None, None, None, None, None
        else:
            return mongo_query, modified_query, collection_name, in2out, out2in, schema
    except Exception:
        return None, None, None, None, None, None


def modify_all_rows_base_form(mongo_queries: List[str], schemas: List[Dict[str, Any]], nl_queries: List[str], additional_infos: List[str]) -> List[Dict[str, Any]]:
    """
    Modifies all MongoDB queries based on the provided schemas.
    """
    modified_queries = []
    for i, (mongo_query, schema) in tqdm(enumerate(zip(mongo_queries, schemas)), total=len(mongo_queries), desc="Modifying Queries"):
        mongo_query, modified_query, collection_name, in2out, out2in, schema = modify_single_row_base_form(mongo_query, schema)
        if modified_query is not None:
            modified_queries.append({
                "mongo_query": mongo_query,
                "natural_language_query": nl_queries[i],
                "additional_info": additional_infos[i],
                "modified_query": modified_query,
                "collection_name": collection_name,
                "in2out": in2out,
                "out2in": out2in,
                "schema": schema
            })
    return modified_queries


def modify_line_based_parsing(modified_query_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Tests the line-based parsing of a modified MongoDB query.
    """
    try:
        modified_query = clean_modified_dict(modified_query_data["modified_query"])
        lines = convert_to_lines(modified_query)
        reconstructed_query = parse_line_based_query(lines)
        if modified_query != reconstructed_query:
            return None
        else:
            modified_query_data["line_based_query"] = lines
            return modified_query_data
    except Exception:
        return None


def modify_all_line_based_parsing(modified_queries: List[Dict[str, Any]]):
    """
    Tests the line-based parsing for all modified MongoDB queries.
    """
    line_based_queries = []
    for query_data in tqdm(modified_queries, desc="Testing Line-based Parsing", total=len(modified_queries)):
        line_based_query = modify_line_based_parsing(query_data)
        if line_based_query:
            line_based_queries.append(line_based_query)
    return line_based_queries


def modify_all_schema(query_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Converts all schemas to line-based format.
    """
    final_data = []
    for query in tqdm(query_data, desc="Converting Schemas to Line-based Format", total=len(query_data)):
        line_based_schema = schema_to_line_based(query["schema"])
        query["line_based_schema"] = line_based_schema
        final_data.append(query)
    return final_data


def load_csv(file_path: str) -> pd.DataFrame:
    """
    Loads a CSV file into a pandas DataFrame.
    """
    try:
        df = pd.read_csv(file_path)
        logger.info(f"Loaded CSV file: {file_path}")
        return df
    except Exception as e:
        logger.error(f"Error loading CSV file: {e}")
        raise e


def modify_dataframe(df: pd.DataFrame) -> List[Dict[str, Any]]:
    """
    Modifies a DataFrame by applying the row, line-based, and schema conversion passes.
    """
    logger.info("Modifying DataFrame...")
    logger.debug(f"input DataFrame length: {len(df)}")
    mongo_queries = df["mongo_query"].tolist()
    schemas = df["schema"].apply(eval).tolist()
    nl_queries = df["natural_language_query"].tolist()
    additional_infos = df["additional_info"].tolist()
    modified_queries = modify_all_rows_base_form(mongo_queries, schemas, nl_queries, additional_infos)
    logger.debug(f"Modified queries length: {len(modified_queries)}")
    line_based_queries = modify_all_line_based_parsing(modified_queries)
    logger.debug(f"Line-based queries length: {len(line_based_queries)}")
    final_data = modify_all_schema(line_based_queries)
    logger.debug(f"Modified schemas length: {len(final_data)}")
    return final_data


def main(final_data: List[Dict[str, Any]]):
    # try reconstructing original query from line-based query
    for i in range(len(final_data)):
        index_allowed = [746]
        if i in index_allowed:
            continue
        original_query = final_data[i]["mongo_query"]
        line_based_query = final_data[i]["line_based_query"]
        # reconstructed modified query
        reconstructed_modified_query = parse_line_based_query(line_based_query)
        # reconstructed original query
        reconstructed_original_query = convert_modified_to_actual_code_string(reconstructed_modified_query, final_data[i]["in2out"], final_data[i]["collection_name"])
        if original_query != clean_query(reconstructed_original_query):
            logger.error(f"index: {i}")
            logger.error(f"Original query: {original_query}")
            logger.error(f"Reconstructed original query: {reconstructed_original_query}")
            logger.error(f"Modified query: {final_data[i]['modified_query']}")
            logger.error(f"Reconstructed modified query: {reconstructed_modified_query}")
            logger.error(f"Line-based query: {line_based_query}")
            logger.warning("--------------------------------------------------")
            assert original_query == clean_query(reconstructed_original_query), f"Original query does not match reconstructed original query at index {i}"
            exit(0)


if __name__ == "__main__":
    pdf_path = "./data_v3/data_v2.csv"
    df = load_csv(pdf_path)
    final_data = modify_dataframe(df)
    # main(final_data)
    logger.info(f"Final data length: {len(final_data)}")
    logger.debug(f"Final data type: {final_data[0]}\n\n")

    for i, query_data in enumerate(final_data):
        logger.debug(f"Modified schema {i}: {query_data['line_based_schema']}")
        logger.debug(f"Line-based query {i}: {query_data['line_based_query']}")
        logger.debug(f"NL query {i}: {query_data['natural_language_query']}")
        logger.debug(f"Additional info {i}: {query_data['additional_info']}")
        print('\n\n\n\n')
        if i > 3:
            break
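
A minimal sketch (not part of the commit) of the round trip that modify_single_row_base_form validates, using a hypothetical one-field schema:

from data_utils.base_conversion_utils import (
    clean_query, build_schema_maps,
    convert_actual_code_to_modified_dict, convert_modified_to_actual_code_string
)

schema = {"collections": [{"name": "events", "document": {"properties": {
    "severity_level": {"bsonType": "int"}
}}}]}
in2out, out2in = build_schema_maps(schema)

query = clean_query('db.events.find({"severity_level": 3})')
flat = convert_actual_code_to_modified_dict(query, out2in)
rebuilt = convert_modified_to_actual_code_string(flat, in2out, "events")
assert clean_query(rebuilt) == query  # a row is kept only if this holds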
requirements.txt
ADDED
@@ -0,0 +1,62 @@
aiofiles>=23.0
annotated-types>=0.5.0
anyio>=3.7.1
certifi>=2023.11.17
charset-normalizer==3.4.2
click==8.2.1
diskcache==5.6.3
fastapi==0.115.12
ffmpy==0.6.0
filelock==3.18.0
fsspec==2025.5.1
gradio==5.32.1
gradio_client==1.10.2
groovy==0.1.2
h11==0.16.0
hf-xet==1.1.3
httpcore==1.0.9
httpx==0.28.1
huggingface-hub==0.32.4
idna==3.10
Jinja2==3.1.6
llama_cpp_python==0.3.9
loguru==0.7.3
markdown-it-py==3.0.0
MarkupSafe==3.0.2
mdurl==0.1.2
numpy==2.2.6
orjson==3.10.18
packaging==25.0
pandas==2.2.3
pillow==11.2.1
pydantic==2.11.5
pydantic-settings==2.9.1
pydantic_core==2.33.2
pydub==0.25.1
Pygments==2.19.1
python-dateutil==2.9.0.post0
python-dotenv==1.1.0
python-multipart==0.0.20
pytz==2025.2
PyYAML==6.0.2
requests==2.32.3
rich==14.0.0
ruff==0.11.12
safehttpx==0.1.6
semantic-version==2.10.0
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
sse-starlette==2.3.6
starlette==0.46.2
starlette-context==0.4.0
tomlkit==0.13.2
tqdm==4.67.1
typer==0.16.0
typing-inspection==0.4.1
typing_extensions==4.14.0
tzdata==2025.2
urllib3==2.4.0
uvicorn==0.34.3
websockets==15.0.1
wget