import os
import torch
import duckdb
import spaces
import lancedb
import gradio as gr
import pandas as pd
import pyarrow as pa
from langchain import hub
from langsmith import traceable
from sentence_transformers import SentenceTransformer
from langchain_huggingface.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline

# Height of the Tabs Text Area
TAB_LINES = 8

#----------CONNECT TO DATABASE----------
# The MotherDuck token is read from the environment and the connection is
# opened read-only.
md_token = os.getenv('MD_TOKEN')
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
#---------------------------------------

# Pick the torch device: CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    device = torch.device("cpu")
    print("Using CPU")
#---------------------------------------

#--------------LanceDB-------------
# Connection settings come from the environment.
lance_db = lancedb.connect(
    uri=os.getenv('lancedb_uri'),
    api_key=os.getenv('lancedb_api_key'),
    region=os.getenv('lancedb_region')
)

# One row per generated query: its embedding vector and the SQL text.
lance_schema = pa.schema([
    pa.field("vector", pa.list_(pa.float32())),
    pa.field("sql-query", pa.utf8())
])

# Create the table on first run; if creation fails (presumably because the
# table already exists) fall back to opening it. Only `Exception` is caught so
# that KeyboardInterrupt / SystemExit are not swallowed.
try:
    table = lance_db.create_table(name="SQL-Queries", schema=lance_schema)
except Exception:
    table = lance_db.open_table(name="SQL-Queries")
#---------------------------------------

#-------LOAD HUGGINGFACE PIPELINE-------
tokenizer = AutoTokenizer.from_pretrained("defog/llama-3-sqlcoder-8b")

# 4-bit NF4 quantization with double quantization; compute in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4")

model = AutoModelForCausalLM.from_pretrained("defog/llama-3-sqlcoder-8b", quantization_config=quantization_config,
                                             device_map="auto", torch_dtype=torch.bfloat16)

# return_full_text=False: only the newly generated text is returned, not the prompt.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1024, return_full_text=False)
hf = HuggingFacePipeline(pipeline=pipe)
#---------------------------------------

#-----LOAD PROMPT FROM LANGCHAIN HUB-----
prompt = hub.pull("sql-agent-prompt")
#--------------------------------------- #-----LOAD EMBEDDING MODEL----- embedding_model = SentenceTransformer("all-MiniLM-L6-v2", device=device) #--------------------------------------- #--------------ALL UTILS---------------- # Get Databases def get_schemas(): schemas = conn.execute(""" SELECT DISTINCT schema_name FROM information_schema.schemata WHERE schema_name NOT IN ('information_schema', 'pg_catalog') """).fetchall() return [item[0] for item in schemas] # Get Tables def get_tables(schema_name): tables = conn.execute(f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{schema_name}'").fetchall() return [table[0] for table in tables] # Update Tables def update_tables(schema_name): tables = get_tables(schema_name) return gr.update(choices=tables) # Get Schema def get_table_schema(table): result = conn.sql(f"SELECT sql, database_name, schema_name FROM duckdb_tables() where table_name ='{table}';").df() ddl_create = result.iloc[0,0] parent_database = result.iloc[0,1] schema_name = result.iloc[0,2] full_path = f"{parent_database}.{schema_name}.{table}" if schema_name != "main": old_path = f"{schema_name}.{table}" else: old_path = table ddl_create = ddl_create.replace(old_path, full_path) return ddl_create # Get Prompt def get_prompt(schema, query_input): return prompt.format(schema=schema, query_input=query_input) @spaces.GPU(duration=60) @traceable() def generate_sql(prompt): result = hf.invoke(prompt) return result.strip() @spaces.GPU(duration=10) def embed_query(sql_query): print(f'Creating Emebeddings {sql_query}') if sql_query is not None: embeddings = embedding_model.encode(sql_query, normalize_embeddings=True).tolist() return embeddings def log2lancedb(embeddings, sql_query): data = [{ "sql-query": sql_query, "vector": embeddings }] table.add(data) print(f'Added to Lance DB.') #--------------------------------------- # Generate SQL def text2sql(table, query_input): if table is None: return { table_schema: "", input_prompt: "", 
generated_query: "", result_output:pd.DataFrame([{"error": "❌ Please Select Table, Schema.}"}]) } schema = get_table_schema(table) print(f'Schema Generated...') prompt = get_prompt(schema, query_input) print(f'Prompt Generated...') try: print(f'Generating SQL... {model.device}') result = generate_sql(prompt) print('SQL Generated...') except Exception as e: return { table_schema: schema, input_prompt: prompt, generated_query: "", result_output:pd.DataFrame([{"error": f"❌ Unable to get the SQL query based on the text. {e}"}]) } try: embeddings = embed_query(result) log2lancedb(embeddings, result) except Exception as e: print("Error Generating and Logging Embeddings...") print(e) try: query_result = conn.sql(result).df() except Exception as e: return { table_schema: schema, input_prompt: prompt, generated_query: result, result_output:pd.DataFrame([{"error": f"❌ Unable to get the SQL query based on the text. {e}"}]) } return { table_schema: schema, input_prompt: prompt, generated_query: result, result_output:query_result } # Custom CSS styling custom_css = """ .gradio-container { background-color: #f0f4f8; } .logo { max-width: 200px; margin: 20px auto; display: block; } .gr-button { background-color: #4a90e2 !important; } .gr-button:hover { background-color: #3a7bc8 !important; } """ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo: gr.Image("logo.png", label=None, show_label=False, container=False, height=100) gr.Markdown("""