import mlflow
import pandas as pd
import joblib
import os

from fastai.data.block import MultiCategoryBlock, RandomSplitter, DataBlock, CategoryBlock
from fastai.vision.data import *
from fastai.vision.learner import *
from fastai.vision.all import *
import pathlib
# The exported learner was presumably pickled on Windows, so unpickling it
# needs ``pathlib.WindowsPath``. That class cannot be instantiated on Linux,
# so alias it to ``PosixPath`` there.
# ``platform`` is imported explicitly instead of relying on the fastai star
# import. The OS name is deliberately not bound to ``plt``: that would shadow
# the ``matplotlib.pyplot`` alias which ``fastai.vision.all`` exports.
import platform

if platform.system() == "Linux":
    pathlib.WindowsPath = pathlib.PosixPath


# Module-level model handles, both unset until the loaders run:
#   lesion_model - fastai learner, set by f_load_cnn_model()
#   pipe_cancer  - joblib-loaded classifier, set by f_load_cancer_classifier()
lesion_model, pipe_cancer = None, None


def f_load_cnn_model():
    """Load the MLflow-logged fastai lesion CNN into the global ``lesion_model``.

    The model is read from ``mole_models/ed52a28a7b504ff7ba851c850221d1dd``,
    resolved relative to the current working directory. Three callbacks are
    then detached from the learner by their position in ``cbs``.
    """
    global lesion_model
    model_uri = os.path.join("mole_models", "ed52a28a7b504ff7ba851c850221d1dd")
    lesion_model = mlflow.fastai.load_model(model_uri)
    # NOTE(review): callbacks are dropped by hard-coded index: 6, then 4,
    # then 3. Going from the highest index down means each removal leaves the
    # lower indices pointing at the same callbacks. Which callbacks these are
    # is not visible here. They are presumably training-only ones that
    # interfere with inference; confirm against the exported learner, because
    # any change to its callback list would silently remove the wrong ones.
    for position in (6, 4, 3):
        lesion_model.cbs.remove(lesion_model.cbs[position])


def get_image_files(df):
    """Return *df* unchanged.

    NOTE(review): this intentionally shadows fastai's ``get_image_files``,
    which the star imports above bring in. It is presumably kept, together
    with ``get_x``/``get_y``, because the exported learner referenced these
    helpers by name when it was pickled. Confirm before renaming or removing.

    No path prefixing is applied here, so callers must already pass usable
    file paths.
    """
    return df

def get_x(df):
    """Return the ``path`` field of *df* (presumably the image location — confirm)."""
    column = "path"
    return df[column]

def get_y(df):
    """Return the ``dx`` field of *df* (the label column; empty at inference time)."""
    column = "dx"
    return df[column]


def f_create_df_with_files_input(file_path):
    """Wrap a single image path in a one-row DataFrame.

    The frame has two columns, in this order:
      - ``path``: set to *file_path*.
      - ``dx``: an empty-string placeholder label.

    These are the columns that ``get_x`` and ``get_y`` read.
    """
    return pd.DataFrame({"path": [file_path], "dx": [""]})


def f_predict_cnn_with_tta(file_path):
    """Return the CNN's per-class scores for one image, using test-time augmentation.

    The steps are:
      1. Build a test DataLoader for *file_path*.
      2. Replace its per-item transforms with a 700px crop-resize followed by
         a 350px ``RandomResizedCrop``.
      3. Run ``Learner.tta`` with ``n=4`` and ``use_max=False``.

    Returns:
        A list of floats, one per class. ``f_predict_cancer`` treats them as
        aligned with ``lesion_model.dls.vocab``.
    """
    single_row_df = f_create_df_with_files_input(file_path)
    test_loader = lesion_model.dls.test_dl(single_row_df)
    # Inference-specific item transforms. NOTE(review): RandomResizedCrop is
    # presumably random only on augmented TTA passes, with tta switching the
    # split; confirm against the fastai version in use.
    item_transforms = [
        ToTensor,
        Resize(700, method=ResizeMethod.Crop),
        RandomResizedCrop(350),
    ]
    test_loader.after_item = Pipeline(item_transforms)
    tta_preds, _ = lesion_model.tta(dl=test_loader, n=4, use_max=False)
    # Only one image was submitted, so take the single row of predictions.
    return tta_preds[0].tolist()


def f_predict_cnn_simple(file_path):
    """Return the CNN's per-class probabilities for one image, without augmentation.

    ``Learner.predict`` returns a tuple whose third element holds the
    per-class probabilities. That element is converted to a plain list.
    """
    probabilities = lesion_model.predict(file_path)[2]
    return probabilities.tolist()


def f_load_cancer_classifier():
    """Load the serialized cancer classifier into the global ``pipe_cancer``.

    Reads ``mole_models/pipeline.rf_classifier_mole``, resolved relative to the
    current working directory, with ``joblib.load``.
    """
    global pipe_cancer
    # joblib.load unpickles the file, so only point it at trusted,
    # locally produced artifacts.
    classifier_path = os.path.join("mole_models", "pipeline.rf_classifier_mole")
    pipe_cancer = joblib.load(classifier_path)

def f_predict_cancer(preds, age, sex, localization):
    """Classify a lesion as suspicious or benign from CNN scores plus metadata.

    Args:
        preds: Per-class CNN scores, one per entry of
            ``lesion_model.dls.vocab``. For example, the output of
            ``f_predict_cnn_with_tta``.
        age: Copied as-is into the ``age`` feature column.
        sex: Copied as-is into the ``sex`` feature column.
        localization: Copied as-is into the ``localization`` feature column.

    Returns:
        ``"Possibly suspicious"`` when ``pipe_cancer`` predicts ``"cancer"``,
        otherwise ``"benign"``.
    """
    features = pd.DataFrame(data=[preds], columns=lesion_model.dls.vocab)
    # Columns are appended in this exact order. "label" is an empty
    # placeholder; presumably the pipeline expects the column to exist.
    # Confirm against how pipe_cancer was trained.
    extra_columns = (
        ("age", age),
        ("sex", sex),
        ("localization", localization),
        ("label", ""),
    )
    for column, value in extra_columns:
        features[column] = value
    predicted = pipe_cancer.predict(features)[0]
    if predicted == "cancer":
        return "Possibly suspicious"
    return "benign"