|
import asyncio |
|
from concurrent.futures import ThreadPoolExecutor |
|
import duckdb |
|
import pandas as pd |
|
import os |
|
import requests |
|
import tempfile |
|
|
|
def find_indicator_column(table: str, indicator_columns_per_table: dict[str, str]) -> str:
    """Return the name of the indicator column of *table*.

    The lookup is done in the mapping passed by the caller; this function does
    not consult any module-level table.

    Args:
        table (str): Name of the table in the database.
        indicator_columns_per_table (dict[str, str]): Mapping from table name to
            the name of that table's indicator column.

    Returns:
        str: Name of the indicator column for the specified table.

    Raises:
        KeyError: If *table* has no entry in *indicator_columns_per_table*.
    """
    print(f"---- Find indicator column in table {table} ----")
    try:
        return indicator_columns_per_table[table]
    except KeyError:
        # Same exception type as a plain dict lookup, but with a message that
        # says what was being looked up rather than just the bare key.
        raise KeyError(f"No indicator column registered for table {table!r}") from None
|
|
|
# Seconds requests waits to connect and between received bytes when
# downloading a file for the local-file fallback (not a total-time limit).
_DOWNLOAD_TIMEOUT_SECONDS = 60


def _quote_sql_literal(value: str) -> str:
    """Return *value* as a single-quoted SQL string literal, doubling embedded quotes."""
    return "'" + value.replace("'", "''") + "'"


def _find_hf_url(error: Exception, sql_query: str) -> "str | None":
    """Locate the remote URL that a failed query was reading.

    The URL is taken from text of the form ``HTTP GET error on '<url>'`` in the
    error message when present; otherwise from the first single-quoted
    ``https://huggingface.co/...`` literal in *sql_query*. Returns None when
    neither is found.
    """
    import re

    marker = "HTTP GET error on '"
    error_str = str(error)
    if marker in error_str:
        return error_str.split(marker, 1)[1].split("'", 1)[0]
    url_match = re.search(r"'(https://huggingface\.co/[^']+)'", sql_query)
    return url_match.group(1) if url_match else None


def _download_hf_file(url: str, token: "str | None") -> "str | None":
    """Download *url* into the system temp directory, keeping its file name.

    Sends ``Authorization: Bearer <token>`` when *token* is set.

    Returns:
        The local file path on HTTP 200, or None for any other status except 401.

    Raises:
        Exception: On HTTP 401, since retrying with the same token cannot help.
        requests.RequestException: On network errors or timeouts.
    """
    from urllib.parse import urlsplit

    # Use only the URL path so a query string never ends up in the file name.
    file_name = os.path.basename(urlsplit(url).path) or "hf_download"
    local_path = os.path.join(tempfile.gettempdir(), file_name)
    print(f"Downloading {url} to {local_path}")

    headers = {"Authorization": f"Bearer {token}"} if token else {}
    # Using the response as a context manager releases the streamed connection.
    with requests.get(
        url, headers=headers, stream=True, timeout=_DOWNLOAD_TIMEOUT_SECONDS
    ) as response:
        if response.status_code == 200:
            with open(local_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            return local_path
        if response.status_code == 401:
            print("Authentication failed - check your HF_TTD_TOKEN")
            raise Exception(
                "Authentication failed. Please check your HF_TTD_TOKEN environment variable."
            )
        print(f"Failed to download file: {response.status_code}")
        return None


async def execute_sql_query(sql_query: str) -> pd.DataFrame:
    """Execute *sql_query* with DuckDB in a worker thread and return the results.

    The query runs on a fresh DuckDB connection opened with ``duckdb.connect()``
    (no database path), so it must reference its data sources directly, e.g. by
    URL. Presumably these are the DRIAS climate datasets; confirm with callers.

    Behaviour:
      * If ``HF_TTD_TOKEN`` is set, it is registered as a DuckDB HUGGINGFACE
        secret before running the query.
      * On ``duckdb.HTTPException``, when a token is set the query is first
        retried on a new connection without the secret.
      * If that also fails (or no token is set), the remote file is downloaded
        to the temp directory and the query is rerun with the URL literal
        replaced by the local path.

    Every connection opened here is closed before returning.

    Args:
        sql_query (str): The SQL query to execute.

    Returns:
        pd.DataFrame: A DataFrame containing the query results.

    Raises:
        duckdb.Error: For non-HTTP query errors. Also raised when the HTTP
            fallback cannot recover (URL not found, download status other than
            200/401, or the rerun query failing).
        Exception: If the fallback download is rejected with HTTP 401.
    """

    def _execute_query() -> pd.DataFrame:
        token = os.getenv("HF_TTD_TOKEN")

        with duckdb.connect() as con:
            try:
                if token:
                    # DuckDB's CREATE SECRET takes a literal, not a bound
                    # parameter, so quote the token explicitly.
                    con.execute(
                        "CREATE SECRET IF NOT EXISTS hf_token "
                        f"(TYPE HUGGINGFACE, TOKEN {_quote_sql_literal(token)});"
                    )
                    print("Hugging Face authentication configured")

                return con.execute(sql_query).fetchdf()

            except duckdb.HTTPException as e:
                print(f"HTTP error accessing Hugging Face dataset: {e}")

                if token:
                    print("Retrying without authentication...")
                    try:
                        # A new connection does not carry the secret created above.
                        with duckdb.connect() as con_no_auth:
                            return con_no_auth.execute(sql_query).fetchdf()
                    except Exception as e2:
                        print(f"Also failed without authentication: {e2}")

                print("Trying to download file locally and retry...")
                url = _find_hf_url(e, sql_query)
                if url is None:
                    print("Could not extract URL from error message")
                    raise

                local_path = _download_hf_file(url, token)
                if local_path is None:
                    raise

                modified_sql = sql_query.replace(
                    f"'{url}'", _quote_sql_literal(local_path)
                )
                return con.execute(modified_sql).fetchdf()

            except Exception as e:
                # Only errors from the try body land here; failures inside the
                # HTTP handler above propagate directly.
                print(f"Unexpected error: {e}")
                raise

    # Run the blocking DuckDB/requests work on the loop's default executor.
    return await asyncio.to_thread(_execute_query)
|
|
|
|