{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Talk to data: step-by-step Vanna pipeline on DRIAS\n",
    "\n",
    "Walks one natural-language question through the talk-to-data pipeline, one step per section:\n",
    "\n",
    "1. detect the location mentioned in the query,\n",
    "2. convert it to coordinates,\n",
    "3. replace them with the closest coordinates available in each relevant table of the DRIAS database,\n",
    "4. ask Vanna to generate and run the SQL query, and display the resulting table and figure.\n",
    "\n",
    "**Requirements:** a `.env` file defining `OPENAI_API_KEY`, `VANNA_PINECONE_API_KEY`, `VANNA_INDEX_NAME` and `VANNA_MODEL`, and the SQLite database `data/drias/drias.db` under the repository root (assumed to be two directories above the notebook's working directory)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Repository root, assumed to be two levels above the kernel's working directory;\n",
    "# adding it to sys.path makes the `climateqa` package importable.\n",
    "ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))\n",
    "if ROOT_PATH not in sys.path:  # keep the cell idempotent on re-run\n",
    "    sys.path.append(ROOT_PATH)\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sqlite3\n",
    "\n",
    "import pandas as pd\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "from climateqa.engine.talk_to_data.main import ask_vanna\n",
    "from climateqa.engine.talk_to_data.myVanna import MyVanna\n",
    "from climateqa.engine.talk_to_data.utils import (\n",
    "    loc2coords,\n",
    "    detect_location_with_openai,\n",
    "    detectTable,\n",
    "    nearestNeighbourSQL,\n",
    "    detect_relevant_tables,\n",
    "    replace_coordonates,\n",
    ")\n",
    "from climateqa.engine.llm import get_llm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration\n",
    "\n",
    "Credentials come from the environment (loaded from `.env`). Missing values and a missing database are reported here, before any client is created."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "load_dotenv()\n",
    "\n",
    "OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\")\n",
    "PC_API_KEY = os.getenv(\"VANNA_PINECONE_API_KEY\")\n",
    "INDEX_NAME = os.getenv(\"VANNA_INDEX_NAME\")\n",
    "VANNA_MODEL = os.getenv(\"VANNA_MODEL\")\n",
    "\n",
    "# Fail fast here rather than with an opaque error from a client later on.\n",
    "missing_env = [\n",
    "    name\n",
    "    for name in (\"OPENAI_API_KEY\", \"VANNA_PINECONE_API_KEY\", \"VANNA_INDEX_NAME\", \"VANNA_MODEL\")\n",
    "    if not os.getenv(name)\n",
    "]\n",
    "assert not missing_env, f\"Missing environment variables: {missing_env}\"\n",
    "\n",
    "DB_VANNA_PATH = os.path.join(ROOT_PATH, \"data\", \"drias\", \"drias.db\")\n",
    "assert os.path.isfile(DB_VANNA_PATH), f\"DRIAS database not found at {DB_VANNA_PATH}\"\n",
    "\n",
    "VANNA_TOP_K = 4  # forwarded to MyVanna as its `top_k` setting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Models and database connection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = get_llm(provider=\"openai\")\n",
    "\n",
    "vn = MyVanna(config={\n",
    "    \"temperature\": 0,\n",
    "    \"api_key\": OPENAI_API_KEY,\n",
    "    \"model\": VANNA_MODEL,\n",
    "    \"pc_api_key\": PC_API_KEY,\n",
    "    \"index_name\": INDEX_NAME,\n",
    "    \"top_k\": VANNA_TOP_K,\n",
    "})\n",
    "vn.connect_to_sqlite(DB_VANNA_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## User query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"Quelle sera la température à Marseille sur les prochaines années ?\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Detect location\n",
    "\n",
    "The detected location is substituted into the query in the next step, so it must appear in the query verbatim (case-insensitively)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "location = detect_location_with_openai(OPENAI_API_KEY, query)\n",
    "assert location, \"No location detected in the query\"\n",
    "# Otherwise str.replace below would silently leave the query without coordinates.\n",
    "assert location.lower() in query.lower(), f\"Detected location {location!r} does not appear in the query\"\n",
    "location"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert the location to coordinates\n",
    "\n",
    "The location name in the lower-cased query is replaced by `lat, long : <coords>`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coords = loc2coords(location)\n",
    "user_input = query.lower().replace(location.lower(), f\"lat, long : {coords}\")\n",
    "print(user_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Find the closest coordinates in each relevant table and replace them in the query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "relevant_tables = detect_relevant_tables(user_input, llm)\n",
    "assert relevant_tables, \"No relevant table detected for the query\"\n",
    "\n",
    "coords_tables = [nearestNeighbourSQL(DB_VANNA_PATH, coords, table) for table in relevant_tables]\n",
    "user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)\n",
    "print(user_input_with_coords)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ask Vanna with the corrected coordinates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql_query, result_dataframe, figure = vn.ask(\n",
    "    user_input_with_coords,\n",
    "    print_results=False,\n",
    "    allow_llm_to_see_data=True,\n",
    "    auto_train=False,\n",
    ")\n",
    "# NOTE(review): the dataframe is assumed to be None when the SQL could not be\n",
    "# generated or run; confirm against MyVanna / Vanna `ask`.\n",
    "assert result_dataframe is not None, \"Vanna returned no result; inspect the generated SQL\"\n",
    "print(sql_query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result_dataframe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "figure"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "climateqa",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}