class HebEMO:
    """Hebrew emotion detection with one binary classifier per emotion.

    Each emotion in ``emotions`` gets its own Hugging Face ``pipeline``;
    :meth:`hebemo` runs all of them over the input texts and collects, per
    emotion, a 0/1 flag and the pipeline's confidence score.
    """

    # Default emotion set. A tuple (not a list) so the shared default value
    # cannot be mutated through an instance.
    DEFAULT_EMOTIONS = ('anticipation', 'joy', 'trust', 'fear', 'surprise',
                        'anger', 'sadness', 'disgust')

    # Raw pipeline labels -> 0/1 flag stored in the per-emotion columns.
    _LABEL_VALUES = {'LABEL_0': 0, 'LABEL_1': 1}

    def __init__(self, device=-1, emotions=DEFAULT_EMOTIONS,
                 model_prefix='avichr/hebEMO_', tokenizer='avichr/heBERT'):
        """Load one classification pipeline per emotion.

        Args:
            device: passed to ``pipeline``; -1 runs on CPU, otherwise a
                device ID.
            emotions: iterable of emotion names to load models for (copied).
            model_prefix: model id prefix; the model for emotion ``emo`` is
                ``model_prefix + emo``.
            tokenizer: tokenizer id passed to every emotion pipeline.
        """
        # imported here rather than at module level, as in the original file
        from transformers import pipeline
        from tqdm import tqdm

        self.device = device
        self.emotions = list(emotions)
        self.hebemo_models = {}

        for emo in tqdm(self.emotions):
            # NOTE(review): confirm the task name and the default model /
            # tokenizer ids against the published HebEMO checkpoints.
            self.hebemo_models[emo] = pipeline(
                'sentiment-analysis',
                model=model_prefix + emo,
                tokenizer=tokenizer,
                device=self.device,  # -1 run on CPU, else - device ID
            )

    def hebemo(self, text=None, input_path=False, save_results=False, read_lines=False, plot=False):
        """Run every emotion classifier over one or more texts.

        Args:
            text: a string, or a list/tuple of strings, to analyze. Must be
                None when ``input_path`` is used.
            input_path: path to a UTF-8 text file holding one instance per
                line (trailing newlines are stripped). Used only when
                ``text`` is None.
            save_results: if truthy, write the result table as
                ``<timestamp>_heEMOed.csv``. A str / path-like value is the
                target directory; any other truthy value writes to the
                current working directory.
            read_lines: when ``text`` is a single string, split it on '\\n'
                and analyze each line as a separate instance.
            plot: if true, draw the emotion profile of the first text.

        Returns:
            A pandas DataFrame whose column ``0`` holds the input texts and
            which has, for each emotion ``emo``, a 0/1 column ``emo`` and a
            ``'confidence_' + emo`` column. When ``plot`` is true, returns
            ``(first_text, ax)`` instead.

        Raises:
            ValueError: if neither or both of ``text`` / ``input_path`` are
                given, or ``text`` is not a string or list of strings.
        """
        import os
        import time
        import pandas as pd

        try:
            from tqdm import tqdm as progress
        except ImportError:  # the progress bar is cosmetic; iterate plainly
            def progress(iterable):
                return iterable

        texts = self._resolve_texts(text, input_path, read_lines)

        # run hebEMO: one label column and one confidence column per emotion.
        # Labels are mapped per emotion column only, so the text column is
        # never rewritten even if an input text equals e.g. 'LABEL_1'.
        results = pd.DataFrame(texts)
        for emo in progress(self.emotions):
            predictions = self.hebemo_models[emo](texts)
            results[emo] = [self._LABEL_VALUES.get(p['label'], p['label'])
                            for p in predictions]
            results['confidence_' + emo] = [p['score'] for p in predictions]

        if save_results:
            file_name = str(int(time.time() * 1e7)) + '_heEMOed.csv'
            if isinstance(save_results, (str, os.PathLike)):
                file_name = os.path.join(save_results, file_name)
            results.to_csv(file_name, encoding='utf8')

        if plot:
            return self._plot_first(results)
        return results

    @staticmethod
    def _resolve_texts(text, input_path, read_lines):
        """Normalize the ``text`` / ``input_path`` arguments into a list of texts."""
        import os

        if text is None and isinstance(input_path, (str, os.PathLike)):
            with open(input_path, encoding='utf8') as handle:
                # one instance per line; drop the newline so it is neither
                # fed to the models nor stored in the results
                return [line.rstrip('\n') for line in handle]

        if text is not None and (input_path is None or input_path is False):
            if isinstance(text, str):
                return text.split('\n') if read_lines else [text]
            if isinstance(text, (list, tuple)):
                return list(text)
            raise ValueError('text should be text or list of text.')

        raise ValueError('you should provide a text string, list of strings or text path.')

    def _plot_first(self, results):
        """Plot the emotion profile of the first analyzed text; return ``(text, ax)``."""
        import pandas as pd
        from pyplutchik import plutchik
        from spider_plot import spider_plot

        # |label - (1 - score)| equals score for label 1 and 1 - score for
        # label 0. NOTE(review): this reads as P(LABEL_1) only if score is
        # the probability of the predicted label in a binary classifier.
        presence = pd.DataFrame({
            emo: (results[emo] - (1 - results['confidence_' + emo])).abs()
            for emo in self.emotions
        })

        try:
            ax = plutchik(presence.to_dict(orient='records')[0])
        except Exception:  # fall back to the generic radar chart
            ax = spider_plot(presence)
        return results[0].iloc[0], ax
# HebEMO_model = HebEMO()