File size: 2,052 Bytes
06a2e89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from dotenv import load_dotenv
import streamlit as st
import sqlite3
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser

# Load variables from a local .env file into the process environment, keeping
# any that are already set (presumably this supplies the OpenAI credentials
# ChatOpenAI needs below — confirm).
load_dotenv(override=False)

# function to retrieve data from the database
def read_sql_query(sql, db, *, read_only=True):
    """Run one SQL statement against a SQLite database and return its rows.

    The statement is echoed to stdout before it runs. Rows are fetched eagerly
    so the connection can be closed before returning.

    Args:
        sql: SQL text to execute. sqlite3 accepts only one statement per call.
            In this app the text is generated by a language model, so treat
            it as untrusted.
        db: Path of the SQLite database, passed straight to ``sqlite3.connect``.
        read_only: When true (the default), enable ``PRAGMA query_only`` so any
            statement that would modify the database fails instead of running.

    Returns:
        A list of row tuples. It is empty when the statement yields no rows.

    Raises:
        sqlite3.Error: If the statement is invalid or, with ``read_only``,
            attempts to change the database.
    """
    print(sql)
    conn = sqlite3.connect(db)
    try:
        if read_only:
            # Guard against model-generated DROP/DELETE/UPDATE statements.
            # SQLite rejects any write on this connection with SQLITE_READONLY.
            conn.execute('PRAGMA query_only = ON')
        return conn.execute(sql).fetchall()
    finally:
        # Always release the connection, even when execution fails.
        conn.close()

# Chat model used to translate English questions into SQL.
model = ChatOpenAI(model='gpt-4o-mini')

# System instructions: the table schema, two few-shot examples, and the exact
# sentinel reply for questions that cannot be expressed in SQL. The UI compares
# the model's reply against that sentinel, so it is shown unquoted.
# NOTE: ChatPromptTemplate parses `{...}` in these strings as template
# variables (see `{input}` below), so literal braces must be doubled.
system_prompt = """
You are an expert in converting English questions to SQL code!

The SQL database consists of a table 'student' which has the columns
'name', 'class' and 'section'.

Example 1:
Input: How many entries of records are present?
Output: SELECT COUNT(*) FROM student;

Example 2:
Input: List all the students in the frontend class.
Output: SELECT * FROM student WHERE class='frontend';

Only generate read-only SELECT statements.

Reply with only the SQL query: no explanation, and no ``` at the beginning
or the end.

If the given question cannot be converted to SQL, reply with exactly the
following text and nothing else (no quotes):
Given query cannot be converted to SQL
"""

# The system message carries the instructions; the human message is the user's
# question, filled in from the `input` key passed to `chain.invoke`.
prompt_template = ChatPromptTemplate.from_messages(
    [
        ('system', system_prompt),
        ('human', '{input}')
    ]
)

# prompt -> model -> plain string, so `chain.invoke(...)` returns the reply text.
chain = prompt_template | model | StrOutputParser()

# Streamlit app

# Reply the system prompt asks the model to give when a question cannot be
# expressed in SQL.
_UNCONVERTIBLE_REPLY = 'Given query cannot be converted to SQL'


def _strip_code_fences(text):
    """Return *text* trimmed and without a surrounding ``` / ```sql fence.

    The prompt tells the model not to fence its answer. A fenced reply would
    otherwise reach SQLite verbatim and fail to parse.
    """
    text = text.strip()
    if text.startswith('```'):
        # The opening fence may carry a language tag such as ```sql.
        text = text.split('\n', 1)[1] if '\n' in text else text[3:]
    if text.endswith('```'):
        text = text[:-3]
    return text.strip()


def _is_unconvertible(reply):
    """Return True if *reply* is the model's "cannot convert" sentinel.

    Surrounding quotes, a trailing period, and case differences are ignored,
    so a slightly-off sentinel is never executed as SQL.
    """
    normalized = reply.strip().strip('"\'').rstrip('.').strip()
    return normalized.casefold() == _UNCONVERTIBLE_REPLY.casefold()


st.set_page_config(
    page_title='I will retrieve any SQL query'
)
st.header('Retrieve SQL data in plain English')

question = st.text_input(label='Query the database', placeholder='Enter your query in plain english here')
submit = st.button('Query')

if submit:
    if not question:
        st.warning('Please enter a question first.')
    else:
        sql = _strip_code_fences(chain.invoke({'input': question}))
        if _is_unconvertible(sql):
            st.subheader('Given query cannot be translated to SQL')
        else:
            st.subheader('Generated SQL')
            st.code(sql, language='sql')
            try:
                # list() materialises the rows, whether the helper returns a
                # list or a cursor.
                rows = list(read_sql_query(sql, 'student_records.db'))
            except (sqlite3.Error, sqlite3.Warning) as exc:
                # The SQL is model-generated and may be invalid, multi-statement,
                # or rejected by the database. Report it instead of crashing.
                st.error(f'The generated SQL could not be executed: {exc}')
            else:
                st.subheader('Retrieved Data')
                if not rows:
                    st.write('No rows returned.')
                for row in rows:
                    st.write(row)