from sentence_transformers import SentenceTransformer, util as st_util
from transformers import CLIPModel, CLIPProcessor

from PIL import Image
import requests
import os
import torch
torch.set_printoptions(precision=10)
from tqdm import tqdm
import s3fs
from io import BytesIO
import vector_db

# Models selected for use. Other keys available in ``model_name_to_ids``:
# "sentence-transformer-clip-ViT-L-14", "openai-clip".
# NOTE(review): ``model_names`` is not read anywhere in this module; both
# load_models() and generate_and_save_embeddings() iterate every key of
# ``model_name_to_ids`` instead. Confirm whether other modules rely on it.
model_names = ["fashion"]

# Local model name -> model identifier. Keys starting with
# "sentence-transformer" are loaded with SentenceTransformer; all others are
# loaded as Hugging Face CLIP checkpoints (see load_models()).
model_name_to_ids = {
    "sentence-transformer-clip-ViT-L-14": "clip-ViT-L-14",
    "fashion": "patrickjohncyh/fashion-clip",
    "openai-clip": "openai/clip-vit-base-patch32",
}

# Bucket that holds all datasets; each dataset is a folder under ROOT_DATA_PATH.
S3_BUCKET = "s3://disco-io"
ROOT_DATA_PATH = os.path.join(S3_BUCKET, 'data')

# Credentials come from the environment; a missing variable raises KeyError
# when this module is imported.
AWS_ACCESS_KEY_ID = os.environ["AWS_ACCESS_KEY_ID"]
AWS_SECRET_ACCESS_KEY = os.environ["AWS_SECRET_ACCESS_KEY"]

# Shared S3 filesystem handle used by the helpers in this module.
fs = s3fs.S3FileSystem(
    key=AWS_ACCESS_KEY_ID,
    secret=AWS_SECRET_ACCESS_KEY,
)

def get_data_path():
    """Return the folder of the currently selected dataset: ``ROOT_DATA_PATH/<cur_dataset>``."""
    dataset_dir = os.path.join(ROOT_DATA_PATH, cur_dataset)
    return dataset_dir

def get_image_path():
    """Return the ``images`` folder inside the current dataset's folder."""
    dataset_dir = get_data_path()
    return os.path.join(dataset_dir, 'images')

def get_metadata_path():
    """Return the ``metadata`` folder inside the current dataset's folder."""
    dataset_dir = get_data_path()
    return os.path.join(dataset_dir, 'metadata')

def get_embeddings_path():
    """Return the path of ``<cur_dataset>_embeddings.pq`` inside the current dataset's metadata folder."""
    metadata_dir = get_metadata_path()
    file_name = cur_dataset + '_embeddings.pq'
    return os.path.join(metadata_dir, file_name)

# Cache of loaded models keyed by local model name (see load_models()).
model_dict = {}


def download_to_s3(url, s3_path, timeout=30):
    """
    Stream the resource at ``url`` into ``s3_path`` on the module-level ``fs``.

    Args:
        url: URL to fetch with ``requests``.
        s3_path: Destination path passed to ``fs.open``.
        timeout: Seconds passed to ``requests.get``. It bounds the connect
            wait and each read wait, so a stalled server cannot block the
            caller forever.

    Raises:
        requests.exceptions.HTTPError: The response status is an error.
        requests.exceptions.RequestException: Connection failures and timeouts.
    """
    # ``with`` closes the streamed response (releasing its pooled connection)
    # even when the status check or the upload raises.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()

        # Copy the body to S3 in 8 KiB chunks rather than buffering it whole.
        with fs.open(s3_path, "wb") as s3_file:
            for chunk in response.iter_content(chunk_size=8192):
                s3_file.write(chunk)


def remove_all_files_from_s3_directory(s3_directory):
    """
    Delete each entry listed directly under ``s3_directory`` (best-effort).

    A directory that does not exist yet is treated as already empty. A failure
    to remove one entry is printed and skipped so the remaining entries are
    still processed.

    Args:
        s3_directory: Folder path on the module-level ``fs``.
    """
    try:
        objects = fs.ls(s3_directory)
    except FileNotFoundError:
        # Nothing to clear, e.g. the first download into a new dataset folder.
        return

    for obj in objects:
        try:
            fs.rm(obj)
        except Exception as e:
            # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt
            # and SystemExit still propagate; include the cause in the report.
            print(f'Error removing file: {obj} ({e})')

def download_images(df, img_folder):
    """
    Replace the contents of ``img_folder`` with the images referenced by ``df``.

    The folder is emptied first. Then each row's ``IMG_URL`` is downloaded to
    ``<img_folder>/<title>.jpg``. In the file name, '/' in the title is
    replaced by '_' and newlines are removed. A row whose download fails is
    reported and skipped.

    Args:
        df: Table iterated with ``iterrows()`` (presumably a pandas DataFrame)
            that has ``IMG_URL`` and ``title`` columns.
        img_folder: Destination folder on the module-level ``fs``.
    """
    remove_all_files_from_s3_directory(img_folder)
    for index, row in df.iterrows():
        title = row['title']
        try:
            file_name = title.replace('/', '_').replace('\n', '') + '.jpg'
            download_to_s3(row['IMG_URL'], os.path.join(img_folder, file_name))
        except Exception as e:
            # The f-string replaces ``str(index) + row['title']``. That
            # concatenation raised TypeError inside this handler for non-string
            # titles (e.g. NaN), aborting the whole loop.
            print(f'Error downloading image: {index} {title} ({e})')


def load_models():
    """
    Load every model listed in ``model_name_to_ids`` into ``model_dict``.

    Names already present in ``model_dict`` are skipped. Names starting with
    'sentence-transformer' are loaded with ``SentenceTransformer``. All others
    are loaded as a Hugging Face ``CLIPModel`` plus its ``CLIPProcessor``, and
    the checkpoint id is recorded under 'hf_dir'.

    Each entry is fully built before it is stored. If a load raises, no
    half-initialised entry is left behind for later calls to skip.
    """
    for model_name, model_id in model_name_to_ids.items():
        if model_name in model_dict:
            continue

        entry = dict()
        if model_name.startswith('sentence-transformer'):
            entry['model'] = SentenceTransformer(model_id)
        else:
            entry['hf_dir'] = model_id
            entry['model'] = CLIPModel.from_pretrained(model_id)
            entry['processor'] = CLIPProcessor.from_pretrained(model_id)

        # Publish only after every component loaded successfully.
        model_dict[model_name] = entry


# Populate the model cache when this module is imported, so the embedding
# helpers below can index model_dict directly.
if not model_dict:
    print('Loading models...')
    load_models()


def get_image_embedding(model_name, image):
    """
    Return the embedding of ``image`` produced by the model named ``model_name``.

    Args:
        model_name: Key of ``model_dict``. Names starting with
            'sentence-transformer' use ``SentenceTransformer.encode``; all
            others use the stored CLIP processor and ``get_image_features``.
        image: Image to embed. In this module it is a PIL image, from
            s3_path_to_image() or url_to_image().

    Returns:
        A numpy array. For the CLIP branch this is the first (only) row of
        ``get_image_features``. For sentence-transformers it is whatever
        ``encode`` returns for a single image (presumably a 1-D array).
    """
    entry = model_dict[model_name]
    model = entry['model']
    if model_name.startswith('sentence-transformer'):
        return model.encode(image)

    # Inference only: without no_grad, a caller outside a no_grad context
    # would build and hold an autograd graph for nothing.
    with torch.no_grad():
        inputs = entry['processor'](images=image, return_tensors="pt")
        image_features = model.get_image_features(**inputs)
    return image_features.detach().numpy()[0]

def s3_path_to_image(fs, s3_path):
    """
    Read an image file from S3 and return it as a PIL Image.

    The whole object is read into memory before decoding, so the returned image
    does not depend on the S3 file handle staying open.

    Args:
        fs: Filesystem object providing ``open(path, "rb")`` (the module passes
            its s3fs handle).
        s3_path (str): Path of the image, e.g. "s3://bucket_name/path/to/image.jpg"
            as built by generate_and_save_embeddings().

    Returns:
        Image: A PIL Image object.
    """
    with fs.open(s3_path, "rb") as handle:
        raw_bytes = handle.read()
    return Image.open(BytesIO(raw_bytes))

def generate_and_save_embeddings():
    """
    Embed every ``.jpg`` in the current dataset's image folder with each model.

    Each embedding is stored via ``vector_db.add_image_embedding_to_db`` with
    the model name, the current dataset name, the image's ``s3://`` path and
    its file name.
    """
    # Inference only; no autograd graph is needed.
    with torch.no_grad():
        for fp in tqdm(fs.ls(get_image_path()), desc="Generate embeddings for Images"):
            if not fp.endswith('.jpg'):
                continue

            name = fp.split('/')[-1]
            # The listed path is prefixed with the scheme before use
            # (presumably fs.ls returns it without one).
            s3_path = 's3://' + fp

            # Fetch and decode the image once and reuse it for every model,
            # instead of re-downloading it from S3 for each model.
            image = s3_path_to_image(fs, s3_path)

            for model_name in model_name_to_ids:
                vector_db.add_image_embedding_to_db(
                    embedding=get_image_embedding(model_name, image),
                    model_name=model_name,
                    dataset_name=cur_dataset,
                    path_to_image=s3_path,
                    image_name=name,
                )


def get_immediate_subdirectories(s3_path):
    """
    Return the names of the folders directly under ``s3_path``.

    Only the last path component of each folder is returned; plain files
    matched by the glob are skipped.
    """
    subdirectory_names = []
    for candidate in fs.glob(f"{s3_path}/*"):
        if fs.isdir(candidate):
            subdirectory_names.append(candidate.rsplit('/', 1)[-1])
    return subdirectory_names

# Discover the dataset folders under ROOT_DATA_PATH and select the first one
# by default. If there are none yet, start with no selection (None) instead of
# failing the import with IndexError; set_cur_dataset() must then be called
# before the dataset path helpers are used.
all_datasets = get_immediate_subdirectories(ROOT_DATA_PATH)
cur_dataset = all_datasets[0] if all_datasets else None

def set_cur_dataset(dataset):
    """
    Make ``dataset`` the module's current dataset.

    The cached ``all_datasets`` list is re-read first. ``dataset`` itself is
    not checked against that list.
    """
    global cur_dataset
    refresh_all_datasets()
    print(f"Setting current dataset to {dataset}")
    cur_dataset = dataset

def refresh_all_datasets():
    """Re-list the dataset folders under ``ROOT_DATA_PATH`` into ``all_datasets`` and print them."""
    global all_datasets
    discovered = get_immediate_subdirectories(ROOT_DATA_PATH)
    all_datasets = discovered
    print(f"Refreshing all datasets: {all_datasets}")

def url_to_image(url, timeout=30):
    """
    Download ``url`` and open the response body as a PIL Image.

    Args:
        url: URL of the image.
        timeout: Seconds passed to ``requests.get`` so an unresponsive server
            cannot block the caller forever.

    Returns:
        The PIL Image, or None when the request fails (connection error,
        timeout, error status) or the body cannot be identified as an image.
        The failure is printed in both cases.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        content = response.content
    except requests.exceptions.RequestException:
        print(f"Error fetching image from URL: {url}")
        return None

    try:
        return Image.open(BytesIO(content))
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for bodies
        # that are not a recognised image; report it like a fetch failure.
        print(f"Error decoding image from URL: {url}")
        return None