Spaces:
Sleeping
Sleeping
import gradio as gr | |
from PIL import Image | |
import os | |
import sqlite3 | |
import google.generativeai as genai | |
import time | |
# --- Gemini client setup -------------------------------------------------
# The API key comes from the "gemini_key" environment variable; when it is
# unset the lookup yields None and that is what gets passed to configure().
gemini_key = os.environ.get("gemini_key")
genai.configure(api_key=gemini_key)

# Shared model handle; SQLPromptModel.text2sql_gemini sends its prompts here.
# NOTE(review): "gemini-pro" may no longer be served by the Gemini API —
# confirm the model name is still available.
genai_model = genai.GenerativeModel(model_name='gemini-pro')
class SQLPromptModel:
    """Turn natural-language requests into SQLite queries and run them.

    Holds one ``sqlite3`` connection to ``database`` and uses the module-level
    ``genai_model`` to generate SQL. Instances can be used as a context
    manager (``with SQLPromptModel(path) as m: ...``) so the connection is
    closed deterministically.
    """

    def __init__(self, database):
        """Open a connection to the SQLite database at ``database``.

        ``sqlite3.connect`` creates an empty database file when the path does
        not exist yet.
        """
        self.database = database
        self.conn = sqlite3.connect(self.database)

    def close(self):
        """Close the underlying database connection."""
        self.conn.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def fetch_table_schema(self, table_name):
        """Return the ``PRAGMA table_info`` rows for ``table_name``.

        Each row is ``(cid, name, type, notnull, dflt_value, pk)``. Returns
        ``None`` when the PRAGMA yields no rows (e.g. the table is missing).
        """
        # Quote the identifier so names containing spaces or quotes cannot
        # break out of the PRAGMA argument.
        quoted = '"' + str(table_name).replace('"', '""') + '"'
        cursor = self.conn.cursor()
        try:
            cursor.execute(f"PRAGMA table_info({quoted})")
            schema = cursor.fetchall()
        finally:
            cursor.close()
        return schema or None

    @staticmethod
    def _build_prompt(schema, instruction, table_name="sql_pdf"):
        """Render the text-to-SQL prompt for ``instruction`` over ``schema``."""
        # Column list as "name TYPE" pairs taken from the PRAGMA rows.
        table_columns = ', '.join(f"{col[1]} {col[2]}" for col in schema)
        return (
            "Below are SQL table schemas paired with instructions that describe a task.\n"
            "Using valid SQLite, write a response that appropriately completes the request for the provided tables.\n"
            f"### Instruction: {instruction} ###\n"
            f"Input: CREATE TABLE {table_name}({table_columns});\n"
            "### Response: (Return only generated query based on user_prompt , nothing extra)"
        )

    @staticmethod
    def _extract_sql(text):
        """Return the SQL statement contained in a raw model reply.

        If the reply contains a Markdown code fence, only the first fenced
        block is searched. The result runs from the first ``SELECT``
        (upper case preferred, any case as a fallback) through the next
        ``;``. If there is no ``;``, the rest of the searched text is
        returned. If there is no ``SELECT`` at all, the searched text is
        returned stripped so the caller still has something to try.
        """
        import re

        # Optional language tag (e.g. "sql") only counts when it ends the
        # opening fence line, so "```SELECT ...```" is not mistaken for a tag.
        fenced = re.search(r"```(?:[A-Za-z]+[ \t]*\n)?(.*?)```", text, re.DOTALL)
        body = fenced.group(1) if fenced else text

        start = body.find("SELECT")
        if start == -1:
            # Models sometimes answer in lower case.
            match = re.search(r"\bselect\b", body, re.IGNORECASE)
            if match is None:
                return body.strip()
            start = match.start()

        end = body.find(";", start)
        if end == -1:
            return body[start:].strip()
        return body[start:end + 1]

    def text2sql_gemini(self, schema, user_prompt, inp_prompt=None, table_name="sql_pdf"):
        """Ask Gemini for a SQLite query that fulfils a natural-language request.

        Args:
            schema: ``PRAGMA table_info`` rows, as returned by
                ``fetch_table_schema``.
            user_prompt: Default instruction, used when ``inp_prompt`` is None.
            inp_prompt: Instruction that overrides ``user_prompt`` when given.
            table_name: Table name shown to the model in the CREATE TABLE hint.

        Returns:
            The SQL statement extracted from the reply, or the cleaned reply
            itself when no ``SELECT`` could be found.
        """
        # Pick the instruction directly. Substituting it into a finished prompt
        # with str.replace would also rewrite any other occurrence of
        # user_prompt in the template.
        instruction = user_prompt if inp_prompt is None else inp_prompt
        prompt = self._build_prompt(schema, instruction, table_name)
        completion = genai_model.generate_content(prompt)
        return self._extract_sql(completion.text)

    def execute_query(self, query):
        """Execute ``query`` and return ``(rows, columns)``.

        ``rows`` is the list of result tuples; ``columns`` holds the column
        names from ``cursor.description``.

        Raises:
            sqlite3.Error: if SQLite rejects the statement.
            ValueError: if the statement produces no result set (e.g. an
                UPDATE or DELETE). Pending uncommitted changes on the
                connection are rolled back first.
        """
        # NOTE(review): ``query`` is model-generated from user text. Consider
        # opening the database read-only so a generated DELETE/DROP cannot
        # modify it.
        cur = self.conn.cursor()
        try:
            cur.execute(query)
            if cur.description is None:
                # Not a row-returning statement: undo the implicit transaction
                # so no later commit() persists it, and fail clearly.
                self.conn.rollback()
                raise ValueError(f"Statement returned no result set: {query!r}")
            columns = [header[0] for header in cur.description]
            rows = cur.fetchall()
        finally:
            cur.close()
        self.conn.commit()
        return rows, columns
def execute_sql_query(input_prompt):
    """Answer a natural-language request against the ``sql_pdf`` table.

    Generates SQL with ``SQLPromptModel.text2sql_gemini`` and runs it on
    ``sql_pdf.db``. Up to three attempts are made, because generated SQL can
    be invalid and model calls can fail.

    Args:
        input_prompt: The user's request. Blank or ``None`` falls back to a
            default request about properties in India.

    Returns:
        ``{"Query": ..., "Results": ..., "Columns": ...}`` on success.
        ``{"error": "Table schema not found."}`` when the table is missing.
        Otherwise ``{"error": ..., "details": <last failure message>}``.
    """
    database = r"sql_pdf.db"
    default_prompt = "Give complete details of properties in India"
    max_attempts = 3

    # Gradio may pass None for an empty component; treat it like "".
    instruction = (input_prompt or "").strip() or default_prompt

    sql_model = SQLPromptModel(database)
    last_error = None
    try:
        for attempt in range(max_attempts):
            try:
                table_schema = sql_model.fetch_table_schema("sql_pdf")
                if not table_schema:
                    return {"error": "Table schema not found."}
                query = sql_model.text2sql_gemini(table_schema, default_prompt, instruction)
                rows, columns = sql_model.execute_query(query)
                return {"Query": query, "Results": rows, "Columns": columns}
            except Exception as e:  # bad generated SQL / API errors: retry
                last_error = e
                print(f"An error occurred: {e}")
                # Back off briefly before asking again, but not after the
                # final attempt; nothing follows it.
                if attempt + 1 < max_attempts:
                    time.sleep(1)
    finally:
        # Release the SQLite connection rather than relying on garbage collection.
        sql_model.conn.close()
    return {
        "error": "Failed to execute query after 3 retries.",
        "details": str(last_error),
    }
# Load the database schema image. The path is resolved relative to this
# script's directory rather than the current working directory, so the app
# also starts when launched from elsewhere. The CWD is the fallback when
# __file__ is undefined (e.g. in a notebook).
_app_dir = (
    os.path.dirname(os.path.abspath(__file__)) if "__file__" in globals() else os.getcwd()
)
image = Image.open(os.path.join(_app_dir, "house_excel_sheet.png"))

# Create web interface
with gr.Blocks(title="House Database Query") as demo:
    # Header
    gr.Markdown("# House Database Query System")

    # Display database schema image
    gr.Image(image)

    # Description
    gr.Markdown("""### The database contains information about different properties including their fundamental details.
You can query this database using natural language.""")

    # Input section
    with gr.Row():
        query_input = gr.Textbox(
            lines=2,
            label="Database Query",
            placeholder="Enter your query or choose from examples below. Default: 'Properties in India'"
        )

    # Submit button section
    with gr.Row():
        submit_btn = gr.Button("Submit Query", variant="primary")

    # Results section
    with gr.Row():
        query_output = gr.JSON(label="Query Results")

    # Connect button click to query function
    submit_btn.click(
        fn=execute_sql_query,
        inputs=query_input,
        outputs=query_output
    )

    # Example queries section
    gr.Examples(
        examples=[
            "Properties in France",
            "Properties greater than an acre",
            "Properties with more than 400 bedrooms"
        ],
        inputs=query_input,
        outputs=query_output,
        fn=execute_sql_query
    )

if __name__ == "__main__":
    demo.launch(share=True)