import gradio as gr
from huggingface_hub import HfApi, ModelCard, hf_hub_download
from huggingface_hub.repocard import metadata_load
import requests
import re
import pandas as pd
import os
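
# Progress-checker Space for the Hugging Face Audio Course: given a Hub
# username, inspect the user's uploaded models and the Unit 7 demo list,
# and report which course assignments pass their baselines.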


def pass_emoji(passed):
    """Return a check/cross emoji for a boolean pass flag."""
    return "✅" if passed else "❌"

api = HfApi()
USERNAMES_DATASET_ID = "huggingface-course/audio-course-u7-hands-on"
HF_TOKEN = os.environ.get("HF_TOKEN")
U7_USERNAMES = hf_hub_download(
    USERNAMES_DATASET_ID, repo_type="dataset", filename="usernames.csv", token=HF_TOKEN
)
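# usernames.csv lists the users whose demo passed the separate Unit 7
# assessment Space; certification() below checks membership in this file.
# HF_TOKEN must be set in the environment if the dataset requires authentication.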


def get_user_models(hf_username, task):
    """
    List the user's models for a given task, keeping only those trained
    on the dataset used in the corresponding course unit.
    :param hf_username: user's HF username
    :param task: task identifier, e.g. "audio-classification"
    """
    models = api.list_models(author=hf_username, filter=[task])
    user_model_ids = [x.modelId for x in models]

    match task:
        case "audio-classification":
            dataset = "marsyas/gtzan"
        case "automatic-speech-recognition":
            dataset = "PolyAI/minds14"
        case "text-to-speech":
            dataset = ""
        case _:
            print("Unsupported task")
            return []

    # Text-to-speech models are not tied to a specific dataset.
    if dataset == "":
        return user_model_ids

    dataset_specific_models = []
    for model in user_model_ids:
        meta = get_metadata(model)
        if meta is None:
            continue
        # The "datasets" key may be missing from the card metadata.
        if meta.get("datasets") == [dataset]:
            dataset_specific_models.append(model)
    return dataset_specific_models
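
# Example (hypothetical model id):
#   get_user_models("MariaK", "audio-classification")
#   -> ["MariaK/distilhubert-finetuned-gtzan"], i.e. only models whose card
#   metadata lists marsyas/gtzan as the training dataset.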

def calculate_best_result(user_models, task):
    """
    Find the best metric value among a user's models for a given task.
    Accuracy (audio classification) is better when larger; WER (speech
    recognition) is better when smaller.
    :param user_models: model ids of a user
    :param task: task identifier
    """
    best_model = ""

    if task == "audio-classification":
        best_result = -100
        larger_is_better = True
    elif task == "automatic-speech-recognition":
        best_result = 100
        larger_is_better = False
    else:
        raise ValueError(f"Unsupported task: {task}")

    for model in user_models:
        meta = get_metadata(model)
        if meta is None:
            continue
        metric = parse_metrics(model, task)
        if metric is None:
            # No recognizable metric in the model card.
            continue

        if larger_is_better:
            if metric > best_result:
                best_result = metric
                best_model = meta["model-index"][0]["name"]
        else:
            if metric < best_result:
                best_result = metric
                best_model = meta["model-index"][0]["name"]

    return best_result, best_model


def get_metadata(model_id):
    """
    Load the model's metadata (the YAML front matter of its README.md,
    which contains the evaluation data).
    :param model_id: model id
    """
    try:
        readme_path = hf_hub_download(model_id, filename="README.md")
        return metadata_load(readme_path)
    except requests.exceptions.HTTPError:
        # 404: README.md not found
        return None
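
# metadata_load() returns the README's YAML front matter as a dict, with keys
# such as "datasets" and "model-index" that the checks in this script rely on.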


def extract_metric(model_card_content, task):
    """
    Extract the metric value from the model card text.
    :param model_card_content: model card content
    :param task: task identifier
    """
    accuracy_pattern = r"Accuracy: (\d+\.\d+)"
    wer_pattern = r"Wer: (\d+\.\d+)"

    if task == "audio-classification":
        pattern = accuracy_pattern
    elif task == "automatic-speech-recognition":
        pattern = wer_pattern
    else:
        return None

    match = re.search(pattern, model_card_content)
    if match:
        return float(match.group(1))
    return None
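
# Example: an autogenerated Trainer model card typically contains a line like
# "Wer: 0.3552" (illustrative value), which would be returned here as 0.3552
# for the "automatic-speech-recognition" task.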


def parse_metrics(model, task):
    """
    Load a model's card and extract the task metric from it.
    :param model: model id
    :param task: task identifier
    """
    card = ModelCard.load(model)
    return extract_metric(card.content, task)


def certification(hf_username):
    results_certification = [
        {
            "unit": "Unit 4: Audio Classification",
            "task": "audio-classification",
            "baseline_metric": 0.87,
            "best_result": 0,
            "best_model_id": "",
            "passed_": False,
        },
        {
            "unit": "Unit 5: Automatic Speech Recognition",
            "task": "automatic-speech-recognition",
            "baseline_metric": 0.37,
            "best_result": 0,
            "best_model_id": "",
            "passed_": False,
        },
        {
            "unit": "Unit 6: Text-to-Speech",
            "task": "text-to-speech",
            "baseline_metric": 0,
            "best_result": 0,
            "best_model_id": "",
            "passed_": False,
        },
        {
            "unit": "Unit 7: Audio applications",
            "task": "demo",
            "baseline_metric": 0,
            "best_result": 0,
            "best_model_id": "",
            "passed_": False,
        },
    ]
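
    # For each unit, look up the user's models for the task, pick the best
    # metric among them, and compare it against the unit's baseline.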

    for unit in results_certification:
        unit["passed"] = pass_emoji(unit["passed_"])

        match unit["task"]:
            case "audio-classification":
                try:
                    user_ac_models = get_user_models(hf_username, task="audio-classification")
                    best_result, best_model_id = calculate_best_result(user_ac_models, task="audio-classification")
                    unit["best_result"] = best_result
                    unit["best_model_id"] = best_model_id
                    # Higher accuracy is better.
                    if unit["best_result"] >= unit["baseline_metric"]:
                        unit["passed_"] = True
                        unit["passed"] = pass_emoji(unit["passed_"])
                except Exception:
                    print("Either no relevant models found, or no metrics in the model card for audio classification")
            case "automatic-speech-recognition":
                try:
                    user_asr_models = get_user_models(hf_username, task="automatic-speech-recognition")
                    best_result, best_model_id = calculate_best_result(user_asr_models, task="automatic-speech-recognition")
                    unit["best_result"] = best_result
                    unit["best_model_id"] = best_model_id
                    # Lower WER is better.
                    if unit["best_result"] <= unit["baseline_metric"]:
                        unit["passed_"] = True
                        unit["passed"] = pass_emoji(unit["passed_"])
                except Exception:
                    print("Either no relevant models found, or no metrics in the model card for automatic speech recognition")
            case "text-to-speech":
                try:
                    user_tts_models = get_user_models(hf_username, task="text-to-speech")
                    # Any uploaded text-to-speech model counts as a pass.
                    if user_tts_models:
                        unit["best_result"] = 0
                        unit["best_model_id"] = user_tts_models[0]
                        unit["passed_"] = True
                        unit["passed"] = pass_emoji(unit["passed_"])
                except Exception:
                    print("No relevant models found for text-to-speech")
            case "demo":
                u7_users = pd.read_csv(U7_USERNAMES)
                # `in` on a pandas Series checks the index, so compare against the values.
                if hf_username in u7_users["username"].values:
                    unit["best_result"] = 0
                    unit["best_model_id"] = "Demo check passed, no model id"
                    unit["passed_"] = True
                    unit["passed"] = pass_emoji(unit["passed_"])
            case _:
                print("Unknown task")

    # Log the raw results for debugging in the Space logs.
    print(results_certification)

    df = pd.DataFrame(results_certification)
    df = df[["passed", "unit", "task", "baseline_metric", "best_result", "best_model_id"]]
    return df
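
# Gradio UI: a username textbox, a "Check my progress" button, and a results
# table with one row per unit.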

with gr.Blocks() as demo:
    gr.Markdown("""
    # 🏆 Check your progress in the Audio Course 🏆

    - To get a certificate of completion, you must **pass 3 out of 4 assignments before July 31st 2023**.
    - To get an honors certificate, you must **pass 4 out of 4 assignments before July 31st 2023**.

    For the assignments where you have to train a model, your model's metric must be equal to or better than the baseline metric.
    For the Unit 7 assignment, first check your demo with the Unit 7 assessment Space: https://huggingface.co/spaces/huggingface-course/audio-course-u7-assessment

    Make sure that you have uploaded your model(s) to the Hub, and that your Unit 7 demo is public.
    To check your progress, type your Hugging Face username below (in my case, MariaK).
    """)

    hf_username = gr.Textbox(placeholder="MariaK", label="Your Hugging Face Username")
    check_progress_button = gr.Button(value="Check my progress")
    output = gr.components.Dataframe()
    check_progress_button.click(fn=certification, inputs=hf_username, outputs=output)

demo.launch()