from config import *


# --- Runtime configuration, read from environment variables at import time. ---
# Any missing variable raises KeyError here, so the module fails to import.
apiUrl = os.environ['apiUrl']             # base URL of the /upload, /public and /status endpoints
uploadToken = os.environ['uploadToken']   # sent as 'token' in /upload requests
openId = os.environ['openId']             # sent with /public task submissions
apiKey = os.environ['apiKey']             # sent with /public task submissions
Regions = os.environ['Regions']           # raw string; check_region tests `nation in Regions` (a substring test)
tokenUrl = os.environ['tokenUrl']         # endpoint checkToken POSTs to when validating API keys
LimitTask = int(os.environ['LimitTask'])  # max recorded tasks per IP; token users are not limited by this


# Project paths: Datas/ holds example images and per-user task records;
# tmp/ holds images written temporarily before upload (created at import time).
proj_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.join(proj_dir, 'Datas')
tmpFolder = os.path.join(proj_dir, 'tmp')
os.makedirs(tmpFolder, exist_ok=True)



def load_pkl(path):
    '''Deserialize and return the object pickled in the file at *path*.

    Only use on trusted files: unpickling can execute arbitrary code.
    '''
    with open(path, mode='rb') as handle:
        obj = pickle.load(handle)
    return obj

def save_pkl(data, path, reweite=False):
    '''Pickle a mapping to *path*, merging it into any existing file.

    Args:
        data: mapping to store.
        path: destination file; missing parent directories are created.
        reweite: (sic, "rewrite") when True, overwrite the file with *data*
            instead of merging into its current contents.

    Returns:
        The object actually written: *data* itself when the file did not exist
        or ``reweite`` is True; otherwise the previously stored mapping updated
        with every key of *data* (new values win).

    The merge is a read-modify-write without any locking, so concurrent
    callers on the same path can overwrite each other's updates.
    '''
    parent = os.path.dirname(path)
    if parent:  # os.makedirs('') raises, so skip it for bare file names
        os.makedirs(parent, exist_ok=True)
    # Write *data* as-is when the file is absent or a rewrite is forced;
    # otherwise merge it into what is already stored.
    if reweite or not os.path.exists(path):
        to_write = data
    else:
        to_write = load_pkl(path)
        for k in data:
            to_write[k] = data[k]
    with open(path, 'wb') as file:
        pickle.dump(to_write, file, protocol=4)
    return to_write

def checkToken(token):
    '''Check whether an API token is valid and still has credits left.

    POSTs ``{'trans_type': 'hf_space', 'cost_credits': 1}`` to ``tokenUrl``
    with the token as a Bearer credential.
    NOTE(review): the body asks for ``cost_credits: 1``, so this check may
    itself consume a credit -- confirm with the token service.

    Returns:
        True only when the service answers 200 with a JSON object whose
        ``left_credits`` is greater than 0. False otherwise: non-200 status,
        non-JSON body, or missing/non-numeric ``left_credits``.

    Raises:
        requests.RequestException: on network errors or when the request
            exceeds its timeout.
    '''
    params = {'trans_type': 'hf_space', 'cost_credits': 1}
    headers = {"Authorization": f"Bearer {token}"}
    # Bounded wait: without a timeout a stalled service would block the caller forever.
    ret = requests.post(f"{tokenUrl}", data=json.dumps(params), headers=headers, timeout=60)
    print(ret)
    try:
        payload = ret.json()
    except ValueError:  # body is not valid JSON
        payload = None
    if ret.status_code != 200:
        print(payload if payload is not None else ret.text, ret.status_code, 'call token failed')
        return False
    if not isinstance(payload, dict) or 'left_credits' not in payload:
        return False
    try:
        return bool(payload['left_credits'] > 0)
    except TypeError:  # e.g. null or a string instead of a number
        return False

class UserRecorder(object):
    '''Persist and summarise each user's task history as pickle files.

    A user is identified by API token when one is given, otherwise by IP.
    Each identity has one ``<identity>.pkl`` file holding an ordered mapping
    ``task id -> {'input1', 'output1', 'status'}`` under
    ``<data_dir>/UserRecord_<taskType>/Tokens`` or ``.../Ips``.
    '''

    # Task statuses grouped for the counters reported by get_record.
    _FAILED_STATUSES = ('FAILED', 'CANCELLED', 'TIMED_OUT')
    _DONE_STATUSES = ('COMPLETED',)
    _RUNNING_STATUSES = ('IN_QUEUE', 'IN_PROGRESS')
    # get_record renders at most this many (input, output) image pairs.
    _MAX_SHOWN_PAIRS = 3

    def __init__(self, ):
        super().__init__()
        record_dir = os.path.join(data_dir, f'UserRecord_{taskType}')
        self.ip_dir = os.path.join(record_dir, 'Ips')
        self.token_dir = os.path.join(record_dir, 'Tokens')
        os.makedirs(self.ip_dir, exist_ok=True)
        os.makedirs(self.token_dir, exist_ok=True)

    def _record_path(self, ip="", token=""):
        '''Return ``(identity, record_path)``; a non-empty token takes precedence over ip.

        ip/token are treated as untrusted: path separators are replaced with
        ``_`` so the record file always stays inside its directory (a token
        such as ``../x`` cannot escape ``token_dir``). Ordinary tokens and
        IPs map to the same file names as before.
        '''
        if token:
            identity, folder = token, self.token_dir
        else:
            identity, folder = ip, self.ip_dir
        safe_name = identity
        for sep in (os.sep, os.altsep):
            if sep:
                safe_name = safe_name.replace(sep, '_')
        return identity, os.path.join(folder, f'{safe_name}.pkl')

    def save_record(self, taskRes, ip="", token=""):
        '''Store (or update) one task's entry in the user's record file.

        Args:
            taskRes: task dict with at least ``id`` and ``status``. When it has
                ``output.job_results``, that mapping's ``input1``/``output1``
                are stored too; otherwise both are stored as None.
            ip, token: user identity; the token wins when both are given.

        Returns:
            The full record mapping after saving, or None when neither ip nor
            token is given (nothing is recorded).
        '''
        if not ip and not token:
            return None
        _, record_path = self._record_path(ip=ip, token=token)
        results = taskRes.get('output')
        results = results.get('job_results') if isinstance(results, dict) else None
        if isinstance(results, dict):
            input1, output1 = results.get('input1'), results.get('output1')
        else:
            # No usable job results in this response.
            input1, output1 = None, None
        data = OrderedDict()
        data[taskRes['id']] = {'input1': input1, 'output1': output1, 'status': taskRes['status'], }
        return save_pkl(data, record_path, reweite=False)

    def check_record(self, ip="", token=""):
        '''Decide whether the user may submit another task.

        Token users are checked against the remote credits service
        (checkToken); IP users are limited to ``LimitTask`` recorded tasks.

        Returns:
            ``(allowed, message)``; message is "" when allowed.
        '''
        if token:
            if checkToken(token):
                return True, ""
            return False, "faild, api key is invalid"
        if not ip:
            # Without any identity nothing can be counted (save_record skips it),
            # so refuse rather than granting unlimited attempts.
            return False, no_more_attempts
        _, total_n, _ = self.get_record(ip=ip)
        if total_n >= LimitTask:
            return False, no_more_attempts
        return True, ""

    def get_record(self, ip="", token=""):
        '''Summarise a user's recorded tasks.

        Returns:
            ``(shows, total_n, msg)``:
            - ``shows``: 6-slot list with ``<img>`` tags for up to 3 completed
              tasks, most recently added first, interleaved input/output. Any
              remaining slots are None.
            - ``total_n``: number of recorded tasks.
            - ``msg``: a one-line status summary.
            With neither ip nor token this is ``([None]*6, 0, "")``.
        '''
        shows = [None] * (2 * self._MAX_SHOWN_PAIRS)
        if not ip and not token:
            return shows, 0, ""
        identity, record_path = self._record_path(ip=ip, token=token)
        record_data = load_pkl(record_path) if os.path.exists(record_path) else {}
        success_n = fail_n = process_n = 0
        shown = 0
        # Reverse insertion order: the most recently added task comes first.
        for key in reversed(record_data):
            entry = record_data[key]
            status = entry['status']
            if status in self._FAILED_STATUSES:
                fail_n += 1
            elif status in self._DONE_STATUSES:
                success_n += 1
                if entry['input1'] is not None and shown < self._MAX_SHOWN_PAIRS:
                    shows[2 * shown] = f"<img src=\"{entry['input1']}\" >"
                    shows[2 * shown + 1] = f"<img src=\"{entry['output1']}\" >"
                    shown += 1
            elif status in self._RUNNING_STATUSES:
                process_n += 1
        total_n = len(record_data)
        msg = f"Dear {identity}, You have {total_n} tasks, {success_n} successed, {fail_n} failed, {process_n} processing,  "
        return shows, total_n, msg

        
def get_temps_examples(taskType):
    '''List the template images of a task as single-path rows.

    Looks in ``<data_dir>/task<taskType>/temps``. Only names that contain a
    dot and have a non-empty part before the first dot are kept, so
    ``.hidden`` files and extension-less names are skipped. Rows are ordered
    by file name, descending.

    Returns:
        ``[[path], ...]``, or ``[]`` when the directory does not exist.
    '''
    temp_dir = os.path.join(data_dir, f'task{taskType}/temps')
    if not os.path.exists(temp_dir):
        return []
    return [
        [os.path.join(temp_dir, name)]
        for name in sorted(os.listdir(temp_dir), reverse=True)
        if '.' in name and name.split(".")[0]
    ]

def get_user_examples(taskType):
    '''List the sample user (source) images of a task as single-path rows.

    Looks in ``<data_dir>/task<taskType>/srcs``. Only names that contain a
    dot and have a non-empty part before the first dot are kept. Rows are
    ordered by file name, ascending.

    Returns:
        ``[[path], ...]``, or ``[]`` when the directory does not exist.
    '''
    user_dir = os.path.join(data_dir, f'task{taskType}/srcs')
    if not os.path.exists(user_dir):
        return []
    rows = []
    for name in sorted(os.listdir(user_dir)):
        stem, dot, _ = name.partition('.')
        if not dot or not stem:
            continue
        rows.append([os.path.join(user_dir, name)])
    return rows

def get_showcase_examples(taskType):
    '''Return the curated showcase rows for a task type.

    Each row is ``[template_path, source_path, showcase_path]`` (files from the
    task's ``temps``, ``srcs`` and ``showcases`` folders), resolved under
    ``data_dir``. ``taskType`` may be given as str or int. Unknown task types
    give ``[]``.

    Raises:
        FileNotFoundError: if any listed file is missing. This is an explicit
            check rather than an ``assert``, so it still runs under ``python -O``.
    '''
    showcases = {
        "1": [
            ["task1/temps/caption.jpg", "task1/srcs/src01.jpg", "task1/showcases/src01_seg.png"],
        ],
        "2": [
            ["task2/temps/caption.jpg", "task2/srcs/street.webp", "task2/showcases/out1.jpg"],
        ],
        "3": [
            ["task3/temps/flow-water.jpg", "task3/srcs/src01.jpg", "task3/showcases/src01_flower-water.jpg"],
            ["task3/temps/mountain-water.jpg", "task3/srcs/src01.jpg", "task3/showcases/src01_mountain-water.jpg"],
            ["task3/temps/rock-on-water.jpg", "task3/srcs/src01.jpg", "task3/showcases/src01_rock-on-water.jpg"],
        ],
        "4": [
            ["task4/temps/Vivienne.jpg", "task4/srcs/src02.jpg", "task4/showcases/src02_vivienne.jpg"],
            ["task4/temps/Bella.jpg", "task4/srcs/src04.jpg", "task4/showcases/src04_balle.jpg"],
            ["task4/temps/Nia.jpg", "task4/srcs/src02.jpg", "task4/showcases/src02_nia.jpg"],
            ["task4/temps/Leo.jpg", "task4/srcs/src03.jpg", "task4/showcases/src03_male.jpg"],
        ],
        "5": [
            ["task5/temps/caption.jpg", "task5/srcs/src01.jpg", "task5/showcases/src01_street.jpg"],
            ["task5/temps/caption.jpg", "task5/srcs/src01.jpg", "task5/showcases/src01_walk.jpg"],
        ],
        "6": [
            ["task6/temps/niantu.jpg", "task6/srcs/src01.jpg", "task6/showcases/src01_niantu.jpg"],
            ["task6/temps/3d-shouban.jpg", "task6/srcs/src02.jpg", "task6/showcases/src02_shouban.jpg"],
        ],
        "7": [
            ["task7/temps/task7.webp", "task7/srcs/305.jpg", "task7/showcases/task7out.jpg"],
        ],
        "9": [
            ["task9/temps/caption.jpg", "task9/srcs/use1.jpg", "task9/showcases/show0.jpg"],
            ["task9/temps/caption.jpg", "task9/srcs/use2.jpg", "task9/showcases/show1.webp"],
            ["task9/temps/caption.jpg", "task9/srcs/use3.jpg", "task9/showcases/show2.webp"],
        ],
    }
    examples = []
    for row in showcases.get(str(taskType), []):
        resolved = [os.path.join(data_dir, rel) for rel in row]
        for path in resolved:
            if not os.path.exists(path):
                raise FileNotFoundError(path)
        examples.append(resolved)
    return examples

def get_result_example(cloth_id, pose_id):
    '''Return the path of the pre-rendered result image for a (cloth, pose) pair.

    The path is ``<data_dir>/ResultImgs/<cloth_id>_<pose_id>.jpg``; whether
    the file exists is not checked.
    '''
    file_name = f"{cloth_id}_{pose_id}.jpg"
    return os.path.join(data_dir, 'ResultImgs', file_name)

def upload_user_img_mask(clientIp, img, mask=None, taskType='1'):
    '''Upload a user image (and optional mask) through the /upload endpoint.

    The image (and mask) are written to ``tmpFolder`` as JPEG files. Then
    ``/upload`` is asked for upload URLs (``upload1`` for the image,
    ``upload2`` for the mask), and each file is PUT to its URL. The temporary
    files are always removed, even if a request raises.

    Args:
        clientIp: used, with dots stripped, as a file-name prefix.
        img: image array; its third axis is reversed before ``cv2.imwrite``
            (presumably RGB -> BGR; confirm with the caller).
        mask: optional array, written as-is.
        taskType: for '8' and '9' nothing is uploaded.

    Returns:
        ``(image_url, mask_url)``: the URL each file was uploaded to, or ''
        for a file that was not uploaded. Task types '8'/'9' return
        ``("img", '')``.
    '''
    if taskType in ['8', '9']:
        return "img", ''
    timeId = int(str(time.time()).replace(".", "")) + random.randint(1000, 9999)
    stem = clientIp.replace(".", "") + str(timeId)
    fileName = stem + ".jpg"
    filemName = stem + "_m.jpg"
    local_path = os.path.join(tmpFolder, fileName)
    localm_path = os.path.join(tmpFolder, filemName)
    # Both temp files are .jpg, so one Content-Type header serves image and mask.
    headers = {"Content-Type": "image/jpeg"}
    res, uploadm_url = "", ''
    try:
        cv2.imwrite(local_path, img[:, :, ::-1].astype(np.uint8))
        if mask is not None:
            cv2.imwrite(localm_path, mask)
        params = {'token': uploadToken, 'input1': fileName, 'input2': filemName}
        ret = requests.post(f"{apiUrl}/upload", data=json.dumps(params), timeout=60)
        if ret.status_code != 200:
            print(ret.text, ret.status_code, 'call upload failed')
            return res, uploadm_url
        try:
            urls = ret.json()
        except ValueError:  # 200 but the body is not JSON
            print(ret.text, 'call upload returned non-JSON')
            return res, uploadm_url
        if not isinstance(urls, dict) or 'upload1' not in urls:
            return res, uploadm_url

        upload_url = urls['upload1']
        with open(local_path, 'rb') as file:
            response = requests.put(upload_url, data=file, headers=headers, timeout=300)
        if response.status_code == 200:
            res = upload_url
        else:
            print(response)

        if mask is not None:
            uploadm_url = urls.get('upload2') or ''
            if uploadm_url:
                with open(localm_path, 'rb') as file:
                    response = requests.put(uploadm_url, data=file, headers=headers, timeout=300)
                if response.status_code != 200:
                    uploadm_url = ''
                    print(response)
    finally:
        # Remove the temp files whether or not the upload succeeded.
        for path in (local_path, localm_path):
            if os.path.exists(path):
                os.remove(path)
    return res, uploadm_url


def publicSelfitTask(image, mask, temp_image, caption_text, param4_text, param5_text):
    '''Submit a task to the ``/public`` endpoint and return its task id.

    The request carries ``openId``/``apiKey`` and the module-level ``taskType``.

    Args:
        image, mask: sent as ``image``/``mask`` (presumably the URLs returned
            by upload_user_img_mask; confirm with the caller).
        temp_image: template file path; its base name up to the first dot is
            sent as ``param1``.
        caption_text: sent, converted to str, as ``param2``.
        param4_text, param5_text: sent as ``param4``/``param5``.

    Returns:
        The ``id`` from the JSON response, or None when the call failed, the
        body is not JSON, or it has no ``id``.
    '''
    temp_name = os.path.basename(temp_image).split('.')[0]
    params = {'openId': openId, 'apiKey': apiKey, 'image': image, 'mask': mask,
              "image_type": "2", "task_type": taskType, 'param1': temp_name,
              'param2': str(caption_text), 'param3': "1", 'param4': param4_text, 'param5': param5_text}
    ret = requests.post(f"{apiUrl}/public", data=json.dumps(params), timeout=60)
    print(ret)
    try:
        payload = ret.json()
    except ValueError:  # body is not valid JSON
        payload = None
    if ret.status_code != 200:
        print(payload if payload is not None else ret.text, ret.status_code, 'call public failed')
        return None
    if isinstance(payload, dict) and 'id' in payload:
        return payload['id']
    return None

def getTaskRes(taskId, taskType):
    '''Query the ``/status`` endpoint for one task.

    Returns:
        The decoded JSON dict when the call returns 200 and the body contains
        ``status``; otherwise None (non-200, non-JSON body, or no ``status``).
    '''
    params = {'id': taskId, 'task_type': taskType}
    ret = requests.post(f"{apiUrl}/status", data=json.dumps(params), timeout=60)
    try:
        payload = ret.json()
    except ValueError:  # body is not valid JSON
        payload = None
    if ret.status_code != 200:
        print(payload if payload is not None else ret.text, ret.status_code, 'call status failed')
        return None
    if isinstance(payload, dict) and 'status' in payload:
        return payload
    return None

@func_timeout.func_set_timeout(10)
def check_region(ip):
    '''Return False if any location reported for *ip* has a nation listed in ``Regions``.

    Queries the webapi-pc.meitu.com ``ip_location`` endpoint. The decorator
    aborts the call after 10 s by raising ``func_timeout.FunctionTimedOut``.
    Request errors and unexpected payloads also raise; see check_region_warp
    for the wrapper that handles them.

    NOTE(review): ``Regions`` is the raw environment string, so
    ``nat in Regions`` is a substring test: e.g. "Niger" matches "Nigeria",
    and an empty nation always matches. Confirm this is intended before
    tightening it.
    '''
    # Let requests build/encode the query string instead of interpolating ip.
    ret = requests.get("https://webapi-pc.meitu.com/common/ip_location",
                       params={'ip': ip}, timeout=10)
    locations = ret.json()['data']
    for k in locations:
        nat = locations[k]['nation']
        if nat in Regions:
            print(nat, 'invalid')
            return False
        print(nat, 'valid')
    return True
def check_region_warp(ip):
    '''Fail-open wrapper around check_region.

    Any error or timeout during the lookup counts as "allowed" and returns
    True; the exception is printed.
    '''
    try:
        return check_region(ip)
    except (Exception, func_timeout.FunctionTimedOut) as e:
        # FunctionTimedOut derives from BaseException, not Exception, so
        # `except Exception` alone would let the 10 s timeout escape to the caller.
        print(e)
        return True