LeonceNsh commited on
Commit
01e2f1a
·
verified ·
1 Parent(s): 653fb3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -19
app.py CHANGED
@@ -1,17 +1,19 @@
1
  import json
2
- import openai
3
  import gradio as gr
4
  import duckdb
5
  import tempfile
6
  from functools import lru_cache
7
  import os
 
8
 
9
  # =========================
10
  # Configuration and Setup
11
  # =========================
12
 
13
- openai.api_key = os.getenv("OPENAI_API_KEY")
14
- dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
 
 
15
 
16
  schema = [
17
  {"column_name": "department_ind_agency", "column_type": "VARCHAR"},
@@ -59,26 +61,46 @@ COLUMN_TYPES = {col['column_name']: col['column_type'] for col in get_schema()}
59
  # OpenAI API Integration
60
  # =========================
61
 
 
 
 
62
  def parse_query(nl_query):
63
- messages = [
64
- {"role": "system", "content": "You are an assistant that converts natural language queries into SQL queries for the 'contract_data' table."},
65
- {"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"}
66
- ]
 
 
 
 
 
 
 
 
 
67
 
 
 
 
 
 
68
  try:
69
- response = openai.chat.completions.create(
70
- model="gpt-4o-mini",
71
- messages=messages,
72
- temperature=0,
73
- max_tokens=150,
74
- )
75
- sql_query = response.choices[0].message.content.strip()
76
- # Remove surrounding backticks and formatting artifacts
77
- if sql_query.startswith("```") and sql_query.endswith("```"):
78
- sql_query = sql_query[sql_query.find('\n')+1:sql_query.rfind('\n')].strip()
79
- return sql_query, ""
 
 
80
  except Exception as e:
81
- return "", f"Error generating SQL query: {e}"
 
82
 
83
  # =========================
84
  # Database Interaction
 
1
  import json
 
2
  import gradio as gr
3
  import duckdb
4
  import tempfile
5
  from functools import lru_cache
6
  import os
7
+ import replicate
8
 
9
  # =========================
10
  # Configuration and Setup
11
  # =========================
12
 
13
+ replicate_api_token = os.getenv("REPLICATE_API_TOKEN")
14
+ replicate.Client(api_token=replicate_api_token)
15
+
16
+ dataset_path = 'sample_contract_df.parquet'
17
 
18
  schema = [
19
  {"column_name": "department_ind_agency", "column_type": "VARCHAR"},
 
61
  # OpenAI API Integration
62
  # =========================
63
 
64
+ import replicate
65
+ import json
66
+
67
  def parse_query(nl_query):
68
+ system_prompt = (
69
+ "You are an assistant that converts natural language into SQL queries for a DuckDB database. "
70
+ "The table is named `contract_data`. Only return the SQL query, no explanations."
71
+ )
72
+
73
+ user_prompt = f"""\
74
+ Schema:
75
+ {json.dumps(schema, indent=2)}
76
+
77
+ Natural Language Query:
78
+ "{nl_query}"
79
+
80
+ SQL:"""
81
 
82
+ full_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{system_prompt}<|eot|>" \
83
+ f"<|start_header_id|>user<|end_header_id|>\n{user_prompt}<|eot|>" \
84
+ f"<|start_header_id|>assistant<|end_header_id|>\n"
85
+
86
+ output = ""
87
  try:
88
+ for event in replicate.stream(
89
+ "meta/meta-llama-3-8b-instruct",
90
+ input={
91
+ "top_p": 0.9,
92
+ "temperature": 0.2,
93
+ "max_tokens": 300,
94
+ "prompt": full_prompt
95
+ }
96
+ ):
97
+ output += str(event)
98
+
99
+ return output.strip()
100
+
101
  except Exception as e:
102
+ return f"-- Error generating SQL: {e}"
103
+
104
 
105
  # =========================
106
  # Database Interaction