import os

from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
import gradio as gr
from sklearn.metrics.pairwise import cosine_similarity
from groq import Groq

# Load environment variables and read the Groq API key from the .env file.
load_dotenv()
api = os.getenv("groq_api_key")
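
# Plain-text metadata for each table. The descriptions are embedded with a
# sentence transformer so that an incoming natural-language query can be
# routed to the most relevant table via cosine similarity.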
def create_metadata_embeddings():
    student = """
    Table: student
    Columns:
    - student_id: an integer representing the unique ID of a student.
    - first_name: a string containing the first name of the student.
    - last_name: a string containing the last name of the student.
    - date_of_birth: a date representing the student's birthdate.
    - email: a string for the student's email address.
    - phone_number: a string for the student's contact number.
    - major: a string representing the student's major field of study.
    - year_of_enrollment: an integer for the year the student enrolled.
    """

    employee = """
    Table: employee
    Columns:
    - employee_id: an integer representing the unique ID of an employee.
    - first_name: a string containing the first name of the employee.
    - last_name: a string containing the last name of the employee.
    - email: a string for the employee's email address.
    - department: a string for the department the employee works in.
    - position: a string representing the employee's job title.
    - salary: a float representing the employee's salary.
    - date_of_joining: a date for when the employee joined the college.
    """

    course = """
    Table: course_info
    Columns:
    - course_id: an integer representing the unique ID of the course.
    - course_name: a string containing the course's name.
    - course_code: a string for the course's unique code.
    - instructor_id: an integer for the ID of the instructor teaching the course.
    - department: a string for the department offering the course.
    - credits: an integer representing the course credits.
    - semester: a string for the semester when the course is offered.
    """

    metadata_list = [student, employee, course]
    model = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = model.encode(metadata_list)
    return embeddings, model, student, employee, course
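
# Pick the table whose metadata embedding is most similar to the user query.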
def find_best_fit(embeddings, model, user_query, student, employee, course):
    query_embedding = model.encode([user_query])
    similarities = cosine_similarity(query_embedding, embeddings)
    best_match_table = similarities.argmax()
    if best_match_table == 0:
        table_metadata = student
    elif best_match_table == 1:
        table_metadata = employee
    else:
        table_metadata = course
    return table_metadata
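
# Build the system and user prompts for the LLM from the query and the
# selected table's metadata.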
def create_prompt(user_query, table_metadata):
    system_prompt = """
    You are a SQL query generator specialized in generating SELECT queries for a single table at a time. Your task is to accurately convert natural language queries into SQL SELECT statements based on the user's intent and the provided table metadata.
    Rules:
    Focus on SELECT Queries: Only generate SELECT queries. Do not generate INSERT, UPDATE, DELETE, or multi-table JOINs.
    Single Table Only: Assume all queries are related to a single table provided in the metadata. Ignore any references to other tables.
    Metadata-Based Validation: Always ensure the generated query matches the table name, columns, and data types provided in the metadata.
    User Intent: Accurately capture the user's requirements, such as filters, sorting, or aggregations, as expressed in natural language.
    SQL Syntax: Use standard SQL syntax that is compatible with most relational database systems.
    Input Format:
    User Query: The user's natural language request.
    Table Metadata: The structure of the relevant table, including the table name, column names, and data types.
    Output Format:
    SQL Query: A valid SELECT query formatted for readability.
    Do not output anything except the SQL query, not even a single word extra. Output the whole query on a single line.
    You are ready to generate SQL queries based on the user input and table metadata.
    """
    user_prompt = f"""
    User Query: {user_query}
    Table Metadata: {table_metadata}
    """
    return system_prompt, user_prompt
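
# Send the prompts to the Groq chat API and keep the response only if it
# looks like a SELECT statement.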
def generate_output(system_prompt, user_prompt):
    client = Groq(api_key=api)
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        model="llama3-70b-8192",
    )
    res = chat_completion.choices[0].message.content
    # Basic guard: only return responses that start with SELECT.
    if res[:6].lower() == "select":
        output = res
    else:
        output = "Can't perform the task at the moment."
    return output
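
# End-to-end pipeline: embed the table metadata, pick the best-matching table,
# build the prompts, and generate the SQL query.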
def response(user_query):
    embeddings, model, student, employee, course = create_metadata_embeddings()
    table_metadata = find_best_fit(embeddings, model, user_query, student, employee, course)
    system_prompt, user_prompt = create_prompt(user_query, table_metadata)
    output = generate_output(system_prompt, user_prompt)
    return output
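
# Simple Gradio UI: a textbox for the natural-language query in, the generated
# SQL query out.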
demo = gr.Interface(
    fn=response,
    inputs=gr.Textbox(label="Please provide the natural language query"),
    outputs=gr.Textbox(label="SQL Query"),
    title="SQL Query Generator",
)

demo.launch(share=True)