|
import os |
|
import json |
|
import pytest |
|
import torch |
|
import numpy as np |
|
|
|
from src.gpp import parse_markdown_table, GPP, GPPConfig |
|
from src.qa import Retriever, RetrieverConfig, Reranker, RerankerConfig, AnswerGenerator |
|
from src.utils import LLMClient |
|
|
|
|
|
def test_parse_markdown_table_valid():
    """A well-formed two-column markdown table parses into headers and rows."""
    table_md = (
        "\n"
        "|h1|h2|\n"
        "|--|--|\n"
        "|a|b|\n"
        "|c|d|\n"
    )
    parsed = parse_markdown_table(table_md)
    assert parsed['headers'] == ['h1', 'h2']
    assert parsed['rows'] == [['a', 'b'], ['c', 'd']]
|
|
|
|
|
def test_parse_markdown_table_invalid():
    """Plain text that is not a markdown table yields None."""
    assert parse_markdown_table("not a table") is None
|
|
|
|
|
class DummyGPPConfig(GPPConfig):
    """GPP config for tests: tiny chunk size so splitting triggers on short text."""
    # Deliberately small so the 4-word text block fills exactly one chunk.
    CHUNK_TOKEN_SIZE = 4
|
|
|
@pytest.fixture
def gpp():
    """GPP instance built with the tiny test chunk size."""
    config = DummyGPPConfig()
    return GPP(config)
|
|
|
@pytest.fixture
def blocks():
    """Two text blocks surrounding one markdown-table block."""
    table_block = {'type': 'table', 'text': '|h|\n|-|\n|v|'}
    return [
        {'type': 'text', 'text': 'one two three four'},
        table_block,
        {'type': 'text', 'text': 'five six'},
    ]
|
|
|
def test_chunk_blocks_table_isolation(gpp, blocks):
    """A table block becomes its own chunk carrying parsed table structure."""
    result = gpp.chunk_blocks(blocks)

    assert len(result) == 3
    table_chunk = result[1]
    assert table_chunk['type'] == 'table'
    assert 'table_structure' in table_chunk
|
|
|
|
|
def test_retriever_combine_unique(monkeypatch):
    """Hybrid retrieval merges sparse and dense hits, deduplicated in order."""
    docs = [{'narration': 'a'}, {'narration': 'b'}, {'narration': 'c'}]
    retriever = Retriever(docs, RetrieverConfig())

    # Sparse and dense overlap on docs[1]; the merge must keep it once.
    monkeypatch.setattr(
        Retriever, 'retrieve_sparse',
        lambda self, q, top_k: [docs[0], docs[1]],
    )
    monkeypatch.setattr(
        Retriever, 'retrieve_dense',
        lambda self, q, top_k: [docs[1], docs[2]],
    )

    merged = retriever.retrieve('query', top_k=2)
    assert merged == [docs[0], docs[1], docs[2]]
|
|
|
|
|
class DummyTokenizer:
    """Stand-in HF tokenizer: returns fixed (batch, 1) long tensors for any input.

    Mirrors the cross-encoder call signature (query/context pairs plus the
    padding/truncation/return_tensors options) but ignores the text entirely.
    """

    def __call__(self, queries, contexts, padding, truncation, return_tensors):
        batch_size = len(queries)
        return {
            'input_ids': torch.ones((batch_size, 1), dtype=torch.long),
            'attention_mask': torch.ones((batch_size, 1), dtype=torch.long),
        }
|
|
|
class DummyModel:
    """Stand-in cross-encoder model.

    Emits deterministic logits [[0.1], [0.9]] for a batch of exactly two
    (so ordering tests are reproducible) and random logits otherwise.
    """

    def __init__(self):
        pass

    def to(self, device):
        # Device placement is a no-op for the dummy; keep the fluent API.
        return self

    def __call__(self, **kwargs):
        batch_size = kwargs['input_ids'].shape[0]
        if batch_size == 2:
            logits = torch.tensor([[0.1], [0.9]])
        else:
            logits = torch.rand((batch_size, 1))
        # Minimal object exposing .logits, like an HF model output.
        return type('Out', (), {'logits': logits})
|
|
|
@pytest.fixture(autouse=True)
def dummy_pretrained(monkeypatch):
    """Patch HF `from_pretrained` loaders with local dummies for every test.

    Fix: the original lambdas accepted exactly one positional argument, so
    any production call passing extra options (e.g. `num_labels=...`,
    `cache_dir=...`, the usual `from_pretrained` pattern) would raise
    TypeError inside the patch. Accept arbitrary args/kwargs instead.
    """
    import transformers

    monkeypatch.setattr(
        transformers.AutoTokenizer, 'from_pretrained',
        lambda *args, **kwargs: DummyTokenizer(),
    )
    monkeypatch.setattr(
        transformers.AutoModelForSequenceClassification, 'from_pretrained',
        lambda *args, **kwargs: DummyModel(),
    )
|
|
|
def test_reranker_order():
    """Reranked candidates come back sorted by descending dummy logit."""
    reranker = Reranker(RerankerConfig())
    candidates = [{'narration': 'A'}, {'narration': 'B'}]

    ranked = reranker.rerank('q', candidates, top_k=2)

    # DummyModel scores the second candidate higher (0.9 vs 0.1).
    assert [c['narration'] for c in ranked] == ['B', 'A']
|
|
|
|
|
def test_answer_generator(monkeypatch):
    """End-to-end answer flow with retrieval, reranking and the LLM all stubbed."""
    chunks = [{'narration': 'hello world'}]

    class StubRetriever:
        def __init__(self, chunks, config):
            pass

        def retrieve(self, q, top_k=10):
            return chunks

    class StubReranker:
        def __init__(self, config):
            pass

        def rerank(self, q, cands, top_k):
            return chunks

    # Patch the classes where src.qa looks them up, and make the LLM canned.
    monkeypatch.setattr('src.qa.Retriever', StubRetriever)
    monkeypatch.setattr('src.qa.Reranker', StubReranker)
    monkeypatch.setattr(
        LLMClient, 'generate', staticmethod(lambda prompt: 'TEST_ANSWER')
    )

    generator = AnswerGenerator()
    answer, supporting = generator.answer(chunks, 'What?')

    assert answer == 'TEST_ANSWER'
    assert supporting == chunks
|
|