File size: 5,318 Bytes
711bc31
 
 
 
 
8f369fe
 
711bc31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f369fe
 
711bc31
8f369fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
711bc31
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import asyncio
from concurrent.futures import ThreadPoolExecutor
import duckdb
import pandas as pd
import os
import requests
import tempfile

def find_indicator_column(table: str, indicator_columns_per_table: dict[str,str]) -> str:
    """Look up the indicator column associated with a table.

    A progress line naming the table is printed before the lookup is made.

    Args:
        table (str): Name of the table to look up
        indicator_columns_per_table (dict[str, str]): Mapping from table
            name to the name of that table's indicator column

    Returns:
        str: The indicator column name stored for ``table``

    Raises:
        KeyError: If ``table`` is not a key of ``indicator_columns_per_table``
    """
    progress_line = f"---- Find indicator column in table {table} ----"
    print(progress_line)
    indicator_column = indicator_columns_per_table[table]
    return indicator_column

async def execute_sql_query(sql_query: str) -> pd.DataFrame:
    """Execute a SQL query with DuckDB and return the result as a DataFrame.

    The query runs on a fresh in-memory DuckDB connection inside a worker
    thread so the event loop is not blocked. When the ``HF_TTD_TOKEN``
    environment variable is set, it is registered as a Hugging Face secret
    on that connection before the query runs.

    If DuckDB reports an HTTP error, the following fallbacks are tried in order:

    1. When a token was configured, the query is retried on a new connection
       that has no secret.
    2. The remote file URL is taken from the error message (or, failing that,
       from a quoted ``https://huggingface.co/...`` literal in the query),
       downloaded to a private temporary directory, and the query is re-run
       with that URL literal replaced by the local path. The temporary copy
       is removed afterwards.

    Args:
        sql_query (str): The SQL query to execute

    Returns:
        pd.DataFrame: A DataFrame containing the query results

    Raises:
        duckdb.Error: If the query fails and no fallback succeeds
        Exception: If the fallback download is rejected with HTTP 401
    """
    # (connect, read) timeouts in seconds, so a stalled download cannot hang
    # the worker thread forever. The read timeout applies per socket read,
    # not to the whole transfer.
    download_timeout = (10, 60)

    def _find_remote_url(error):
        """Return the remote file URL from the error text or the query, else None."""
        import re

        error_text = str(error)
        marker = "HTTP GET error on '"
        if marker in error_text:
            return error_text.split(marker)[1].split("'")[0]

        # Fall back to a quoted Hugging Face URL inside the SQL query
        url_match = re.search(r"'(https://huggingface\.co/[^']+)'", sql_query)
        return url_match.group(1) if url_match else None

    def _query_local_copy(con, url, hf_token, error):
        """Download ``url`` to a private temp directory and query the local copy."""
        import shutil

        # A per-call directory keeps concurrent calls for the same file from
        # overwriting each other's download, while preserving the file name
        # (and therefore its extension) in the path handed to DuckDB.
        download_dir = tempfile.mkdtemp()
        try:
            table_name = url.split('/')[-1]
            local_path = os.path.join(download_dir, table_name)
            print(f"Downloading {url} to {local_path}")

            # Add authentication headers if token is available
            headers = {}
            if hf_token:
                headers['Authorization'] = f'Bearer {hf_token}'

            with requests.get(
                url, headers=headers, stream=True, timeout=download_timeout
            ) as response:
                if response.status_code == 401:
                    print("Authentication failed - check your HF_TTD_TOKEN")
                    raise Exception(
                        "Authentication failed. Please check your HF_TTD_TOKEN environment variable."
                    ) from error
                if response.status_code != 200:
                    print(f"Failed to download file: {response.status_code}")
                    raise error

                with open(local_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)

            # Point the query at the local file. If the URL does not appear
            # verbatim as a quoted literal, the replace is a no-op and the
            # original query is simply executed again.
            modified_sql = sql_query.replace(f"'{url}'", f"'{local_path}'")
            return con.execute(modified_sql).fetchdf()
        finally:
            # The download is never reused, so do not leave it behind
            shutil.rmtree(download_dir, ignore_errors=True)

    def _execute_query():
        # Try to use Hugging Face authentication if token is available
        hf_token = os.getenv("HF_TTD_TOKEN")

        # The context manager closes the connection on every exit path
        with duckdb.connect() as con:
            try:
                if hf_token:
                    # Double any single quote so the token cannot terminate
                    # the SQL string literal it is embedded in
                    escaped_token = hf_token.replace("'", "''")
                    con.execute(f"""
                        CREATE SECRET IF NOT EXISTS hf_token (
                            TYPE HUGGINGFACE,
                            TOKEN '{escaped_token}'
                        );
                    """)
                    print("Hugging Face authentication configured")

                return con.execute(sql_query).fetchdf()

            except duckdb.HTTPException as e:
                print(f"HTTP error accessing Hugging Face dataset: {e}")

                # If we have a token but still get HTTP error, try without authentication
                if hf_token:
                    print("Retrying without authentication...")
                    try:
                        # New connection on which no secret has been created
                        with duckdb.connect() as con_no_auth:
                            return con_no_auth.execute(sql_query).fetchdf()
                    except Exception as e2:
                        print(f"Also failed without authentication: {e2}")

                # Try to download the file locally and retry
                print("Trying to download file locally and retry...")
                url = _find_remote_url(e)
                if not url:
                    print("Could not extract URL from error message")
                    raise
                return _query_local_copy(con, url, hf_token, e)

            except Exception as e:
                print(f"Unexpected error: {e}")
                raise

    # Run the blocking query in a worker thread to avoid blocking the event
    # loop. get_running_loop() is the supported way to obtain the loop from
    # inside a coroutine; a single worker is all this one task needs.
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor(max_workers=1) as executor:
        return await loop.run_in_executor(executor, _execute_query)