|
import random |
|
import numpy as np |
|
from nltk import word_tokenize |
|
from collections import defaultdict |
|
from copy import deepcopy |
|
import tqdm |
|
|
|
|
|
class PunktTokenizer: |
|
def __call__(self, texts): |
|
return [word_tokenize(t) for t in texts] |
|
|
|
|
|
class WhiteSpaceTokenizer:
    """Batch tokenizer that splits each text on runs of whitespace."""

    def __call__(self, texts):
        """Return one list of whitespace-delimited tokens per input text."""
        tokenized = []
        for text in texts:
            tokenized.append(text.split())
        return tokenized
|
|
|
|
|
class SearchState:
    """History of one hill-climbing run over a single tokenized sentence.

    Every evaluated step appends a mask, its summary and its score to three
    parallel lists. ``best_step`` indexes the highest score recorded so far;
    only a strictly greater score replaces it, so the earliest step wins ties.
    """

    def __init__(self, tokens):
        self.tokens = tokens
        # Parallel per-step histories, one entry per call to ``update``.
        self.masks = []
        self.summaries = []
        self.scores = []
        # Tuple copies of every recorded mask, for O(1) "already visited" checks.
        self.mask_set = set()
        # Index of the best entry so far; None until the first ``update``.
        self.best_step = None
        # Flag the search loop uses to request that this run be restarted.
        self.terminated = False
        # Number of recorded steps (always equal to ``len(self.scores)``).
        self.step = 0

    def update(self, mask, summary, score):
        """Record one evaluated step and move ``best_step`` if it improves.

        Args:
            mask: per-token keep (1) / drop (0) flags for this step.
            summary: text built from ``mask``.
            score: objective value for ``summary``; higher is better.
        """
        if self.best_step is None or score > self.scores[self.best_step]:
            self.best_step = self.step
        self.masks.append(mask)
        self.mask_set.add(tuple(mask))
        self.summaries.append(summary)
        self.scores.append(score)
        self.step += 1

    # The three ``best_*`` accessors require at least one recorded step.

    def best_mask(self):
        """Return the mask of the best step recorded so far."""
        return self.masks[self.best_step]

    def best_score(self):
        """Return the highest score recorded so far."""
        return self.scores[self.best_step]

    def best_summary(self):
        """Return the summary of the best step recorded so far."""
        return self.summaries[self.best_step]

    def to_dict(self):
        """Export the run history as a plain dict.

        The ``best_summary`` and ``best_score`` entries are ``None`` when no
        step has been recorded yet, instead of failing on ``masks[None]``.
        """
        has_best = self.best_step is not None
        return {
            "scores": self.scores,
            "masks": self.masks,
            "summaries": self.summaries,
            "best_summary": self.best_summary() if has_best else None,
            "best_score": self.best_score() if has_best else None,
        }
|
|
|
|
|
class DynamicRestartHCSC:
    """Hill-climbing extractive sentence compression with dynamic restarts.

    For every input sentence the search keeps a keep/drop mask over its tokens
    with ``target_len`` ones (all ones when the sentence is not longer than
    ``target_len``). Each step does three things:

    1. For each sentence, propose an unvisited neighbour of that run's best
       mask: one kept token swapped for one dropped token.
    2. Score all proposals with a single batched ``objective`` call.
    3. Record each result in the sentence's run.

    When a run cannot propose a new mask, it is archived and a fresh run is
    started from a random mask.
    """

    def __init__(self, tokenizer, objective, n_trials=100):
        """
        Args:
            tokenizer: callable mapping a list of texts to a list of token lists.
            objective: callable ``objective(sentences, summaries)`` returning a
                pair whose first element holds one score per summary (higher is
                better); the second element is ignored.
            n_trials: random swaps attempted when looking for an unvisited
                neighbour before a run is declared exhausted.
        """
        self.tokenizer = tokenizer
        self.objective = objective
        self.n_trials = n_trials

    def _mask_to_summary(self, mask, tokens):
        """Join the tokens whose mask flag is 1 with single spaces."""
        return " ".join(tok for flag, tok in zip(mask, tokens) if flag == 1)

    def _sample(self, state, sent_len, target_len, from_scratch=False):
        """Propose the next mask to evaluate for ``state``.

        * If ``target_len >= sent_len``, every token is kept. There is nothing
          to explore, so the run is reported as terminated.
        * On a run's first step, or when ``from_scratch`` is set, a uniformly
          random mask with ``target_len`` ones is drawn.
        * Otherwise, up to ``self.n_trials`` random one-for-one swaps of the
          run's best mask are tried. The first swap not already in
          ``state.mask_set`` is returned.

        Returns:
            ``(mask, terminated)``. ``terminated`` is True when no new mask
            could be found. In that case ``mask`` is the run's best mask, or
            the all-ones mask in the first case above. The best mask is also
            returned when it has no kept or no dropped position, e.g.
            ``target_len == 0``, since no swap is possible.
        """
        if target_len >= sent_len:
            return [1] * sent_len, True

        if state.step == 0 or from_scratch:
            chosen = set(random.sample(range(sent_len), min(target_len, sent_len)))
            return [int(pos in chosen) for pos in range(sent_len)], False

        best = state.masks[state.best_step]
        kept = [pos for pos, flag in enumerate(best) if flag == 1]
        dropped = [pos for pos, flag in enumerate(best) if flag == 0]
        if not kept or not dropped:
            # No swap neighbour exists; random.choice([]) would raise here.
            return best, True

        for _ in range(self.n_trials):
            candidate = best.copy()  # never mutate a mask stored in the history
            candidate[random.choice(kept)] = 0
            candidate[random.choice(dropped)] = 1
            if tuple(candidate) not in state.mask_set:
                return candidate, False

        # Every sampled neighbour was already visited: treat the run as exhausted.
        return best, True

    def aggregate_states(self, states):
        """Merge the histories of several runs for one sentence.

        Returns a dict with the concatenated ``masks``, ``summaries`` and
        ``scores``, plus ``best_score`` and ``best_summary`` for the highest
        score. ``np.argmax`` picks the earliest entry on ties. When no step
        was recorded at all (e.g. ``n_steps=0``), both ``best_*`` entries are
        ``None``.
        """
        masks = [m for s in states for m in s.masks]
        summaries = [x for s in states for x in s.summaries]
        scores = [x for s in states for x in s.scores]
        if scores:
            best_step = int(np.argmax(scores))
            best_score, best_summary = scores[best_step], summaries[best_step]
        else:
            best_score, best_summary = None, None
        return {
            "masks": masks,
            "summaries": summaries,
            "scores": scores,
            "best_score": best_score,
            "best_summary": best_summary,
        }

    def __call__(
        self,
        sentences,
        target_lens,
        n_steps=100,
        verbose=False,
        return_states=False,
    ):
        """Run the search for ``n_steps`` batched steps over ``sentences``.

        Args:
            sentences: list of input texts.
            target_lens: number of tokens to keep per sentence, indexed in
                parallel with ``sentences``.
            n_steps: number of search steps; each makes one ``objective`` call
                covering the whole batch.
            verbose: print per-step scores/summaries and restart events.
            return_states: currently unused; accepted for backward compatibility.

        Returns:
            One dict per sentence, as produced by ``aggregate_states`` over all
            of that sentence's runs.
        """
        tok_sentences = self.tokenizer(sentences)
        batch_size = len(sentences)
        # Archived (finished) runs per sentence, merged at the end.
        terminated_states = [[] for _ in range(batch_size)]
        states = [SearchState(s) for s in tok_sentences]

        for t in tqdm.tqdm(range(1, n_steps + 1)):
            masks = []
            for i in range(batch_size):
                if states[i].terminated:
                    if verbose:
                        print(f"step {t}, restarting state {i} with score {states[i].best_score()}")
                    terminated_states[i].append(states[i])
                    states[i] = SearchState(tok_sentences[i])

                mask, terminated = self._sample(
                    states[i],
                    sent_len=len(tok_sentences[i]),
                    target_len=target_lens[i],
                )
                states[i].terminated = terminated
                masks.append(mask)

            summaries = [
                self._mask_to_summary(m, tokens)
                for m, tokens in zip(masks, tok_sentences)
            ]
            # The objective is batched, so every sentence is scored each step,
            # including runs that are about to be restarted.
            scores, _ = self.objective(sentences, summaries)

            if verbose:
                print(f"t={t}")
                for i in range(batch_size):
                    print(f"[{scores[i]:.3f}][{summaries[i]}]")
                print()

            for i in range(batch_size):
                states[i].update(masks[i], summaries[i], scores[i])

        for i in range(batch_size):
            terminated_states[i].append(states[i])

        return [self.aggregate_states(i_states) for i_states in terminated_states]
|
|