|
"""A naive baseline method: just pass the full expression to CLIP.""" |
|
|
|
from overrides import overrides |
|
from typing import Dict, Any, List |
|
import numpy as np |
|
import torch |
|
import spacy |
|
from argparse import Namespace |
|
|
|
from .ref_method import RefMethod |
|
from lattice import Product as L |
|
|
|
|
|
class Baseline(RefMethod):
    """CLIP-only baseline where each box is evaluated with the full expression.

    Every candidate box is scored against the entire referring expression.
    When ``args.baseline_head`` is truthy, those scores are additionally
    combined (via ``L.meet``) with the scores obtained for the noun chunk
    that contains the expression's syntactic head.
    """

    # spaCy pipeline shared by every instance; loaded once, when the class
    # body is executed.
    nlp = spacy.load('en_core_web_sm')

    def __init__(self, args: Namespace):
        """Store the run configuration.

        Args:
            args: Parsed options; this class reads ``box_area_threshold``,
                ``batch_size`` and (in ``execute``) ``baseline_head``.
        """
        self.args = args
        self.box_area_threshold = args.box_area_threshold
        self.batch_size = args.batch_size
        self.batch = []

    @overrides
    def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
        """Select the box that best matches ``caption``.

        Args:
            caption: The referring expression.
            env: Environment exposing ``filter`` (per-box scoring) and
                ``boxes`` (the candidate boxes).

        Returns:
            Dict with ``"probs"`` (per-box scores), ``"pred"`` (index of the
            highest score, from ``np.argmax``) and ``"box"`` (the entry of
            ``env.boxes`` at that index).
        """
        probs = env.filter(caption, area_threshold=self.box_area_threshold, softmax=True)
        if self.args.baseline_head:
            # Parse the caption only when the head chunk is actually needed;
            # otherwise the spaCy parse would be thrown away.
            head_chunk = self.get_chunk_texts(caption)[0]
            probs2 = env.filter(head_chunk, area_threshold=self.box_area_threshold, softmax=True)
            probs = L.meet(probs, probs2)
        pred = np.argmax(probs)
        return {
            "probs": probs,
            "pred": pred,
            "box": env.boxes[pred],
        }

    def get_chunk_texts(self, expression: str) -> List[str]:
        """Return the noun-chunk texts of ``expression``, head chunk first.

        The head token is the first token that is its own dependency head
        (a root of the spaCy parse). The head chunk is the noun chunk that
        spans that token. If no chunk spans it (or no root was found), the
        first noun chunk is used instead. If the expression has no noun
        chunks at all, the whole expression is used.

        Args:
            expression: Text to parse with ``self.nlp``.

        Returns:
            ``[head_chunk] + others``. ``others`` keeps document order and
            omits any chunk whose text equals ``head_chunk``. The list always
            has at least one element.
        """
        doc = self.nlp(expression)
        # First token that heads itself, i.e. a root of the dependency parse.
        head = next((token for token in doc if token.head.i == token.i), None)

        # ``doc.noun_chunks`` yields spans lazily; materialize it once rather
        # than iterating it repeatedly.
        chunks = list(doc.noun_chunks)
        chunk_texts = [chunk.text for chunk in chunks]

        head_chunk = None
        if head is not None:
            for chunk in chunks:
                if chunk.start <= head.i < chunk.end:
                    head_chunk = chunk.text

        if head_chunk is None:
            head_chunk = chunk_texts[0] if chunk_texts else expression

        return [head_chunk] + [text for text in chunk_texts if text != head_chunk]
|
|