import requests
from io import BytesIO
import numpy as np
from gensim.models.fasttext import FastText
from scipy import spatial
import itertools
import gdown
import warnings
import nltk
# warnings.filterwarnings('ignore')

import pickle
import pdb
from concurrent.futures import ProcessPoolExecutor

import matplotlib.pyplot as plt
import streamlit as st
import argparse


# NLTK resources fetched at import time. The names are passed to
# nltk.download one after another, in this order.
for _nltk_resource in ('wordnet', 'punkt', 'averaged_perceptron_tagger'):
  nltk.download(_nltk_resource)

# Average embedding → Compare
def recommend_ingredients(yum, leftovers, n=10):
  '''
  Recommend ingredients close to the mean embedding of the leftovers.

  :params
  yum -> FastText Word2Vec Obj (must provide get_vector and similar_by_vector)
  leftovers -> list of str, each one a key known to `yum`
  n -> int top_n to return

  :returns
  output -> list of (name, score) tuples, at most n long, in the order
            returned by yum.similar_by_vector. Underscores in names are
            replaced by spaces. Empty list when `leftovers` is empty.
  '''
  if not leftovers:
    # Nothing to average: avoids a 0/0 division that would give a NaN query.
    return []

  # Average the normalised vectors. Stacking them means the vector size comes
  # from the model itself instead of a hard-coded constant; float64 keeps the
  # same accumulation precision as summing into an np.zeros buffer.
  vectors = [yum.get_vector(ingredient, norm=True) for ingredient in leftovers]
  leftovers_embedding = np.mean(vectors, axis=0, dtype=np.float64) # Embedding for leftovers

  # Over-fetch candidates: the filter below throws some away, and a fixed 100
  # would leave fewer than n results for large n.
  top_matches = yum.similar_by_vector(leftovers_embedding, topn=max(100, 2 * n))
  top_matches = [(name.replace('_', ' '), score) for name, score in top_matches]

  # Candidate names just had '_' turned into ' ', so give the leftovers the
  # same treatment; otherwise a leftover such as "olive_oil" never matches
  # "extra virgin olive oil".
  ignored = [item.replace('_', ' ') for item in leftovers]

  # Remove boring same item matches, e.g. "romaine lettuce" if leftovers already contain "lettuce".
  output = [match for match in top_matches if not any(ignore in match[0] for ignore in ignored)]
  return output[:n]

# Compare → Find intersection
def recommend_ingredients_intersect(yum, leftovers, n=10):
  '''
  Keep only the candidates that appear in every leftover's neighbour list.

  Each leftover is queried for its 10000 nearest neighbours; the result is
  the intersection of those lists by name. Order and scores are the ones
  from the FIRST leftover's list (scores are not combined across leftovers).

  :params
  yum -> FastText Word2Vec Obj (must provide get_vector and similar_by_vector)
  leftovers -> list of str, each one a key known to `yum`
  n -> int top_n to return

  :returns
  output -> list of (name, score) tuples, at most n long. Underscores in
            names are replaced by spaces. Empty list when `leftovers` is empty.
  '''
  # Candidate names get '_' turned into ' ' below, so normalise the leftovers
  # the same way before the substring filter.
  ignored = [item.replace('_', ' ') for item in leftovers]

  output = []  # stays empty when there are no leftovers
  for position, ingredient in enumerate(leftovers):
    ingredient_embedding = yum.get_vector(ingredient, norm=True)
    ingredient_matches = yum.similar_by_vector(ingredient_embedding, topn=10000)
    ingredient_matches = [(name.replace('_', ' '), score) for name, score in ingredient_matches]
    # Remove boring same item matches, e.g. "romaine lettuce" if leftovers already contain "lettuce".
    ingredient_output = [m for m in ingredient_matches if not any(ignore in m[0] for ignore in ignored)]
    if position == 0:
      output = ingredient_output
    else:
      # Set membership keeps this linear instead of comparing every pair of
      # candidates, and never duplicates a row when two names collide.
      shared_names = {name for name, _ in ingredient_output}
      output = [m for m in output if m[0] in shared_names]
  return output[:n]

def recommend_ingredients_subsets(model, yum, leftovers, subset_size, n=10):
  '''
  Recommend ingredients for every subset of the leftovers of a given size.

  For each combination of `subset_size` leftovers, the normalised vectors
  are averaged and the nearest neighbours of that mean are looked up.

  :params
  model -> FastText Obj (kept for backward compatibility; not used, all
           lookups go through `yum`)
  yum -> FastText Word2Vec Obj (must provide get_vector and similar_by_vector)
  leftovers -> list of str, each one a key known to `yum`
  subset_size -> int number of leftovers per combination
  n -> int top_n to return per subset (default 10)

  :returns
  all_outputs -> dict mapping each subset (tuple of str) to its list of
                 (name, score) recommendations, at most n long
  '''
  all_outputs = {}
  for leftovers_subset in itertools.combinations(leftovers, subset_size):
    if not leftovers_subset:
      # subset_size == 0 yields one empty tuple; there is nothing to average.
      continue

    # get_vector(..., norm=True) is the gensim 4 spelling and is what the
    # other recommenders in this file use. Stacking the vectors takes the
    # dimension from the model instead of a hard-coded constant.
    vectors = [yum.get_vector(ingredient, norm=True) for ingredient in leftovers_subset]
    leftovers_embedding = np.mean(vectors, axis=0, dtype=np.float64) # Embedding for leftovers

    # Query the keyed vectors: that is the object the sibling functions call
    # similar_by_vector on.
    top_matches = yum.similar_by_vector(leftovers_embedding, topn=max(100, 2 * n))
    top_matches = [(name.replace('_', ' '), score) for name, score in top_matches]

    # Compare against space-separated leftovers, matching the candidate names.
    ignored = [item.replace('_', ' ') for item in leftovers_subset]
    # Remove boring same item matches, e.g. "romaine lettuce" if leftovers already contain "lettuce".
    output = [m for m in top_matches if not any(ignore in m[0] for ignore in ignored)]
    all_outputs[leftovers_subset] = output[:n]
  return all_outputs



def filter_adjectives(data):
    '''
    Drop single-word entries that are not tagged as a noun.

    Every entry is tokenized and POS-tagged with NLTK. An entry is kept when
    it has more than one token, or when its tags include 'NN' or 'NNS'.

    :params
    data -> list of str

    :returns
    data -> list of the entries that were kept, in their original order
    '''
    noun_tags = ('NN', 'NNS')
    kept = []
    for entry in data:
        tagged = nltk.pos_tag(nltk.word_tokenize(entry))
        tags = [tag for _, tag in tagged]
        # Multi-word entries always pass; a lone word must be a noun.
        if len(tags) > 1 or any(noun in tags for noun in noun_tags):
            kept.append(entry)
    return kept

def plural_to_singular(lemma, recipe): 
  '''
  Run every ingredient of one recipe through the lemmatizer.

  :params
  lemma -> nltk lemma Obj (anything with a lemmatize(str) method)
  recipe -> list of str

  :returns
  recipe -> new list with lemma.lemmatize applied to each item, same order
  '''
  converted = []
  for ingredient in recipe:
    converted.append(lemma.lemmatize(ingredient))
  return converted

def filter_lemma(data):
    '''
    Lemmatize every recipe, spreading the work over worker processes.

    :params 
    data -> list of lists (one list of ingredient strings per recipe)

    :returns
    data -> list with plural_to_singular applied to each recipe, same order
    '''
    # One WordNet lemmatizer (reduces plurals to stems), handed to every call.
    lemmatizer = nltk.wordnet.WordNetLemmatizer()

    # NOTE: This uses all the computational resources of your computer 
    with ProcessPoolExecutor() as pool:
        lemmatizer_per_recipe = itertools.repeat(lemmatizer)
        pending = pool.map(plural_to_singular, lemmatizer_per_recipe, data)
        # Drain the iterator while the pool is still open.
        lemmatized = list(pending)

    return lemmatized


def train_model(data):
    '''
    Train a FastText model on the recipe corpus.
    NOTE: gensim==4.1.2

    :params
    data -> list of lists of all recipes

    :returns 
    model -> trained FastText model obj
    '''
    # Fixed training configuration, passed straight to the FastText constructor.
    hyperparameters = {
        'vector_size': 32,
        'window': 99,
        'min_count': 5,
        'workers': 40,
        'sg': 1,
    }
    return FastText(data, **hyperparameters)

@st.cache(allow_output_mutation=True)
def load_model(filename='models/fastfood_orig_4.model'):
  '''
  Load a saved FastText model and expose its word vectors.
  The result is cached by Streamlit via the st.cache decorator.

  :params:
  filename -> path to the model 

  :returns
  model -> this is the full FastText obj
  yum -> this is the FastText Word2Vec obj (model.wv)
  '''
  fasttext_model = FastText.load(filename)
  return fasttext_model, fasttext_model.wv

@st.cache(allow_output_mutation=True)
def load_data(filename='data/all_recipes_ingredients_lemma.pkl'):
  '''
  Load the pickled recipe dataset.
  The result is cached by Streamlit via the st.cache decorator.

  :params:
  filename -> path to dataset

  :return
  data -> whatever object the pickle file holds (list of all recipes)
  '''
  # NOTE(review): unpickling can execute arbitrary code; only load trusted files.
  # The context manager closes the file handle even if unpickling fails.
  with open(filename, 'rb') as handle:
    return pickle.load(handle)

def plot_results(names, probs, n=5):
  '''
  Plots a bar chart of the names of the items vs. probability of similarity 

  NOTE: the title reads st.session_state.leftovers, so this has to run in a
  Streamlit session where that key is set.

  :params:
  names -> list of str, one tick label per bar
  probs -> list of float values, one bar height per name
  n -> int shown in the title as "Top {n}" (it does not limit the bars)

  :return
  fig -> the current pyplot figure, cleared and redrawn with this chart
  '''
  # Reuse the current pyplot figure but wipe it first. pyplot keeps that
  # figure alive between calls, so without the clear every call would draw
  # its bars on top of the previous ones.
  fig = plt.gcf()
  fig.clf()
  ax = fig.gca()

  positions = range(len(names))
  ax.bar(positions, probs, align='center')

  # One tick per bar, labelled with the ingredient name.
  ax.xaxis.set_major_locator(plt.FixedLocator(positions))
  ax.xaxis.set_major_formatter(plt.FixedFormatter(names))
  ax.set_ylabel('Probability',fontsize='large', fontweight='bold')
  ax.set_xlabel('Ingredients', fontsize='large', fontweight='bold')
  ax.xaxis.labelpad = 10
  ax.set_title(f'FastFood Top {n} Predictions for Leftovers = {st.session_state.leftovers}')

  return fig


if __name__ == "__main__":
    # NOTE: the argparse-driven train/test switch below is disabled; the app
    # always loads the pre-trained model from 'fastfood.pth'.

    # Initialize argparse
    # parser = argparse.ArgumentParser()

    # Defaults 
    # data_path = 'data/all_recipes_ingredients_lemma.pkl'
    # model_path = 'models/fastfood_lemma_4.model'
    
    # Arguments
    # parser.add_argument('-d', '--dataset',        default=data_path, type=str,   help="the filepath of the dataset")
    # parser.add_argument('-t', '--train',          default=False, type=bool,   help="the filepath of the dataset")
    # parser.add_argument('-m', '--model',          default=model_path, type=str,   help="the filepath of the dataset")

    # args = parser.parse_args()
    # print(args)

    ## Train or Test ## 
    # if args.train:
    #   # Load Dataset

    #   data = load_data(args.dataset) #pickle.load(open(args.dataset, 'rb'))
    #   # model = train_model(data)
    #   # model_path = input("Model filename and directory [eg. models/new_model.model]:   ")
    #   # model.save(model_path)
    # else:
    # gdown.download('https://drive.google.com/uc?id=1fXGsWEbr-1BftKtOsnxc61cM3akMAIC0', 'fastfood.pth')
    # gdown.download('https://drive.google.com/uc?id=1h_TijdSw1K9RT3dnlfIg4xtl8WPNNQmn', 'fastfood.pth.wv.vectors_ngrams.npy')
    model, yum = load_model('fastfood.pth')

    ##### UI/UX #####
    ## Sidebar ##
    # The selected page is stored but not acted on anywhere in this block.
    add_selectbox = st.sidebar.selectbox(
        "Food Utilization App",
        ("FastFood Recommendation Model", "Food Donation Resources", "Contact Team")
    )

    ## Selection Tool ##
    # Options are the model's vocabulary keys; the selection is stored in
    # st.session_state.leftovers through the widget key.
    st.multiselect("Select leftovers", list(yum.key_to_index.keys()), default=['bread', 'lettuce'], key="leftovers")

    ## Slider ## 
    st.slider("Number of Recommendations", min_value=1, max_value=100, value=5, step=1, key='top_n')

    ## Get food recommendation ##
    leftovers = st.session_state.leftovers
    if leftovers:
        out = recommend_ingredients(yum, leftovers, n=st.session_state.top_n)
    else:
        # Every option was deselected: there is nothing to average, so ask
        # for input instead of querying the model with an empty list.
        st.info("Select at least one leftover to get recommendations.")
        out = []
    names = [o[0] for o in out]
    probs = [o[1] for o in out]

    # Show either (name, score) rows or just the names.
    st.checkbox(label="Show model score", value=False, key="probs")
    if st.session_state.probs:
        st.table(data=out)
    else:
        st.table(data=names)

    ## Plot Results ##
    st.checkbox(label="Show model bar chart", value=False, key="plot")
    if st.session_state.plot:
        fig = plot_results(names, probs, st.session_state.top_n)

        ## Show Plot ## 
        st.pyplot(fig)