Spaces:
Runtime error
Runtime error
Caleb Fahlgren
commited on
Commit
·
467c2a7
1
Parent(s):
9fc2d21
fix pickle issue by using dict instead of pydantic model
Browse files
app.py
CHANGED
@@ -86,7 +86,7 @@ CREATE TABLE {} (
|
|
86 |
|
87 |
|
88 |
@spaces.GPU
|
89 |
-
def generate_query(dataset_id: str, query: str) ->
|
90 |
ddl = get_dataset_ddl(dataset_id)
|
91 |
|
92 |
system_prompt = f"""
|
@@ -118,37 +118,38 @@ def generate_query(dataset_id: str, query: str) -> str:
|
|
118 |
|
119 |
print("Received Response: ", resp)
|
120 |
|
121 |
-
return resp
|
122 |
|
123 |
|
124 |
def query_dataset(dataset_id: str, query: str) -> Tuple[pd.DataFrame, str, plt.Figure]:
|
125 |
-
response
|
126 |
|
127 |
print("Querying Parquet...")
|
128 |
-
df = conn.execute(response.sql).fetchdf()
|
129 |
|
130 |
plot = None
|
131 |
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
response.data_key = None
|
137 |
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
|
|
|
|
|
|
|
|
142 |
plt.xticks(rotation=45, ha="right")
|
143 |
plt.tight_layout()
|
144 |
-
elif
|
145 |
-
plot = df.plot(
|
146 |
-
kind="bar", x=response.label_key, y=response.data_key
|
147 |
-
).get_figure()
|
148 |
plt.xticks(rotation=45, ha="right")
|
149 |
plt.tight_layout()
|
150 |
|
151 |
-
markdown_output = f"""```sql\n{
|
152 |
return df, markdown_output, plot
|
153 |
|
154 |
|
|
|
86 |
|
87 |
|
88 |
@spaces.GPU
|
89 |
+
def generate_query(dataset_id: str, query: str) -> dict:
|
90 |
ddl = get_dataset_ddl(dataset_id)
|
91 |
|
92 |
system_prompt = f"""
|
|
|
118 |
|
119 |
print("Received Response: ", resp)
|
120 |
|
121 |
+
return resp.model_dump()
|
122 |
|
123 |
|
124 |
def query_dataset(dataset_id: str, query: str) -> Tuple[pd.DataFrame, str, plt.Figure]:
|
125 |
+
response = generate_query(dataset_id, query)
|
126 |
|
127 |
print("Querying Parquet...")
|
128 |
+
df = conn.execute(response.get("sql")).fetchdf()
|
129 |
|
130 |
plot = None
|
131 |
|
132 |
+
label_key = response.get("label_key")
|
133 |
+
data_key = response.get("data_key")
|
134 |
+
viz_type = response.get("visualization_type")
|
135 |
+
sql = response.get("sql")
|
|
|
136 |
|
137 |
+
# handle incorrect data and label keys
|
138 |
+
if label_key and label_key not in df.columns:
|
139 |
+
label_key = None
|
140 |
+
if data_key and data_key not in df.columns:
|
141 |
+
data_key = None
|
142 |
+
|
143 |
+
if viz_type == OutputTypes.LINECHART:
|
144 |
+
plot = df.plot(kind="line", x=label_key, y=data_key).get_figure()
|
145 |
plt.xticks(rotation=45, ha="right")
|
146 |
plt.tight_layout()
|
147 |
+
elif viz_type == OutputTypes.BARCHART:
|
148 |
+
plot = df.plot(kind="bar", x=label_key, y=data_key).get_figure()
|
|
|
|
|
149 |
plt.xticks(rotation=45, ha="right")
|
150 |
plt.tight_layout()
|
151 |
|
152 |
+
markdown_output = f"""```sql\n{sql}\n```"""
|
153 |
return df, markdown_output, plot
|
154 |
|
155 |
|