Caleb Fahlgren commited on
Commit
467c2a7
·
1 Parent(s): 9fc2d21

fix pickle issue by using dict instead of pydantic model

Browse files
Files changed (1) hide show
  1. app.py +19 -18
app.py CHANGED
@@ -86,7 +86,7 @@ CREATE TABLE {} (
86
 
87
 
88
  @spaces.GPU
89
- def generate_query(dataset_id: str, query: str) -> str:
90
  ddl = get_dataset_ddl(dataset_id)
91
 
92
  system_prompt = f"""
@@ -118,37 +118,38 @@ def generate_query(dataset_id: str, query: str) -> str:
118
 
119
  print("Received Response: ", resp)
120
 
121
- return resp
122
 
123
 
124
  def query_dataset(dataset_id: str, query: str) -> Tuple[pd.DataFrame, str, plt.Figure]:
125
- response: SQLResponse = generate_query(dataset_id, query)
126
 
127
  print("Querying Parquet...")
128
- df = conn.execute(response.sql).fetchdf()
129
 
130
  plot = None
131
 
132
- # handle incorrect data and label keys better
133
- if response.label_key and response.label_key not in df.columns:
134
- response.label_key = None
135
- if response.data_key and response.data_key not in df.columns:
136
- response.data_key = None
137
 
138
- if response.visualization_type == OutputTypes.LINECHART:
139
- plot = df.plot(
140
- kind="line", x=response.label_key, y=response.data_key
141
- ).get_figure()
 
 
 
 
142
  plt.xticks(rotation=45, ha="right")
143
  plt.tight_layout()
144
- elif response.visualization_type == OutputTypes.BARCHART:
145
- plot = df.plot(
146
- kind="bar", x=response.label_key, y=response.data_key
147
- ).get_figure()
148
  plt.xticks(rotation=45, ha="right")
149
  plt.tight_layout()
150
 
151
- markdown_output = f"""```sql\n{response.sql}\n```"""
152
  return df, markdown_output, plot
153
 
154
 
 
86
 
87
 
88
  @spaces.GPU
89
+ def generate_query(dataset_id: str, query: str) -> dict:
90
  ddl = get_dataset_ddl(dataset_id)
91
 
92
  system_prompt = f"""
 
118
 
119
  print("Received Response: ", resp)
120
 
121
+ return resp.model_dump()
122
 
123
 
124
  def query_dataset(dataset_id: str, query: str) -> Tuple[pd.DataFrame, str, plt.Figure]:
125
+ response = generate_query(dataset_id, query)
126
 
127
  print("Querying Parquet...")
128
+ df = conn.execute(response.get("sql")).fetchdf()
129
 
130
  plot = None
131
 
132
+ label_key = response.get("label_key")
133
+ data_key = response.get("data_key")
134
+ viz_type = response.get("visualization_type")
135
+ sql = response.get("sql")
 
136
 
137
+ # handle incorrect data and label keys
138
+ if label_key and label_key not in df.columns:
139
+ label_key = None
140
+ if data_key and data_key not in df.columns:
141
+ data_key = None
142
+
143
+ if viz_type == OutputTypes.LINECHART:
144
+ plot = df.plot(kind="line", x=label_key, y=data_key).get_figure()
145
  plt.xticks(rotation=45, ha="right")
146
  plt.tight_layout()
147
+ elif viz_type == OutputTypes.BARCHART:
148
+ plot = df.plot(kind="bar", x=label_key, y=data_key).get_figure()
 
 
149
  plt.xticks(rotation=45, ha="right")
150
  plt.tight_layout()
151
 
152
+ markdown_output = f"""```sql\n{sql}\n```"""
153
  return df, markdown_output, plot
154
 
155