File size: 3,212 Bytes
165d3b1
 
 
 
 
 
 
265c22b
165d3b1
 
 
9915de3
165d3b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8306023
165d3b1
 
 
 
 
9915de3
165d3b1
 
 
9915de3
 
 
 
 
 
 
 
 
 
 
 
165d3b1
9915de3
e96573b
 
 
e07f113
c1c0ae0
f71e74f
e07f113
f71e74f
e96573b
 
 
c1c0ae0
9915de3
7e6be4d
 
 
 
 
3dc8d65
7e6be4d
6940770
 
3365037
7e6be4d
 
 
d7cbb36
 
 
165d3b1
 
 
 
 
d7cbb36
73ee6aa
3dc8d65
7e6be4d
3dc8d65
8306023
 
 
 
 
3365037
 
8306023
 
 
53821ac
 
dda7aa3
6ba0f89
265c22b
9915de3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from  dotenv import load_dotenv
load_dotenv()

import streamlit as st
import google.generativeai as genai
import sqlite3
import os


# Configure the Gemini client with the API key taken from the environment
# (os.environ.get returns None when GOOGLE_API_KEY is unset).
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
# Shared model handle used to translate questions into SQL.
# NOTE(review): 'gemini-pro' is a legacy model id; confirm it is still served.
model = genai.GenerativeModel(model_name="gemini-pro")
 

# Few-shot instruction text sent to Gemini ahead of the user's question.
# Kept as a one-element list because gemini_sql_query() reads prompt[0].
# Every example query is valid SQLite against a STUDENT table with the
# columns listed below, so the model is not taught invalid SQL.
prompt = [
    """
    You are an expert in converting English questions to SQL queries!
    The example SQL database has a table named STUDENT with the following
    columns - NAME, CLASS, SECTION, MARKS.

    For example,
    Example 1 - How many entries of records are present?
    the SQL command will be something like this SELECT COUNT(*) FROM STUDENT;
    Example 2 - Tell me all the students studying in Data Science class?
    the SQL command will be something like this SELECT * FROM STUDENT WHERE CLASS='Data Science';
    Example 3 - Retrieve the classes in which at least 2 students scored more than 40 marks.
    the SQL command will be something like this SELECT CLASS FROM STUDENT WHERE MARKS > 40 GROUP BY CLASS HAVING COUNT(*) >= 2;
    Example 4 - Find the names and marks of students in the Science class who have scored more than 60 marks.
    SQL Command Example: SELECT NAME, MARKS FROM STUDENT WHERE CLASS='Science' AND MARKS > 60;
    Example 5 - List the classes with the highest average marks.
    SQL Command Example: SELECT CLASS FROM STUDENT GROUP BY CLASS HAVING AVG(MARKS) = (SELECT MAX(AVG_MARKS) FROM (SELECT AVG(MARKS) AS AVG_MARKS FROM STUDENT GROUP BY CLASS));

    The question ends with the name of the table to query (e.g. "... in STUDENT table");
    use that table name in the query instead of STUDENT.
    Return only the SQL query itself: do not wrap it in ``` and do not include the word sql in the output.
    """
]


#llm response
def gemini_sql_query(prompt, input):
    """Ask the Gemini model to translate a natural-language question into SQL.

    Args:
        prompt: Sequence whose first element is the instruction text.
        input: The user's question, sent to the model after the instructions.

    Returns:
        The generated SQL text with surrounding whitespace and any Markdown
        code fence (```` ```sql ... ``` ````) removed.
    """
    response = model.generate_content([prompt[0], input])
    text = response.text.strip()
    # The prompt asks the model not to fence its answer, but LLM output does
    # not reliably follow that; a leftover fence would make sqlite3 reject
    # the query, so strip it defensively.
    if text.startswith("```"):
        text = text[3:]
        if text[:3].lower() == "sql":
            text = text[3:]
        text = text.rstrip()
        if text.endswith("```"):
            text = text[:-3]
        text = text.strip()
    return text
    
# Function to run a query against the SQLite database and return its rows
def read_sql_query(sql, db):
    """Execute a single SQL statement against a SQLite file and return its rows.

    The connection is opened with ``PRAGMA query_only`` enabled: the SQL
    comes from an LLM prompted with free user text, so any statement that
    would modify the database is refused by SQLite instead of being applied.

    Args:
        sql: The SQL statement to run (one statement only).
        db: Path to the SQLite database file.

    Returns:
        A list of row tuples as returned by ``fetchall()``.

    Raises:
        sqlite3.Error: If the statement is invalid or attempts a write.
    """
    conn = sqlite3.connect(db)
    try:
        conn.execute("PRAGMA query_only = ON")
        rows = conn.execute(sql).fetchall()
    finally:
        # Close even when execution fails so the file handle is not leaked.
        conn.close()
    for row in rows:
        print(row)
    return rows


st.set_page_config("DataChat: Explore Your Database")
st.header("DataChat: Chat With SQL Database")

question = st.text_input("Enter your input/question")

table_name = st.text_input("Enter the correct table name")

# Build the text sent to the LLM. Previously this was always a non-empty
# string (" in  table" when both fields were blank), so truthiness checks on
# ``input`` could never detect an incomplete form. Leave it empty (falsy)
# unless both fields are filled in.
# NOTE: ``input`` shadows the builtin; the name is kept because later
# module-level code refers to it.
if question.strip() and table_name.strip():
    input = f"{question.strip()} in {table_name.strip()} table"
else:
    input = ""


#save uploaded file 
def save_uploaded_file(uploaded_file):
    """Write an uploaded file's bytes to ``uploaded.db`` in the working directory.

    Any existing ``uploaded.db`` there is overwritten.

    Args:
        uploaded_file: Object exposing ``getbuffer()`` (e.g. a BytesIO-like
            upload) whose contents are written verbatim.

    Returns:
        The path of the written file, built from ``os.getcwd()``.
    """
    destination = os.path.join(os.getcwd(), "uploaded.db")
    with open(destination, "wb") as out:
        contents = uploaded_file.getbuffer()
        out.write(contents)
    return destination

# File uploader component
# Sidebar: let the user supply the SQLite file the questions will run against.
with st.sidebar:
    st.header("Database Upload")
    uploaded_file = st.file_uploader("Upload SQLite Database", type=["db"])
    if uploaded_file is not None:
        # Persist the upload to disk so sqlite3 can open it by path.
        db_path = save_uploaded_file(uploaded_file)
        st.success("Database uploaded successfully.")


submit = st.button("submit")

# On submit: generate SQL from the question, run it against the uploaded
# database, and show the rows next to the generated query.
if submit:
    if not uploaded_file:
        st.warning("Please upload a SQLite database first.")
    elif not question.strip() or not table_name.strip():
        st.warning("Please enter both a question and a table name.")
    else:
        query = gemini_sql_query(prompt, input)
        # Log the query before running it so failing queries are visible too.
        print(query)
        col1, col2 = st.columns(2)

        with col2:
            st.header("Generated SQL Query:")
            with st.container(height=300):
                st.code(query)

        with col1:
            st.header("Response:")
            try:
                response = read_sql_query(query, db_path)
            except sqlite3.Error as err:
                # LLM-generated SQL may be invalid or reference unknown
                # tables/columns; report it instead of crashing the page.
                st.error(f"The generated query could not be executed: {err}")
            else:
                for row in response:
                    st.write(*(str(value) for value in row))