import json
import sqlite3
from contextlib import asynccontextmanager
from fastapi import FastAPI, Query, HTTPException
from typing import List, Optional
from pydantic import BaseModel
from data_loader import refresh_data
import numpy as np
from pandas import Timestamp


def get_db_connection(db_path: str = "datasets.db") -> sqlite3.Connection:
    """Open a connection to the SQLite datasets database.

    Rows fetched through this connection are ``sqlite3.Row`` objects, so
    columns can be read by name (``row["hub_id"]``) and converted with
    ``dict(row)``.

    Args:
        db_path: Database file to open. Defaults to ``"datasets.db"``,
            resolved relative to the current working directory. Pass
            ``":memory:"`` for a throwaway in-memory database.

    Returns:
        A new connection. The caller is responsible for closing it.
    """
    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    return conn


def setup_database():
    """Create the ``datasets`` table and its index if they do not exist yet.

    Both statements use ``IF NOT EXISTS``, so calling this repeatedly is
    harmless. The connection is closed even if a statement fails.
    """
    conn = get_db_connection()
    try:
        c = conn.cursor()
        c.execute("""CREATE TABLE IF NOT EXISTS datasets
                 (hub_id TEXT PRIMARY KEY, 
                  likes INTEGER,
                  downloads INTEGER,
                  tags TEXT,
                  created_at INTEGER,
                  last_modified INTEGER,
                  license TEXT,
                  language TEXT,
                  config_name TEXT,
                  column_names TEXT,
                  features TEXT)""")
        # NOTE(review): this indexes the raw JSON text of column_names. The
        # /search filters go through json_each(column_names), which probably
        # cannot use this index for lookups — confirm with EXPLAIN QUERY PLAN
        # before relying on it. Kept for schema compatibility.
        c.execute("CREATE INDEX IF NOT EXISTS idx_column_names ON datasets (column_names)")
        conn.commit()
    finally:
        # Previously the connection leaked if either execute() raised.
        conn.close()


def serialize_numpy(obj):
    """``default=`` hook for ``json.dumps`` covering numpy/pandas values.

    Conversions:

    * ``np.ndarray``       -> ``list`` (via ``tolist()``)
    * ``np.bool_``         -> ``bool``
    * ``np.integer``       -> ``int``
    * ``np.floating``      -> ``float``
    * ``pandas.Timestamp`` -> integer Unix seconds (fraction truncated by ``int``)

    Raises:
        TypeError: for any other type, which is what ``json`` expects from a
            ``default`` hook that cannot handle a value.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # np.bool_ is not a subclass of np.integer (nor of Python bool), so it
    # needs its own branch; without it boolean numpy values fail to encode.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, Timestamp):
        return int(obj.timestamp())
    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")


def _epoch_seconds(value):
    """Return *value* as integer Unix seconds if it is a ``pandas.Timestamp``.

    Any other value (e.g. an int that is already an epoch) is returned as-is.
    """
    return int(value.timestamp()) if isinstance(value, Timestamp) else value


def insert_data(conn, data):
    """Upsert one dataset record into the ``datasets`` table and commit.

    Uses ``INSERT OR REPLACE``, so an existing row with the same ``hub_id``
    (the table's primary key) is overwritten.

    Args:
        conn: Open SQLite connection.
        data: Mapping with a required ``"hub_id"`` key. Missing optional keys
            fall back to ``0`` (likes, downloads, timestamps), ``""``
            (config_name) or an empty JSON list (tags, license, language,
            column_names, features). List-like fields are stored as JSON text,
            encoded with ``serialize_numpy``. ``pandas.Timestamp`` timestamps
            are stored as integer Unix seconds.
    """
    cursor = conn.cursor()

    def as_json(key):
        # Missing keys encode as "[]"; numpy/pandas values go through the hook.
        return json.dumps(data.get(key, []), default=serialize_numpy)

    created = _epoch_seconds(data.get("created_at", 0))
    modified = _epoch_seconds(data.get("last_modified", 0))

    row = (
        data["hub_id"],
        data.get("likes", 0),
        data.get("downloads", 0),
        as_json("tags"),
        created,
        modified,
        as_json("license"),
        as_json("language"),
        data.get("config_name", ""),
        as_json("column_names"),
        as_json("features"),
    )
    cursor.execute(
        """
        INSERT OR REPLACE INTO datasets 
        (hub_id, likes, downloads, tags, created_at, last_modified, license, language, config_name, column_names, features) 
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    """,
        row,
    )
    # Each call commits its own row immediately.
    conn.commit()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler that loads the dataset table before serving.

    Startup:

    1. Ensures the schema exists (``setup_database``).
    2. Upserts every record returned by ``refresh_data()`` via ``insert_data``.

    Requests are served only after the ``yield``. Nothing is done on shutdown.

    NOTE(review): the loading work runs synchronously on the event loop, and
    ``insert_data`` commits once per row. A batched single transaction would
    make large refreshes considerably faster.
    """
    setup_database()
    conn = get_db_connection()
    try:
        for data in refresh_data():
            insert_data(conn, data)
    finally:
        # Close even when refresh_data() or an insert raises, so a failed
        # startup does not leak the connection.
        conn.close()
    yield


# ASGI application. `lifespan` builds and populates the SQLite database before
# the app starts serving requests.
app = FastAPI(lifespan=lifespan)


class SearchResponse(BaseModel):
    """Paginated response body returned by ``GET /search``."""

    total: int  # number of matching datasets across all pages
    page: int  # 1-based page number that was requested
    page_size: int  # maximum number of results per page
    results: List[dict]  # rows for this page, with JSON-encoded columns decoded


@app.get("/search", response_model=SearchResponse)
async def search_datasets(
    columns: List[str] = Query(...),
    match_all: bool = Query(False),
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=1000),
):
    """Find datasets whose ``column_names`` contain the requested columns.

    Args:
        columns: Column names to search for. Repeat the query parameter for
            several (``?columns=text&columns=label``). Duplicates are ignored.
        match_all: If true, a dataset must contain every requested column.
            Otherwise any one of them is enough.
        page: 1-based page number.
        page_size: Maximum number of results per page (1-1000).

    Returns:
        A ``SearchResponse`` with the total number of matches and one page of
        rows, ordered by ``hub_id``, with the JSON-encoded columns decoded.

    Raises:
        HTTPException: status 500 if SQLite reports an error.

    NOTE(review): the SQLite calls are synchronous and run on the event-loop
    thread. Declaring this endpoint with plain ``def`` would let FastAPI run
    it in its threadpool instead.
    """
    # Deduplicate (order-preserving) so match_all compares against the number
    # of *distinct* requested columns. Otherwise ?columns=a&columns=a could
    # never match anything.
    wanted = list(dict.fromkeys(columns))
    placeholders = ", ".join(["?"] * len(wanted))

    if match_all:
        # COUNT(DISTINCT ...) prevents a dataset that lists one column twice
        # from standing in for a different, missing column.
        where = (
            "(SELECT COUNT(DISTINCT value) FROM json_each(column_names)"
            f" WHERE value IN ({placeholders})) = ?"
        )
        where_params = (*wanted, len(wanted))
    else:
        where = (
            "EXISTS (SELECT 1 FROM json_each(column_names)"
            f" WHERE value IN ({placeholders}))"
        )
        where_params = tuple(wanted)

    # Only "?" placeholders are interpolated into the SQL text. Every
    # user-supplied value is bound as a parameter.
    count_sql = f"SELECT COUNT(*) AS total FROM datasets WHERE {where}"
    # LIMIT/OFFSET pages are only stable with a deterministic order. hub_id is
    # the primary key, so it is unique and indexed.
    page_sql = (
        f"SELECT * FROM datasets WHERE {where} "
        "ORDER BY hub_id LIMIT ? OFFSET ?"
    )
    offset = (page - 1) * page_size

    conn = get_db_connection()
    try:
        cur = conn.cursor()
        cur.execute(count_sql, where_params)
        total = cur.fetchone()["total"]
        cur.execute(page_sql, (*where_params, page_size, offset))
        results = [dict(row) for row in cur.fetchall()]
    except sqlite3.Error as e:
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}") from e
    finally:
        conn.close()

    # These columns are stored as JSON text by insert_data.
    json_columns = ("tags", "license", "language", "column_names", "features")
    for result in results:
        for key in json_columns:
            result[key] = json.loads(result[key])

    return SearchResponse(
        total=total, page=page, page_size=page_size, results=results
    )


if __name__ == "__main__":
    # Development entry point: serve the app with uvicorn on port 8000,
    # listening on all network interfaces.
    from uvicorn import run as run_server

    run_server(app, port=8000, host="0.0.0.0")