File size: 3,074 Bytes
165d3b1
 
 
 
 
 
 
265c22b
 
165d3b1
 
 
9915de3
165d3b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8306023
165d3b1
 
 
 
 
9915de3
165d3b1
 
 
9915de3
 
 
 
 
 
 
 
 
 
 
 
165d3b1
 
2cb3c29
 
3dc8d65
9915de3
3dc8d65
9915de3
 
7e6be4d
 
 
 
 
3dc8d65
7e6be4d
 
3365037
7e6be4d
 
 
 
3365037
165d3b1
 
 
 
 
73ee6aa
3dc8d65
7e6be4d
3dc8d65
 
8306023
 
 
 
 
3365037
 
8306023
 
 
 
dda7aa3
6ba0f89
265c22b
9915de3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from  dotenv import load_dotenv
load_dotenv()

import streamlit as st
import google.generativeai as genai
import sqlite3
import os
import pyperclip


# Configure the Gemini client from environment variables (loaded from .env by
# load_dotenv()). GEMINI_MODEL optionally overrides the model name so a retired
# model alias can be swapped without editing code; the default is unchanged.
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
model=genai.GenerativeModel(os.getenv("GEMINI_MODEL", "gemini-pro"))
 

# Few-shot instructions sent ahead of every user question. Kept as a
# one-element list because gemini_sql_query reads prompt[0]. The example SQL
# is valid SQLite (the engine read_sql_query executes it on), and each example
# statement sits on a single line.
prompt=[
    """
    You are an expert in converting English questions to SQL queries!
    The SQL database has a table named STUDENT with the following columns - NAME, CLASS,
    SECTION, MARKS. If the question names a different table, query that table instead.

    For example,
    Example 1 - How many entries of records are present?
    The SQL command will be something like this: SELECT COUNT(*) FROM STUDENT;
    Example 2 - Tell me all the students studying in Data Science class?
    The SQL command will be something like this: SELECT * FROM STUDENT WHERE CLASS='Data Science';
    Example 3 - Find the classes where at least 2 students scored more than 40 marks.
    The SQL command will be something like this: SELECT CLASS FROM STUDENT WHERE MARKS > 40 GROUP BY CLASS HAVING COUNT(*) >= 2;
    Example 4 - Find the names and marks of students in the Science class who have scored more than 60 marks.
    The SQL command will be something like this: SELECT NAME, MARKS FROM STUDENT WHERE CLASS='Science' AND MARKS > 60;
    Example 5 - List the classes with the highest average marks.
    The SQL command will be something like this: SELECT CLASS FROM STUDENT GROUP BY CLASS HAVING AVG(MARKS) = (SELECT MAX(AVG_MARKS) FROM (SELECT AVG(MARKS) AS AVG_MARKS FROM STUDENT GROUP BY CLASS));

    Return exactly one SQL statement and nothing else. Do not wrap it in ``` fences
    and do not prefix it with the word sql.
    """
]


#llm response
def _strip_code_fences(text):
    """Return *text* without surrounding whitespace or a Markdown code fence.

    Handles replies wrapped in triple backticks, optionally tagged ``sql``
    (any case); any other text is only whitespace-stripped.
    """
    sql = text.strip()
    if sql.startswith("```"):
        sql = sql[3:]
        # Drop a language tag like "sql"/"SQL" directly after the opening fence.
        if sql[:3].lower() == "sql" and (len(sql) == 3 or not sql[3].isalnum()):
            sql = sql[3:]
    if sql.endswith("```"):
        sql = sql[:-3]
    return sql.strip()


def gemini_sql_query(prompt,input):
    """Ask the Gemini model to translate a natural-language question into SQL.

    Args:
        prompt: sequence whose first element is the instruction text sent
            ahead of the question.
        input: the user's question text.

    Returns:
        The model's reply with surrounding whitespace and any Markdown code
        fence removed, ready to pass to sqlite3.
    """
    response=model.generate_content([prompt[0],input])
    # The prompt asks for bare SQL, but models still sometimes wrap the answer
    # in ``` fences, which sqlite3 would reject as a syntax error.
    return _strip_code_fences(response.text)
    
# Run a (model-generated) query against an SQLite database file.
def read_sql_query(sql,db):
    """Execute *sql* on the SQLite database at path *db* and return all rows.

    The connection is switched to query-only mode because *sql* is produced
    by an LLM from free-text user input; any statement that would modify the
    database fails instead of being committed. Each returned row is also
    printed to stdout.

    Args:
        sql: a single SQL statement.
        db: filesystem path of the SQLite database.

    Returns:
        List of row tuples from ``fetchall()``.

    Raises:
        sqlite3.Error: if the statement is invalid or attempts a write.
    """
    conn=sqlite3.connect(db)
    try:
        # Reject INSERT/UPDATE/DELETE/CREATE/DROP coming from untrusted SQL.
        conn.execute("PRAGMA query_only = ON")
        rows=conn.execute(sql).fetchall()
    finally:
        # Always release the file handle, even when the generated SQL fails.
        conn.close()
    for row in rows:
        print(row)
    return rows


# Page title, heading, and the free-text question box. The module-level name
# `input` (it shadows the builtin) is read later by the submit handler.
st.set_page_config(page_title="DataChat: Explore Your Database")
st.header(body="chat with your sql database")

input=st.text_input(label="enter your input/question and specify correct table name")


# Persist an uploaded database so sqlite3 can open it by path.
def save_uploaded_file(uploaded_file):
    """Write *uploaded_file*'s bytes to ``uploaded.db`` in the working directory.

    Any existing file with that name is overwritten.

    Args:
        uploaded_file: object exposing ``getbuffer()`` with the file contents.

    Returns:
        Absolute path of the written file.
    """
    working_dir = os.getcwd()
    destination = os.path.join(working_dir, "uploaded.db")
    with open(destination, mode="wb") as out:
        out.write(uploaded_file.getbuffer())
    return destination

# File uploader component. Accept the common SQLite file extensions, not only
# ".db", so databases saved as .sqlite/.sqlite3 can be uploaded directly.
uploaded_file = st.file_uploader("Upload SQLite Database", type=["db", "sqlite", "sqlite3"])

if uploaded_file is not None:
    # Save the upload to disk; db_path is only defined on this branch, so
    # later code must check uploaded_file before using it.
    db_path = save_uploaded_file(uploaded_file)
    st.success("Database uploaded successfully.")

submit=st.button("submit")




# On submit: generate SQL for the question, run it on the uploaded database,
# and show the rows next to the generated query.
if submit and uploaded_file and input:
    query=gemini_sql_query(prompt,input)
    print(query)
    try:
        response=read_sql_query(query,db_path)
    except sqlite3.Error as err:
        # Generated SQL may reference tables/columns the uploaded database does
        # not have (or be otherwise invalid); report it instead of crashing.
        st.error(f"Could not run the generated query: {err}")
        st.code(query)
    else:
        col1, col2 = st.columns(2)

        with col1:
            st.header("Response:")
            if not response:
                st.info("The query returned no rows.")
            for row in response:
                values = [str(value) for value in row]
                st.write(*values)

        with col2:
            st.header("Generated SQL Query:")
            st.code(query)
elif submit:
    # Button pressed without a database or a question: explain why nothing ran.
    st.warning("Please upload a database and enter a question before submitting.")