Spaces:
Sleeping
Sleeping
import gradio as gr | |
from PIL import Image | |
import os | |
import sqlite3 | |
import google.generativeai as genai | |
import time | |
# Initialize Gemini.
# The API key is taken from the environment (GEMINI_API_KEY, then
# GOOGLE_API_KEY) so a real key never has to be committed to source; the
# literal placeholder is only a last-resort fallback. The model name may be
# overridden with GEMINI_MODEL and defaults to 'gemini-pro'.
genai.configure(
    api_key=os.environ.get("GEMINI_API_KEY")
    or os.environ.get("GOOGLE_API_KEY")
    or "YOUR_GEMINI_API_KEY"
)
genai_model = genai.GenerativeModel(os.environ.get("GEMINI_MODEL", "gemini-pro"))
class SQLPromptModel:
    """Turn natural-language requests into SQLite queries via Gemini and run them.

    The instance owns one ``sqlite3`` connection to *database*, opened in
    ``__init__``. Call :meth:`close` when finished, or use the instance as a
    context manager (``with SQLPromptModel(path) as model: ...``).
    """

    def __init__(self, database):
        """Open a SQLite connection to *database* (a file path or ``":memory:"``)."""
        self.database = database
        self.conn = sqlite3.connect(self.database)

    def close(self):
        """Close the SQLite connection opened in ``__init__``."""
        self.conn.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.close()
        return False

    def fetch_table_schema(self, table_name):
        """Return the ``PRAGMA table_info`` rows for *table_name*.

        Each row is ``(cid, name, type, notnull, dflt_value, pk)``.

        Returns:
            The list of rows, or ``None`` when the table does not exist
            (SQLite returns no rows for an unknown table).
        """
        # PRAGMA arguments cannot be bound as parameters, so quote the name as
        # an SQL identifier (doubling embedded double quotes) instead of
        # interpolating it raw.
        quoted_name = '"' + table_name.replace('"', '""') + '"'
        cursor = self.conn.cursor()
        try:
            cursor.execute(f"PRAGMA table_info({quoted_name})")
            schema = cursor.fetchall()
        finally:
            cursor.close()
        return schema if schema else None

    def text2sql_gemini(self, schema, user_prompt, inp_prompt=None, table_name="sql_pdf"):
        """Ask Gemini for a SQLite query answering a request over one table.

        Args:
            schema: Rows as returned by :meth:`fetch_table_schema`; ``col[1]``
                is used as the column name and ``col[2]`` as its type.
            user_prompt: Default natural-language instruction.
            inp_prompt: When not ``None``, used as the instruction instead of
                *user_prompt*.
            table_name: Table name shown to the model in the ``CREATE TABLE``
                stub (defaults to ``"sql_pdf"``).

        Returns:
            The SQL statement extracted from the model reply by
            :meth:`_extract_sql`.

        Exceptions raised by the Gemini client are not caught here.
        """
        table_columns = ", ".join(f"{col[1]} {col[2]}" for col in schema)
        # Choose the instruction up front rather than str.replace()-ing it into
        # the rendered template: replace() would also rewrite any other
        # occurrence of user_prompt in the template text.
        instruction = user_prompt if inp_prompt is None else inp_prompt
        prompt = f"""Below are SQL table schemas paired with instructions that describe a task.
Using valid SQLite, write a response that appropriately completes the request for the provided tables.
### Instruction: {instruction} ###
Input: CREATE TABLE {table_name}({table_columns});
### Response: (Return only generated query based on user_prompt , nothing extra)"""
        completion = genai_model.generate_content(prompt)
        return self._extract_sql(completion.text)

    @staticmethod
    def _extract_sql(text):
        """Pull a single SQL statement out of a free-form model reply.

        1. If the reply contains a fenced block (```...```), only its content
           is considered; a one-word info string such as ``sql`` on the
           opening fence line is dropped.
        2. If the remaining text starts with a CTE (``WITH name AS (``), the
           statement starts there; otherwise it starts at the first
           ``SELECT`` (uppercase preferred, then any case).
        3. The statement runs through the first ``;`` after its start, or to
           the end of the text when there is none.

        With no recognisable start the (unwrapped, stripped) text is returned.
        """
        import re

        body = text
        fence_parts = body.split("```")
        if len(fence_parts) >= 3:
            body = fence_parts[1]
            first_line, newline, rest = body.partition("\n")
            info = first_line.strip()
            if newline and info.isalnum() and info.upper() not in ("SELECT", "WITH"):
                body = rest
        body = body.strip()

        cte_start = r"WITH\s+(?:RECURSIVE\s+)?[^\s(]+\s*(?:\([^)]*\)\s*)?AS\s*\("
        if re.match(cte_start, body, re.IGNORECASE):
            start = 0
        else:
            start = body.find("SELECT")
            if start == -1:
                found = re.search(r"\bselect\b", body, re.IGNORECASE)
                if found is None:
                    return body
                start = found.start()

        end = body.find(";", start)
        if end == -1:
            return body[start:].strip()
        return body[start:end + 1]

    def execute_query(self, query, *, read_only=True):
        """Run one SQL statement and return ``(rows, columns)``.

        Args:
            query: A single SQL statement (``sqlite3`` rejects several
                statements in one ``execute`` call).
            read_only: When true (the default), ``PRAGMA query_only`` is
                switched on while *query* runs, so generated SQL cannot
                create, modify or drop anything; SQLite raises an
                ``sqlite3.OperationalError`` for any write attempt.

        Returns:
            ``(rows, columns)``: the list of result tuples and the result
            column names. Both are empty for statements without a result set.
        """
        cur = self.conn.cursor()
        try:
            if read_only:
                cur.execute("PRAGMA query_only = ON")
            try:
                cur.execute(query)
                # cursor.description is None when the statement returns no rows.
                columns = [header[0] for header in cur.description or ()]
                rows = cur.fetchall()
            finally:
                if read_only:
                    cur.execute("PRAGMA query_only = OFF")
            self.conn.commit()
        finally:
            cur.close()
        return rows, columns
def execute_sql_query(input_prompt):
    """Answer a natural-language question about the ``sql_pdf`` table.

    The request is turned into SQL by Gemini (through ``SQLPromptModel``) and
    executed against ``sql_pdf.db``, a path relative to the current working
    directory. A blank or missing *input_prompt* falls back to a default
    request about properties in India.

    Generation plus execution is attempted up to three times, sleeping one
    second between attempts, since a failure is often a malformed model reply
    that a fresh generation can fix.

    Returns:
        ``{"Query": ..., "Results": ..., "Columns": ...}`` on success;
        ``{"error": "Table schema not found."}`` when the table is missing;
        otherwise ``{"error": ..., "details": <last exception message>}``.
    """
    database = r"sql_pdf.db"
    user_prompt = "Give complete details of properties in India"
    attempts = 3
    # Defensive: treat a None input like an empty one instead of letting
    # .strip() raise on every attempt.
    request = (input_prompt or "").strip() or None
    sql_model = SQLPromptModel(database)
    last_error = None
    try:
        for attempt in range(1, attempts + 1):
            if attempt > 1:
                time.sleep(1)  # back off only between attempts, not after the last
            try:
                table_schema = sql_model.fetch_table_schema("sql_pdf")
                if not table_schema:
                    return {"error": "Table schema not found."}
                query = sql_model.text2sql_gemini(table_schema, user_prompt, request)
                rows, columns = sql_model.execute_query(query)
                return {"Query": query, "Results": rows, "Columns": columns}
            except Exception as e:  # model, API or SQL failure: try a fresh generation
                last_error = e
                print(f"An error occurred (attempt {attempt}/{attempts}): {e}")
    finally:
        # Release the SQLite connection on every exit path.
        sql_model.conn.close()
    return {
        "error": "Failed to execute query after 3 retries.",
        "details": str(last_error),
    }
# Load the banner image from the current working directory. Image.open is
# lazy and keeps the file open until the pixels are read, so copy the image
# into memory inside a `with` block and release the file handle immediately.
with Image.open(os.path.join(os.path.abspath(''), "house_excel_sheet.png")) as _banner:
    image = _banner.copy()

# Create Gradio interface
with gr.Blocks(title="House Database Query") as demo:
    gr.Markdown("# House Database Query System")
    # Display the image
    gr.Image(image)
    gr.Markdown("""### The database contains information about different properties including their fundamental details.
You can query this database using natural language.""")
    with gr.Row():
        # Query input and output
        query_input = gr.Textbox(
            lines=2,
            label="Database Query",
            placeholder="Enter your query or choose from examples below. Default: 'Properties in India'"
        )
    with gr.Row():
        # Add submit button
        submit_btn = gr.Button("Submit Query", variant="primary")
    with gr.Row():
        query_output = gr.JSON(label="Query Results")
    # Connect submit button to the query function
    submit_btn.click(
        fn=execute_sql_query,
        inputs=query_input,
        outputs=query_output
    )
    # Example queries; clicking one fills the textbox (and, with fn/outputs
    # given, lets Gradio run execute_sql_query for it).
    gr.Examples(
        examples=[
            "Properties in France",
            "Properties greater than an acre",
            "Properties with more than 400 bedrooms"
        ],
        inputs=query_input,
        outputs=query_output,
        fn=execute_sql_query
    )

if __name__ == "__main__":
    # NOTE(review): share=True publishes a public link that executes
    # model-generated SQL against the local database; confirm this exposure
    # is intended.
    demo.launch(share=True)