"""A naive baseline method: just pass the full expression to CLIP."""

from overrides import overrides
from typing import Dict, Any, List
import numpy as np
import torch
import spacy
from argparse import Namespace

from .ref_method import RefMethod
from lattice import Product as L


class Baseline(RefMethod):
    """CLIP-only baseline where each box is evaluated with the full expression.

    When ``args.baseline_head`` is truthy, the scores for the full expression
    are additionally combined (via ``lattice.Product.meet``) with the scores
    obtained for the head noun chunk of the expression.
    """

    # Loaded once when the class body is executed and shared by all instances.
    nlp = spacy.load('en_core_web_sm')

    def __init__(self, args: Namespace):
        """Store the run configuration.

        Args:
            args: Parsed options. ``box_area_threshold`` and ``batch_size``
                are read here; ``baseline_head`` is read later in ``execute``.
        """
        self.args = args
        self.box_area_threshold = args.box_area_threshold
        self.batch_size = args.batch_size
        self.batch = []

    @overrides
    def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
        """Score every box against ``caption`` and pick the best one.

        Args:
            caption: The referring expression to resolve.
            env: Object exposing ``filter(text, area_threshold=..., softmax=...)``
                and an indexable ``boxes`` attribute.
                NOTE(review): presumably ``filter`` returns one score per
                entry of ``env.boxes`` -- confirm against ``Environment``.

        Returns:
            A dict with ``"probs"`` (the scores returned by ``env.filter``,
            optionally combined with the head-chunk scores), ``"pred"`` (the
            ``np.argmax`` of those scores) and ``"box"`` (``env.boxes[pred]``).
        """
        probs = env.filter(
            caption, area_threshold=self.box_area_threshold, softmax=True
        )
        if self.args.baseline_head:
            # Parse the caption only on this path: the chunk list is not used
            # anywhere else, so parsing unconditionally is wasted work.
            head_text = self.get_chunk_texts(caption)[0]
            head_probs = env.filter(
                head_text, area_threshold=self.box_area_threshold, softmax=True
            )
            probs = L.meet(probs, head_probs)
        pred = np.argmax(probs)
        return {
            "probs": probs,
            "pred": pred,
            "box": env.boxes[pred],
        }

    def get_chunk_texts(self, expression: str) -> List[str]:
        """Return the noun-chunk texts of ``expression``, head chunk first.

        The head chunk is the noun chunk whose token span contains the first
        token that is its own syntactic head. If no chunk contains it, the
        first noun chunk is used; if there are no noun chunks at all, the
        whole ``expression`` is used instead.

        Args:
            expression: Text to parse with the class-level spaCy pipeline.

        Returns:
            A list whose first element is the head chunk text, followed by the
            remaining chunk texts in document order. Any chunk whose text
            equals the head chunk text is dropped from the tail.
        """
        doc = self.nlp(expression)

        # First token that is its own head; None when no such token exists
        # (e.g. nothing to iterate over).
        root = next((token for token in doc if token.head.i == token.i), None)

        # Materialize once instead of re-running the chunk iterator per use.
        chunks = list(doc.noun_chunks)
        chunk_texts = [chunk.text for chunk in chunks]

        head_chunk = None
        if root is not None:
            for chunk in chunks:
                if chunk.start <= root.i < chunk.end:
                    head_chunk = chunk.text

        if head_chunk is None:
            head_chunk = chunks[0].text if chunks else expression

        return [head_chunk] + [text for text in chunk_texts if text != head_chunk]