import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

import custom_llm_inference


@pytest.fixture
def model_and_tokenizer():
    # Small instruction-tuned model, loaded on CPU in float16 so the tests
    # run without a GPU.
    model_name = 'google/gemma-2-2b-it'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.bos_token_id is None:
        tokenizer.bos_token_id = tokenizer.pad_token_id
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="cpu",
        torch_dtype=torch.float16
    )
    return model, tokenizer


@pytest.fixture
def sample_inputs():
    # A source document, a rewriting instruction, and a partially written
    # assistant response used as the document in progress.
    doc = "The quick brown fox loves to jump over lazy dogs."
    prompt = "Rewrite this document to make more sense."
    doc_in_progress = "Sure, here's the document rewritten as requested:\n\nA fox,"
    return doc, prompt, doc_in_progress


def test_get_next_token_predictions(model_and_tokenizer, sample_inputs):
    model, tokenizer = model_and_tokenizer
    doc, prompt, doc_in_progress = sample_inputs

    predictions = custom_llm_inference.get_next_token_predictions_slow(
        model, tokenizer, doc, prompt, doc_in_progress=doc_in_progress, k=5
    )

    assert len(predictions) == 2  # Should return (token_texts, logits)
    assert len(predictions[0]) == 5  # Should return k=5 predictions
    assert predictions[1].shape[1] == model.config.vocab_size


def test_get_tokenized_chat(model_and_tokenizer, sample_inputs):
    model, tokenizer = model_and_tokenizer
    doc, prompt, _ = sample_inputs

    tokenized_chat = custom_llm_inference.get_tokenized_chat(tokenizer, prompt, doc)

    assert isinstance(tokenized_chat, torch.Tensor)
    assert tokenized_chat.dim() == 1
    assert tokenized_chat.dtype == torch.int64


def test_highlights(model_and_tokenizer, sample_inputs):
    model, tokenizer = model_and_tokenizer
    doc, prompt, updated_doc = sample_inputs

    highlights = custom_llm_inference.get_highlights_inner(
        model, tokenizer, doc, prompt, updated_doc=updated_doc, k=5
    )

    assert isinstance(highlights, list)
    assert len(highlights) > 0
    for h in highlights:
        assert h['start'] >= 0
        assert h['end'] >= h['start']
        assert isinstance(h['token'], str)
        assert isinstance(h['token_loss'], float)
        assert isinstance(h['most_likely_token'], str)
        assert isinstance(h['topk_tokens'], list)


def compare_lookahead_predictions(model, tokenizer, doc, prompt, doc_in_progress, k=5):
    """
    Compare next-token predictions from the fast method (which reuses the
    cache) and the slow method (which recomputes from scratch).

    Returns the differences between the two approaches for analysis.
    """
    # Get predictions from the fast method (using the cache)
    fast_tokens, fast_logits = custom_llm_inference.get_next_token_predictions_inner(
        model, tokenizer, doc, prompt, doc_in_progress, k
    )

    # Get predictions from the slow method (recomputing for each token)
    slow_tokens, slow_logits = custom_llm_inference.get_next_token_predictions_slow(
        model, tokenizer, doc, prompt, doc_in_progress, k
    )

    # Compare the decoded tokens (this is what users will see)
    token_matches = [fast == slow for fast, slow in zip(fast_tokens, slow_tokens)]

    # Check whether both methods agree on the most likely next token
    fast_most_likely = fast_logits.argmax(dim=-1)
    slow_most_likely = slow_logits.argmax(dim=-1)
    logit_match = torch.eq(fast_most_likely, slow_most_likely).cpu().numpy()

    # Measure the numerical difference between the two sets of logits
    logit_diff_norm = torch.linalg.vector_norm(
        (fast_logits - slow_logits).to(torch.float32), dim=1
    ).cpu().numpy()

    return {
        "fast_tokens": fast_tokens,
        "slow_tokens": slow_tokens,
        "token_matches": token_matches,
        "token_match_all": all(token_matches),
        "logit_match": logit_match,
        "logit_diff_norm": logit_diff_norm,
    }


def test_lookahead_token_consistency(model_and_tokenizer, sample_inputs):
    """
    Test that demonstrates the potential issue with cache position indices
    when generating lookahead tokens.
    """
    model, tokenizer = model_and_tokenizer
    doc, prompt, doc_in_progress = sample_inputs

    results = compare_lookahead_predictions(model, tokenizer, doc, prompt, doc_in_progress)

    # Check that both methods decode the same tokens
    assert results["token_match_all"], (
        f"Fast and slow methods produced different tokens.\n"
        f"Fast: {results['fast_tokens']}\n"
        f"Slow: {results['slow_tokens']}"
    )

    # Check that both methods agree on the most likely next token
    assert all(results["logit_match"]), (
        "Fast and slow methods predicted different most likely next tokens"
    )

    # Check that the logit differences are minimal.
    # This may fail if there is a bug in the cache position indices.
    assert all(diff < 1e-4 for diff in results["logit_diff_norm"]), (
        f"Significant difference in logits between fast and slow methods: {results['logit_diff_norm']}"
    )
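

# Optional convenience entry point: a minimal sketch that rebuilds the fixture
# inputs by hand so the fast/slow comparison above can be inspected
# interactively, outside of pytest. The model name and sample texts simply
# mirror the fixtures defined in this file.
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b-it')
    if tokenizer.bos_token_id is None:
        tokenizer.bos_token_id = tokenizer.pad_token_id
    model = AutoModelForCausalLM.from_pretrained(
        'google/gemma-2-2b-it',
        device_map="cpu",
        torch_dtype=torch.float16
    )

    doc = "The quick brown fox loves to jump over lazy dogs."
    prompt = "Rewrite this document to make more sense."
    doc_in_progress = "Sure, here's the document rewritten as requested:\n\nA fox,"

    results = compare_lookahead_predictions(model, tokenizer, doc, prompt, doc_in_progress)
    print("token_match_all:", results["token_match_all"])
    print("logit_match:", results["logit_match"])
    print("logit_diff_norm:", results["logit_diff_norm"])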