allinaigc commited on
Commit
944dc18
·
verified ·
1 Parent(s): eb63e0e

Upload sql_command.py

Browse files
Files changed (1) hide show
  1. sql_command.py +50 -0
sql_command.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+
3
+ '''
4
+ ##TODO:
5
+
6
+ import sqlite3
7
+ import pandas as pd
8
+
9
+ ## run the following function to set up a dedicated database using SQLite.
10
+ def construct_db(data_file=None):
11
+ excel_file = "/Users/yunshi/Downloads/360Data/Data Center/Working-On Task/演讲与培训/2023ChatGPT/Coding/Text2SQL/模拟数据.csv" # Replace with your actual file path
12
+ df = pd.read_csv(excel_file)
13
+ print('df:', df.head())
14
+
15
+ conn = sqlite3.connect('myexcelDB.db') # Replace 'mydatabase.db' with your desired name
16
+ # Create a cursor object to execute SQL commands
17
+ cursor = conn.cursor()
18
+
19
+ ##NOTE: Insert data from DataFrame into the table. 注意这里的if_exists选项,考虑是否要覆盖原始内容。这里如果需要指定index,那么就需要指定index_label,且index=False,否则会选择默认的index,这样就会产生duplicate错误。
20
+ df.to_sql('table01', conn, if_exists='replace', index=False, index_label="产品ID")
21
+
22
+ return None
23
+
24
+
25
+ def llm_query(sql_command):
26
+ # Connect to the database (or create it if it doesn't exist)
27
+ conn = sqlite3.connect('./myexcelDB.db') # Replace 'myexcelDB.db' with your desired name
28
+
29
+ # Create a cursor object to execute SQL commands
30
+ cursor = conn.cursor()
31
+ ## SQL command
32
+ cursor.execute(sql_command)
33
+ query_result = cursor.fetchall() # Fetch all rows as a list of tuples
34
+ # print('query_result:', query_result)
35
+
36
+ ## 将列名也取出来。
37
+ column_names = [description[0] for description in cursor.description]
38
+
39
+ query_df = pd.DataFrame(query_result, columns=column_names)
40
+ # query_df.set_index("产品ID", inplace=True)
41
+ print('query_df:', query_df)
42
+
43
+ return query_df
44
+
45
+ # llm_query("SELECT * FROM table01 WHERE 宽度 > 300") ## sample function call.
46
+
47
+ # construct_db()
48
+ # sql = "SELECT 产品ID FROM table01 WHERE 长度 > 50"
49
+ # res = llm_query(sql)
50
+ # res