Spaces:
Runtime error
Runtime error
static llama_grammar* build_grammar(const std::string & grammar_str) { | |
auto parsed_grammar = grammar_parser::parse(grammar_str.c_str()); | |
// Ensure we parsed correctly | |
assert(!parsed_grammar.rules.empty()); | |
// Ensure we have a root node | |
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end())); | |
std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules()); | |
llama_grammar* grammar = llama_grammar_init( | |
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); | |
return grammar; | |
} | |
static bool match_string(const std::string & input, llama_grammar* grammar) { | |
auto decoded = decode_utf8(input, {}); | |
const auto & code_points = decoded.first; | |
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { | |
auto prev_stacks = grammar->stacks; | |
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks); | |
if (grammar->stacks.empty()) { | |
// no stacks means that the grammar failed to match at this point | |
return false; | |
} | |
} | |
for (const auto & stack : grammar->stacks) { | |
if (stack.empty()) { | |
// An empty stack means that the grammar has been completed | |
return true; | |
} | |
} | |
return false; | |
} | |
static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) { | |
fprintf(stderr, "⚫ Testing %s. Grammar: %s\n", test_desc.c_str(), grammar_str.c_str()); | |
fflush(stderr); | |
auto grammar = build_grammar(grammar_str); | |
// Save the original grammar stacks so that we can reset after every new string we want to test | |
auto original_stacks = grammar->stacks; | |
fprintf(stderr, " 🔵 Valid strings:\n"); | |
// Passing strings | |
for (const auto & test_string : passing_strings) { | |
fprintf(stderr, " \"%s\" ", test_string.c_str()); | |
fflush(stderr); | |
bool matched = match_string(test_string, grammar); | |
if (!matched) { | |
fprintf(stderr, "❌ (failed to match)\n"); | |
} else { | |
fprintf(stdout, "✅︎\n"); | |
} | |
assert(matched); | |
// Reset the grammar stacks | |
grammar->stacks = original_stacks; | |
} | |
fprintf(stderr, " 🟠 Invalid strings:\n"); | |
// Failing strings | |
for (const auto & test_string : failing_strings) { | |
fprintf(stderr, " \"%s\" ", test_string.c_str()); | |
fflush(stderr); | |
bool matched = match_string(test_string, grammar); | |
if (matched) { | |
fprintf(stderr, "❌ (incorrectly matched)\n"); | |
} else { | |
fprintf(stdout, "✅︎\n"); | |
} | |
assert(!matched); | |
// Reset the grammar stacks | |
grammar->stacks = original_stacks; | |
} | |
// Clean up allocated memory | |
llama_grammar_free(grammar); | |
} | |
static void test_simple_grammar() { | |
// Test case for a simple grammar | |
test_grammar( | |
"simple grammar", | |
R"""( | |
root ::= expr | |
expr ::= term ("+" term)* | |
term ::= number | |
number ::= [0-9]+)""", | |
// Passing strings | |
{ | |
"42", | |
"1+2+3+4+5", | |
"123+456", | |
}, | |
// Failing strings | |
{ | |
"+", | |
"/ 3", | |
"1+2+3+4+5+", | |
"12a45", | |
} | |
); | |
} | |
static void test_complex_grammar() { | |
// Test case for a more complex grammar, with both failure strings and success strings | |
test_grammar( | |
"medium complexity grammar", | |
// Grammar | |
R"""( | |
root ::= expression | |
expression ::= term ws (("+"|"-") ws term)* | |
term ::= factor ws (("*"|"/") ws factor)* | |
factor ::= number | variable | "(" expression ")" | function-call | |
number ::= [0-9]+ | |
variable ::= [a-zA-Z_][a-zA-Z0-9_]* | |
function-call ::= variable ws "(" (expression ("," ws expression)*)? ")" | |
ws ::= [ \t\n\r]?)""", | |
// Passing strings | |
{ | |
"42", | |
"1*2*3*4*5", | |
"x", | |
"x+10", | |
"x1+y2", | |
"(a+b)*(c-d)", | |
"func()", | |
"func(x,y+2)", | |
"a*(b+c)-d/e", | |
"f(g(x),h(y,z))", | |
"x + 10", | |
"x1 + y2", | |
"(a + b) * (c - d)", | |
"func()", | |
"func(x, y + 2)", | |
"a * (b + c) - d / e", | |
"f(g(x), h(y, z))", | |
"123+456", | |
"123*456*789-123/456+789*123", | |
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456" | |
}, | |
// Failing strings | |
{ | |
"+", | |
"/ 3x", | |
"x + + y", | |
"a * / b", | |
"func(,)", | |
"func(x y)", | |
"(a + b", | |
"x + y)", | |
"a + b * (c - d", | |
"42 +", | |
"x +", | |
"x + 10 +", | |
"(a + b) * (c - d", | |
"func(", | |
"func(x, y + 2", | |
"a * (b + c) - d /", | |
"f(g(x), h(y, z)", | |
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/", | |
} | |
); | |
} | |
static void test_quantifiers() { | |
// A collection of tests to exercise * + and ? quantifiers | |
test_grammar( | |
"* quantifier", | |
// Grammar | |
R"""(root ::= "a"*)""", | |
// Passing strings | |
{ | |
"", | |
"a", | |
"aaaaa", | |
"aaaaaaaaaaaaaaaaaa", | |
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" | |
}, | |
// Failing strings | |
{ | |
"b", | |
"ab", | |
"aab", | |
"ba", | |
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" | |
} | |
); | |
test_grammar( | |
"+ quantifier", | |
// Grammar | |
R"""(root ::= "a"+)""", | |
// Passing strings | |
{ | |
"a", | |
"aaaaa", | |
"aaaaaaaaaaaaaaaaaa", | |
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" | |
}, | |
// Failing strings | |
{ | |
"", | |
"b", | |
"ab", | |
"aab", | |
"ba", | |
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" | |
} | |
); | |
test_grammar( | |
"? quantifier", | |
// Grammar | |
R"""(root ::= "a"?)""", | |
// Passing strings | |
{ | |
"", | |
"a" | |
}, | |
// Failing strings | |
{ | |
"b", | |
"ab", | |
"aa", | |
"ba", | |
} | |
); | |
test_grammar( | |
"mixed quantifiers", | |
// Grammar | |
R"""( | |
root ::= cons+ vowel* cons? (vowel cons)* | |
vowel ::= [aeiouy] | |
cons ::= [bcdfghjklmnpqrstvwxyz] | |
)""", | |
// Passing strings | |
{ | |
"yes", | |
"no", | |
"noyes", | |
"crwth", | |
"four", | |
"bryyyy", | |
}, | |
// Failing strings | |
{ | |
"yess", | |
"yesno", | |
"forty", | |
"catyyy", | |
} | |
); | |
} | |
static void test_failure_missing_root() { | |
fprintf(stderr, "⚫ Testing missing root node:\n"); | |
// Test case for a grammar that is missing a root rule | |
const std::string grammar_str = R"""(rot ::= expr | |
expr ::= term ("+" term)* | |
term ::= number | |
number ::= [0-9]+)"""; | |
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); | |
// Ensure we parsed correctly | |
assert(!parsed_grammar.rules.empty()); | |
// Ensure we do NOT have a root node | |
assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()); | |
fprintf(stderr, " ✅︎ Passed\n"); | |
} | |
static void test_failure_missing_reference() { | |
fprintf(stderr, "⚫ Testing missing reference node:\n"); | |
// Test case for a grammar that is missing a referenced rule | |
const std::string grammar_str = | |
R"""(root ::= expr | |
expr ::= term ("+" term)* | |
term ::= numero | |
number ::= [0-9]+)"""; | |
fprintf(stderr, " Expected error: "); | |
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); | |
// Ensure we did NOT parsed correctly | |
assert(parsed_grammar.rules.empty()); | |
fprintf(stderr, " End of expected error.\n"); | |
fprintf(stderr, " ✅︎ Passed\n"); | |
} | |
int main() { | |
fprintf(stdout, "Running grammar integration tests...\n"); | |
test_simple_grammar(); | |
test_complex_grammar(); | |
test_quantifiers(); | |
test_failure_missing_root(); | |
test_failure_missing_reference(); | |
fprintf(stdout, "All tests passed.\n"); | |
return 0; | |
} | |