File size: 4,088 Bytes
3301b3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import json
import pytest
import torch
import numpy as np

from src.gpp import parse_markdown_table, GPP, GPPConfig
from src.qa import Retriever, RetrieverConfig, Reranker, RerankerConfig, AnswerGenerator
from src.utils import LLMClient

# --- Tests for parse_markdown_table ---
def test_parse_markdown_table_valid():
    """A well-formed markdown table yields its header cells and body rows."""
    md = """
    |h1|h2|
    |--|--|
    |a|b|
    |c|d|
    """
    parsed = parse_markdown_table(md)
    assert parsed['headers'] == ['h1', 'h2']
    assert parsed['rows'] == [['a', 'b'], ['c', 'd']]


def test_parse_markdown_table_invalid():
    """Input with no table structure parses to None."""
    assert parse_markdown_table("not a table") is None

# --- Tests for GPP.chunk_blocks ---
# GPP config override with a deliberately tiny chunk budget so the
# chunking tests can trigger splits with only a few tokens of input.
class DummyGPPConfig(GPPConfig):
    CHUNK_TOKEN_SIZE = 4  # small threshold for testing

@pytest.fixture
def gpp():
    """Provide a GPP instance built from the small-chunk test config."""
    config = DummyGPPConfig()
    return GPP(config)

@pytest.fixture
def blocks():
    """Sample input: a 4-token text, a one-cell table, and a 2-token text."""
    first_text = {'type': 'text', 'text': 'one two three four'}
    table_block = {'type': 'table', 'text': '|h|\n|-|\n|v|'}
    second_text = {'type': 'text', 'text': 'five six'}
    return [first_text, table_block, second_text]

def test_chunk_blocks_table_isolation(gpp, blocks):
    """Tables must become their own chunk, carrying a parsed structure."""
    chunks = gpp.chunk_blocks(blocks)
    # text (4 tokens) / table / text (2 tokens) -> three separate chunks
    assert len(chunks) == 3
    table_chunk = chunks[1]
    assert table_chunk['type'] == 'table'
    assert 'table_structure' in table_chunk

# --- Tests for Retriever.retrieve combining sparse & dense ---
def test_retriever_combine_unique(monkeypatch):
    """retrieve() merges sparse and dense hits, deduplicated in order."""
    chunks = [{'narration': 'a'}, {'narration': 'b'}, {'narration': 'c'}]
    retriever = Retriever(chunks, RetrieverConfig())
    # Stub both retrieval paths so their results overlap on chunks[1].
    monkeypatch.setattr(
        Retriever, 'retrieve_sparse',
        lambda self, q, top_k: [chunks[0], chunks[1]],
    )
    monkeypatch.setattr(
        Retriever, 'retrieve_dense',
        lambda self, q, top_k: [chunks[1], chunks[2]],
    )
    merged = retriever.retrieve('query', top_k=2)
    assert merged == [chunks[0], chunks[1], chunks[2]]

# --- Tests for Reranker.rerank with dummy model and tokenizer ---
class DummyTokenizer:
    """Stand-in HF tokenizer: one dummy token per (query, context) pair."""

    def __call__(self, queries, contexts, padding, truncation, return_tensors):
        n_pairs = len(queries)
        return {
            'input_ids': torch.ones((n_pairs, 1), dtype=torch.long),
            'attention_mask': torch.ones((n_pairs, 1), dtype=torch.long),
        }

class DummyModel:
    """Stand-in cross-encoder whose logits favour the second candidate."""

    def __init__(self):
        pass

    def to(self, device):
        # Device placement is a no-op for the dummy; keep the fluent API.
        return self

    def __call__(self, **kwargs):
        batch_size = kwargs['input_ids'].shape[0]
        if batch_size == 2:
            # Fixed scores: candidate #2 outranks candidate #1.
            logits = torch.tensor([[0.1], [0.9]])
        else:
            logits = torch.rand((batch_size, 1))
        # NOTE: mirrors the original quirk of returning a throwaway class
        # (not an instance) — attribute access on it works the same.
        return type('Out', (), {'logits': logits})

@pytest.fixture(autouse=True)
def dummy_pretrained(monkeypatch):
    """Swap HF from_pretrained loaders for local dummies in every test."""
    import transformers

    monkeypatch.setattr(
        transformers.AutoTokenizer, 'from_pretrained',
        lambda name: DummyTokenizer(),
    )
    monkeypatch.setattr(
        transformers.AutoModelForSequenceClassification, 'from_pretrained',
        lambda name: DummyModel(),
    )

def test_reranker_order():
    """The dummy model scores the second candidate higher, so 'B' wins."""
    reranker = Reranker(RerankerConfig())
    candidates = [{'narration': 'A'}, {'narration': 'B'}]
    ranked = reranker.rerank('q', candidates, top_k=2)
    assert [c['narration'] for c in ranked] == ['B', 'A']

# --- Tests for AnswerGenerator end-to-end logic ---
def test_answer_generator(monkeypatch):
    """End-to-end: AnswerGenerator wires retriever, reranker and LLM."""
    chunks = [{'narration': 'hello world'}]

    # Local stubs so the generator never touches real models/indexes.
    class StubRetriever:
        def __init__(self, chunks, config):
            pass

        def retrieve(self, q, top_k=10):
            return chunks

    class StubReranker:
        def __init__(self, config):
            pass

        def rerank(self, q, cands, top_k):
            return chunks

    monkeypatch.setattr('src.qa.Retriever', StubRetriever)
    monkeypatch.setattr('src.qa.Reranker', StubReranker)
    # LLM always answers with a sentinel string.
    monkeypatch.setattr(LLMClient, 'generate', staticmethod(lambda prompt: 'TEST_ANSWER'))

    generator = AnswerGenerator()
    answer, support = generator.answer(chunks, 'What?')
    assert answer == 'TEST_ANSWER'
    assert support == chunks