import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import custom_llm_inference

@pytest.fixture
def model_and_tokenizer():
    model_name = 'google/gemma-2-2b-it'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Fall back to the pad token when the tokenizer does not define a BOS token id.
    if tokenizer.bos_token_id is None:
        tokenizer.bos_token_id = tokenizer.pad_token_id
    model = AutoModelForCausalLM.from_pretrained(
        model_name, 
        device_map="cpu", 
        torch_dtype=torch.float16
    )
    return model, tokenizer

@pytest.fixture
def sample_inputs():
    doc = "The quick brown fox loves to jump over lazy dogs."
    prompt = "Rewrite this document to make more sense."
    doc_in_progress = "Sure, here's the document rewritten as requested:\n\nA fox,"
    return doc, prompt, doc_in_progress

def test_get_next_token_predictions(model_and_tokenizer, sample_inputs):
    model, tokenizer = model_and_tokenizer
    doc, prompt, doc_in_progress = sample_inputs
    
    predictions = custom_llm_inference.get_next_token_predictions_slow(
        model, tokenizer, doc, prompt, doc_in_progress=doc_in_progress, k=5
    )
    
    assert len(predictions) == 2  # Should return (token_texts, logits)
    assert len(predictions[0]) == 5  # Should return k=5 predictions
    assert predictions[1].shape[1] == model.config.vocab_size

def test_get_tokenized_chat(model_and_tokenizer, sample_inputs):
    model, tokenizer = model_and_tokenizer
    doc, prompt, _ = sample_inputs
    
    tokenized_chat = custom_llm_inference.get_tokenized_chat(tokenizer, prompt, doc)
    
    assert isinstance(tokenized_chat, torch.Tensor)
    assert tokenized_chat.dim() == 1
    assert tokenized_chat.dtype == torch.int64

def test_highlights(model_and_tokenizer, sample_inputs):
    model, tokenizer = model_and_tokenizer
    doc, prompt, updated_doc = sample_inputs
    
    highlights = custom_llm_inference.get_highlights_inner(
        model, tokenizer, doc, prompt, updated_doc=updated_doc, k=5
    )
    
    assert isinstance(highlights, list)
    assert len(highlights) > 0
    for h in highlights:
        assert h['start'] >= 0
        assert h['end'] >= h['start']
        assert isinstance(h['token'], str)
        assert isinstance(h['token_loss'], float)
        assert isinstance(h['most_likely_token'], str)
        assert isinstance(h['topk_tokens'], list)

def compare_lookahead_predictions(model, tokenizer, doc, prompt, doc_in_progress, k=5):
    """
    Extracts and compares the next token predictions between the fast method and slow method.
    Returns the differences between the two approaches for analysis.
    """
    # Get predictions from the fast method (using cache)
    fast_tokens, fast_logits = custom_llm_inference.get_next_token_predictions_inner(
        model, tokenizer, doc, prompt, doc_in_progress, k
    )
    
    # Get predictions from the slow method (recomputing for each token)
    slow_tokens, slow_logits = custom_llm_inference.get_next_token_predictions_slow(
        model, tokenizer, doc, prompt, doc_in_progress, k
    )
    
    # Compare the decoded tokens (this is what users will see)
    token_matches = [fast == slow for fast, slow in zip(fast_tokens, slow_tokens)]
    
    # Check whether the most likely next token (argmax over the logits) matches
    fast_most_likely = fast_logits.argmax(dim=-1)
    slow_most_likely = slow_logits.argmax(dim=-1)
    logit_match = torch.eq(fast_most_likely, slow_most_likely).cpu().numpy()
    
    # L2 norm of the per-position logit difference (computed in float32 for precision)
    logit_diff_norm = torch.linalg.vector_norm((fast_logits - slow_logits).to(torch.float32), dim=1).cpu().numpy()
    
    return {
        "fast_tokens": fast_tokens,
        "slow_tokens": slow_tokens,
        "token_matches": token_matches,
        "token_match_all": all(token_matches),
        "logit_match": logit_match,
        "logit_diff_norm": logit_diff_norm
    }

def test_lookahead_token_consistency(model_and_tokenizer, sample_inputs):
    """
    Test that demonstrates the potential issue with cache position indices
    when generating lookahead tokens.
    """
    model, tokenizer = model_and_tokenizer
    doc, prompt, doc_in_progress = sample_inputs
    
    results = compare_lookahead_predictions(model, tokenizer, doc, prompt, doc_in_progress)
    
    # Check if the tokens are the same
    assert results["token_match_all"], (
        f"Fast and slow methods produced different tokens.\n"
        f"Fast: {results['fast_tokens']}\n"
        f"Slow: {results['slow_tokens']}"
    )
    
    # Check if the most likely next tokens based on logits are the same
    assert all(results["logit_match"]), (
        f"Fast and slow methods predicted different most likely next tokens"
    )
    
    # Check that the logit differences are minimal
    # This might fail if there's a bug in the cache position indices
    assert all(diff < 1e-4 for diff in results["logit_diff_norm"]), (
        f"Significant difference in logits between fast and slow methods: {results['logit_diff_norm']}"
    )
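
# --- Optional extra check (a sketch, not part of the original suite): repeat the
# fast-vs-slow comparison with in-progress documents of different lengths, which
# shifts the cache positions exercised above. The alternate strings are illustrative.
@pytest.mark.parametrize("varied_doc_in_progress", [
    "Sure, here's the document rewritten as requested:",
    "Sure, here's the document rewritten as requested:\n\nA fox, quick and brown, loves",
])
def test_lookahead_consistency_varied_lengths(model_and_tokenizer, sample_inputs, varied_doc_in_progress):
    model, tokenizer = model_and_tokenizer
    doc, prompt, _ = sample_inputs

    results = compare_lookahead_predictions(model, tokenizer, doc, prompt, varied_doc_in_progress)

    assert results["token_match_all"], (
        f"Fast: {results['fast_tokens']}\nSlow: {results['slow_tokens']}"
    )
    assert all(results["logit_match"])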