HugoVoxx commited on
Commit
764db3e
·
verified ·
1 Parent(s): 53aa5f8

Upload 21 files

Browse files
ag4masses/alphageometry/alphageometry.py CHANGED
@@ -33,8 +33,6 @@ import problem as pr
33
  #=============
34
  import sys, os, math, re
35
  import multiprocessing
36
- import warnings
37
- warnings.filterwarnings("ignore")
38
  model = None # global variable used in multi-processing workers
39
 
40
  _GIN_SEARCH_PATHS = flags.DEFINE_list(
@@ -152,8 +150,8 @@ def write_solution(g: gh.Graph, p: pr.Problem, out_file: str) -> None:
152
  g, p.goal, merge_trivials=False
153
  )
154
 
155
- solution = ''
156
- solution += 'Theo đề bài ta có:\n'
157
  premises_nl = []
158
  for premises, [points] in setup:
159
  solution += ' '.join([p.name.upper() for p in points]) + ' '
@@ -165,18 +163,15 @@ def write_solution(g: gh.Graph, p: pr.Problem, out_file: str) -> None:
165
  ]
166
  solution += ': Points\n' + '\n'.join(premises_nl)
167
 
168
- solution += '\n\nCác điểm cần dựng thêm:\n'
169
  aux_premises_nl = []
170
- if len(aux) == 0:
171
- solution += 'Không cần dựng thêm điểm nào.'
172
- else:
173
- for premises, [points] in aux:
174
- solution += ' '.join([p.name.upper() for p in points]) + ' '
175
- aux_premises_nl += [
176
- natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()])
177
- for p in premises
178
- ]
179
- solution += ': Points\n' + '\n'.join(aux_premises_nl)
180
 
181
  # some special case where the deduction rule has a well known name.
182
  r2name = {
@@ -194,19 +189,22 @@ def write_solution(g: gh.Graph, p: pr.Problem, out_file: str) -> None:
194
  'a02': '(Angle chase)',
195
  }
196
 
197
- solution += '\n\nCác bước chứng minh:\n'
198
  for i, step in enumerate(proof_steps):
199
  _, [con] = step
200
  nl = proof_step_string(step, refs, last_step=i == len(proof_steps) - 1)
201
  rule_name = r2name.get(con.rule_name, '')
202
  nl = nl.replace('\u21d2', f'{rule_name}\u21d2 ')
203
  solution += '{:03}. '.format(i + 1) + nl + '\n'
 
 
204
  logging.info(solution)
205
  if out_file:
206
  with open(out_file, 'w') as f:
207
  f.write(solution)
208
  logging.info('Solution written to %s.', out_file)
209
 
 
210
  def get_lm(ckpt_init: str, vocab_path: str) -> lm.LanguageModelInference:
211
  lm.parse_gin_configuration(
212
  _GIN_FILE.value, _GIN_PARAM.value, gin_paths=_GIN_SEARCH_PATHS.value
@@ -234,10 +232,12 @@ def run_ddar(g: gh.Graph, p: pr.Problem, out_file: str) -> bool:
234
  return False
235
 
236
  write_solution(g, p, out_file)
 
237
  gh.nm.draw(
238
  g.type2nodes[gh.Point],
239
  g.type2nodes[gh.Line],
240
  g.type2nodes[gh.Circle],
 
241
  g.type2nodes[gh.Segment],
242
  goal=(p.goal.name, goal_args),
243
  save_to="ag4mout/output.png",)
@@ -718,7 +718,7 @@ def main(_):
718
  # point names will be renamed to alphabetical a, b, c, d, e, ...
719
  # instead of staying with their original names,
720
  # in order to match the synthetic training data generation.
721
- need_rename = _MODE.value != 'ddar'
722
 
723
  # load problems from the problems_file,
724
  problems = pr.Problem.from_txt_file(
@@ -752,4 +752,4 @@ def main(_):
752
 
753
 
754
  if __name__ == '__main__':
755
- app.run(main)
 
33
  #=============
34
  import sys, os, math, re
35
  import multiprocessing
 
 
36
  model = None # global variable used in multi-processing workers
37
 
38
  _GIN_SEARCH_PATHS = flags.DEFINE_list(
 
150
  g, p.goal, merge_trivials=False
151
  )
152
 
153
+ solution = '\n=========================='
154
+ solution += '\n * From theorem premises:\n'
155
  premises_nl = []
156
  for premises, [points] in setup:
157
  solution += ' '.join([p.name.upper() for p in points]) + ' '
 
163
  ]
164
  solution += ': Points\n' + '\n'.join(premises_nl)
165
 
166
+ solution += '\n\n * Auxiliary Constructions:\n'
167
  aux_premises_nl = []
168
+ for premises, [points] in aux:
169
+ solution += ' '.join([p.name.upper() for p in points]) + ' '
170
+ aux_premises_nl += [
171
+ natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()])
172
+ for p in premises
173
+ ]
174
+ solution += ': Points\n' + '\n'.join(aux_premises_nl)
 
 
 
175
 
176
  # some special case where the deduction rule has a well known name.
177
  r2name = {
 
189
  'a02': '(Angle chase)',
190
  }
191
 
192
+ solution += '\n\n * Proof steps:\n'
193
  for i, step in enumerate(proof_steps):
194
  _, [con] = step
195
  nl = proof_step_string(step, refs, last_step=i == len(proof_steps) - 1)
196
  rule_name = r2name.get(con.rule_name, '')
197
  nl = nl.replace('\u21d2', f'{rule_name}\u21d2 ')
198
  solution += '{:03}. '.format(i + 1) + nl + '\n'
199
+
200
+ solution += '==========================\n'
201
  logging.info(solution)
202
  if out_file:
203
  with open(out_file, 'w') as f:
204
  f.write(solution)
205
  logging.info('Solution written to %s.', out_file)
206
 
207
+
208
  def get_lm(ckpt_init: str, vocab_path: str) -> lm.LanguageModelInference:
209
  lm.parse_gin_configuration(
210
  _GIN_FILE.value, _GIN_PARAM.value, gin_paths=_GIN_SEARCH_PATHS.value
 
232
  return False
233
 
234
  write_solution(g, p, out_file)
235
+
236
  gh.nm.draw(
237
  g.type2nodes[gh.Point],
238
  g.type2nodes[gh.Line],
239
  g.type2nodes[gh.Circle],
240
+ g.type2nodes[gh.SemiCircle],
241
  g.type2nodes[gh.Segment],
242
  goal=(p.goal.name, goal_args),
243
  save_to="ag4mout/output.png",)
 
718
  # point names will be renamed to alphabetical a, b, c, d, e, ...
719
  # instead of staying with their original names,
720
  # in order to match the synthetic training data generation.
721
+ need_rename = False
722
 
723
  # load problems from the problems_file,
724
  problems = pr.Problem.from_txt_file(
 
752
 
753
 
754
  if __name__ == '__main__':
755
+ app.run(main)
ag4masses/alphageometry/dd.py CHANGED
@@ -1,1156 +1,1220 @@
1
- # Copyright 2023 DeepMind Technologies Limited
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """Implements Deductive Database (DD)."""
17
-
18
- # pylint: disable=g-multiple-import,g-importing-member
19
- from collections import defaultdict
20
- import time
21
- from typing import Any, Callable, Generator
22
-
23
- import geometry as gm
24
- import graph as gh
25
- import graph_utils as utils
26
- import numericals as nm
27
- import problem as pr
28
- from problem import Dependency, EmptyDependency
29
-
30
-
31
- def intersect1(set1: set[Any], set2: set[Any]) -> Any:
32
- for x in set1:
33
- if x in set2:
34
- return x
35
- return None
36
-
37
-
38
- def diff_point(l: gm.Line, a: gm.Point) -> gm.Point:
39
- for x in l.neighbors(gm.Point):
40
- if x != a:
41
- return x
42
- return None
43
-
44
-
45
- # pylint: disable=protected-access
46
- # pylint: disable=unused-argument
47
-
48
-
49
- def match_eqratio_eqratio_eqratio(
50
- g: gh.Graph,
51
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
52
- theorem: pr.Theorem,
53
- ) -> Generator[dict[str, gm.Point], None, None]:
54
- """Match eqratio a b c d m n p q, eqratio c d e f p q r u => eqratio a b e f m n r u."""
55
- for m1 in g.type2nodes[gm.Value]:
56
- for m2 in g.type2nodes[gm.Value]:
57
- rats1 = []
58
- for rat in m1.neighbors(gm.Ratio):
59
- l1, l2 = rat.lengths
60
- if l1 is None or l2 is None:
61
- continue
62
- rats1.append((l1, l2))
63
-
64
- rats2 = []
65
- for rat in m2.neighbors(gm.Ratio):
66
- l1, l2 = rat.lengths
67
- if l1 is None or l2 is None:
68
- continue
69
- rats2.append((l1, l2))
70
-
71
- pairs = []
72
- for (l1, l2), (l3, l4) in utils.cross(rats1, rats2):
73
- if l2 == l3:
74
- pairs.append((l1, l2, l4))
75
-
76
- for (l1, l12, l2), (l3, l34, l4) in utils.comb2(pairs):
77
- if (l1, l12, l2) == (l3, l34, l4):
78
- continue
79
- if l1 == l2 or l3 == l4:
80
- continue
81
- if l1 == l12 or l12 == l2 or l3 == l34 or l4 == l34:
82
- continue
83
- # d12 - d1 = d34 - d3 = m1
84
- # d2 - d12 = d4 - d34 = m2
85
- # => d2 - d1 = d4 - d3 (= m1+m2)
86
- a, b = g.two_points_of_length(l1)
87
- c, d = g.two_points_of_length(l12)
88
- m, n = g.two_points_of_length(l3)
89
- p, q = g.two_points_of_length(l34)
90
- # eqangle a b c d m n p q
91
- e, f = g.two_points_of_length(l2)
92
- r, u = g.two_points_of_length(l4)
93
- yield dict(zip('abcdefmnpqru', [a, b, c, d, e, f, m, n, p, q, r, u]))
94
-
95
-
96
- def match_eqangle_eqangle_eqangle(
97
- g: gh.Graph,
98
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
99
- theorem: pr.Theorem,
100
- ) -> Generator[dict[str, gm.Point], None, None]:
101
- """Match eqangle a b c d m n p q, eqangle c d e f p q r u => eqangle a b e f m n r u."""
102
- for m1 in g.type2nodes[gm.Measure]:
103
- for m2 in g.type2nodes[gm.Measure]:
104
- angs1 = []
105
- for ang in m1.neighbors(gm.Angle):
106
- d1, d2 = ang.directions
107
- if d1 is None or d2 is None:
108
- continue
109
- angs1.append((d1, d2))
110
-
111
- angs2 = []
112
- for ang in m2.neighbors(gm.Angle):
113
- d1, d2 = ang.directions
114
- if d1 is None or d2 is None:
115
- continue
116
- angs2.append((d1, d2))
117
-
118
- pairs = []
119
- for (d1, d2), (d3, d4) in utils.cross(angs1, angs2):
120
- if d2 == d3:
121
- pairs.append((d1, d2, d4))
122
-
123
- for (d1, d12, d2), (d3, d34, d4) in utils.comb2(pairs):
124
- if (d1, d12, d2) == (d3, d34, d4):
125
- continue
126
- if d1 == d2 or d3 == d4:
127
- continue
128
- if d1 == d12 or d12 == d2 or d3 == d34 or d4 == d34:
129
- continue
130
- # d12 - d1 = d34 - d3 = m1
131
- # d2 - d12 = d4 - d34 = m2
132
- # => d2 - d1 = d4 - d3
133
- a, b = g.two_points_on_direction(d1)
134
- c, d = g.two_points_on_direction(d12)
135
- m, n = g.two_points_on_direction(d3)
136
- p, q = g.two_points_on_direction(d34)
137
- # eqangle a b c d m n p q
138
- e, f = g.two_points_on_direction(d2)
139
- r, u = g.two_points_on_direction(d4)
140
- yield dict(zip('abcdefmnpqru', [a, b, c, d, e, f, m, n, p, q, r, u]))
141
-
142
-
143
- def match_perp_perp_npara_eqangle(
144
- g: gh.Graph,
145
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
146
- theorem: pr.Theorem,
147
- ) -> Generator[dict[str, gm.Point], None, None]:
148
- """Match perp A B C D, perp E F G H, npara A B E F => eqangle A B E F C D G H."""
149
- dpairs = []
150
- for ang in g.vhalfpi.neighbors(gm.Angle):
151
- d1, d2 = ang.directions
152
- if d1 is None or d2 is None:
153
- continue
154
- dpairs.append((d1, d2))
155
-
156
- for (d1, d2), (d3, d4) in utils.comb2(dpairs):
157
- a, b = g.two_points_on_direction(d1)
158
- c, d = g.two_points_on_direction(d2)
159
- m, n = g.two_points_on_direction(d3)
160
- p, q = g.two_points_on_direction(d4)
161
- if g.check_npara([a, b, m, n]):
162
- if ({a, b}, {c, d}) == ({m, n}, {p, q}):
163
- continue
164
- if ({a, b}, {c, d}) == ({p, q}, {m, n}):
165
- continue
166
-
167
- yield dict(zip('ABCDEFGH', [a, b, c, d, m, n, p, q]))
168
-
169
-
170
- def match_circle_coll_eqangle_midp(
171
- g: gh.Graph,
172
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
173
- theorem: pr.Theorem,
174
- ) -> Generator[dict[str, gm.Point], None, None]:
175
- """Match circle O A B C, coll M B C, eqangle A B A C O B O M => midp M B C."""
176
- for p, a, b, c in g.all_circles():
177
- ab = g._get_line(a, b)
178
- if ab is None:
179
- continue
180
- if ab.val is None:
181
- continue
182
- ac = g._get_line(a, c)
183
- if ac is None:
184
- continue
185
- if ac.val is None:
186
- continue
187
- pb = g._get_line(p, b)
188
- if pb is None:
189
- continue
190
- if pb.val is None:
191
- continue
192
-
193
- bc = g._get_line(b, c)
194
- if bc is None:
195
- continue
196
- bc_points = bc.neighbors(gm.Point, return_set=True)
197
-
198
- anga, _ = g._get_angle(ab.val, ac.val)
199
-
200
- for angp in pb.val.neighbors(gm.Angle):
201
- if not g.is_equal(anga, angp):
202
- continue
203
-
204
- _, d = angp.directions
205
- for l in d.neighbors(gm.Line):
206
- l_points = l.neighbors(gm.Point, return_set=True)
207
- m = intersect1(bc_points, l_points)
208
- if m is not None:
209
- yield dict(zip('ABCMO', [a, b, c, m, p]))
210
-
211
-
212
- def match_midp_perp_cong(
213
- g: gh.Graph,
214
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
215
- theorem: pr.Theorem,
216
- ) -> Generator[dict[str, gm.Point], None, None]:
217
- """Match midp M A B, perp O M A B => cong O A O B."""
218
- for m, a, b in g.all_midps():
219
- ab = g._get_line(a, b)
220
- for l in m.neighbors(gm.Line):
221
- if g.check_perpl(l, ab):
222
- for o in l.neighbors(gm.Point):
223
- if o != m:
224
- yield dict(zip('ABMO', [a, b, m, o]))
225
-
226
-
227
- def match_cyclic_eqangle_cong(
228
- g: gh.Graph,
229
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
230
- theorem: pr.Theorem,
231
- ) -> Generator[dict[str, gm.Point], None, None]:
232
- """Match cyclic A B C P Q R, eqangle C A C B R P R Q => cong A B P Q."""
233
- for c in g.type2nodes[gm.Circle]:
234
- ps = c.neighbors(gm.Point)
235
- for (a, b, c), (x, y, z) in utils.comb2(list(utils.perm3(ps))):
236
- if {a, b, c} == {x, y, z}:
237
- continue
238
- if g.check_eqangle([c, a, c, b, z, x, z, y]):
239
- yield dict(zip('ABCPQR', [a, b, c, x, y, z]))
240
-
241
-
242
- def match_circle_eqangle_perp(
243
- g: gh.Graph,
244
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
245
- theorem: pr.Theorem,
246
- ) -> Generator[dict[str, gm.Point], None, None]:
247
- """Match circle O A B C, eqangle A X A B C A C B => perp O A A X."""
248
- for p, a, b, c in g.all_circles():
249
- ca = g._get_line(c, a)
250
- if ca is None:
251
- continue
252
- cb = g._get_line(c, b)
253
- if cb is None:
254
- continue
255
- ab = g._get_line(a, b)
256
- if ab is None:
257
- continue
258
-
259
- if ca.val is None:
260
- continue
261
- if cb.val is None:
262
- continue
263
- if ab.val is None:
264
- continue
265
-
266
- c_ang, _ = g._get_angle(cb.val, ca.val)
267
- if c_ang is None:
268
- continue
269
-
270
- for ang in ab.val.neighbors(gm.Angle):
271
- if g.is_equal(ang, c_ang):
272
- _, d = ang.directions
273
- for l in d.neighbors(gm.Line):
274
- if a not in l.neighbors(gm.Point):
275
- continue
276
- x = diff_point(l, a)
277
- if x is None:
278
- continue
279
- yield dict(zip('OABCX', [p, a, b, c, x]))
280
- break
281
-
282
-
283
- def match_circle_perp_eqangle(
284
- g: gh.Graph,
285
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
286
- theorem: pr.Theorem,
287
- ) -> Generator[dict[str, gm.Point], None, None]:
288
- """Match circle O A B C, perp O A A X => eqangle A X A B C A C B."""
289
- for p, a, b, c in g.all_circles():
290
- pa = g._get_line(p, a)
291
- if pa is None:
292
- continue
293
- if pa.val is None:
294
- continue
295
- for l in a.neighbors(gm.Line):
296
- if g.check_perpl(pa, l):
297
- x = diff_point(l, a)
298
- if x is not None:
299
- yield dict(zip('OABCX', [p, a, b, c, x]))
300
-
301
-
302
- def match_perp_perp_ncoll_para(
303
- g: gh.Graph,
304
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
305
- theorem: pr.Theorem,
306
- ) -> Generator[dict[str, gm.Point], None, None]:
307
- """Match perp A B C D, perp C D E F, ncoll A B E => para A B E F."""
308
- d2d = defaultdict(list)
309
- for ang in g.vhalfpi.neighbors(gm.Angle):
310
- d1, d2 = ang.directions
311
- if d1 is None or d2 is None:
312
- continue
313
- d2d[d1] += [d2]
314
- d2d[d2] += [d1]
315
-
316
- for x, ys in d2d.items():
317
- if len(ys) < 2:
318
- continue
319
- c, d = g.two_points_on_direction(x)
320
- for y1, y2 in utils.comb2(ys):
321
- a, b = g.two_points_on_direction(y1)
322
- e, f = g.two_points_on_direction(y2)
323
- if nm.check_ncoll([a.num, b.num, e.num]):
324
- yield dict(zip('ABCDEF', [a, b, c, d, e, f]))
325
-
326
-
327
- def match_eqangle6_ncoll_cong(
328
- g: gh.Graph,
329
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
330
- theorem: pr.Theorem,
331
- ) -> Generator[dict[str, gm.Point], None, None]:
332
- """Match eqangle6 A O A B B A B O, ncoll O A B => cong O A O B."""
333
- for a in g.type2nodes[gm.Point]:
334
- for b, c in utils.comb2(g.type2nodes[gm.Point]):
335
- if a == b or a == c:
336
- continue
337
- if g.check_eqangle([b, a, b, c, c, b, c, a]):
338
- if g.check_ncoll([a, b, c]):
339
- yield dict(zip('OAB', [a, b, c]))
340
-
341
-
342
- def match_eqangle_perp_perp(
343
- g: gh.Graph,
344
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
345
- theorem: pr.Theorem,
346
- ) -> Generator[dict[str, gm.Point], None, None]:
347
- """Match eqangle A B P Q C D U V, perp P Q U V => perp A B C D."""
348
- for ang in g.vhalfpi.neighbors(gm.Angle):
349
- # d1 perp d2
350
- d1, d2 = ang.directions
351
- if d1 is None or d2 is None:
352
- continue
353
- for d3, d4 in utils.comb2(g.type2nodes[gm.Direction]):
354
- if d1 == d3 or d2 == d4:
355
- continue
356
- # if d1 - d3 = d2 - d4 => d3 perp d4
357
- a13, a31 = g._get_angle(d1, d3)
358
- a24, a42 = g._get_angle(d2, d4)
359
- if a13 is None or a31 is None or a24 is None or a42 is None:
360
- continue
361
- if g.is_equal(a13, a24) and g.is_equal(a31, a42):
362
- a, b = g.two_points_on_direction(d1)
363
- c, d = g.two_points_on_direction(d2)
364
- m, n = g.two_points_on_direction(d3)
365
- p, q = g.two_points_on_direction(d4)
366
- yield dict(zip('ABCDPQUV', [m, n, p, q, a, b, c, d]))
367
-
368
-
369
- def match_eqangle_ncoll_cyclic(
370
- g: gh.Graph,
371
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
372
- theorem: pr.Theorem,
373
- ) -> Generator[dict[str, gm.Point], None, None]:
374
- """Match eqangle6 P A P B Q A Q B, ncoll P Q A B => cyclic A B P Q."""
375
- for l1, l2, l3, l4 in g.all_eqangles_distinct_linepairss():
376
- if len(set([l1, l2, l3, l4])) < 4:
377
- continue # they all must be distinct.
378
-
379
- p1s = l1.neighbors(gm.Point, return_set=True)
380
- p2s = l2.neighbors(gm.Point, return_set=True)
381
- p3s = l3.neighbors(gm.Point, return_set=True)
382
- p4s = l4.neighbors(gm.Point, return_set=True)
383
-
384
- p = intersect1(p1s, p2s)
385
- if not p:
386
- continue
387
- q = intersect1(p3s, p4s)
388
- if not q:
389
- continue
390
- a = intersect1(p1s, p3s)
391
- if not a:
392
- continue
393
- b = intersect1(p2s, p4s)
394
- if not b:
395
- continue
396
- if len(set([a, b, p, q])) < 4:
397
- continue
398
-
399
- if not g.check_ncoll([a, b, p, q]):
400
- continue
401
-
402
- yield dict(zip('ABPQ', [a, b, p, q]))
403
-
404
-
405
- def match_eqangle_para(
406
- g: gh.Graph,
407
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
408
- theorem: pr.Theorem,
409
- ) -> Generator[dict[str, gm.Point], None, None]:
410
- """Match eqangle A B P Q C D P Q => para A B C D."""
411
- for measure in g.type2nodes[gm.Measure]:
412
- angs = measure.neighbors(gm.Angle)
413
- d12, d21 = defaultdict(list), defaultdict(list)
414
- for ang in angs:
415
- d1, d2 = ang.directions
416
- if d1 is None or d2 is None:
417
- continue
418
- d12[d1].append(d2)
419
- d21[d2].append(d1)
420
-
421
- for d1, d2s in d12.items():
422
- a, b = g.two_points_on_direction(d1)
423
- for d2, d3 in utils.comb2(d2s):
424
- c, d = g.two_points_on_direction(d2)
425
- e, f = g.two_points_on_direction(d3)
426
- yield dict(zip('ABCDPQ', [c, d, e, f, a, b]))
427
-
428
-
429
- def match_cyclic_eqangle(
430
- g: gh.Graph,
431
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
432
- theorem: pr.Theorem,
433
- ) -> Generator[dict[str, gm.Point], None, None]:
434
- """Match cyclic A B P Q => eqangle P A P B Q A Q B."""
435
- record = set()
436
- for a, b, c, d in g_matcher('cyclic'):
437
- if (a, b, c, d) in record:
438
- continue
439
- record.add((a, b, c, d))
440
- record.add((a, b, d, c))
441
- record.add((b, a, c, d))
442
- record.add((b, a, d, c))
443
- yield dict(zip('ABPQ', [a, b, c, d]))
444
-
445
-
446
- def rotate_simtri(
447
- a: gm.Point, b: gm.Point, c: gm.Point, x: gm.Point, y: gm.Point, z: gm.Point
448
- ) -> Generator[tuple[gm.Point, ...], None, None]:
449
- """Rotate points around for similar triangle predicates."""
450
- yield (z, y, x, c, b, a)
451
- for p in [
452
- (b, c, a, y, z, x),
453
- (c, a, b, z, x, y),
454
- (x, y, z, a, b, c),
455
- (y, z, x, b, c, a),
456
- (z, x, y, c, a, b),
457
- ]:
458
- yield p
459
- yield p[::-1]
460
-
461
-
462
- def match_cong_cong_cong_cyclic(
463
- g: gh.Graph,
464
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
465
- theorem: pr.Theorem,
466
- ) -> Generator[dict[str, gm.Point], None, None]:
467
- """Match cong O A O B, cong O B O C, cong O C O D => cyclic A B C D."""
468
- for l in g.type2nodes[gm.Length]:
469
- p2p = defaultdict(list)
470
- for s in l.neighbors(gm.Segment):
471
- a, b = s.points
472
- p2p[a].append(b)
473
- p2p[b].append(a)
474
-
475
- for p, ps in p2p.items():
476
- if len(ps) >= 4:
477
- for a, b, c, d in utils.comb4(ps):
478
- yield dict(zip('OABCD', [p, a, b, c, d]))
479
-
480
-
481
- def match_cong_cong_cong_ncoll_contri(
482
- g: gh.Graph,
483
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
484
- theorem: pr.Theorem,
485
- ) -> Generator[dict[str, gm.Point], None, None]:
486
- """Match cong A B P Q, cong B C Q R, cong C A R P, ncoll A B C => contri* A B C P Q R."""
487
- record = set()
488
- for a, b, p, q in g_matcher('cong'):
489
- for c in g.type2nodes[gm.Point]:
490
- for r in g.type2nodes[gm.Point]:
491
- if any([x in record for x in rotate_simtri(a, b, c, p, q, r)]):
492
- continue
493
- if not g.check_ncoll([a, b, c]):
494
- continue
495
- if g.check_cong([b, c, q, r]) and g.check_cong([c, a, r, p]):
496
- record.add((a, b, c, p, q, r))
497
- yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
498
-
499
-
500
- def match_cong_cong_eqangle6_ncoll_contri(
501
- g: gh.Graph,
502
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
503
- theorem: pr.Theorem,
504
- ) -> Generator[dict[str, gm.Point], None, None]:
505
- """Match cong A B P Q, cong B C Q R, eqangle6 B A B C Q P Q R, ncoll A B C => contri* A B C P Q R."""
506
- record = set()
507
- for a, b, p, q in g_matcher('cong'):
508
- for c in g.type2nodes[gm.Point]:
509
- if c in (a, b):
510
- continue
511
- for r in g.type2nodes[gm.Point]:
512
- if r in (p, q):
513
- continue
514
-
515
- in_record = False
516
- for x in [
517
- (c, b, a, r, q, p),
518
- (p, q, r, a, b, c),
519
- (r, q, p, c, b, a),
520
- ]:
521
- if x in record:
522
- in_record = True
523
- break
524
-
525
- if in_record:
526
- continue
527
-
528
- if not g.check_cong([b, c, q, r]):
529
- continue
530
- if not g.check_ncoll([a, b, c]):
531
- continue
532
-
533
- if nm.same_clock(a.num, b.num, c.num, p.num, q.num, r.num):
534
- if g.check_eqangle([b, a, b, c, q, p, q, r]):
535
- record.add((a, b, c, p, q, r))
536
- yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
537
- else:
538
- if g.check_eqangle([b, a, b, c, q, r, q, p]):
539
- record.add((a, b, c, p, q, r))
540
- yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
541
-
542
-
543
- def match_eqratio6_eqangle6_ncoll_simtri(
544
- g: gh.Graph,
545
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
546
- theorem: pr.Theorem,
547
- ) -> Generator[dict[str, gm.Point], None, None]:
548
- """Match eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C => simtri* A B C P Q R."""
549
- enums = g_matcher('eqratio6')
550
-
551
- record = set()
552
- for b, a, b, c, q, p, q, r in enums: # pylint: disable=redeclared-assigned-name,unused-variable
553
- if (a, b, c) == (p, q, r):
554
- continue
555
- if any([x in record for x in rotate_simtri(a, b, c, p, q, r)]):
556
- continue
557
- if not g.check_ncoll([a, b, c]):
558
- continue
559
-
560
- if nm.same_clock(a.num, b.num, c.num, p.num, q.num, r.num):
561
- if g.check_eqangle([b, a, b, c, q, p, q, r]):
562
- record.add((a, b, c, p, q, r))
563
- yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
564
- elif g.check_eqangle([b, a, b, c, q, r, q, p]):
565
- record.add((a, b, c, p, q, r))
566
- yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
567
-
568
-
569
- def match_eqangle6_eqangle6_ncoll_simtri(
570
- g: gh.Graph,
571
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
572
- theorem: pr.Theorem,
573
- ) -> Generator[dict[str, gm.Point], None, None]:
574
- """Match eqangle6 B A B C Q P Q R, eqangle6 C A C B R P R Q, ncoll A B C => simtri A B C P Q R."""
575
- enums = g_matcher('eqangle6')
576
-
577
- record = set()
578
- for b, a, b, c, q, p, q, r in enums: # pylint: disable=redeclared-assigned-name,unused-variable
579
- if (a, b, c) == (p, q, r):
580
- continue
581
- if any([x in record for x in rotate_simtri(a, b, c, p, q, r)]):
582
- continue
583
- if not g.check_eqangle([c, a, c, b, r, p, r, q]):
584
- continue
585
- if not g.check_ncoll([a, b, c]):
586
- continue
587
-
588
- mapping = dict(zip('ABCPQR', [a, b, c, p, q, r]))
589
- record.add((a, b, c, p, q, r))
590
- yield mapping
591
-
592
-
593
- def match_eqratio6_eqratio6_ncoll_simtri(
594
- g: gh.Graph,
595
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
596
- theorem: pr.Theorem,
597
- ) -> Generator[dict[str, gm.Point], None, None]:
598
- """Match eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C => simtri* A B C P Q R."""
599
- enums = g_matcher('eqratio6')
600
-
601
- record = set()
602
- for b, a, b, c, q, p, q, r in enums: # pylint: disable=redeclared-assigned-name,unused-variable
603
- if (a, b, c) == (p, q, r):
604
- continue
605
- if any([x in record for x in rotate_simtri(a, b, c, p, q, r)]):
606
- continue
607
- if not g.check_eqratio([c, a, c, b, r, p, r, q]):
608
- continue
609
- if not g.check_ncoll([a, b, c]):
610
- continue
611
-
612
- mapping = dict(zip('ABCPQR', [a, b, c, p, q, r]))
613
- record.add((a, b, c, p, q, r))
614
- yield mapping
615
-
616
-
617
- def match_eqangle6_eqangle6_ncoll_simtri2(
618
- g: gh.Graph,
619
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
620
- theorem: pr.Theorem,
621
- ) -> Generator[dict[str, gm.Point], None, None]:
622
- """Match eqangle6 B A B C Q R Q P, eqangle6 C A C B R Q R P, ncoll A B C => simtri2 A B C P Q R."""
623
- enums = g_matcher('eqangle6')
624
-
625
- record = set()
626
- for b, a, b, c, q, r, q, p in enums: # pylint: disable=redeclared-assigned-name,unused-variable
627
- if (a, b, c) == (p, q, r):
628
- continue
629
- if any([x in record for x in rotate_simtri(a, b, c, p, q, r)]):
630
- continue
631
- if not g.check_eqangle([c, a, c, b, r, q, r, p]):
632
- continue
633
- if not g.check_ncoll([a, b, c]):
634
- continue
635
-
636
- mapping = dict(zip('ABCPQR', [a, b, c, p, q, r]))
637
- record.add((a, b, c, p, q, r))
638
- yield mapping
639
-
640
-
641
- def rotate_contri(
642
- a: gm.Point, b: gm.Point, c: gm.Point, x: gm.Point, y: gm.Point, z: gm.Point
643
- ) -> Generator[tuple[gm.Point, ...], None, None]:
644
- for p in [(b, a, c, y, x, z), (x, y, z, a, b, c), (y, x, z, b, a, c)]:
645
- yield p
646
-
647
-
648
- def match_eqangle6_eqangle6_ncoll_cong_contri(
649
- g: gh.Graph,
650
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
651
- theorem: pr.Theorem,
652
- ) -> Generator[dict[str, gm.Point], None, None]:
653
- """Match eqangle6 B A B C Q P Q R, eqangle6 C A C B R P R Q, ncoll A B C, cong A B P Q => contri A B C P Q R."""
654
- enums = g_matcher('eqangle6')
655
-
656
- record = set()
657
- for b, a, b, c, q, p, q, r in enums: # pylint: disable=redeclared-assigned-name,unused-variable
658
- if not g.check_cong([a, b, p, q]):
659
- continue
660
- if (a, b, c) == (p, q, r):
661
- continue
662
- if any([x in record for x in rotate_contri(a, b, c, p, q, r)]):
663
- continue
664
- if not g.check_eqangle([c, a, c, b, r, p, r, q]):
665
- continue
666
-
667
- if not g.check_ncoll([a, b, c]):
668
- continue
669
-
670
- mapping = dict(zip('ABCPQR', [a, b, c, p, q, r]))
671
- record.add((a, b, c, p, q, r))
672
- yield mapping
673
-
674
-
675
- def match_eqratio6_eqratio6_ncoll_cong_contri(
676
- g: gh.Graph,
677
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
678
- theorem: pr.Theorem,
679
- ) -> Generator[dict[str, gm.Point], None, None]:
680
- """Match eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C, cong A B P Q => contri* A B C P Q R."""
681
- enums = g_matcher('eqratio6')
682
-
683
- record = set()
684
- for b, a, b, c, q, p, q, r in enums: # pylint: disable=redeclared-assigned-name,unused-variable
685
- if not g.check_cong([a, b, p, q]):
686
- continue
687
- if (a, b, c) == (p, q, r):
688
- continue
689
- if any([x in record for x in rotate_contri(a, b, c, p, q, r)]):
690
- continue
691
- if not g.check_eqratio([c, a, c, b, r, p, r, q]):
692
- continue
693
-
694
- if not g.check_ncoll([a, b, c]):
695
- continue
696
-
697
- mapping = dict(zip('ABCPQR', [a, b, c, p, q, r]))
698
- record.add((a, b, c, p, q, r))
699
- yield mapping
700
-
701
-
702
- def match_eqangle6_eqangle6_ncoll_cong_contri2(
703
- g: gh.Graph,
704
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
705
- theorem: pr.Theorem,
706
- ) -> Generator[dict[str, gm.Point], None, None]:
707
- """Match eqangle6 B A B C Q R Q P, eqangle6 C A C B R Q R P, ncoll A B C, cong A B P Q => contri2 A B C P Q R."""
708
- enums = g_matcher('eqangle6')
709
-
710
- record = set()
711
- for b, a, b, c, q, r, q, p in enums: # pylint: disable=redeclared-assigned-name,unused-variable
712
- if not g.check_cong([a, b, p, q]):
713
- continue
714
- if (a, b, c) == (p, q, r):
715
- continue
716
- if any([x in record for x in rotate_contri(a, b, c, p, q, r)]):
717
- continue
718
- if not g.check_eqangle([c, a, c, b, r, q, r, p]):
719
- continue
720
- if not g.check_ncoll([a, b, c]):
721
- continue
722
-
723
- mapping = dict(zip('ABCPQR', [a, b, c, p, q, r]))
724
- record.add((a, b, c, p, q, r))
725
- yield mapping
726
-
727
-
728
- def match_eqratio6_coll_ncoll_eqangle6(
729
- g: gh.Graph,
730
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
731
- theorem: pr.Theorem,
732
- ) -> Generator[dict[str, gm.Point], None, None]:
733
- """Match eqratio6 d b d c a b a c, coll d b c, ncoll a b c => eqangle6 a b a d a d a c."""
734
- records = set()
735
- for b, d, c in g_matcher('coll'):
736
- for a in g.all_points():
737
- if g.check_coll([a, b, c]):
738
- continue
739
- if (a, b, d, c) in records or (a, c, d, b) in records:
740
- continue
741
- records.add((a, b, d, c))
742
-
743
- if g.check_eqratio([d, b, d, c, a, b, a, c]):
744
- yield dict(zip('abcd', [a, b, c, d]))
745
-
746
-
747
- def match_eqangle6_coll_ncoll_eqratio6(
748
- g: gh.Graph,
749
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
750
- theorem: pr.Theorem,
751
- ) -> Generator[dict[str, gm.Point], None, None]:
752
- """Match eqangle6 a b a d a d a c, coll d b c, ncoll a b c => eqratio6 d b d c a b a c."""
753
- records = set()
754
- for b, d, c in g_matcher('coll'):
755
- for a in g.all_points():
756
- if g.check_coll([a, b, c]):
757
- continue
758
- if (a, b, d, c) in records or (a, c, d, b) in records:
759
- continue
760
- records.add((a, b, d, c))
761
-
762
- if g.check_eqangle([a, b, a, d, a, d, a, c]):
763
- yield dict(zip('abcd', [a, b, c, d]))
764
-
765
-
766
- def match_eqangle6_ncoll_cyclic(
767
- g: gh.Graph,
768
- g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
769
- theorem: pr.Theorem,
770
- ) -> Generator[dict[str, gm.Point], None, None]:
771
- """Match eqangle6 P A P B Q A Q B, ncoll P Q A B => cyclic A B P Q."""
772
- for a, b, a, c, x, y, x, z in g_matcher('eqangle6'): # pylint: disable=redeclared-assigned-name,unused-variable
773
- if (b, c) != (y, z) or a == x:
774
- continue
775
- if nm.check_ncoll([x.num for x in [a, b, c, x]]):
776
- yield dict(zip('ABPQ', [b, c, a, x]))
777
-
778
-
779
- def match_all(
780
- name: str, g: gh.Graph
781
- ) -> Generator[tuple[gm.Point, ...], None, None]:
782
- """Match all instances of a certain relation."""
783
- if name in ['ncoll', 'npara', 'nperp']:
784
- return []
785
- if name == 'coll':
786
- return g.all_colls()
787
- if name == 'para':
788
- return g.all_paras()
789
- if name == 'perp':
790
- return g.all_perps()
791
- if name == 'cong':
792
- return g.all_congs()
793
- if name == 'eqangle':
794
- return g.all_eqangles_8points()
795
- if name == 'eqangle6':
796
- return g.all_eqangles_6points()
797
- if name == 'eqratio':
798
- return g.all_eqratios_8points()
799
- if name == 'eqratio6':
800
- return g.all_eqratios_6points()
801
- if name == 'cyclic':
802
- return g.all_cyclics()
803
- if name == 'midp':
804
- return g.all_midps()
805
- if name == 'circle':
806
- return g.all_circles()
807
- raise ValueError(f'Unrecognize {name}')
808
-
809
-
810
- def cache_match(
811
- graph: gh.Graph,
812
- ) -> Callable[str, list[tuple[gm.Point, ...]]]:
813
- """Cache throughout one single BFS level."""
814
- cache = {}
815
-
816
- def match_fn(name: str) -> list[tuple[gm.Point, ...]]:
817
- if name in cache:
818
- return cache[name]
819
-
820
- result = list(match_all(name, graph))
821
- cache[name] = result
822
- return result
823
-
824
- return match_fn
825
-
826
-
827
- def try_to_map(
828
- clause_enum: list[tuple[pr.Clause, list[tuple[gm.Point, ...]]]],
829
- mapping: dict[str, gm.Point],
830
- ) -> Generator[dict[str, gm.Point], None, None]:
831
- """Recursively try to match the remaining points given current mapping."""
832
- if not clause_enum:
833
- yield mapping
834
- return
835
-
836
- clause, enum = clause_enum[0]
837
- for points in enum:
838
- mpcpy = dict(mapping)
839
-
840
- fail = False
841
- for p, a in zip(points, clause.args):
842
- if a in mpcpy and mpcpy[a] != p or p in mpcpy and mpcpy[p] != a:
843
- fail = True
844
- break
845
- mpcpy[a] = p
846
- mpcpy[p] = a
847
-
848
- if fail:
849
- continue
850
-
851
- for m in try_to_map(clause_enum[1:], mpcpy):
852
- yield m
853
-
854
-
855
- def match_generic(
856
- g: gh.Graph,
857
- cache: Callable[str, list[tuple[gm.Point, ...]]],
858
- theorem: pr.Theorem
859
- ) -> Generator[dict[str, gm.Point], None, None]:
860
- """Match any generic rule that is not one of the above match_*() rules."""
861
- clause2enum = {}
862
-
863
- clauses = []
864
- numerical_checks = []
865
- for clause in theorem.premise:
866
- if clause.name in ['ncoll', 'npara', 'nperp', 'sameside']:
867
- numerical_checks.append(clause)
868
- continue
869
-
870
- enum = cache(clause.name)
871
- if len(enum) == 0: # pylint: disable=g-explicit-length-test
872
- return 0
873
-
874
- clause2enum[clause] = enum
875
- clauses.append((len(set(clause.args)), clause))
876
-
877
- clauses = sorted(clauses, key=lambda x: x[0], reverse=True)
878
- _, clauses = zip(*clauses)
879
-
880
- for mapping in try_to_map([(c, clause2enum[c]) for c in clauses], {}):
881
- if not mapping:
882
- continue
883
-
884
- checks_ok = True
885
- for check in numerical_checks:
886
- args = [mapping[a] for a in check.args]
887
- if check.name == 'ncoll':
888
- checks_ok = g.check_ncoll(args)
889
- elif check.name == 'npara':
890
- checks_ok = g.check_npara(args)
891
- elif check.name == 'nperp':
892
- checks_ok = g.check_nperp(args)
893
- elif check.name == 'sameside':
894
- checks_ok = g.check_sameside(args)
895
- if not checks_ok:
896
- break
897
- if not checks_ok:
898
- continue
899
-
900
- yield mapping
901
-
902
-
903
- BUILT_IN_FNS = {
904
- 'cong_cong_cong_cyclic': match_cong_cong_cong_cyclic,
905
- 'cong_cong_cong_ncoll_contri*': match_cong_cong_cong_ncoll_contri,
906
- 'cong_cong_eqangle6_ncoll_contri*': match_cong_cong_eqangle6_ncoll_contri,
907
- 'eqangle6_eqangle6_ncoll_simtri': match_eqangle6_eqangle6_ncoll_simtri,
908
- 'eqangle6_eqangle6_ncoll_cong_contri': (
909
- match_eqangle6_eqangle6_ncoll_cong_contri
910
- ), # pylint: disable=line-too-long
911
- 'eqangle6_eqangle6_ncoll_simtri2': match_eqangle6_eqangle6_ncoll_simtri2,
912
- 'eqangle6_eqangle6_ncoll_cong_contri2': (
913
- match_eqangle6_eqangle6_ncoll_cong_contri2
914
- ), # pylint: disable=line-too-long
915
- 'eqratio6_eqratio6_ncoll_simtri*': match_eqratio6_eqratio6_ncoll_simtri,
916
- 'eqratio6_eqratio6_ncoll_cong_contri*': (
917
- match_eqratio6_eqratio6_ncoll_cong_contri
918
- ), # pylint: disable=line-too-long
919
- 'eqangle_para': match_eqangle_para,
920
- 'eqangle_ncoll_cyclic': match_eqangle_ncoll_cyclic,
921
- 'eqratio6_eqangle6_ncoll_simtri*': match_eqratio6_eqangle6_ncoll_simtri,
922
- 'eqangle_perp_perp': match_eqangle_perp_perp,
923
- 'eqangle6_ncoll_cong': match_eqangle6_ncoll_cong,
924
- 'perp_perp_ncoll_para': match_perp_perp_ncoll_para,
925
- 'circle_perp_eqangle': match_circle_perp_eqangle,
926
- 'circle_eqangle_perp': match_circle_eqangle_perp,
927
- 'cyclic_eqangle_cong': match_cyclic_eqangle_cong,
928
- 'midp_perp_cong': match_midp_perp_cong,
929
- 'perp_perp_npara_eqangle': match_perp_perp_npara_eqangle,
930
- 'cyclic_eqangle': match_cyclic_eqangle,
931
- 'eqangle_eqangle_eqangle': match_eqangle_eqangle_eqangle,
932
- 'eqratio_eqratio_eqratio': match_eqratio_eqratio_eqratio,
933
- 'eqratio6_coll_ncoll_eqangle6': match_eqratio6_coll_ncoll_eqangle6,
934
- 'eqangle6_coll_ncoll_eqratio6': match_eqangle6_coll_ncoll_eqratio6,
935
- 'eqangle6_ncoll_cyclic': match_eqangle6_ncoll_cyclic,
936
- }
937
-
938
-
939
- SKIP_THEOREMS = set()
940
-
941
-
942
- def set_skip_theorems(theorems: set[str]) -> None:
943
- SKIP_THEOREMS.update(theorems)
944
-
945
-
946
- MAX_BRANCH = 50_000
947
-
948
-
949
- def match_one_theorem(
950
- g: gh.Graph,
951
- cache: Callable[str, list[tuple[gm.Point, ...]]],
952
- theorem: pr.Theorem
953
- ) -> Generator[dict[str, gm.Point], None, None]:
954
- """Match all instances of a single theorem (rule)."""
955
- if cache is None:
956
- cache = cache_match(g)
957
-
958
- if theorem.name in SKIP_THEOREMS:
959
- return []
960
-
961
- if theorem.name.split('_')[-1] in SKIP_THEOREMS:
962
- return []
963
-
964
- if theorem.name in BUILT_IN_FNS:
965
- mps = BUILT_IN_FNS[theorem.name](g, cache, theorem)
966
- else:
967
- mps = match_generic(g, cache, theorem)
968
-
969
- mappings = []
970
- for mp in mps:
971
- mappings.append(mp)
972
- if len(mappings) > MAX_BRANCH: # cap branching at this number.
973
- break
974
-
975
- return mappings
976
-
977
-
978
- def match_all_theorems(
979
- g: gh.Graph, theorems: list[pr.Theorem], goal: pr.Clause
980
- ) -> dict[pr.Theorem, dict[pr.Theorem, dict[str, gm.Point]]]:
981
- """Match all instances of all theorems (rules)."""
982
- cache = cache_match(g)
983
- # for BFS, collect all potential matches
984
- # and then do it at the same time
985
- theorem2mappings = {}
986
-
987
- # Step 1: list all matches
988
- for _, theorem in theorems.items():
989
- name = theorem.name
990
- if name.split('_')[-1] in [
991
- 'acompute',
992
- 'rcompute',
993
- 'fixl',
994
- 'fixc',
995
- 'fixb',
996
- 'fixt',
997
- 'fixp',
998
- ]:
999
- if goal and goal.name != name:
1000
- continue
1001
-
1002
- mappings = match_one_theorem(g, cache, theorem)
1003
- if len(mappings): # pylint: disable=g-explicit-length-test
1004
- theorem2mappings[theorem] = list(mappings)
1005
- return theorem2mappings
1006
-
1007
-
1008
- def bfs_one_level(
1009
- g: gh.Graph,
1010
- theorems: list[pr.Theorem],
1011
- level: int,
1012
- controller: pr.Problem,
1013
- verbose: bool = False,
1014
- nm_check: bool = False,
1015
- timeout: int = 600,
1016
- ) -> tuple[
1017
- list[pr.Dependency],
1018
- dict[str, list[tuple[gm.Point, ...]]],
1019
- dict[str, list[tuple[gm.Point, ...]]],
1020
- int,
1021
- ]:
1022
- """Forward deduce one breadth-first level."""
1023
-
1024
- # Step 1: match all theorems:
1025
- theorem2mappings = match_all_theorems(g, theorems, controller.goal)
1026
-
1027
- # Step 2: traceback for each deduce:
1028
- theorem2deps = {}
1029
- t0 = time.time()
1030
- for theorem, mappings in theorem2mappings.items():
1031
- if time.time() - t0 > timeout:
1032
- break
1033
- mp_deps = []
1034
- for mp in mappings:
1035
- deps = EmptyDependency(level=level, rule_name=theorem.rule_name)
1036
- fail = False # finding why deps might fail.
1037
-
1038
- for p in theorem.premise:
1039
- p_args = [mp[a] for a in p.args]
1040
- # Trivial deps.
1041
- if p.name == 'cong':
1042
- a, b, c, d = p_args
1043
- if {a, b} == {c, d}:
1044
- continue
1045
- if p.name == 'para':
1046
- a, b, c, d = p_args
1047
- if {a, b} == {c, d}:
1048
- continue
1049
-
1050
- if theorem.name in [
1051
- 'cong_cong_eqangle6_ncoll_contri*',
1052
- 'eqratio6_eqangle6_ncoll_simtri*',
1053
- ]:
1054
- if p.name in ['eqangle', 'eqangle6']: # SAS or RAR
1055
- b, a, b, c, y, x, y, z = ( # pylint: disable=redeclared-assigned-name,unused-variable
1056
- p_args
1057
- )
1058
- if not nm.same_clock(a.num, b.num, c.num, x.num, y.num, z.num):
1059
- p_args = b, a, b, c, y, z, y, x
1060
-
1061
- dep = Dependency(p.name, p_args, rule_name='', level=level)
1062
- try:
1063
- dep = dep.why_me_or_cache(g, level)
1064
- except: # pylint: disable=bare-except
1065
- fail = True
1066
- break
1067
-
1068
- if dep.why is None:
1069
- fail = True
1070
- break
1071
- g.cache_dep(p.name, p_args, dep)
1072
- deps.why.append(dep)
1073
-
1074
- if fail:
1075
- continue
1076
-
1077
- mp_deps.append((mp, deps))
1078
- theorem2deps[theorem] = mp_deps
1079
-
1080
- theorem2deps = list(theorem2deps.items())
1081
-
1082
- # Step 3: add conclusions to graph.
1083
- # Note that we do NOT mix step 2 and 3, strictly going for BFS.
1084
- added = []
1085
- for theorem, mp_deps in theorem2deps:
1086
- for mp, deps in mp_deps:
1087
- if time.time() - t0 > timeout:
1088
- break
1089
- name, args = theorem.conclusion_name_args(mp)
1090
- hash_conclusion = pr.hashed(name, args)
1091
- if hash_conclusion in g.cache:
1092
- continue
1093
-
1094
- add = g.add_piece(name, args, deps=deps)
1095
- added += add
1096
-
1097
- branching = len(added)
1098
-
1099
- # Check if goal is found
1100
- if controller.goal:
1101
- args = []
1102
-
1103
- for a in controller.goal.args:
1104
- if a in g._name2node:
1105
- a = g._name2node[a]
1106
- elif '/' in a:
1107
- a = create_consts_str(g, a)
1108
- elif a.isdigit():
1109
- a = int(a)
1110
- args.append(a)
1111
-
1112
- if g.check(controller.goal.name, args):
1113
- return added, {}, {}, branching
1114
-
1115
- # Run AR, but do NOT apply to the proof state (yet).
1116
- for dep in added:
1117
- g.add_algebra(dep, level)
1118
- derives, eq4s = g.derive_algebra(level, verbose=False)
1119
-
1120
- branching += sum([len(x) for x in derives.values()])
1121
- branching += sum([len(x) for x in eq4s.values()])
1122
-
1123
- return added, derives, eq4s, branching
1124
-
1125
-
1126
- def create_consts_str(g: gh.Graph, s: str) -> gm.Angle | gm.Ratio:
1127
- if 'pi/' in s:
1128
- n, d = s.split('pi/')
1129
- n, d = int(n), int(d)
1130
- p0, _ = g.get_or_create_const_ang(n, d)
1131
- else:
1132
- n, d = s.split('/')
1133
- n, d = int(n), int(d)
1134
- p0, _ = g.get_or_create_const_rat(n, d)
1135
- return p0
1136
-
1137
-
1138
- def do_algebra(
1139
- g: gh.Graph, added: list[pr.Dependency], verbose: bool = False
1140
- ) -> None:
1141
- for add in added:
1142
- g.add_algebra(add, None)
1143
- derives, eq4s = g.derive_algebra(level=None, verbose=verbose)
1144
- apply_derivations(g, derives)
1145
- apply_derivations(g, eq4s)
1146
-
1147
-
1148
- def apply_derivations(
1149
- g: gh.Graph, derives: dict[str, list[tuple[gm.Point, ...]]]
1150
- ) -> list[pr.Dependency]:
1151
- applied = []
1152
- all_derives = list(derives.items())
1153
- for name, args in all_derives:
1154
- for arg in args:
1155
- applied += g.do_algebra(name, arg)
1156
- return applied
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Implements Deductive Database (DD)."""
17
+
18
+ # pylint: disable=g-multiple-import,g-importing-member
19
+ from collections import defaultdict
20
+ import time
21
+ from typing import Any, Callable, Generator
22
+
23
+ import geometry as gm
24
+ import graph as gh
25
+ import graph_utils as utils
26
+ import numericals as nm
27
+ import problem as pr
28
+ from problem import Dependency, EmptyDependency
29
+ from typing import Union
30
+
31
+
32
+ def intersect1(set1: set[Any], set2: set[Any]) -> Any:
33
+ for x in set1:
34
+ if x in set2:
35
+ return x
36
+ return None
37
+
38
+
39
+ def diff_point(l: gm.Line, a: gm.Point) -> gm.Point:
40
+ for x in l.neighbors(gm.Point):
41
+ if x != a :
42
+ return x
43
+ return None
44
+
45
+
46
+ # pylint: disable=protected-access
47
+ # pylint: disable=unused-argument
48
+
49
+
50
+ def match_eqratio_eqratio_eqratio(
51
+ g: gh.Graph,
52
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
53
+ theorem: pr.Theorem,
54
+ ) -> Generator[dict[str, gm.Point], None, None]:
55
+ """Match eqratio a b c d m n p q, eqratio c d e f p q r u => eqratio a b e f m n r u."""
56
+ for m1 in g.type2nodes[gm.Value]:
57
+ for m2 in g.type2nodes[gm.Value]:
58
+ rats1 = []
59
+ for rat in m1.neighbors(gm.Ratio):
60
+ l1, l2 = rat.lengths
61
+ if l1 is None or l2 is None:
62
+ continue
63
+ rats1.append((l1, l2))
64
+
65
+ rats2 = []
66
+ for rat in m2.neighbors(gm.Ratio):
67
+ l1, l2 = rat.lengths
68
+ if l1 is None or l2 is None:
69
+ continue
70
+ rats2.append((l1, l2))
71
+
72
+ pairs = []
73
+ for (l1, l2), (l3, l4) in utils.cross(rats1, rats2):
74
+ if l2 == l3:
75
+ pairs.append((l1, l2, l4))
76
+
77
+ for (l1, l12, l2), (l3, l34, l4) in utils.comb2(pairs):
78
+ if (l1, l12, l2) == (l3, l34, l4):
79
+ continue
80
+ if l1 == l2 or l3 == l4:
81
+ continue
82
+ if l1 == l12 or l12 == l2 or l3 == l34 or l4 == l34:
83
+ continue
84
+ # d12 - d1 = d34 - d3 = m1
85
+ # d2 - d12 = d4 - d34 = m2
86
+ # => d2 - d1 = d4 - d3 (= m1+m2)
87
+ a, b = g.two_points_of_length(l1)
88
+ c, d = g.two_points_of_length(l12)
89
+ m, n = g.two_points_of_length(l3)
90
+ p, q = g.two_points_of_length(l34)
91
+ # eqangle a b c d m n p q
92
+ e, f = g.two_points_of_length(l2)
93
+ r, u = g.two_points_of_length(l4)
94
+ yield dict(zip('abcdefmnpqru', [a, b, c, d, e, f, m, n, p, q, r, u]))
95
+
96
+
97
+ def match_eqangle_eqangle_eqangle(
98
+ g: gh.Graph,
99
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
100
+ theorem: pr.Theorem,
101
+ ) -> Generator[dict[str, gm.Point], None, None]:
102
+ """Match eqangle a b c d m n p q, eqangle c d e f p q r u => eqangle a b e f m n r u."""
103
+ for m1 in g.type2nodes[gm.Measure]:
104
+ for m2 in g.type2nodes[gm.Measure]:
105
+ angs1 = []
106
+ for ang in m1.neighbors(gm.Angle):
107
+ d1, d2 = ang.directions
108
+ if d1 is None or d2 is None:
109
+ continue
110
+ angs1.append((d1, d2))
111
+
112
+ angs2 = []
113
+ for ang in m2.neighbors(gm.Angle):
114
+ d1, d2 = ang.directions
115
+ if d1 is None or d2 is None:
116
+ continue
117
+ angs2.append((d1, d2))
118
+
119
+ pairs = []
120
+ for (d1, d2), (d3, d4) in utils.cross(angs1, angs2):
121
+ if d2 == d3:
122
+ pairs.append((d1, d2, d4))
123
+
124
+ for (d1, d12, d2), (d3, d34, d4) in utils.comb2(pairs):
125
+ if (d1, d12, d2) == (d3, d34, d4):
126
+ continue
127
+ if d1 == d2 or d3 == d4:
128
+ continue
129
+ if d1 == d12 or d12 == d2 or d3 == d34 or d4 == d34:
130
+ continue
131
+ # d12 - d1 = d34 - d3 = m1
132
+ # d2 - d12 = d4 - d34 = m2
133
+ # => d2 - d1 = d4 - d3
134
+ a, b = g.two_points_on_direction(d1)
135
+ c, d = g.two_points_on_direction(d12)
136
+ m, n = g.two_points_on_direction(d3)
137
+ p, q = g.two_points_on_direction(d34)
138
+ # eqangle a b c d m n p q
139
+ e, f = g.two_points_on_direction(d2)
140
+ r, u = g.two_points_on_direction(d4)
141
+ yield dict(zip('abcdefmnpqru', [a, b, c, d, e, f, m, n, p, q, r, u]))
142
+
143
+
144
+ def match_perp_perp_npara_eqangle(
145
+ g: gh.Graph,
146
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
147
+ theorem: pr.Theorem,
148
+ ) -> Generator[dict[str, gm.Point], None, None]:
149
+ """Match perp A B C D, perp E F G H, npara A B E F => eqangle A B E F C D G H."""
150
+ dpairs = []
151
+ for ang in g.vhalfpi.neighbors(gm.Angle):
152
+ d1, d2 = ang.directions
153
+ if d1 is None or d2 is None:
154
+ continue
155
+ dpairs.append((d1, d2))
156
+
157
+ for (d1, d2), (d3, d4) in utils.comb2(dpairs):
158
+ a, b = g.two_points_on_direction(d1)
159
+ c, d = g.two_points_on_direction(d2)
160
+ m, n = g.two_points_on_direction(d3)
161
+ p, q = g.two_points_on_direction(d4)
162
+ if g.check_npara([a, b, m, n]):
163
+ if ({a, b}, {c, d}) == ({m, n}, {p, q}):
164
+ continue
165
+ if ({a, b}, {c, d}) == ({p, q}, {m, n}):
166
+ continue
167
+
168
+ yield dict(zip('ABCDEFGH', [a, b, c, d, m, n, p, q]))
169
+
170
+
171
+ def match_circle_coll_eqangle_midp(
172
+ g: gh.Graph,
173
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
174
+ theorem: pr.Theorem,
175
+ ) -> Generator[dict[str, gm.Point], None, None]:
176
+ """Match circle O A B C, coll M B C, eqangle A B A C O B O M => midp M B C."""
177
+ for p, a, b, c in g.all_circles():
178
+ ab = g._get_line(a, b)
179
+ if ab is None:
180
+ continue
181
+ if ab.val is None:
182
+ continue
183
+ ac = g._get_line(a, c)
184
+ if ac is None:
185
+ continue
186
+ if ac.val is None:
187
+ continue
188
+ pb = g._get_line(p, b)
189
+ if pb is None:
190
+ continue
191
+ if pb.val is None:
192
+ continue
193
+
194
+ bc = g._get_line(b, c)
195
+ if bc is None:
196
+ continue
197
+ bc_points = bc.neighbors(gm.Point, return_set=True)
198
+
199
+ anga, _ = g._get_angle(ab.val, ac.val)
200
+
201
+ for angp in pb.val.neighbors(gm.Angle):
202
+ if not g.is_equal(anga, angp):
203
+ continue
204
+
205
+ _, d = angp.directions
206
+ for l in d.neighbors(gm.Line):
207
+ l_points = l.neighbors(gm.Point, return_set=True)
208
+ m = intersect1(bc_points, l_points)
209
+ if m is not None:
210
+ yield dict(zip('ABCMO', [a, b, c, m, p]))
211
+
212
+
213
+ def match_midp_perp_cong(
214
+ g: gh.Graph,
215
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
216
+ theorem: pr.Theorem,
217
+ ) -> Generator[dict[str, gm.Point], None, None]:
218
+ """Match midp M A B, perp O M A B => cong O A O B."""
219
+ for m, a, b in g.all_midps():
220
+ ab = g._get_line(a, b)
221
+ for l in m.neighbors(gm.Line):
222
+ if g.check_perpl(l, ab):
223
+ for o in l.neighbors(gm.Point):
224
+ if o != m:
225
+ yield dict(zip('ABMO', [a, b, m, o]))
226
+
227
+
228
+ def match_cyclic_eqangle_cong(
229
+ g: gh.Graph,
230
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
231
+ theorem: pr.Theorem,
232
+ ) -> Generator[dict[str, gm.Point], None, None]:
233
+ """Match cyclic A B C P Q R, eqangle C A C B R P R Q => cong A B P Q."""
234
+ for c in g.type2nodes[gm.Circle]:
235
+ ps = c.neighbors(gm.Point)
236
+ for (a, b, c), (x, y, z) in utils.comb2(list(utils.perm3(ps))):
237
+ if {a, b, c} == {x, y, z}:
238
+ continue
239
+ if g.check_eqangle([c, a, c, b, z, x, z, y]):
240
+ yield dict(zip('ABCPQR', [a, b, c, x, y, z]))
241
+
242
+
243
+ def match_circle_eqangle_perp(
244
+ g: gh.Graph,
245
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
246
+ theorem: pr.Theorem,
247
+ ) -> Generator[dict[str, gm.Point], None, None]:
248
+ """Match circle O A B C, eqangle A X A B C A C B => perp O A A X."""
249
+ for p, a, b, c in g.all_circles():
250
+ ca = g._get_line(c, a)
251
+ if ca is None:
252
+ continue
253
+ cb = g._get_line(c, b)
254
+ if cb is None:
255
+ continue
256
+ ab = g._get_line(a, b)
257
+ if ab is None:
258
+ continue
259
+
260
+ if ca.val is None:
261
+ continue
262
+ if cb.val is None:
263
+ continue
264
+ if ab.val is None:
265
+ continue
266
+
267
+ c_ang, _ = g._get_angle(cb.val, ca.val)
268
+ if c_ang is None:
269
+ continue
270
+
271
+ for ang in ab.val.neighbors(gm.Angle):
272
+ if g.is_equal(ang, c_ang):
273
+ _, d = ang.directions
274
+ for l in d.neighbors(gm.Line):
275
+ if a not in l.neighbors(gm.Point):
276
+ continue
277
+ x = diff_point(l, a)
278
+ if x is None:
279
+ continue
280
+ yield dict(zip('OABCX', [p, a, b, c, x]))
281
+ break
282
+
283
+
284
+ def match_circle_perp_eqangle(
285
+ g: gh.Graph,
286
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
287
+ theorem: pr.Theorem,
288
+ ) -> Generator[dict[str, gm.Point], None, None]:
289
+ """Match circle O A B C, perp O A A X => eqangle A X A B C A C B."""
290
+ for p, a, b, c in g.all_circles():
291
+ pa = g._get_line(p, a)
292
+ if pa is None:
293
+ continue
294
+ if pa.val is None:
295
+ continue
296
+ for l in a.neighbors(gm.Line):
297
+ if g.check_perpl(pa, l):
298
+ x = diff_point(l, a)
299
+ if x is not None:
300
+ yield dict(zip('OABCX', [p, a, b, c, x]))
301
+
302
+ def match_semicircle_eqangle_perp(
303
+ g: gh.Graph,
304
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
305
+ theorem: pr.Theorem,
306
+ ) -> Generator[dict[str, gm.Point], None, None]:
307
+ """Match circle O A B C, eqangle A X A B C A C B => perp O A A X."""
308
+ for p, a, b, c in g.all_circles():
309
+ ca = g._get_line(c, a)
310
+ if ca is None:
311
+ continue
312
+ cb = g._get_line(c, b)
313
+ if cb is None:
314
+ continue
315
+ ab = g._get_line(a, b)
316
+ if ab is None:
317
+ continue
318
+
319
+ if ca.val is None:
320
+ continue
321
+ if cb.val is None:
322
+ continue
323
+ if ab.val is None:
324
+ continue
325
+
326
+ c_ang, _ = g._get_angle(cb.val, ca.val)
327
+ if c_ang is None:
328
+ continue
329
+
330
+ for ang in ab.val.neighbors(gm.Angle):
331
+ if g.is_equal(ang, c_ang):
332
+ _, d = ang.directions
333
+ for l in d.neighbors(gm.Line):
334
+ if a not in l.neighbors(gm.Point):
335
+ continue
336
+ x = diff_point(l, a)
337
+ if x is None:
338
+ continue
339
+ yield dict(zip('OABCX', [p, a, b, c, x]))
340
+ break
341
+
342
+
343
+ def match_semicircle_perp_eqangle(
344
+ g: gh.Graph,
345
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
346
+ theorem: pr.Theorem,
347
+ ) -> Generator[dict[str, gm.Point], None, None]:
348
+ """Match semicircle O A B, perp O A A X => eqangle A X A B O A B."""
349
+ for o, a, b, c in g.all_semicircles():
350
+ oa = g._get_line(o, a)
351
+ if oa is None:
352
+ continue
353
+ if oa.val is None:
354
+ continue
355
+ for l in a.neighbors(gm.Line):
356
+ if g.check_perpl(oa, l):
357
+ x = diff_point(l, a)
358
+ if x is not None:
359
+ yield dict(zip('OABCX', [o, a, b, c, x]))
360
+
361
+
362
+ def match_perp_perp_ncoll_para(
363
+ g: gh.Graph,
364
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
365
+ theorem: pr.Theorem,
366
+ ) -> Generator[dict[str, gm.Point], None, None]:
367
+ """Match perp A B C D, perp C D E F, ncoll A B E => para A B E F."""
368
+ d2d = defaultdict(list)
369
+ for ang in g.vhalfpi.neighbors(gm.Angle):
370
+ d1, d2 = ang.directions
371
+ if d1 is None or d2 is None:
372
+ continue
373
+ d2d[d1] += [d2]
374
+ d2d[d2] += [d1]
375
+
376
+ for x, ys in d2d.items():
377
+ if len(ys) < 2:
378
+ continue
379
+ c, d = g.two_points_on_direction(x)
380
+ for y1, y2 in utils.comb2(ys):
381
+ a, b = g.two_points_on_direction(y1)
382
+ e, f = g.two_points_on_direction(y2)
383
+ if nm.check_ncoll([a.num, b.num, e.num]):
384
+ yield dict(zip('ABCDEF', [a, b, c, d, e, f]))
385
+
386
+
387
+ def match_eqangle6_ncoll_cong(
388
+ g: gh.Graph,
389
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
390
+ theorem: pr.Theorem,
391
+ ) -> Generator[dict[str, gm.Point], None, None]:
392
+ """Match eqangle6 A O A B B A B O, ncoll O A B => cong O A O B."""
393
+ for a in g.type2nodes[gm.Point]:
394
+ for b, c in utils.comb2(g.type2nodes[gm.Point]):
395
+ if a == b or a == c:
396
+ continue
397
+ if g.check_eqangle([b, a, b, c, c, b, c, a]):
398
+ if g.check_ncoll([a, b, c]):
399
+ yield dict(zip('OAB', [a, b, c]))
400
+
401
+
402
+ def match_eqangle_perp_perp(
403
+ g: gh.Graph,
404
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
405
+ theorem: pr.Theorem,
406
+ ) -> Generator[dict[str, gm.Point], None, None]:
407
+ """Match eqangle A B P Q C D U V, perp P Q U V => perp A B C D."""
408
+ for ang in g.vhalfpi.neighbors(gm.Angle):
409
+ # d1 perp d2
410
+ d1, d2 = ang.directions
411
+ if d1 is None or d2 is None:
412
+ continue
413
+ for d3, d4 in utils.comb2(g.type2nodes[gm.Direction]):
414
+ if d1 == d3 or d2 == d4:
415
+ continue
416
+ # if d1 - d3 = d2 - d4 => d3 perp d4
417
+ a13, a31 = g._get_angle(d1, d3)
418
+ a24, a42 = g._get_angle(d2, d4)
419
+ if a13 is None or a31 is None or a24 is None or a42 is None:
420
+ continue
421
+ if g.is_equal(a13, a24) and g.is_equal(a31, a42):
422
+ a, b = g.two_points_on_direction(d1)
423
+ c, d = g.two_points_on_direction(d2)
424
+ m, n = g.two_points_on_direction(d3)
425
+ p, q = g.two_points_on_direction(d4)
426
+ yield dict(zip('ABCDPQUV', [m, n, p, q, a, b, c, d]))
427
+
428
+
429
+ def match_eqangle_ncoll_cyclic(
430
+ g: gh.Graph,
431
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
432
+ theorem: pr.Theorem,
433
+ ) -> Generator[dict[str, gm.Point], None, None]:
434
+ """Match eqangle6 P A P B Q A Q B, ncoll P Q A B => cyclic A B P Q."""
435
+ for l1, l2, l3, l4 in g.all_eqangles_distinct_linepairss():
436
+ if len(set([l1, l2, l3, l4])) < 4:
437
+ continue # they all must be distinct.
438
+
439
+ p1s = l1.neighbors(gm.Point, return_set=True)
440
+ p2s = l2.neighbors(gm.Point, return_set=True)
441
+ p3s = l3.neighbors(gm.Point, return_set=True)
442
+ p4s = l4.neighbors(gm.Point, return_set=True)
443
+
444
+ p = intersect1(p1s, p2s)
445
+ if not p:
446
+ continue
447
+ q = intersect1(p3s, p4s)
448
+ if not q:
449
+ continue
450
+ a = intersect1(p1s, p3s)
451
+ if not a:
452
+ continue
453
+ b = intersect1(p2s, p4s)
454
+ if not b:
455
+ continue
456
+ if len(set([a, b, p, q])) < 4:
457
+ continue
458
+
459
+ if not g.check_ncoll([a, b, p, q]):
460
+ continue
461
+
462
+ yield dict(zip('ABPQ', [a, b, p, q]))
463
+
464
+
465
+ def match_eqangle_para(
466
+ g: gh.Graph,
467
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
468
+ theorem: pr.Theorem,
469
+ ) -> Generator[dict[str, gm.Point], None, None]:
470
+ """Match eqangle A B P Q C D P Q => para A B C D."""
471
+ for measure in g.type2nodes[gm.Measure]:
472
+ angs = measure.neighbors(gm.Angle)
473
+ d12, d21 = defaultdict(list), defaultdict(list)
474
+ for ang in angs:
475
+ d1, d2 = ang.directions
476
+ if d1 is None or d2 is None:
477
+ continue
478
+ d12[d1].append(d2)
479
+ d21[d2].append(d1)
480
+
481
+ for d1, d2s in d12.items():
482
+ a, b = g.two_points_on_direction(d1)
483
+ for d2, d3 in utils.comb2(d2s):
484
+ c, d = g.two_points_on_direction(d2)
485
+ e, f = g.two_points_on_direction(d3)
486
+ yield dict(zip('ABCDPQ', [c, d, e, f, a, b]))
487
+
488
+
489
+ def match_cyclic_eqangle(
490
+ g: gh.Graph,
491
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
492
+ theorem: pr.Theorem,
493
+ ) -> Generator[dict[str, gm.Point], None, None]:
494
+ """Match cyclic A B P Q => eqangle P A P B Q A Q B."""
495
+ record = set()
496
+ for a, b, c, d in g_matcher('cyclic'):
497
+ if (a, b, c, d) in record:
498
+ continue
499
+ record.add((a, b, c, d))
500
+ record.add((a, b, d, c))
501
+ record.add((b, a, c, d))
502
+ record.add((b, a, d, c))
503
+ yield dict(zip('ABPQ', [a, b, c, d]))
504
+
505
+
506
+ def rotate_simtri(
507
+ a: gm.Point, b: gm.Point, c: gm.Point, x: gm.Point, y: gm.Point, z: gm.Point
508
+ ) -> Generator[tuple[gm.Point, ...], None, None]:
509
+ """Rotate points around for similar triangle predicates."""
510
+ yield (z, y, x, c, b, a)
511
+ for p in [
512
+ (b, c, a, y, z, x),
513
+ (c, a, b, z, x, y),
514
+ (x, y, z, a, b, c),
515
+ (y, z, x, b, c, a),
516
+ (z, x, y, c, a, b),
517
+ ]:
518
+ yield p
519
+ yield p[::-1]
520
+
521
+
522
+ def match_cong_cong_cong_cyclic(
523
+ g: gh.Graph,
524
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
525
+ theorem: pr.Theorem,
526
+ ) -> Generator[dict[str, gm.Point], None, None]:
527
+ """Match cong O A O B, cong O B O C, cong O C O D => cyclic A B C D."""
528
+ for l in g.type2nodes[gm.Length]:
529
+ p2p = defaultdict(list)
530
+ for s in l.neighbors(gm.Segment):
531
+ a, b = s.points
532
+ p2p[a].append(b)
533
+ p2p[b].append(a)
534
+
535
+ for p, ps in p2p.items():
536
+ if len(ps) >= 4:
537
+ for a, b, c, d in utils.comb4(ps):
538
+ yield dict(zip('OABCD', [p, a, b, c, d]))
539
+
540
+
541
+ def match_cong_cong_cong_ncoll_contri(
542
+ g: gh.Graph,
543
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
544
+ theorem: pr.Theorem,
545
+ ) -> Generator[dict[str, gm.Point], None, None]:
546
+ """Match cong A B P Q, cong B C Q R, cong C A R P, ncoll A B C => contri* A B C P Q R."""
547
+ record = set()
548
+ for a, b, p, q in g_matcher('cong'):
549
+ for c in g.type2nodes[gm.Point]:
550
+ for r in g.type2nodes[gm.Point]:
551
+ if any([x in record for x in rotate_simtri(a, b, c, p, q, r)]):
552
+ continue
553
+ if not g.check_ncoll([a, b, c]):
554
+ continue
555
+ if g.check_cong([b, c, q, r]) and g.check_cong([c, a, r, p]):
556
+ record.add((a, b, c, p, q, r))
557
+ yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
558
+
559
+
560
+ def match_cong_cong_eqangle6_ncoll_contri(
561
+ g: gh.Graph,
562
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
563
+ theorem: pr.Theorem,
564
+ ) -> Generator[dict[str, gm.Point], None, None]:
565
+ """Match cong A B P Q, cong B C Q R, eqangle6 B A B C Q P Q R, ncoll A B C => contri* A B C P Q R."""
566
+ record = set()
567
+ for a, b, p, q in g_matcher('cong'):
568
+ for c in g.type2nodes[gm.Point]:
569
+ if c in (a, b):
570
+ continue
571
+ for r in g.type2nodes[gm.Point]:
572
+ if r in (p, q):
573
+ continue
574
+
575
+ in_record = False
576
+ for x in [
577
+ (c, b, a, r, q, p),
578
+ (p, q, r, a, b, c),
579
+ (r, q, p, c, b, a),
580
+ ]:
581
+ if x in record:
582
+ in_record = True
583
+ break
584
+
585
+ if in_record:
586
+ continue
587
+
588
+ if not g.check_cong([b, c, q, r]):
589
+ continue
590
+ if not g.check_ncoll([a, b, c]):
591
+ continue
592
+
593
+ if nm.same_clock(a.num, b.num, c.num, p.num, q.num, r.num):
594
+ if g.check_eqangle([b, a, b, c, q, p, q, r]):
595
+ record.add((a, b, c, p, q, r))
596
+ yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
597
+ else:
598
+ if g.check_eqangle([b, a, b, c, q, r, q, p]):
599
+ record.add((a, b, c, p, q, r))
600
+ yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
601
+
602
+
603
+ def match_eqratio6_eqangle6_ncoll_simtri(
604
+ g: gh.Graph,
605
+ g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
606
+ theorem: pr.Theorem,
607
+ ) -> Generator[dict[str, gm.Point], None, None]:
608
+ """Match eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C => simtri* A B C P Q R."""
609
+ enums = g_matcher('eqratio6')
610
+
611
+ record = set()
612
+ for b, a, b, c, q, p, q, r in enums: # pylint: disable=redeclared-assigned-name,unused-variable
613
+ if (a, b, c) == (p, q, r):
614
+ continue
615
+ if any([x in record for x in rotate_simtri(a, b, c, p, q, r)]):
616
+ continue
617
+ if not g.check_ncoll([a, b, c]):
618
+ continue
619
+
620
+ if nm.same_clock(a.num, b.num, c.num, p.num, q.num, r.num):
621
+ if g.check_eqangle([b, a, b, c, q, p, q, r]):
622
+ record.add((a, b, c, p, q, r))
623
+ yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
624
+ elif g.check_eqangle([b, a, b, c, q, r, q, p]):
625
+ record.add((a, b, c, p, q, r))
626
+ yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
627
+
628
+
629
def match_eqangle6_eqangle6_ncoll_simtri(
    g: gh.Graph,
    g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
    theorem: pr.Theorem,
) -> Generator[dict[str, gm.Point], None, None]:
  """Match eqangle6 B A B C Q P Q R, eqangle6 C A C B R P R Q, ncoll A B C => simtri A B C P Q R."""
  seen = set()
  for b, a, b, c, q, p, q, r in g_matcher('eqangle6'):  # pylint: disable=redeclared-assigned-name,unused-variable
    if (a, b, c) == (p, q, r):
      continue  # Degenerate: both triangles are identical.
    if any(x in seen for x in rotate_simtri(a, b, c, p, q, r)):
      continue  # An equivalent permutation was already produced.
    # Require the second angle equality and a non-degenerate triangle.
    if not g.check_eqangle([c, a, c, b, r, p, r, q]):
      continue
    if not g.check_ncoll([a, b, c]):
      continue

    seen.add((a, b, c, p, q, r))
    yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
651
+
652
+
653
def match_eqratio6_eqratio6_ncoll_simtri(
    g: gh.Graph,
    g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
    theorem: pr.Theorem,
) -> Generator[dict[str, gm.Point], None, None]:
  """Match eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C => simtri* A B C P Q R."""
  seen = set()
  for b, a, b, c, q, p, q, r in g_matcher('eqratio6'):  # pylint: disable=redeclared-assigned-name,unused-variable
    if (a, b, c) == (p, q, r):
      continue  # Degenerate: both triangles are identical.
    if any(x in seen for x in rotate_simtri(a, b, c, p, q, r)):
      continue  # An equivalent permutation was already produced.
    # Require the second ratio equality and a non-degenerate triangle.
    if not g.check_eqratio([c, a, c, b, r, p, r, q]):
      continue
    if not g.check_ncoll([a, b, c]):
      continue

    seen.add((a, b, c, p, q, r))
    yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
675
+
676
+
677
def match_eqangle6_eqangle6_ncoll_simtri2(
    g: gh.Graph,
    g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
    theorem: pr.Theorem,
) -> Generator[dict[str, gm.Point], None, None]:
  """Match eqangle6 B A B C Q R Q P, eqangle6 C A C B R Q R P, ncoll A B C => simtri2 A B C P Q R."""
  seen = set()
  for b, a, b, c, q, r, q, p in g_matcher('eqangle6'):  # pylint: disable=redeclared-assigned-name,unused-variable
    if (a, b, c) == (p, q, r):
      continue  # Degenerate: both triangles are identical.
    if any(x in seen for x in rotate_simtri(a, b, c, p, q, r)):
      continue  # An equivalent permutation was already produced.
    # Require the reflected second angle equality and non-degeneracy.
    if not g.check_eqangle([c, a, c, b, r, q, r, p]):
      continue
    if not g.check_ncoll([a, b, c]):
      continue

    seen.add((a, b, c, p, q, r))
    yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
699
+
700
+
701
def rotate_contri(
    a: gm.Point, b: gm.Point, c: gm.Point, x: gm.Point, y: gm.Point, z: gm.Point
) -> Generator[tuple[gm.Point, ...], None, None]:
  """Yield the symmetric re-orderings of a contri statement (a b c ~ x y z)."""
  yield b, a, c, y, x, z
  yield x, y, z, a, b, c
  yield y, x, z, b, a, c
706
+
707
+
708
def match_eqangle6_eqangle6_ncoll_cong_contri(
    g: gh.Graph,
    g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
    theorem: pr.Theorem,
) -> Generator[dict[str, gm.Point], None, None]:
  """Match eqangle6 B A B C Q P Q R, eqangle6 C A C B R P R Q, ncoll A B C, cong A B P Q => contri A B C P Q R."""
  seen = set()
  for b, a, b, c, q, p, q, r in g_matcher('eqangle6'):  # pylint: disable=redeclared-assigned-name,unused-variable
    # The congruent side is the cheapest filter, so apply it first.
    if not g.check_cong([a, b, p, q]):
      continue
    if (a, b, c) == (p, q, r):
      continue  # Degenerate: both triangles are identical.
    if any(x in seen for x in rotate_contri(a, b, c, p, q, r)):
      continue  # An equivalent permutation was already produced.
    if not g.check_eqangle([c, a, c, b, r, p, r, q]):
      continue

    if not g.check_ncoll([a, b, c]):
      continue

    seen.add((a, b, c, p, q, r))
    yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
733
+
734
+
735
def match_eqratio6_eqratio6_ncoll_cong_contri(
    g: gh.Graph,
    g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
    theorem: pr.Theorem,
) -> Generator[dict[str, gm.Point], None, None]:
  """Match eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C, cong A B P Q => contri* A B C P Q R."""
  seen = set()
  for b, a, b, c, q, p, q, r in g_matcher('eqratio6'):  # pylint: disable=redeclared-assigned-name,unused-variable
    # The congruent side is the cheapest filter, so apply it first.
    if not g.check_cong([a, b, p, q]):
      continue
    if (a, b, c) == (p, q, r):
      continue  # Degenerate: both triangles are identical.
    if any(x in seen for x in rotate_contri(a, b, c, p, q, r)):
      continue  # An equivalent permutation was already produced.
    if not g.check_eqratio([c, a, c, b, r, p, r, q]):
      continue

    if not g.check_ncoll([a, b, c]):
      continue

    seen.add((a, b, c, p, q, r))
    yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
760
+
761
+
762
def match_eqangle6_eqangle6_ncoll_cong_contri2(
    g: gh.Graph,
    g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
    theorem: pr.Theorem,
) -> Generator[dict[str, gm.Point], None, None]:
  """Match eqangle6 B A B C Q R Q P, eqangle6 C A C B R Q R P, ncoll A B C, cong A B P Q => contri2 A B C P Q R."""
  seen = set()
  for b, a, b, c, q, r, q, p in g_matcher('eqangle6'):  # pylint: disable=redeclared-assigned-name,unused-variable
    # The congruent side is the cheapest filter, so apply it first.
    if not g.check_cong([a, b, p, q]):
      continue
    if (a, b, c) == (p, q, r):
      continue  # Degenerate: both triangles are identical.
    if any(x in seen for x in rotate_contri(a, b, c, p, q, r)):
      continue  # An equivalent permutation was already produced.
    if not g.check_eqangle([c, a, c, b, r, q, r, p]):
      continue
    if not g.check_ncoll([a, b, c]):
      continue

    seen.add((a, b, c, p, q, r))
    yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
786
+
787
+
788
def match_eqratio6_coll_ncoll_eqangle6(
    g: gh.Graph,
    g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
    theorem: pr.Theorem,
) -> Generator[dict[str, gm.Point], None, None]:
  """Match eqratio6 d b d c a b a c, coll d b c, ncoll a b c => eqangle6 a b a d a d a c."""
  seen = set()
  for b, d, c in g_matcher('coll'):
    for a in g.all_points():
      if g.check_coll([a, b, c]):
        continue  # apex must be off the line b-c.
      # The (b, c) pair is symmetric; record both orders as one candidate.
      if (a, b, d, c) in seen or (a, c, d, b) in seen:
        continue
      seen.add((a, b, d, c))

      if g.check_eqratio([d, b, d, c, a, b, a, c]):
        yield dict(zip('abcd', [a, b, c, d]))
805
+
806
+
807
def match_eqangle6_coll_ncoll_eqratio6(
    g: gh.Graph,
    g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
    theorem: pr.Theorem,
) -> Generator[dict[str, gm.Point], None, None]:
  """Match eqangle6 a b a d a d a c, coll d b c, ncoll a b c => eqratio6 d b d c a b a c."""
  seen = set()
  for b, d, c in g_matcher('coll'):
    for a in g.all_points():
      if g.check_coll([a, b, c]):
        continue  # apex must be off the line b-c.
      # The (b, c) pair is symmetric; record both orders as one candidate.
      if (a, b, d, c) in seen or (a, c, d, b) in seen:
        continue
      seen.add((a, b, d, c))

      if g.check_eqangle([a, b, a, d, a, d, a, c]):
        yield dict(zip('abcd', [a, b, c, d]))
824
+
825
+
826
def match_eqangle6_ncoll_cyclic(
    g: gh.Graph,
    g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
    theorem: pr.Theorem,
) -> Generator[dict[str, gm.Point], None, None]:
  """Match eqangle6 P A P B Q A Q B, ncoll P Q A B => cyclic A B P Q."""
  for a, b, a, c, x, y, x, z in g_matcher('eqangle6'):  # pylint: disable=redeclared-assigned-name,unused-variable
    # Only keep matches where both angles subtend the same pair (b, c)
    # from two distinct vertices a and x.
    if (b, c) != (y, z) or a == x:
      continue
    if nm.check_ncoll([pt.num for pt in (a, b, c, x)]):
      yield dict(zip('ABPQ', [b, c, a, x]))
837
+
838
+
839
def match_all(
    name: str, g: gh.Graph
) -> Generator[tuple[gm.Point, ...], None, None]:
  """Match all instances of a certain relation.

  Args:
    name: relation name, e.g. 'coll', 'eqangle6', 'circle'.
    g: the proof-state graph to enumerate from.

  Returns:
    An iterable of point tuples, one per matched relation instance.
    Numerical-check relations ('ncoll', 'npara', 'nperp') have no stored
    instances, so an empty list is returned for them.

  Raises:
    ValueError: if `name` is not a recognized relation.
  """
  if name in ('ncoll', 'npara', 'nperp'):
    return []
  # Dispatch table instead of a long if-chain; the method names are the
  # graph's enumeration entry points.
  enumerator_by_name = {
      'coll': 'all_colls',
      'para': 'all_paras',
      'perp': 'all_perps',
      'cong': 'all_congs',
      'eqangle': 'all_eqangles_8points',
      'eqangle6': 'all_eqangles_6points',
      'eqratio': 'all_eqratios_8points',
      'eqratio6': 'all_eqratios_6points',
      'cyclic': 'all_cyclics',
      'midp': 'all_midps',
      'circle': 'all_circles',
      'semicircle': 'all_semicircles',
  }
  try:
    enumerator = enumerator_by_name[name]
  except KeyError:
    # Fixed typo in the original message ('Unrecognize').
    raise ValueError(f'Unrecognized {name}') from None
  return getattr(g, enumerator)()
870
+
871
+
872
def cache_match(
    graph: gh.Graph,
) -> Callable[str, list[tuple[gm.Point, ...]]]:
  """Cache throughout one single BFS level."""
  memo = {}

  def match_fn(name: str) -> list[tuple[gm.Point, ...]]:
    # Enumerate each relation at most once per BFS level.
    if name not in memo:
      memo[name] = list(match_all(name, graph))
    return memo[name]

  return match_fn
887
+
888
+
889
def try_to_map(
    clause_enum: list[tuple[pr.Clause, list[tuple[gm.Point, ...]]]],
    mapping: dict[str, gm.Point],
) -> Generator[dict[str, gm.Point], None, None]:
  """Recursively try to match the remaining points given current mapping."""
  if not clause_enum:
    yield mapping
    return

  clause, enum = clause_enum[0]
  for points in enum:
    # Extend a copy so sibling candidates do not see each other's bindings.
    extended = dict(mapping)

    consistent = True
    for point, arg in zip(points, clause.args):
      # The mapping is kept as a bijection (both directions stored);
      # reject a candidate that contradicts either direction.
      if arg in extended and extended[arg] != point:
        consistent = False
        break
      if point in extended and extended[point] != arg:
        consistent = False
        break
      extended[arg] = point
      extended[point] = arg

    if not consistent:
      continue

    yield from try_to_map(clause_enum[1:], extended)
915
+
916
+
917
def match_generic(
    g: gh.Graph,
    cache: Callable[str, list[tuple[gm.Point, ...]]],
    theorem: pr.Theorem
) -> Generator[dict[str, gm.Point], None, None]:
  """Match any generic rule that is not one of the above match_*() rules.

  Args:
    g: the proof-state graph.
    cache: per-BFS-level memoized relation enumerator (see cache_match).
    theorem: the rule whose premises should be matched.

  Yields:
    Mappings from the rule's symbolic names to concrete points.
  """
  clause2enum = {}

  clauses = []
  numerical_checks = []
  for clause in theorem.premise:
    # Numerical predicates are validated per-mapping below, not enumerated.
    if clause.name in ['ncoll', 'npara', 'nperp', 'sameside']:
      numerical_checks.append(clause)
      continue

    enum = cache(clause.name)
    if len(enum) == 0:  # pylint: disable=g-explicit-length-test
      # No instance of this premise exists: the rule cannot fire.
      # Fix: this is a generator, so a bare `return` is the correct way to
      # end it; the original `return 0` value was never observable.
      return

    clause2enum[clause] = enum
    clauses.append((len(set(clause.args)), clause))

  # Match clauses with the most distinct symbols first to prune the search.
  clauses = sorted(clauses, key=lambda x: x[0], reverse=True)
  _, clauses = zip(*clauses)

  for mapping in try_to_map([(c, clause2enum[c]) for c in clauses], {}):
    if not mapping:
      continue

    # Validate the numerical (diagram-based) premises on the candidate.
    checks_ok = True
    for check in numerical_checks:
      args = [mapping[a] for a in check.args]
      if check.name == 'ncoll':
        checks_ok = g.check_ncoll(args)
      elif check.name == 'npara':
        checks_ok = g.check_npara(args)
      elif check.name == 'nperp':
        checks_ok = g.check_nperp(args)
      elif check.name == 'sameside':
        checks_ok = g.check_sameside(args)
      if not checks_ok:
        break
    if not checks_ok:
      continue

    yield mapping
963
+
964
+
965
# Rule name -> hand-written specialized matcher.  Rules absent from this
# table fall back to match_generic() (see match_one_theorem).
BUILT_IN_FNS = {
    'cong_cong_cong_cyclic': match_cong_cong_cong_cyclic,
    'cong_cong_cong_ncoll_contri*': match_cong_cong_cong_ncoll_contri,
    'cong_cong_eqangle6_ncoll_contri*': match_cong_cong_eqangle6_ncoll_contri,
    'eqangle6_eqangle6_ncoll_simtri': match_eqangle6_eqangle6_ncoll_simtri,
    'eqangle6_eqangle6_ncoll_cong_contri': (
        match_eqangle6_eqangle6_ncoll_cong_contri
    ),  # pylint: disable=line-too-long
    'eqangle6_eqangle6_ncoll_simtri2': match_eqangle6_eqangle6_ncoll_simtri2,
    'eqangle6_eqangle6_ncoll_cong_contri2': (
        match_eqangle6_eqangle6_ncoll_cong_contri2
    ),  # pylint: disable=line-too-long
    'eqratio6_eqratio6_ncoll_simtri*': match_eqratio6_eqratio6_ncoll_simtri,
    'eqratio6_eqratio6_ncoll_cong_contri*': (
        match_eqratio6_eqratio6_ncoll_cong_contri
    ),  # pylint: disable=line-too-long
    'eqangle_para': match_eqangle_para,
    'eqangle_ncoll_cyclic': match_eqangle_ncoll_cyclic,
    'eqratio6_eqangle6_ncoll_simtri*': match_eqratio6_eqangle6_ncoll_simtri,
    'eqangle_perp_perp': match_eqangle_perp_perp,
    'eqangle6_ncoll_cong': match_eqangle6_ncoll_cong,
    'perp_perp_ncoll_para': match_perp_perp_ncoll_para,
    'circle_perp_eqangle': match_circle_perp_eqangle,
    'circle_eqangle_perp': match_circle_eqangle_perp,
    'cyclic_eqangle_cong': match_cyclic_eqangle_cong,
    'midp_perp_cong': match_midp_perp_cong,
    'perp_perp_npara_eqangle': match_perp_perp_npara_eqangle,
    'cyclic_eqangle': match_cyclic_eqangle,
    'eqangle_eqangle_eqangle': match_eqangle_eqangle_eqangle,
    'eqratio_eqratio_eqratio': match_eqratio_eqratio_eqratio,
    'eqratio6_coll_ncoll_eqangle6': match_eqratio6_coll_ncoll_eqangle6,
    'eqangle6_coll_ncoll_eqratio6': match_eqangle6_coll_ncoll_eqratio6,
    'eqangle6_ncoll_cyclic': match_eqangle6_ncoll_cyclic,
    'semicircle_perp_eqangle': match_semicircle_perp_eqangle,
    'semicircle_eqangle_perp': match_semicircle_eqangle_perp,
}
1001
+
1002
+
1003
# Rule names (or rule-name suffixes) that the matcher should ignore.
SKIP_THEOREMS = set()


def set_skip_theorems(theorems: set[str]) -> None:
  """Register rule names to be skipped by match_one_theorem."""
  SKIP_THEOREMS.update(theorems)


# Cap on the number of mappings collected for a single rule per BFS level.
MAX_BRANCH = 50_000
1011
+
1012
+
1013
def match_one_theorem(
    g: gh.Graph,
    cache: Callable[str, list[tuple[gm.Point, ...]]],
    theorem: pr.Theorem
) -> Generator[dict[str, gm.Point], None, None]:
  """Match all instances of a single theorem (rule)."""
  if cache is None:
    cache = cache_match(g)

  # Honor both exact names and name-suffix entries in the skip list.
  name = theorem.name
  if name in SKIP_THEOREMS or name.split('_')[-1] in SKIP_THEOREMS:
    return []

  if name in BUILT_IN_FNS:
    candidates = BUILT_IN_FNS[name](g, cache, theorem)
  else:
    candidates = match_generic(g, cache, theorem)

  mappings = []
  for candidate in candidates:
    mappings.append(candidate)
    if len(mappings) > MAX_BRANCH:  # cap branching at this number.
      break

  return mappings
1040
+
1041
+
1042
def match_all_theorems(
    g: gh.Graph, theorems: dict[str, pr.Theorem], goal: pr.Clause
) -> dict[pr.Theorem, list[dict[str, gm.Point]]]:
  """Match all instances of all theorems (rules).

  Args:
    g: the proof-state graph.
    theorems: mapping of rule key -> rule.  (Annotation fixed: the previous
      `list[pr.Theorem]` was wrong — the body iterates `theorems.items()`.)
    goal: problem goal; used to skip goal-specific compute/fix rules.

  Returns:
    Mapping of each matched rule to its list of point mappings.
    (Return annotation fixed: values are lists of mappings, not dicts.)
  """
  cache = cache_match(g)
  # for BFS, collect all potential matches
  # and then do it at the same time
  theorem2mappings = {}

  # Step 1: list all matches
  for _, theorem in theorems.items():
    name = theorem.name
    # Compute / fix rules are goal-specific: only match them when they are
    # the goal itself.
    if name.split('_')[-1] in [
        'acompute',
        'rcompute',
        'fixl',
        'fixc',
        'fixb',
        'fixt',
        'fixp',
    ]:
      if goal and goal.name != name:
        continue

    mappings = match_one_theorem(g, cache, theorem)
    if len(mappings):  # pylint: disable=g-explicit-length-test
      theorem2mappings[theorem] = list(mappings)
  return theorem2mappings
1070
+
1071
+
1072
def bfs_one_level(
    g: gh.Graph,
    theorems: list[pr.Theorem],
    level: int,
    controller: pr.Problem,
    verbose: bool = False,
    nm_check: bool = False,
    timeout: int = 600,
) -> tuple[
    list[pr.Dependency],
    dict[str, list[tuple[gm.Point, ...]]],
    dict[str, list[tuple[gm.Point, ...]]],
    int,
]:
  """Forward deduce one breadth-first level.

  Returns a tuple of (newly added dependencies, algebra derivations,
  4-term-equality derivations, branching factor for this level).
  NOTE(review): `verbose` and `nm_check` are accepted but never read in
  this body — presumably kept for interface compatibility; confirm.
  """

  # Step 1: match all theorems:
  theorem2mappings = match_all_theorems(g, theorems, controller.goal)

  # Step 2: traceback for each deduce:
  theorem2deps = {}
  t0 = time.time()  # single deadline shared by steps 2 and 3.
  for theorem, mappings in theorem2mappings.items():
    if time.time() - t0 > timeout:
      break
    mp_deps = []
    for mp in mappings:
      deps = EmptyDependency(level=level, rule_name=theorem.rule_name)
      fail = False  # finding why deps might fail.

      for p in theorem.premise:
        p_args = [mp[a] for a in p.args]
        # Trivial deps.
        if p.name == 'cong':
          a, b, c, d = p_args
          if {a, b} == {c, d}:
            continue
        if p.name == 'para':
          a, b, c, d = p_args
          if {a, b} == {c, d}:
            continue

        if theorem.name in [
            'cong_cong_eqangle6_ncoll_contri*',
            'eqratio6_eqangle6_ncoll_simtri*',
        ]:
          if p.name in ['eqangle', 'eqangle6']:  # SAS or RAR
            # Reorient the angle premise to agree with the numerical
            # diagram's orientation before tracing its justification.
            b, a, b, c, y, x, y, z = (  # pylint: disable=redeclared-assigned-name,unused-variable
                p_args
            )
            if not nm.same_clock(a.num, b.num, c.num, x.num, y.num, z.num):
              p_args = b, a, b, c, y, z, y, x

        dep = Dependency(p.name, p_args, rule_name='', level=level)
        try:
          dep = dep.why_me_or_cache(g, level)
        except:  # pylint: disable=bare-except
          # Any failure to justify the premise discards this mapping.
          fail = True
          break

        if dep.why is None:
          fail = True
          break
        g.cache_dep(p.name, p_args, dep)
        deps.why.append(dep)

      if fail:
        continue

      mp_deps.append((mp, deps))
    theorem2deps[theorem] = mp_deps

  theorem2deps = list(theorem2deps.items())

  # Step 3: add conclusions to graph.
  # Note that we do NOT mix step 2 and 3, strictly going for BFS.
  added = []
  for theorem, mp_deps in theorem2deps:
    for mp, deps in mp_deps:
      if time.time() - t0 > timeout:
        break
      name, args = theorem.conclusion_name_args(mp)
      hash_conclusion = pr.hashed(name, args)
      if hash_conclusion in g.cache:
        continue  # conclusion already known; skip duplicate work.

      add = g.add_piece(name, args, deps=deps)
      added += add

  branching = len(added)

  # Check if goal is found
  if controller.goal:
    args = []

    for a in controller.goal.args:
      # Goal arguments may be point names, constant angle/ratio strings
      # (containing '/'), or plain integers.
      if a in g._name2node:
        a = g._name2node[a]
      elif '/' in a:
        a = create_consts_str(g, a)
      elif a.isdigit():
        a = int(a)
      args.append(a)

    if g.check(controller.goal.name, args):
      return added, {}, {}, branching

  # Run AR, but do NOT apply to the proof state (yet).
  for dep in added:
    g.add_algebra(dep, level)
  derives, eq4s = g.derive_algebra(level, verbose=False)

  branching += sum([len(x) for x in derives.values()])
  branching += sum([len(x) for x in eq4s.values()])

  return added, derives, eq4s, branching
1188
+
1189
+
1190
def create_consts_str(g: gh.Graph, s: str) -> Union[gm.Angle, gm.Ratio]:
  """Get or create the constant angle ('Npi/D') or ratio ('N/D') named by s."""
  if 'pi/' in s:
    numerator, denominator = s.split('pi/')
    node, _ = g.get_or_create_const_ang(int(numerator), int(denominator))
  else:
    numerator, denominator = s.split('/')
    node, _ = g.get_or_create_const_rat(int(numerator), int(denominator))
  return node
1200
+
1201
+
1202
def do_algebra(
    g: gh.Graph, added: list[pr.Dependency], verbose: bool = False
) -> None:
  """Feed new facts to the algebraic reasoner and apply its derivations."""
  for dep in added:
    g.add_algebra(dep, None)
  derives, eq4s = g.derive_algebra(level=None, verbose=verbose)
  # Apply plain derivations first, then the 4-term equalities.
  apply_derivations(g, derives)
  apply_derivations(g, eq4s)
1210
+
1211
+
1212
def apply_derivations(
    g: gh.Graph, derives: dict[str, list[tuple[gm.Point, ...]]]
) -> list[pr.Dependency]:
  """Apply every algebraic derivation to the graph; return new dependencies."""
  applied = []
  for name, args in list(derives.items()):
    for arg in args:
      applied.extend(g.do_algebra(name, arg))
  return applied
ag4masses/alphageometry/ddar.py CHANGED
@@ -1,159 +1,157 @@
1
- # Copyright 2023 DeepMind Technologies Limited
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """Implements the combination DD+AR."""
17
- import time
18
-
19
- from absl import logging
20
- import dd
21
- import graph as gh
22
- import problem as pr
23
- from problem import Dependency # pylint: disable=g-importing-member
24
- import trace_back
25
-
26
-
27
- def saturate_or_goal(
28
- g: gh.Graph,
29
- theorems: list[pr.Theorem],
30
- level_times: list[float],
31
- p: pr.Problem,
32
- max_level: int = 100,
33
- timeout: int = 600,
34
- ) -> tuple[
35
- list[dict[str, list[tuple[gh.Point, ...]]]],
36
- list[dict[str, list[tuple[gh.Point, ...]]]],
37
- list[int],
38
- list[pr.Dependency],
39
- ]:
40
- """Run DD until saturation or goal found."""
41
- derives = []
42
- eq4s = []
43
- branching = []
44
- all_added = []
45
-
46
- while len(level_times) < max_level:
47
- level = len(level_times) + 1
48
-
49
- t = time.time()
50
- added, derv, eq4, n_branching = dd.bfs_one_level(
51
- g, theorems, level, p, verbose=False, nm_check=True, timeout=timeout
52
- )
53
- all_added += added
54
- branching.append(n_branching)
55
-
56
- derives.append(derv)
57
- eq4s.append(eq4)
58
- level_time = time.time() - t
59
-
60
- logging.info(f'Depth {level}/{max_level} time = {level_time}') # pylint: disable=logging-fstring-interpolation
61
- level_times.append(level_time)
62
-
63
- if p.goal is not None:
64
- goal_args = list(map(lambda x: g.get(x, lambda: int(x)), p.goal.args))
65
- if g.check(p.goal.name, goal_args): # found goal
66
- break
67
-
68
- if not added: # saturated
69
- break
70
-
71
- if level_time > timeout:
72
- break
73
-
74
- return derives, eq4s, branching, all_added
75
-
76
-
77
- def solve(
78
- g: gh.Graph,
79
- theorems: list[pr.Problem],
80
- controller: pr.Problem,
81
- max_level: int = 1000,
82
- timeout: int = 600,
83
- ) -> tuple[gh.Graph, list[float], str, list[int], list[pr.Dependency]]:
84
- """Alternate between DD and AR until goal is found."""
85
- status = 'saturated'
86
- level_times = []
87
-
88
- dervs, eq4 = g.derive_algebra(level=0, verbose=False)
89
- derives = [dervs]
90
- eq4s = [eq4]
91
- branches = []
92
- all_added = []
93
-
94
- while len(level_times) < max_level:
95
- dervs, eq4, next_branches, added = saturate_or_goal(
96
- g, theorems, level_times, controller, max_level, timeout=timeout
97
- )
98
- all_added += added
99
-
100
- derives += dervs
101
- eq4s += eq4
102
- branches += next_branches
103
-
104
- # Now, it is either goal or saturated
105
- if controller.goal is not None:
106
- goal_args = g.names2points(controller.goal.args)
107
- if g.check(controller.goal.name, goal_args): # found goal
108
- status = 'solved'
109
- break
110
-
111
- if not derives: # officially saturated.
112
- logging.info("derives empty, breaking")
113
- break
114
-
115
- # Now we resort to algebra derivations.
116
- added = []
117
- while derives and not added:
118
- added += dd.apply_derivations(g, derives.pop(0))
119
-
120
- if added:
121
- continue
122
-
123
- # Final help from AR.
124
- while eq4s and not added:
125
- added += dd.apply_derivations(g, eq4s.pop(0))
126
-
127
- all_added += added
128
-
129
- if not added: # Nothing left. saturated.
130
- logging.info("Nothing added, breaking")
131
- break
132
-
133
- return g, level_times, status, branches, all_added
134
-
135
-
136
- def get_proof_steps(
137
- g: gh.Graph, goal: pr.Clause, merge_trivials: bool = False
138
- ) -> tuple[
139
- list[pr.Dependency],
140
- list[pr.Dependency],
141
- list[tuple[list[pr.Dependency], list[pr.Dependency]]],
142
- dict[tuple[str, ...], int],
143
- ]:
144
- """Extract proof steps from the built DAG."""
145
- goal_args = g.names2nodes(goal.args)
146
- query = Dependency(goal.name, goal_args, None, None)
147
-
148
- setup, aux, log, setup_points = trace_back.get_logs(
149
- query, g, merge_trivials=merge_trivials
150
- )
151
-
152
- refs = {}
153
- setup = trace_back.point_log(setup, refs, set())
154
- aux = trace_back.point_log(aux, refs, setup_points)
155
-
156
- setup = [(prems, [tuple(p)]) for p, prems in setup]
157
- aux = [(prems, [tuple(p)]) for p, prems in aux]
158
-
159
- return setup, aux, log, refs
 
1
+ # Copyright 2023 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Implements the combination DD+AR."""
17
+ import time
18
+
19
+ from absl import logging
20
+ import dd
21
+ import graph as gh
22
+ import problem as pr
23
+ from problem import Dependency # pylint: disable=g-importing-member
24
+ import trace_back
25
+
26
+
27
def saturate_or_goal(
    g: gh.Graph,
    theorems: list[pr.Theorem],
    level_times: list[float],
    p: pr.Problem,
    max_level: int = 100,
    timeout: int = 600,
) -> tuple[
    list[dict[str, list[tuple[gh.Point, ...]]]],
    list[dict[str, list[tuple[gh.Point, ...]]]],
    list[int],
    list[pr.Dependency],
]:
  """Run DD until saturation or goal found.

  Note: `level_times` is mutated in place — one entry per BFS level — and
  also serves as the loop's level counter across repeated calls from solve().
  """
  derives = []
  eq4s = []
  branching = []
  all_added = []

  while len(level_times) < max_level:
    level = len(level_times) + 1

    t = time.time()
    added, derv, eq4, n_branching = dd.bfs_one_level(
        g, theorems, level, p, verbose=False, nm_check=True, timeout=timeout
    )
    all_added += added
    branching.append(n_branching)

    derives.append(derv)
    eq4s.append(eq4)
    level_time = time.time() - t

    logging.info(f'Depth {level}/{max_level} time = {level_time}')  # pylint: disable=logging-fstring-interpolation
    level_times.append(level_time)

    if p.goal is not None:
      # Goal args may be point names or integer literals.
      goal_args = list(map(lambda x: g.get(x, lambda: int(x)), p.goal.args))
      if g.check(p.goal.name, goal_args):  # found goal
        break

    if not added:  # saturated
      break

    if level_time > timeout:
      break

  return derives, eq4s, branching, all_added
75
+
76
+
77
def solve(
    g: gh.Graph,
    theorems: list[pr.Problem],
    controller: pr.Problem,
    max_level: int = 1000,
    timeout: int = 600,
) -> tuple[gh.Graph, list[float], str, list[int], list[pr.Dependency]]:
  """Alternate between DD and AR until goal is found.

  Returns (graph, per-level times, status 'solved'/'saturated',
  per-level branching factors, all added dependencies).
  """
  status = 'saturated'
  level_times = []

  # Seed the algebraic reasoner from the initial construction (level 0).
  dervs, eq4 = g.derive_algebra(level=0, verbose=False)
  derives = [dervs]
  eq4s = [eq4]
  branches = []
  all_added = []

  while len(level_times) < max_level:
    dervs, eq4, next_branches, added = saturate_or_goal(
        g, theorems, level_times, controller, max_level, timeout=timeout
    )
    all_added += added

    derives += dervs
    eq4s += eq4
    branches += next_branches

    # Now, it is either goal or saturated
    if controller.goal is not None:
      goal_args = g.names2points(controller.goal.args)
      if g.check(controller.goal.name, goal_args):  # found goal
        status = 'solved'
        break

    if not derives:  # officially saturated.
      break

    # Now we resort to algebra derivations.
    # Pop pending derivation batches (oldest first) until one of them
    # actually adds something new to the proof state.
    added = []
    while derives and not added:
      added += dd.apply_derivations(g, derives.pop(0))

    if added:
      continue

    # Final help from AR.
    while eq4s and not added:
      added += dd.apply_derivations(g, eq4s.pop(0))

    all_added += added

    if not added:  # Nothing left. saturated.
      break

  return g, level_times, status, branches, all_added
132
+
133
+
134
def get_proof_steps(
    g: gh.Graph, goal: pr.Clause, merge_trivials: bool = False
) -> tuple[
    list[pr.Dependency],
    list[pr.Dependency],
    list[tuple[list[pr.Dependency], list[pr.Dependency]]],
    dict[tuple[str, ...], int],
]:
  """Extract proof steps from the built DAG."""
  # Trace back from the goal statement through the dependency DAG.
  query = Dependency(goal.name, g.names2nodes(goal.args), None, None)
  setup, aux, log, setup_points = trace_back.get_logs(
      query, g, merge_trivials=merge_trivials
  )

  # Assign reference numbers, then normalize into (premises, [points]) pairs.
  refs = {}
  setup = trace_back.point_log(setup, refs, set())
  aux = trace_back.point_log(aux, refs, setup_points)

  setup = [(premises, [tuple(points)]) for points, premises in setup]
  aux = [(premises, [tuple(points)]) for points, premises in aux]

  return setup, aux, log, refs
 
 
ag4masses/alphageometry/decoder_stack.py CHANGED
@@ -1,55 +1,55 @@
1
- # Copyright 2023 DeepMind Technologies Limited
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """The decoder stack in inference mode."""
17
-
18
- from typing import Any, Tuple
19
-
20
- import gin
21
- from transformer import decoder_stack
22
- import transformer_layer as tl
23
-
24
-
25
- struct = decoder_stack.struct
26
- nn_components = decoder_stack.nn_components
27
- position = decoder_stack.position
28
- jnp = decoder_stack.jnp
29
- attention = decoder_stack.attention
30
-
31
- DStackWindowState = decoder_stack.DStackWindowState
32
-
33
- Array = Any
34
-
35
- TransformerTaskConfig = decoder_stack.TransformerTaskConfig
36
-
37
- DStackDecoderState = Tuple[tl.DecoderState, ...]
38
-
39
-
40
- @gin.configurable
41
- class DecoderStackGenerate(decoder_stack.DecoderStack):
42
- """Stack of transformer decoder layers."""
43
-
44
- layer_factory = tl.TransformerLayerGenerate
45
-
46
- def init_decoder_state_vanilla(
47
- self, sequence_length: int, start_of_sequence: Array
48
- ) -> DStackDecoderState:
49
- """Return initial state for autoregressive generation."""
50
- return tuple(
51
- [
52
- layer.init_decoder_state_vanilla(sequence_length, start_of_sequence)
53
- for layer in self.transformer_layers
54
- ]
55
- )
 
1
+ # Copyright 2023 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """The decoder stack in inference mode."""
17
+
18
+ from typing import Any, Tuple
19
+
20
+ import gin
21
+ from meliad_lib.meliad.transformer import decoder_stack
22
+ import transformer_layer as tl
23
+
24
+
25
# Re-export names from the upstream meliad decoder_stack module so the rest
# of this file can refer to them with short local names.
struct = decoder_stack.struct
nn_components = decoder_stack.nn_components
position = decoder_stack.position
jnp = decoder_stack.jnp
attention = decoder_stack.attention

DStackWindowState = decoder_stack.DStackWindowState

# Generic array type alias (jax arrays in practice — TODO confirm).
Array = Any

TransformerTaskConfig = decoder_stack.TransformerTaskConfig

# Decoder state for the whole stack: one entry per transformer layer.
DStackDecoderState = Tuple[tl.DecoderState, ...]
38
+
39
+
40
@gin.configurable
class DecoderStackGenerate(decoder_stack.DecoderStack):
  """Stack of transformer decoder layers."""

  # Use the inference-mode layer implementation for every layer.
  layer_factory = tl.TransformerLayerGenerate

  def init_decoder_state_vanilla(
      self, sequence_length: int, start_of_sequence: Array
  ) -> DStackDecoderState:
    """Return initial state for autoregressive generation."""
    states = []
    for layer in self.transformer_layers:
      states.append(
          layer.init_decoder_state_vanilla(sequence_length, start_of_sequence)
      )
    return tuple(states)
ag4masses/alphageometry/defs.txt CHANGED
@@ -405,3 +405,15 @@ x : a b c
405
  a b c = ncoll a b c
406
  x : cyclic a b c x
407
  cyclic a b c
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  a b c = ncoll a b c
406
  x : cyclic a b c x
407
  cyclic a b c
408
+
409
+ semicircle x a b c
410
+ x : a b c
411
+ a b c = ncoll a b c
412
+ x : cong x a x b; cong x b x c
413
+ bline a b, bline a c
414
+
415
+ on_semicircle x o a
416
+ x : x o a
417
+ o a = diff o a
418
+ x : cong o x o a
419
+ circle o o a
ag4masses/alphageometry/geometry.py CHANGED
@@ -1,578 +1,621 @@
1
- # Copyright 2023 DeepMind Technologies Limited
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """Implements geometric objects used in the graph representation."""
17
- from __future__ import annotations
18
- from collections import defaultdict # pylint: disable=g-importing-member
19
- from typing import Any, Type
20
-
21
- # pylint: disable=protected-access
22
-
23
-
24
- class Node:
25
- r"""Node in the proof state graph.
26
-
27
- Can be Point, Line, Circle, etc.
28
-
29
- Each node maintains a merge history to
30
- other nodes if they are (found out to be) equivalent
31
-
32
- a -> b -
33
- \
34
- c -> d -> e -> f -> g
35
-
36
- d.merged_to = e
37
- d.rep = g
38
- d.merged_from = {a, b, c, d}
39
- d.equivs = {a, b, c, d, e, f, g}
40
- """
41
-
42
- def __init__(self, name: str = '', graph: Any = None):
43
- self.name = name or str(self)
44
- self.graph = graph
45
-
46
- self.edge_graph = {}
47
- # Edge graph: what other nodes is connected to this node.
48
- # edge graph = {
49
- # other1: {self1: deps, self2: deps},
50
- # other2: {self2: deps, self3: deps}
51
- # }
52
-
53
- self.merge_graph = {}
54
- # Merge graph: history of merges with other nodes.
55
- # merge_graph = {self1: {self2: deps1, self3: deps2}}
56
-
57
- self.rep_by = None # represented by.
58
- self.members = {self}
59
-
60
- self._val = None
61
- self._obj = None
62
-
63
- self.deps = []
64
-
65
- # numerical representation.
66
- self.num = None
67
- self.change = set() # what other nodes' num rely on this node?
68
-
69
- def set_rep(self, node: Node) -> None:
70
- if node == self:
71
- return
72
- self.rep_by = node
73
- node.merge_edge_graph(self.edge_graph)
74
- node.members.update(self.members)
75
-
76
- def rep(self) -> Node:
77
- x = self
78
- while x.rep_by:
79
- x = x.rep_by
80
- return x
81
-
82
- def why_rep(self) -> list[Any]:
83
- return self.why_equal([self.rep()], None)
84
-
85
- def rep_and_why(self) -> tuple[Node, list[Any]]:
86
- rep = self.rep()
87
- return rep, self.why_equal([rep], None)
88
-
89
- def neighbors(
90
- self, oftype: Type[Node], return_set: bool = False, do_rep: bool = True
91
- ) -> list[Node]:
92
- """Neighbors of this node in the proof state graph."""
93
- if do_rep:
94
- rep = self.rep()
95
- else:
96
- rep = self
97
- result = set()
98
-
99
- for n in rep.edge_graph:
100
- if oftype is None or oftype and isinstance(n, oftype):
101
- if do_rep:
102
- result.add(n.rep())
103
- else:
104
- result.add(n)
105
-
106
- if return_set:
107
- return result
108
- return list(result)
109
-
110
- def merge_edge_graph(
111
- self, new_edge_graph: dict[Node, dict[Node, list[Node]]]
112
- ) -> None:
113
- for x, xdict in new_edge_graph.items():
114
- if x in self.edge_graph:
115
- self.edge_graph[x].update(dict(xdict))
116
- else:
117
- self.edge_graph[x] = dict(xdict)
118
-
119
- def merge(self, nodes: list[Node], deps: list[Any]) -> None:
120
- for node in nodes:
121
- self.merge_one(node, deps)
122
-
123
- def merge_one(self, node: Node, deps: list[Any]) -> None:
124
- node.rep().set_rep(self.rep())
125
-
126
- if node in self.merge_graph:
127
- return
128
-
129
- self.merge_graph[node] = deps
130
- node.merge_graph[self] = deps
131
-
132
- def is_val(self, node: Node) -> bool:
133
- return (
134
- isinstance(self, Line)
135
- and isinstance(node, Direction)
136
- or isinstance(self, Segment)
137
- and isinstance(node, Length)
138
- or isinstance(self, Angle)
139
- and isinstance(node, Measure)
140
- or isinstance(self, Ratio)
141
- and isinstance(node, Value)
142
- )
143
-
144
- def set_val(self, node: Node) -> None:
145
- self._val = node
146
-
147
- def set_obj(self, node: Node) -> None:
148
- self._obj = node
149
-
150
- @property
151
- def val(self) -> Node:
152
- if self._val is None:
153
- return None
154
- return self._val.rep()
155
-
156
- @property
157
- def obj(self) -> Node:
158
- if self._obj is None:
159
- return None
160
- return self._obj.rep()
161
-
162
- def equivs(self) -> set[Node]:
163
- return self.rep().members
164
-
165
- def connect_to(self, node: Node, deps: list[Any] = None) -> None:
166
- rep = self.rep()
167
-
168
- if node in rep.edge_graph:
169
- rep.edge_graph[node].update({self: deps})
170
- else:
171
- rep.edge_graph[node] = {self: deps}
172
-
173
- if self.is_val(node):
174
- self.set_val(node)
175
- node.set_obj(self)
176
-
177
- def equivs_upto(self, level: int) -> dict[Node, Node]:
178
- """What are the equivalent nodes up to a certain level."""
179
- parent = {self: None}
180
- visited = set()
181
- queue = [self]
182
- i = 0
183
-
184
- while i < len(queue):
185
- current = queue[i]
186
- i += 1
187
- visited.add(current)
188
-
189
- for neighbor in current.merge_graph:
190
- if (
191
- level is not None
192
- and current.merge_graph[neighbor].level is not None
193
- and current.merge_graph[neighbor].level >= level
194
- ):
195
- continue
196
- if neighbor not in visited:
197
- queue.append(neighbor)
198
- parent[neighbor] = current
199
-
200
- return parent
201
-
202
- def why_equal(self, others: list[Node], level: int) -> list[Any]:
203
- """BFS why this node is equal to other nodes."""
204
- others = set(others)
205
- found = 0
206
-
207
- parent = {}
208
- queue = [self]
209
- i = 0
210
-
211
- while i < len(queue):
212
- current = queue[i]
213
- if current in others:
214
- found += 1
215
- if found == len(others):
216
- break
217
-
218
- i += 1
219
-
220
- for neighbor in current.merge_graph:
221
- if (
222
- level is not None
223
- and current.merge_graph[neighbor].level is not None
224
- and current.merge_graph[neighbor].level >= level
225
- ):
226
- continue
227
- if neighbor not in parent:
228
- queue.append(neighbor)
229
- parent[neighbor] = current
230
-
231
- return bfs_backtrack(self, others, parent)
232
-
233
- def why_equal_groups(
234
- self, groups: list[list[Node]], level: int
235
- ) -> tuple[list[Any], list[Node]]:
236
- """BFS for why self is equal to at least one member of each group."""
237
- others = [None for _ in groups]
238
- found = 0
239
-
240
- parent = {}
241
- queue = [self]
242
- i = 0
243
-
244
- while i < len(queue):
245
- current = queue[i]
246
-
247
- for j, grp in enumerate(groups):
248
- if others[j] is None and current in grp:
249
- others[j] = current
250
- found += 1
251
-
252
- if found == len(others):
253
- break
254
-
255
- i += 1
256
-
257
- for neighbor in current.merge_graph:
258
- if (
259
- level is not None
260
- and current.merge_graph[neighbor].level is not None
261
- and current.merge_graph[neighbor].level >= level
262
- ):
263
- continue
264
- if neighbor not in parent:
265
- queue.append(neighbor)
266
- parent[neighbor] = current
267
-
268
- return bfs_backtrack(self, others, parent), others
269
-
270
- def why_val(self, level: int) -> list[Any]:
271
- return self._val.why_equal([self.val], level)
272
-
273
- def why_connect(self, node: Node, level: int = None) -> list[Any]:
274
- rep = self.rep()
275
- equivs = list(rep.edge_graph[node].keys())
276
- if not equivs:
277
- return None
278
- equiv = equivs[0]
279
- dep = rep.edge_graph[node][equiv]
280
- return [dep] + self.why_equal(equiv, level)
281
-
282
-
283
- def why_connect(*pairs: list[tuple[Node, Node]]) -> list[Any]:
284
- result = []
285
- for node1, node2 in pairs:
286
- result += node1.why_connect(node2)
287
- return result
288
-
289
-
290
- def is_equiv(x: Node, y: Node, level: int = None) -> bool:
291
- level = level or float('inf')
292
- return x.why_equal([y], level) is not None
293
-
294
-
295
- def is_equal(x: Node, y: Node, level: int = None) -> bool:
296
- if x == y:
297
- return True
298
- if x._val is None or y._val is None:
299
- return False
300
- if x.val != y.val:
301
- return False
302
- return is_equiv(x._val, y._val, level)
303
-
304
-
305
- def bfs_backtrack(
306
- root: Node, leafs: list[Node], parent: dict[Node, Node]
307
- ) -> list[Any]:
308
- """Return the path given BFS trace of parent nodes."""
309
- backtracked = {root} # no need to backtrack further when touching this set.
310
- deps = []
311
- for node in leafs:
312
- if node is None:
313
- return None
314
- if node in backtracked:
315
- continue
316
- if node not in parent:
317
- return None
318
- while node not in backtracked:
319
- backtracked.add(node)
320
- deps.append(node.merge_graph[parent[node]])
321
- node = parent[node]
322
-
323
- return deps
324
-
325
-
326
- class Point(Node):
327
- pass
328
-
329
-
330
- class Line(Node):
331
- """Node of type Line."""
332
-
333
- def new_val(self) -> Direction:
334
- return Direction()
335
-
336
- def why_coll(self, points: list[Point], level: int = None) -> list[Any]:
337
- """Why points are connected to self."""
338
- level = level or float('inf')
339
-
340
- groups = []
341
- for p in points:
342
- group = [
343
- l
344
- for l, d in self.edge_graph[p].items()
345
- if d is None or d.level < level
346
- ]
347
- if not group:
348
- return None
349
- groups.append(group)
350
-
351
- min_deps = None
352
- for line in groups[0]:
353
- deps, others = line.why_equal_groups(groups[1:], level)
354
- if deps is None:
355
- continue
356
- for p, o in zip(points, [line] + others):
357
- deps.append(self.edge_graph[p][o])
358
- if min_deps is None or len(deps) < len(min_deps):
359
- min_deps = deps
360
-
361
- if min_deps is None:
362
- return None
363
- return [d for d in min_deps if d is not None]
364
-
365
-
366
- class Segment(Node):
367
-
368
- def new_val(self) -> Length:
369
- return Length()
370
-
371
-
372
- class Circle(Node):
373
- """Node of type Circle."""
374
-
375
- def why_cyclic(self, points: list[Point], level: int = None) -> list[Any]:
376
- """Why points are connected to self."""
377
- level = level or float('inf')
378
-
379
- groups = []
380
- for p in points:
381
- group = [
382
- c
383
- for c, d in self.edge_graph[p].items()
384
- if d is None or d.level < level
385
- ]
386
- if not group:
387
- return None
388
- groups.append(group)
389
-
390
- min_deps = None
391
- for circle in groups[0]:
392
- deps, others = circle.why_equal_groups(groups[1:], level)
393
- if deps is None:
394
- continue
395
- for p, o in zip(points, [circle] + others):
396
- deps.append(self.edge_graph[p][o])
397
-
398
- if min_deps is None or len(deps) < len(min_deps):
399
- min_deps = deps
400
-
401
- if min_deps is None:
402
- return None
403
- return [d for d in min_deps if d is not None]
404
-
405
-
406
- def why_equal(x: Node, y: Node, level: int = None) -> list[Any]:
407
- if x == y:
408
- return []
409
- if not x._val or not y._val:
410
- return None
411
- if x._val == y._val:
412
- return []
413
- return x._val.why_equal([y._val], level)
414
-
415
-
416
- class Direction(Node):
417
- pass
418
-
419
-
420
- def get_lines_thru_all(*points: list[Point]) -> list[Line]:
421
- line2count = defaultdict(lambda: 0)
422
- points = set(points)
423
- for p in points:
424
- for l in p.neighbors(Line):
425
- line2count[l] += 1
426
- return [l for l, count in line2count.items() if count == len(points)]
427
-
428
-
429
- def line_of_and_why(
430
- points: list[Point], level: int = None
431
- ) -> tuple[Line, list[Any]]:
432
- """Why points are collinear."""
433
- for l0 in get_lines_thru_all(*points):
434
- for l in l0.equivs():
435
- if all([p in l.edge_graph for p in points]):
436
- x, y = l.points
437
- colls = list({x, y} | set(points))
438
- # if len(colls) < 3:
439
- # return l, []
440
- why = l.why_coll(colls, level)
441
- if why is not None:
442
- return l, why
443
-
444
- return None, None
445
-
446
-
447
- def get_circles_thru_all(*points: list[Point]) -> list[Circle]:
448
- circle2count = defaultdict(lambda: 0)
449
- points = set(points)
450
- for p in points:
451
- for c in p.neighbors(Circle):
452
- circle2count[c] += 1
453
- return [c for c, count in circle2count.items() if count == len(points)]
454
-
455
-
456
- def circle_of_and_why(
457
- points: list[Point], level: int = None
458
- ) -> tuple[Circle, list[Any]]:
459
- """Why points are concyclic."""
460
- for c0 in get_circles_thru_all(*points):
461
- for c in c0.equivs():
462
- if all([p in c.edge_graph for p in points]):
463
- cycls = list(set(points))
464
- why = c.why_cyclic(cycls, level)
465
- if why is not None:
466
- return c, why
467
-
468
- return None, None
469
-
470
-
471
- def name_map(struct: Any) -> Any:
472
- if isinstance(struct, list):
473
- return [name_map(x) for x in struct]
474
- elif isinstance(struct, tuple):
475
- return tuple([name_map(x) for x in struct])
476
- elif isinstance(struct, set):
477
- return set([name_map(x) for x in struct])
478
- elif isinstance(struct, dict):
479
- return {name_map(x): name_map(y) for x, y in struct.items()}
480
- else:
481
- return getattr(struct, 'name', '')
482
-
483
-
484
- class Angle(Node):
485
- """Node of type Angle."""
486
-
487
- def new_val(self) -> Measure:
488
- return Measure()
489
-
490
- def set_directions(self, d1: Direction, d2: Direction) -> None:
491
- self._d = d1, d2
492
-
493
- @property
494
- def directions(self) -> tuple[Direction, Direction]:
495
- d1, d2 = self._d
496
- if d1 is None or d2 is None:
497
- return d1, d2
498
- return d1.rep(), d2.rep()
499
-
500
-
501
- class Measure(Node):
502
- pass
503
-
504
-
505
- class Length(Node):
506
- pass
507
-
508
-
509
- class Ratio(Node):
510
- """Node of type Ratio."""
511
-
512
- def new_val(self) -> Value:
513
- return Value()
514
-
515
- def set_lengths(self, l1: Length, l2: Length) -> None:
516
- self._l = l1, l2
517
-
518
- @property
519
- def lengths(self) -> tuple[Length, Length]:
520
- l1, l2 = self._l
521
- if l1 is None or l2 is None:
522
- return l1, l2
523
- return l1.rep(), l2.rep()
524
-
525
-
526
- class Value(Node):
527
- pass
528
-
529
-
530
- def all_angles(
531
- d1: Direction, d2: Direction, level: int = None
532
- ) -> tuple[Angle, list[Direction], list[Direction]]:
533
- level = level or float('inf')
534
- d1s = d1.equivs_upto(level)
535
- d2s = d2.equivs_upto(level)
536
-
537
- for ang in d1.rep().neighbors(Angle):
538
- d1_, d2_ = ang._d
539
- if d1_ in d1s and d2_ in d2s:
540
- yield ang, d1s, d2s
541
-
542
-
543
- def all_ratios(
544
- d1, d2, level=None
545
- ) -> tuple[Angle, list[Direction], list[Direction]]:
546
- level = level or float('inf')
547
- d1s = d1.equivs_upto(level)
548
- d2s = d2.equivs_upto(level)
549
-
550
- for ang in d1.rep().neighbors(Ratio):
551
- d1_, d2_ = ang._l
552
- if d1_ in d1s and d2_ in d2s:
553
- yield ang, d1s, d2s
554
-
555
-
556
- RANKING = {
557
- Point: 0,
558
- Line: 1,
559
- Segment: 2,
560
- Circle: 3,
561
- Direction: 4,
562
- Length: 5,
563
- Angle: 6,
564
- Ratio: 7,
565
- Measure: 8,
566
- Value: 9,
567
- }
568
-
569
-
570
- def val_type(x: Node) -> Type[Node]:
571
- if isinstance(x, Line):
572
- return Direction
573
- if isinstance(x, Segment):
574
- return Length
575
- if isinstance(x, Angle):
576
- return Measure
577
- if isinstance(x, Ratio):
578
- return Value
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Implements geometric objects used in the graph representation."""
17
+ from __future__ import annotations
18
+ from collections import defaultdict # pylint: disable=g-importing-member
19
+ from typing import Any, Type
20
+ import math
21
+ # pylint: disable=protected-access
22
+
23
+
24
+ class Node:
25
+ r"""Node in the proof state graph.
26
+
27
+ Can be Point, Line, Circle, etc.
28
+
29
+ Each node maintains a merge history to
30
+ other nodes if they are (found out to be) equivalent
31
+
32
+ a -> b -
33
+ \
34
+ c -> d -> e -> f -> g
35
+
36
+ d.merged_to = e
37
+ d.rep = g
38
+ d.merged_from = {a, b, c, d}
39
+ d.equivs = {a, b, c, d, e, f, g}
40
+ """
41
+
42
+ def __init__(self, name: str = '', graph: Any = None):
43
+ self.name = name or str(self)
44
+ self.graph = graph
45
+
46
+ self.edge_graph = {}
47
+ # Edge graph: what other nodes is connected to this node.
48
+ # edge graph = {
49
+ # other1: {self1: deps, self2: deps},
50
+ # other2: {self2: deps, self3: deps}
51
+ # }
52
+
53
+ self.merge_graph = {}
54
+ # Merge graph: history of merges with other nodes.
55
+ # merge_graph = {self1: {self2: deps1, self3: deps2}}
56
+
57
+ self.rep_by = None # represented by.
58
+ self.members = {self}
59
+
60
+ self._val = None
61
+ self._obj = None
62
+
63
+ self.deps = []
64
+
65
+ # numerical representation.
66
+ self.num = None
67
+ self.change = set() # what other nodes' num rely on this node?
68
+
69
+ def set_rep(self, node: Node) -> None:
70
+ if node == self:
71
+ return
72
+ self.rep_by = node
73
+ node.merge_edge_graph(self.edge_graph)
74
+ node.members.update(self.members)
75
+
76
+ def rep(self) -> Node:
77
+ x = self
78
+ while x.rep_by:
79
+ x = x.rep_by
80
+ return x
81
+
82
+ def why_rep(self) -> list[Any]:
83
+ return self.why_equal([self.rep()], None)
84
+
85
+ def rep_and_why(self) -> tuple[Node, list[Any]]:
86
+ rep = self.rep()
87
+ return rep, self.why_equal([rep], None)
88
+
89
+ def neighbors(
90
+ self, oftype: Type[Node], return_set: bool = False, do_rep: bool = True
91
+ ) -> list[Node]:
92
+ """Neighbors of this node in the proof state graph."""
93
+ if do_rep:
94
+ rep = self.rep()
95
+ else:
96
+ rep = self
97
+ result = set()
98
+
99
+ for n in rep.edge_graph:
100
+ if oftype is None or oftype and isinstance(n, oftype):
101
+ if do_rep:
102
+ result.add(n.rep())
103
+ else:
104
+ result.add(n)
105
+
106
+ if return_set:
107
+ return result
108
+ return list(result)
109
+
110
+ def merge_edge_graph(
111
+ self, new_edge_graph: dict[Node, dict[Node, list[Node]]]
112
+ ) -> None:
113
+ for x, xdict in new_edge_graph.items():
114
+ if x in self.edge_graph:
115
+ self.edge_graph[x].update(dict(xdict))
116
+ else:
117
+ self.edge_graph[x] = dict(xdict)
118
+
119
+ def merge(self, nodes: list[Node], deps: list[Any]) -> None:
120
+ for node in nodes:
121
+ self.merge_one(node, deps)
122
+
123
+ def merge_one(self, node: Node, deps: list[Any]) -> None:
124
+ node.rep().set_rep(self.rep())
125
+
126
+ if node in self.merge_graph:
127
+ return
128
+
129
+ self.merge_graph[node] = deps
130
+ node.merge_graph[self] = deps
131
+
132
+ def is_val(self, node: Node) -> bool:
133
+ return (
134
+ isinstance(self, Line)
135
+ and isinstance(node, Direction)
136
+ or isinstance(self, Segment)
137
+ and isinstance(node, Length)
138
+ or isinstance(self, Angle)
139
+ and isinstance(node, Measure)
140
+ or isinstance(self, Ratio)
141
+ and isinstance(node, Value)
142
+ )
143
+
144
+ def set_val(self, node: Node) -> None:
145
+ self._val = node
146
+
147
+ def set_obj(self, node: Node) -> None:
148
+ self._obj = node
149
+
150
+ @property
151
+ def val(self) -> Node:
152
+ if self._val is None:
153
+ return None
154
+ return self._val.rep()
155
+
156
+ @property
157
+ def obj(self) -> Node:
158
+ if self._obj is None:
159
+ return None
160
+ return self._obj.rep()
161
+
162
+ def equivs(self) -> set[Node]:
163
+ return self.rep().members
164
+
165
+ def connect_to(self, node: Node, deps: list[Any] = None) -> None:
166
+ rep = self.rep()
167
+
168
+ if node in rep.edge_graph:
169
+ rep.edge_graph[node].update({self: deps})
170
+ else:
171
+ rep.edge_graph[node] = {self: deps}
172
+
173
+ if self.is_val(node):
174
+ self.set_val(node)
175
+ node.set_obj(self)
176
+
177
+ def equivs_upto(self, level: int) -> dict[Node, Node]:
178
+ """What are the equivalent nodes up to a certain level."""
179
+ parent = {self: None}
180
+ visited = set()
181
+ queue = [self]
182
+ i = 0
183
+
184
+ while i < len(queue):
185
+ current = queue[i]
186
+ i += 1
187
+ visited.add(current)
188
+
189
+ for neighbor in current.merge_graph:
190
+ if (
191
+ level is not None
192
+ and current.merge_graph[neighbor].level is not None
193
+ and current.merge_graph[neighbor].level >= level
194
+ ):
195
+ continue
196
+ if neighbor not in visited:
197
+ queue.append(neighbor)
198
+ parent[neighbor] = current
199
+
200
+ return parent
201
+
202
+ def why_equal(self, others: list[Node], level: int) -> list[Any]:
203
+ """BFS why this node is equal to other nodes."""
204
+ others = set(others)
205
+ found = 0
206
+
207
+ parent = {}
208
+ queue = [self]
209
+ i = 0
210
+
211
+ while i < len(queue):
212
+ current = queue[i]
213
+ if current in others:
214
+ found += 1
215
+ if found == len(others):
216
+ break
217
+
218
+ i += 1
219
+
220
+ for neighbor in current.merge_graph:
221
+ if (
222
+ level is not None
223
+ and current.merge_graph[neighbor].level is not None
224
+ and current.merge_graph[neighbor].level >= level
225
+ ):
226
+ continue
227
+ if neighbor not in parent:
228
+ queue.append(neighbor)
229
+ parent[neighbor] = current
230
+
231
+ return bfs_backtrack(self, others, parent)
232
+
233
+ def why_equal_groups(
234
+ self, groups: list[list[Node]], level: int
235
+ ) -> tuple[list[Any], list[Node]]:
236
+ """BFS for why self is equal to at least one member of each group."""
237
+ others = [None for _ in groups]
238
+ found = 0
239
+
240
+ parent = {}
241
+ queue = [self]
242
+ i = 0
243
+
244
+ while i < len(queue):
245
+ current = queue[i]
246
+
247
+ for j, grp in enumerate(groups):
248
+ if others[j] is None and current in grp:
249
+ others[j] = current
250
+ found += 1
251
+
252
+ if found == len(others):
253
+ break
254
+
255
+ i += 1
256
+
257
+ for neighbor in current.merge_graph:
258
+ if (
259
+ level is not None
260
+ and current.merge_graph[neighbor].level is not None
261
+ and current.merge_graph[neighbor].level >= level
262
+ ):
263
+ continue
264
+ if neighbor not in parent:
265
+ queue.append(neighbor)
266
+ parent[neighbor] = current
267
+
268
+ return bfs_backtrack(self, others, parent), others
269
+
270
+ def why_val(self, level: int) -> list[Any]:
271
+ return self._val.why_equal([self.val], level)
272
+
273
+ def why_connect(self, node: Node, level: int = None) -> list[Any]:
274
+ rep = self.rep()
275
+ equivs = list(rep.edge_graph[node].keys())
276
+ if not equivs:
277
+ return None
278
+ equiv = equivs[0]
279
+ dep = rep.edge_graph[node][equiv]
280
+ return [dep] + self.why_equal(equiv, level)
281
+
282
+
283
+ def why_connect(*pairs: list[tuple[Node, Node]]) -> list[Any]:
284
+ result = []
285
+ for node1, node2 in pairs:
286
+ result += node1.why_connect(node2)
287
+ return result
288
+
289
+
290
+ def is_equiv(x: Node, y: Node, level: int = None) -> bool:
291
+ level = level or float('inf')
292
+ return x.why_equal([y], level) is not None
293
+
294
+
295
+ def is_equal(x: Node, y: Node, level: int = None) -> bool:
296
+ if x == y:
297
+ return True
298
+ if x._val is None or y._val is None:
299
+ return False
300
+ if x.val != y.val:
301
+ return False
302
+ return is_equiv(x._val, y._val, level)
303
+
304
+
305
+ def bfs_backtrack(
306
+ root: Node, leafs: list[Node], parent: dict[Node, Node]
307
+ ) -> list[Any]:
308
+ """Return the path given BFS trace of parent nodes."""
309
+ backtracked = {root} # no need to backtrack further when touching this set.
310
+ deps = []
311
+ for node in leafs:
312
+ if node is None:
313
+ return None
314
+ if node in backtracked:
315
+ continue
316
+ if node not in parent:
317
+ return None
318
+ while node not in backtracked:
319
+ backtracked.add(node)
320
+ deps.append(node.merge_graph[parent[node]])
321
+ node = parent[node]
322
+
323
+ return deps
324
+
325
+
326
+
327
+ class Point(Node):
328
+ pass
329
+
330
+
331
+ class Line(Node):
332
+ """Node of type Line."""
333
+
334
+ def new_val(self) -> Direction:
335
+ return Direction()
336
+
337
+ def why_coll(self, points: list[Point], level: int = None) -> list[Any]:
338
+ """Why points are connected to self."""
339
+ level = level or float('inf')
340
+
341
+ groups = []
342
+ for p in points:
343
+ group = [
344
+ l
345
+ for l, d in self.edge_graph[p].items()
346
+ if d is None or d.level < level
347
+ ]
348
+ if not group:
349
+ return None
350
+ groups.append(group)
351
+
352
+ min_deps = None
353
+ for line in groups[0]:
354
+ deps, others = line.why_equal_groups(groups[1:], level)
355
+ if deps is None:
356
+ continue
357
+ for p, o in zip(points, [line] + others):
358
+ deps.append(self.edge_graph[p][o])
359
+ if min_deps is None or len(deps) < len(min_deps):
360
+ min_deps = deps
361
+
362
+ if min_deps is None:
363
+ return None
364
+ return [d for d in min_deps if d is not None]
365
+
366
+
367
+ class Segment(Node):
368
+
369
+ def new_val(self) -> Length:
370
+ return Length()
371
+
372
+
373
+ class Circle(Node):
374
+ """Node of type Circle."""
375
+
376
+ def why_cyclic(self, points: list[Point], level: int = None) -> list[Any]:
377
+ """Why points are connected to self."""
378
+ level = level or float('inf')
379
+
380
+ groups = []
381
+ for p in points:
382
+ group = [
383
+ c
384
+ for c, d in self.edge_graph[p].items()
385
+ if d is None or d.level < level
386
+ ]
387
+ if not group:
388
+ return None
389
+ groups.append(group)
390
+
391
+ min_deps = None
392
+ for circle in groups[0]:
393
+ deps, others = circle.why_equal_groups(groups[1:], level)
394
+ if deps is None:
395
+ continue
396
+ for p, o in zip(points, [circle] + others):
397
+ deps.append(self.edge_graph[p][o])
398
+
399
+ if min_deps is None or len(deps) < len(min_deps):
400
+ min_deps = deps
401
+
402
+ if min_deps is None:
403
+ return None
404
+ return [d for d in min_deps if d is not None]
405
+
406
class SemiCircle(Circle):
  """Node of type SemiCircle, inheriting from Circle.

  Unlike the purely symbolic Circle, a SemiCircle also carries an explicit
  numerical center point and radius so that point-membership tests can be
  evaluated.
  """

  def __init__(
      self, center: Point, radius: float, name: str = '', graph: Any = None
  ):
    """Initialize a semicircle with a center point and radius.

    Bug fix: the previous version forwarded (center, radius) straight to
    Node.__init__(name, graph) — so ``name`` became the Point object and
    the center/radius were never stored, which made ``contains_point``
    crash on the missing attributes.

    Args:
      center: center point of the semicircle; assumed to expose numeric
        ``.x`` / ``.y`` coordinates — TODO confirm against callers.
      radius: radius of the semicircle.
      name: optional node name, passed through to Node.
      graph: optional owning graph, passed through to Node.
    """
    super().__init__(name, graph)
    self.center = center
    self.radius = radius

  def contains_point(self, point: Point) -> bool:
    """Check whether ``point`` lies inside (or on) the semicircle."""
    # Bug fix: Point defines no ``distance`` method; compute the Euclidean
    # distance explicitly from the coordinates instead (the same .x/.y
    # attributes that is_on_correct_side already relies on).
    dist = math.hypot(point.x - self.center.x, point.y - self.center.y)
    if dist > self.radius:
      return False

    # Additional logic to determine if the point is within the semicircle.
    return self.is_on_correct_side(point)

  def is_on_correct_side(self, point: Point) -> bool:
    """Check if the point is on the correct side of the semicircle."""
    # Angle of the ray from the center to the point.
    angle = math.atan2(point.y - self.center.y, point.x - self.center.x)

    # Boundary angles of the semicircle.
    # NOTE(review): the range [-pi/2, pi/2] actually selects the RIGHT
    # half-plane (x >= center.x), not "flat side down" as the original
    # comment claimed — confirm the intended orientation with callers.
    start_angle = -math.pi / 2
    end_angle = math.pi / 2

    # Check if the point's angle lies within the boundary angles.
    return start_angle <= angle <= end_angle

  def why_cyclic(self, points: list[Point], level: int = None) -> list[Any]:
    """Override why_cyclic to additionally require semicircle membership.

    Returns the parent Circle's dependency list only when every point
    also lies within this semicircle; otherwise None.
    """
    cyclic_points = super().why_cyclic(points, level)
    if cyclic_points is None:
      return None

    # Ensure that all points lie within the semicircle.
    if all(self.contains_point(p) for p in points):
      return cyclic_points
    return None
447
+ def why_equal(x: Node, y: Node, level: int = None) -> list[Any]:
448
+ if x == y:
449
+ return []
450
+ if not x._val or not y._val:
451
+ return None
452
+ if x._val == y._val:
453
+ return []
454
+ return x._val.why_equal([y._val], level)
455
+
456
+
457
+
458
+ class Direction(Node):
459
+ pass
460
+
461
+
462
+ def get_lines_thru_all(*points: list[Point]) -> list[Line]:
463
+ line2count = defaultdict(lambda: 0)
464
+ points = set(points)
465
+ for p in points:
466
+ for l in p.neighbors(Line):
467
+ line2count[l] += 1
468
+ return [l for l, count in line2count.items() if count == len(points)]
469
+
470
+
471
+ def line_of_and_why(
472
+ points: list[Point], level: int = None
473
+ ) -> tuple[Line, list[Any]]:
474
+ """Why points are collinear."""
475
+ for l0 in get_lines_thru_all(*points):
476
+ for l in l0.equivs():
477
+ if all([p in l.edge_graph for p in points]):
478
+ x, y = l.points
479
+ colls = list({x, y} | set(points))
480
+ # if len(colls) < 3:
481
+ # return l, []
482
+ why = l.why_coll(colls, level)
483
+ if why is not None:
484
+ return l, why
485
+
486
+ return None, None
487
+
488
+
489
+ def get_circles_thru_all(*points: list[Point]) -> list[Circle]:
490
+ circle2count = defaultdict(lambda: 0)
491
+ points = set(points)
492
+ for p in points:
493
+ for c in p.neighbors(Circle):
494
+ circle2count[c] += 1
495
+ return [c for c, count in circle2count.items() if count == len(points)]
496
+
497
+
498
+ def circle_of_and_why(
499
+ points: list[Point], level: int = None
500
+ ) -> tuple[Circle, list[Any]]:
501
+ """Why points are concyclic."""
502
+ for c0 in get_circles_thru_all(*points):
503
+ for c in c0.equivs():
504
+ if all([p in c.edge_graph for p in points]):
505
+ cycls = list(set(points))
506
+ why = c.why_cyclic(cycls, level)
507
+ if why is not None:
508
+ return c, why
509
+
510
+ return None, None
511
+
512
+
513
+ def name_map(struct: Any) -> Any:
514
+ if isinstance(struct, list):
515
+ return [name_map(x) for x in struct]
516
+ elif isinstance(struct, tuple):
517
+ return tuple([name_map(x) for x in struct])
518
+ elif isinstance(struct, set):
519
+ return set([name_map(x) for x in struct])
520
+ elif isinstance(struct, dict):
521
+ return {name_map(x): name_map(y) for x, y in struct.items()}
522
+ else:
523
+ return getattr(struct, 'name', '')
524
+
525
+
526
+ class Angle(Node):
527
+ """Node of type Angle."""
528
+
529
+ def new_val(self) -> Measure:
530
+ return Measure()
531
+
532
+ def set_directions(self, d1: Direction, d2: Direction) -> None:
533
+ self._d = d1, d2
534
+
535
+ @property
536
+ def directions(self) -> tuple[Direction, Direction]:
537
+ d1, d2 = self._d
538
+ if d1 is None or d2 is None:
539
+ return d1, d2
540
+ return d1.rep(), d2.rep()
541
+
542
+
543
+ class Measure(Node):
544
+ pass
545
+
546
+
547
+ class Length(Node):
548
+ pass
549
+
550
+
551
+ class Ratio(Node):
552
+ """Node of type Ratio."""
553
+
554
+ def new_val(self) -> Value:
555
+ return Value()
556
+
557
+ def set_lengths(self, l1: Length, l2: Length) -> None:
558
+ self._l = l1, l2
559
+
560
+ @property
561
+ def lengths(self) -> tuple[Length, Length]:
562
+ l1, l2 = self._l
563
+ if l1 is None or l2 is None:
564
+ return l1, l2
565
+ return l1.rep(), l2.rep()
566
+
567
+
568
+ class Value(Node):
569
+ pass
570
+
571
+
572
+ def all_angles(
573
+ d1: Direction, d2: Direction, level: int = None
574
+ ) -> tuple[Angle, list[Direction], list[Direction]]:
575
+ level = level or float('inf')
576
+ d1s = d1.equivs_upto(level)
577
+ d2s = d2.equivs_upto(level)
578
+
579
+ for ang in d1.rep().neighbors(Angle):
580
+ d1_, d2_ = ang._d
581
+ if d1_ in d1s and d2_ in d2s:
582
+ yield ang, d1s, d2s
583
+
584
+
585
+ def all_ratios(
586
+ d1, d2, level=None
587
+ ) -> tuple[Angle, list[Direction], list[Direction]]:
588
+ level = level or float('inf')
589
+ d1s = d1.equivs_upto(level)
590
+ d2s = d2.equivs_upto(level)
591
+
592
+ for ang in d1.rep().neighbors(Ratio):
593
+ d1_, d2_ = ang._l
594
+ if d1_ in d1s and d2_ in d2s:
595
+ yield ang, d1s, d2s
596
+
597
+
598
+ RANKING = {
599
+ Point: 0,
600
+ Line: 1,
601
+ Segment: 2,
602
+ Circle: 3,
603
+ SemiCircle: 3,
604
+ Direction: 4,
605
+ Length: 5,
606
+ Angle: 6,
607
+ Ratio: 7,
608
+ Measure: 8,
609
+ Value: 9,
610
+ }
611
+
612
+
613
+ def val_type(x: Node) -> Type[Node]:
614
+ if isinstance(x, Line):
615
+ return Direction
616
+ if isinstance(x, Segment):
617
+ return Length
618
+ if isinstance(x, Angle):
619
+ return Measure
620
+ if isinstance(x, Ratio):
621
+ return Value
ag4masses/alphageometry/graph.py CHANGED
The diff for this file is too large to render. See raw diff
 
ag4masses/alphageometry/graph_utils.py CHANGED
@@ -1,132 +1,132 @@
1
- # Copyright 2023 DeepMind Technologies Limited
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """Utilizations for graph representation.
17
-
18
- Mainly for listing combinations and permutations of elements.
19
- """
20
-
21
- from geometry import Point
22
-
23
-
24
- def _cross(elems1, elems2):
25
- for e1 in elems1:
26
- for e2 in elems2:
27
- yield e1, e2
28
-
29
-
30
- def cross(elems1, elems2):
31
- return list(_cross(elems1, elems2))
32
-
33
-
34
- def _comb2(elems):
35
- if len(elems) < 2:
36
- return
37
- for i, e1 in enumerate(elems[:-1]):
38
- for e2 in elems[i + 1 :]:
39
- yield e1, e2
40
-
41
-
42
- def comb2(elems):
43
- return list(_comb2(elems))
44
-
45
-
46
- def _comb3(elems):
47
- if len(elems) < 3:
48
- return
49
- for i, e1 in enumerate(elems[:-2]):
50
- for j, e2 in enumerate(elems[i + 1 : -1]):
51
- for e3 in elems[i + j + 2 :]:
52
- yield e1, e2, e3
53
-
54
-
55
- def comb3(elems):
56
- return list(_comb3(elems))
57
-
58
-
59
- def _comb4(elems):
60
- if len(elems) < 4:
61
- return
62
- for i, e1 in enumerate(elems[:-3]):
63
- for j, e2 in enumerate(elems[i + 1 : -2]):
64
- for e3, e4 in _comb2(elems[i + j + 2 :]):
65
- yield e1, e2, e3, e4
66
-
67
-
68
- def comb4(elems):
69
- return list(_comb4(elems))
70
-
71
-
72
- def _perm2(elems):
73
- for e1, e2 in comb2(elems):
74
- yield e1, e2
75
- yield e2, e1
76
-
77
-
78
- def perm2(elems):
79
- return list(_perm2(elems))
80
-
81
-
82
- def _all_4points(l1, l2):
83
- p1s = l1.neighbors(Point)
84
- p2s = l2.neighbors(Point)
85
- for a, b in perm2(p1s):
86
- for c, d in perm2(p2s):
87
- yield a, b, c, d
88
-
89
-
90
- def all_4points(l1, l2):
91
- return list(_all_4points(l1, l2))
92
-
93
-
94
- def _all_8points(l1, l2, l3, l4):
95
- for a, b, c, d in all_4points(l1, l2):
96
- for e, f, g, h in all_4points(l3, l4):
97
- yield (a, b, c, d, e, f, g, h)
98
-
99
-
100
- def all_8points(l1, l2, l3, l4):
101
- return list(_all_8points(l1, l2, l3, l4))
102
-
103
-
104
- def _perm3(elems):
105
- for x in elems:
106
- for y in elems:
107
- if y == x:
108
- continue
109
- for z in elems:
110
- if z not in (x, y):
111
- yield x, y, z
112
-
113
-
114
- def perm3(elems):
115
- return list(_perm3(elems))
116
-
117
-
118
- def _perm4(elems):
119
- for x in elems:
120
- for y in elems:
121
- if y == x:
122
- continue
123
- for z in elems:
124
- if z in (x, y):
125
- continue
126
- for t in elems:
127
- if t not in (x, y, z):
128
- yield x, y, z, t
129
-
130
-
131
- def perm4(elems):
132
- return list(_perm4(elems))
 
1
+ # Copyright 2023 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Utilizations for graph representation.
17
+
18
+ Mainly for listing combinations and permutations of elements.
19
+ """
20
+
21
+ from geometry import Point
22
+
23
+
24
+ def _cross(elems1, elems2):
25
+ for e1 in elems1:
26
+ for e2 in elems2:
27
+ yield e1, e2
28
+
29
+
30
+ def cross(elems1, elems2):
31
+ return list(_cross(elems1, elems2))
32
+
33
+
34
def _comb2(elems):
  """Lazily yield all index-ordered pairs (i < j) of elems."""
  n = len(elems)
  if n < 2:
    return
  for i in range(n - 1):
    for j in range(i + 1, n):
      yield elems[i], elems[j]


def comb2(elems):
  """Return all 2-element combinations of elems as a list of tuples."""
  return list(_comb2(elems))
44
+
45
+
46
def _comb3(elems):
  """Lazily yield all index-ordered triples (i < j < k) of elems."""
  n = len(elems)
  if n < 3:
    return
  for i in range(n - 2):
    for j in range(i + 1, n - 1):
      for k in range(j + 1, n):
        yield elems[i], elems[j], elems[k]


def comb3(elems):
  """Return all 3-element combinations of elems as a list of tuples."""
  return list(_comb3(elems))
57
+
58
+
59
def _comb4(elems):
  """Lazily yield all index-ordered quadruples (i < j < k < l) of elems."""
  # The pair-helper call of the original is inlined here as two more loops;
  # the yielded order is identical.
  n = len(elems)
  if n < 4:
    return
  for i in range(n - 3):
    for j in range(i + 1, n - 2):
      for k in range(j + 1, n - 1):
        for l in range(k + 1, n):
          yield elems[i], elems[j], elems[k], elems[l]


def comb4(elems):
  """Return all 4-element combinations of elems as a list of tuples."""
  return list(_comb4(elems))
70
+
71
+
72
def _perm2(elems):
  """Lazily yield both orderings of every index-ordered pair of elems."""
  # The original delegated to comb2; here the pair loop is inlined, which
  # keeps the (forward, reversed) emission order identical.
  n = len(elems)
  for i in range(n - 1):
    for j in range(i + 1, n):
      yield elems[i], elems[j]
      yield elems[j], elems[i]


def perm2(elems):
  """Return all ordered pairs of distinct positions, grouped by pair."""
  return list(_perm2(elems))
80
+
81
+
82
def _all_4points(l1, l2):
  # Lazily yield (a, b, c, d) where (a, b) is an ordered pair of distinct
  # points on line l1 and (c, d) an ordered pair of distinct points on l2.
  p1s = l1.neighbors(Point)
  p2s = l2.neighbors(Point)
  for a, b in perm2(p1s):
    for c, d in perm2(p2s):
      yield a, b, c, d


def all_4points(l1, l2):
  # Materialized list form of _all_4points.
  return list(_all_4points(l1, l2))
92
+
93
+
94
def _all_8points(l1, l2, l3, l4):
  # Lazily yield 8-tuples combining every ordered 4-point choice on
  # (l1, l2) with every ordered 4-point choice on (l3, l4).
  for a, b, c, d in all_4points(l1, l2):
    for e, f, g, h in all_4points(l3, l4):
      yield (a, b, c, d, e, f, g, h)


def all_8points(l1, l2, l3, l4):
  # Materialized list form of _all_8points.
  return list(_all_8points(l1, l2, l3, l4))
102
+
103
+
104
def _perm3(elems):
  """Lazily yield every ordered triple of mutually unequal elements.

  Distinctness is by equality (==), matching the membership tests below,
  so duplicate values in elems are never paired with themselves.
  """
  for first in elems:
    for second in elems:
      if second == first:
        continue
      for third in elems:
        if third not in (first, second):
          yield first, second, third


def perm3(elems):
  """Return all ordered triples of distinct elements as a list."""
  return list(_perm3(elems))
116
+
117
+
118
+ def _perm4(elems):
119
+ for x in elems:
120
+ for y in elems:
121
+ if y == x:
122
+ continue
123
+ for z in elems:
124
+ if z in (x, y):
125
+ continue
126
+ for t in elems:
127
+ if t not in (x, y, z):
128
+ yield x, y, z, t
129
+
130
+
131
+ def perm4(elems):
132
+ return list(_perm4(elems))
ag4masses/alphageometry/inspect_defs.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import problem as pr


def inspect_definition():
  """Load defs.txt and print the details of the 'semicircle' definition."""
  defs = pr.Definition.from_txt_file('defs.txt', to_dict=True)

  # Use an explicit None check: a present-but-falsy definition object must
  # not be reported as missing (the original truthiness test would do that).
  semicircle_def = defs.get('semicircle')
  if semicircle_def is None:
    print("No definition found for 'semicircle'")
    return

  # Print out the details of the 'semicircle' definition.
  print("Semicircle Definition:")
  print(semicircle_def)

  # Print whichever of these optional attributes the definition exposes.
  # (Replaces a repetitive hasattr/print chain with a data-driven loop.)
  labels = {
      'name': 'Name',
      'description': 'Description',
      'some_other_attribute': 'Some Other Attribute',
  }
  for attr, label in labels.items():
    if hasattr(semicircle_def, attr):
      print(f"{label}: {getattr(semicircle_def, attr)}")


if __name__ == "__main__":
  inspect_definition()
ag4masses/alphageometry/lm_inference.py CHANGED
@@ -1,189 +1,189 @@
1
- # Copyright 2023 DeepMind Technologies Limited
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """Wrapper for language modeling inference implemented in Meliad."""
17
- from typing import Any, Dict
18
-
19
- import jax
20
- import models # pylint: disable=unused-import
21
- import t5.data
22
- from transformer import inference_utils
23
-
24
-
25
- np = jax.numpy
26
-
27
-
28
- Trainer = inference_utils.Trainer
29
-
30
- MetricsOutput = Dict[str, Any] # Metrics output by model.
31
-
32
-
33
- parse_gin_configuration = inference_utils.parse_gin_configuration
34
-
35
-
36
- class LanguageModelInference:
37
- """Meliad wrapper for LM inference."""
38
-
39
- def __init__(self, vocab_path: str, load_dir: str, mode='beam_search'):
40
- self.vocab = t5.data.SentencePieceVocabulary(vocab_path)
41
-
42
- # This task won't be pulling from a dataset.
43
- def null_iter_fn() -> None:
44
- return None
45
-
46
- process_summaries_f = inference_utils.models.process_summaries_function(
47
- self.vocab
48
- )
49
-
50
- trainer = inference_utils.training_loop.Trainer(
51
- get_training_dataset_iterator=null_iter_fn,
52
- get_test_dataset_iterator=None,
53
- pretty_print_input_function=None,
54
- process_summaries_function=process_summaries_f,
55
- load_dir=load_dir,
56
- workdir='', # Don't log or save checkpoints.
57
- replicate_mode=False,
58
- ) # Run on a single device at batch size 1.
59
- self.trainer = trainer
60
-
61
- # Create and initialize the model.
62
- (tstate, _, imodel, prngs) = trainer.initialize_model()
63
- self.imodel = imodel
64
- self.batch_size = imodel.task_config.batch_size
65
-
66
- self.n = imodel.num_heads
67
- self.h = imodel.head_size
68
-
69
- # Create an inference task.
70
- writers = {}
71
- self.task = trainer.create_training_task(mode, imodel, prngs, writers) # pylint: disable=too-many-function-args
72
-
73
- # Register any additional actions.
74
- # Actions are cleared first for use with colab.
75
- inference_utils.training_loop.clear_interstep_callbacks()
76
- inference_utils.training_loop.register_interstep_callbacks()
77
- self.tstate = tstate
78
-
79
- # some default parameters.
80
- eos = [0] * 1024
81
- for idx in self.encode_list(['.', ';']):
82
- eos[idx] = 1
83
-
84
- self.eos = np.array(eos, dtype=np.bfloat16)
85
- self.mask = jax.numpy.ones([1024], dtype=np.bfloat16)
86
-
87
- def decode(self, ids: list[int]) -> str:
88
- return self.vocab.decode(ids)
89
-
90
- def decode_list(self, tokens: list[int]) -> list[str]:
91
- return [self.decode([tok]) for tok in tokens]
92
-
93
- def encode(self, inputs_str: str) -> list[int]:
94
- return self.vocab.encode(inputs_str)
95
-
96
- def encode_list(self, inputs_strs: list[str]) -> list[int]:
97
- result = [self.vocab.encode(x) for x in inputs_strs]
98
- assert all([len(x) == 1 for x in result]), [
99
- self.decode(x) for x in result if len(x) != 1
100
- ]
101
- return [x[0] for x in result]
102
-
103
- def call(
104
- self,
105
- inputs: np.ndarray,
106
- dstate: tuple[dict[str, np.ndarray], ...] = None,
107
- eos: np.ndarray = None,
108
- mask: np.ndarray = None,
109
- ) -> MetricsOutput:
110
- """Call the meliad model."""
111
- batch_size, length = inputs.shape
112
- inputs = jax.numpy.pad(inputs, [(0, 0), (0, 1024 - length)])
113
-
114
- if eos is None:
115
- eos = self.eos
116
- if mask is None:
117
- mask = self.mask
118
-
119
- x = {'targets': inputs, 'length': length, 'eos': eos, 'mask': mask}
120
-
121
- if dstate is not None:
122
- x['start_of_sequence'] = jax.numpy.array([False] * batch_size)
123
- else:
124
- dstate = tuple(
125
- [{ # this dummy value will never be used.
126
- 'current_index': np.array([0] * batch_size, dtype=np.int32),
127
- 'keys': np.zeros(
128
- (batch_size, 2048, self.n, self.h), dtype=np.bfloat16
129
- ),
130
- 'values': np.zeros(
131
- (batch_size, 2048, self.n, self.h), dtype=np.bfloat16
132
- ),
133
- 'recurrent_kvq': None,
134
- 'relative_position_bias': np.zeros(
135
- (batch_size, self.n, 1, 1024), dtype=np.bfloat16
136
- ),
137
- }]
138
- * 12
139
- )
140
- x['start_of_sequence'] = jax.numpy.array([True] * batch_size)
141
-
142
- x['dstate'] = dstate
143
- _, metrics_np = self.task.run_step(self.tstate, x, 0)
144
- return metrics_np
145
-
146
- def beam_decode(
147
- self,
148
- inputs: str,
149
- eos_tokens: np.ndarray = None,
150
- mask_tokens: np.ndarray = None,
151
- dstate: dict[str, np.ndarray] = None,
152
- ) -> MetricsOutput:
153
- """Beam search."""
154
- inputs = jax.numpy.array([self.vocab.encode(inputs)] * self.batch_size)
155
-
156
- eos = self.eos
157
- if eos_tokens is not None:
158
- eos_ids = self.encode_list(eos_tokens)
159
- eos = np.array(
160
- [1 if idx in eos_ids else 0 for idx in range(1024)], dtype=np.bfloat16
161
- ).reshape((1, 1, 1024))
162
-
163
- mask = self.mask
164
- if mask_tokens is not None:
165
- mask_ids = self.encode_list(mask_tokens)
166
- mask = np.array(
167
- [0 if idx in mask_ids else 1 for idx in range(1024)],
168
- dtype=np.bfloat16,
169
- ).reshape((1, 1, 1024))
170
-
171
- metrics_np = self.call(inputs, dstate=dstate, eos=eos, mask=mask)
172
-
173
- finished_seqs = metrics_np['finished_seqs']
174
- finished_scores = metrics_np['finished_scores']
175
-
176
- seqs = []
177
- scores = []
178
- for seq, score in zip(finished_seqs, finished_scores):
179
- seq = self.decode(seq[1:])
180
- seqs.append(seq)
181
- scores.append(score)
182
-
183
- return {
184
- 'finished_seqs': finished_seqs,
185
- 'finished_scores': finished_scores,
186
- 'seqs_str': seqs,
187
- 'scores': scores,
188
- 'dstate': metrics_np['dstate'],
189
- }
 
1
+ # Copyright 2023 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Wrapper for language modeling inference implemented in Meliad."""
17
+ from typing import Any, Dict
18
+
19
+ import jax
20
+ import models # pylint: disable=unused-import
21
+ import t5.data
22
+ from meliad_lib.meliad.transformer import inference_utils
23
+
24
+
25
+ np = jax.numpy
26
+
27
+
28
+ Trainer = inference_utils.Trainer
29
+
30
+ MetricsOutput = Dict[str, Any] # Metrics output by model.
31
+
32
+
33
+ parse_gin_configuration = inference_utils.parse_gin_configuration
34
+
35
+
36
class LanguageModelInference:
  """Meliad wrapper for LM inference."""

  def __init__(self, vocab_path: str, load_dir: str, mode='beam_search'):
    """Builds the Meliad trainer and inference task used for decoding.

    Args:
      vocab_path: path to the SentencePiece vocabulary file.
      load_dir: checkpoint directory the trainer restores the model from.
      mode: task mode passed to create_training_task, e.g. 'beam_search'.
    """
    self.vocab = t5.data.SentencePieceVocabulary(vocab_path)

    # This task won't be pulling from a dataset.
    def null_iter_fn() -> None:
      return None

    process_summaries_f = inference_utils.models.process_summaries_function(
        self.vocab
    )

    trainer = inference_utils.training_loop.Trainer(
        get_training_dataset_iterator=null_iter_fn,
        get_test_dataset_iterator=None,
        pretty_print_input_function=None,
        process_summaries_function=process_summaries_f,
        load_dir=load_dir,
        workdir='',  # Don't log or save checkpoints.
        replicate_mode=False,
    )  # Run on a single device at batch size 1.
    self.trainer = trainer

    # Create and initialize the model.
    (tstate, _, imodel, prngs) = trainer.initialize_model()
    self.imodel = imodel
    self.batch_size = imodel.task_config.batch_size

    # Head count / per-head size; used to shape dstate buffers in call().
    self.n = imodel.num_heads
    self.h = imodel.head_size

    # Create an inference task.
    writers = {}
    self.task = trainer.create_training_task(mode, imodel, prngs, writers)  # pylint: disable=too-many-function-args

    # Register any additional actions.
    # Actions are cleared first for use with colab.
    inference_utils.training_loop.clear_interstep_callbacks()
    inference_utils.training_loop.register_interstep_callbacks()
    self.tstate = tstate

    # some default parameters.
    # Default end-of-sequence indicator over the 1024-wide token-id space:
    # '.' and ';' terminate a sequence. NOTE(review): 1024 presumably equals
    # the vocab size — confirm against the SentencePiece model.
    eos = [0] * 1024
    for idx in self.encode_list(['.', ';']):
      eos[idx] = 1

    self.eos = np.array(eos, dtype=np.bfloat16)
    self.mask = jax.numpy.ones([1024], dtype=np.bfloat16)  # No tokens masked.

  def decode(self, ids: list[int]) -> str:
    """Decodes a list of token ids into a string."""
    return self.vocab.decode(ids)

  def decode_list(self, tokens: list[int]) -> list[str]:
    """Decodes each token id individually into its string piece."""
    return [self.decode([tok]) for tok in tokens]

  def encode(self, inputs_str: str) -> list[int]:
    """Encodes a string into a list of token ids."""
    return self.vocab.encode(inputs_str)

  def encode_list(self, inputs_strs: list[str]) -> list[int]:
    """Encodes strings that must each map to exactly one token id."""
    result = [self.vocab.encode(x) for x in inputs_strs]
    # Each input must encode to a single token; report offenders otherwise.
    assert all([len(x) == 1 for x in result]), [
        self.decode(x) for x in result if len(x) != 1
    ]
    return [x[0] for x in result]

  def call(
      self,
      inputs: np.ndarray,
      dstate: tuple[dict[str, np.ndarray], ...] = None,
      eos: np.ndarray = None,
      mask: np.ndarray = None,
  ) -> MetricsOutput:
    """Call the meliad model.

    Args:
      inputs: (batch, length) token-id array; right-padded to 1024 below.
      dstate: decoder state from a previous call, or None for a fresh start.
      eos: end-of-sequence indicator vector; defaults to self.eos.
      mask: token mask vector; defaults to self.mask (all ones).

    Returns:
      The metrics dict produced by the task's run_step.
    """
    batch_size, length = inputs.shape
    inputs = jax.numpy.pad(inputs, [(0, 0), (0, 1024 - length)])

    if eos is None:
      eos = self.eos
    if mask is None:
      mask = self.mask

    x = {'targets': inputs, 'length': length, 'eos': eos, 'mask': mask}

    if dstate is not None:
      # Continuing an existing decode: not the start of a sequence.
      x['start_of_sequence'] = jax.numpy.array([False] * batch_size)
    else:
      # Fresh sequence: supply a placeholder dstate of the right structure.
      dstate = tuple(
          [{  # this dummy value will never be used.
              'current_index': np.array([0] * batch_size, dtype=np.int32),
              'keys': np.zeros(
                  (batch_size, 2048, self.n, self.h), dtype=np.bfloat16
              ),
              'values': np.zeros(
                  (batch_size, 2048, self.n, self.h), dtype=np.bfloat16
              ),
              'recurrent_kvq': None,
              'relative_position_bias': np.zeros(
                  (batch_size, self.n, 1, 1024), dtype=np.bfloat16
              ),
          }]
          * 12  # NOTE(review): 12 presumably matches the decoder layer count.
      )
      x['start_of_sequence'] = jax.numpy.array([True] * batch_size)

    x['dstate'] = dstate
    _, metrics_np = self.task.run_step(self.tstate, x, 0)
    return metrics_np

  def beam_decode(
      self,
      inputs: str,
      eos_tokens: np.ndarray = None,
      mask_tokens: np.ndarray = None,
      dstate: dict[str, np.ndarray] = None,
  ) -> MetricsOutput:
    """Beam search.

    Args:
      inputs: prompt string, encoded and replicated across the batch.
      eos_tokens: optional single-token strings that end a sequence.
      mask_tokens: optional single-token strings the search may not emit.
      dstate: optional decoder state to continue from.

    Returns:
      Dict with raw beams/scores, decoded strings, and the new dstate.
    """
    inputs = jax.numpy.array([self.vocab.encode(inputs)] * self.batch_size)

    # Build a one-hot eos vector from the requested terminator tokens,
    # falling back to the default ('.', ';').
    eos = self.eos
    if eos_tokens is not None:
      eos_ids = self.encode_list(eos_tokens)
      eos = np.array(
          [1 if idx in eos_ids else 0 for idx in range(1024)], dtype=np.bfloat16
      ).reshape((1, 1, 1024))

    # Zero out masked token ids so the search cannot emit them.
    mask = self.mask
    if mask_tokens is not None:
      mask_ids = self.encode_list(mask_tokens)
      mask = np.array(
          [0 if idx in mask_ids else 1 for idx in range(1024)],
          dtype=np.bfloat16,
      ).reshape((1, 1, 1024))

    metrics_np = self.call(inputs, dstate=dstate, eos=eos, mask=mask)

    finished_seqs = metrics_np['finished_seqs']
    finished_scores = metrics_np['finished_scores']

    # Decode each finished beam (dropping the seed token) with its score.
    seqs = []
    scores = []
    for seq, score in zip(finished_seqs, finished_scores):
      seq = self.decode(seq[1:])
      seqs.append(seq)
      scores.append(score)

    return {
        'finished_seqs': finished_seqs,
        'finished_scores': finished_scores,
        'seqs_str': seqs,
        'scores': scores,
        'dstate': metrics_np['dstate'],
    }
ag4masses/alphageometry/models.py CHANGED
@@ -1,178 +1,178 @@
1
- # Copyright 2023 DeepMind Technologies Limited
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """Transformer language model generate mode."""
17
-
18
- from typing import Any, Tuple
19
- import beam_search
20
- import decoder_stack
21
- import gin
22
- import jax
23
- import jax.numpy as jnp
24
- from transformer import models
25
-
26
-
27
- @gin.configurable
28
- class DecoderOnlyLanguageModelGenerate(models.DecoderOnlyLanguageModel):
29
- """Decoder only language modeling in inference mode."""
30
-
31
- decoder_factory = decoder_stack.DecoderStackGenerate
32
-
33
- num_heads: int = gin.REQUIRED
34
- head_size: int = gin.REQUIRED
35
-
36
- def get_fake_input(self) -> dict[str, Any]:
37
- fake_input_dict = super().get_fake_input()
38
- b = self.task_config.batch_size
39
- n = self.num_heads
40
- h = self.head_size
41
- fake_input_dict.update({
42
- 'dstate': tuple(
43
- [{
44
- 'current_index': jnp.array([0] * b, dtype=jnp.int32),
45
- 'keys': jnp.zeros((b, 2048, n, h), dtype=jnp.bfloat16),
46
- 'values': jnp.zeros((b, 2048, n, h), dtype=jnp.bfloat16),
47
- 'recurrent_kvq': None,
48
- 'relative_position_bias': jnp.zeros(
49
- (b, n, 1, 1024), dtype=jnp.bfloat16
50
- ),
51
- }]
52
- * 12
53
- ),
54
- 'eos': jnp.zeros([1024], dtype=jnp.bfloat16),
55
- 'mask': jnp.ones([1024], dtype=jnp.bfloat16),
56
- 'length': 1,
57
- 'temperature': 1.0,
58
- })
59
- return fake_input_dict
60
-
61
- def __call__(self, inputs: ...) -> tuple[Any, dict[str, Any]]:
62
- # Make sure this code is not used on untested cases.
63
- if self.mode not in ['init', 'beam_search']:
64
- raise ValueError(f'{type(self)} cannot do mode {self.mode}')
65
- if self.decoder.supports_generate():
66
- raise ValueError(f'{type(self)}.decoder cannot supports_generate()')
67
-
68
- self.decoder(
69
- input_tokens=inputs['targets'][:, 0:1],
70
- target_tokens=None,
71
- start_of_sequence=inputs['start_of_sequence'],
72
- )
73
-
74
- b = inputs['targets'].shape[0]
75
- no_start_of_seq = jnp.array([False] * b, dtype=jnp.bool_)
76
-
77
- # This fn is used in both beam_search or topk_sampling.
78
- def tokens_to_logits_fn(
79
- input_token: jnp.ndarray, dstate: tuple[dict[str, jnp.ndarray], ...]
80
- ) -> tuple[jnp.ndarray, tuple[dict[str, jnp.ndarray], ...]]:
81
- (logits, dstate, _) = self.decoder(
82
- input_tokens=input_token,
83
- target_tokens=None,
84
- start_of_sequence=no_start_of_seq,
85
- decoder_state=dstate,
86
- )
87
- return logits[:, -1, :], dstate
88
-
89
- last_token = jax.lax.dynamic_slice_in_dim(
90
- inputs['targets'], inputs['length'] - 1, 1, axis=1
91
- )
92
-
93
- # last token is used to seed beam_search
94
- inputs['targets'] = inputs['targets'][:, 0:-1]
95
- dstate = jax.lax.cond(
96
- inputs['start_of_sequence'][0],
97
- lambda: self.generate(inputs)[0],
98
- lambda: inputs['dstate'],
99
- )
100
-
101
- # Then we run beam search, init with last_token & dstate.
102
- finished_seqs, finished_scores, dstate = beam_search.beam_search_flat(
103
- last_token,
104
- dstate,
105
- tokens_to_logits_fn,
106
- max_decode_len=512,
107
- eos=inputs['eos'].reshape((1, 1, -1)),
108
- mask=inputs['mask'].reshape((1, 1, -1)),
109
- )
110
-
111
- return 0.0, {
112
- 'finished_seqs': finished_seqs,
113
- 'finished_scores': finished_scores,
114
- 'dstate': dstate,
115
- }
116
-
117
- def generate(
118
- self, inputs: ...
119
- ) -> tuple[tuple[dict[str, jnp.ndarray, ...], ...], jnp.ndarray]:
120
- """Generate an output sequence.
121
-
122
- Args:
123
- inputs: the same as argument to _call_.
124
-
125
- Returns:
126
- An array of generated tokens of shape (batch_size, sequence_length).
127
- """
128
- input_tokens = inputs['targets'] # [b,seq_len]
129
- start_of_sequence = inputs['start_of_sequence'] # [b]
130
- target_tokens = jnp.pad(input_tokens[:, 1:], [(0, 0), (0, 1)])
131
- batch_size = target_tokens.shape[0]
132
-
133
- # Assuming all sequences start at the same time.
134
- start0 = inputs['start_of_sequence'][0]
135
- dstate = jax.lax.cond(
136
- start0,
137
- lambda: self.decoder.init_decoder_state_vanilla( # pylint: disable=g-long-lambda
138
- 1024, start_of_sequence
139
- ),
140
- lambda: inputs['dstate'],
141
- )
142
-
143
- first_token = input_tokens[:, 0:1]
144
- no_start_of_seq = jnp.array([False] * batch_size, dtype=jnp.bool_)
145
- temperature = 1
146
- if 'temperature' in inputs:
147
- temperature = inputs['temperature']
148
-
149
- num_steps = inputs['length']
150
- if self.mode == 'beam_search':
151
- num_steps -= 1
152
-
153
- def cond_fn(scan_state) -> jnp.bool_:
154
- _, _, i, _ = scan_state
155
- return i < num_steps
156
-
157
- def loop_fn(scan_state: Any) -> Tuple[Any, Any, Any, Any]:
158
- (dstate, input_token, i, _) = scan_state
159
-
160
- (logits, dstate, _) = self.decoder(
161
- input_tokens=input_token,
162
- target_tokens=None,
163
- start_of_sequence=no_start_of_seq,
164
- decoder_state=dstate,
165
- )
166
-
167
- logits = logits / temperature
168
- output_token = jax.lax.dynamic_slice_in_dim(target_tokens, i, 1, axis=1)
169
-
170
- return (dstate, output_token, i + 1, logits)
171
-
172
- # Scan over the sequence length.
173
- dummy_logits = jnp.zeros((batch_size, 1, 1024))
174
- initial_scan_state = (dstate, first_token, 0, dummy_logits)
175
- dstate, _, _, logits = jax.lax.while_loop(
176
- cond_fn, loop_fn, initial_scan_state
177
- )
178
- return dstate, logits
 
1
+ # Copyright 2023 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Transformer language model generate mode."""
17
+
18
+ from typing import Any, Tuple
19
+ import beam_search
20
+ import decoder_stack
21
+ import gin
22
+ import jax
23
+ import jax.numpy as jnp
24
+ from meliad_lib.meliad.transformer import models
25
+
26
+
27
+ @gin.configurable
28
+ class DecoderOnlyLanguageModelGenerate(models.DecoderOnlyLanguageModel):
29
+ """Decoder only language modeling in inference mode."""
30
+
31
+ decoder_factory = decoder_stack.DecoderStackGenerate
32
+
33
+ num_heads: int = gin.REQUIRED
34
+ head_size: int = gin.REQUIRED
35
+
36
+ def get_fake_input(self) -> dict[str, Any]:
37
+ fake_input_dict = super().get_fake_input()
38
+ b = self.task_config.batch_size
39
+ n = self.num_heads
40
+ h = self.head_size
41
+ fake_input_dict.update({
42
+ 'dstate': tuple(
43
+ [{
44
+ 'current_index': jnp.array([0] * b, dtype=jnp.int32),
45
+ 'keys': jnp.zeros((b, 2048, n, h), dtype=jnp.bfloat16),
46
+ 'values': jnp.zeros((b, 2048, n, h), dtype=jnp.bfloat16),
47
+ 'recurrent_kvq': None,
48
+ 'relative_position_bias': jnp.zeros(
49
+ (b, n, 1, 1024), dtype=jnp.bfloat16
50
+ ),
51
+ }]
52
+ * 12
53
+ ),
54
+ 'eos': jnp.zeros([1024], dtype=jnp.bfloat16),
55
+ 'mask': jnp.ones([1024], dtype=jnp.bfloat16),
56
+ 'length': 1,
57
+ 'temperature': 1.0,
58
+ })
59
+ return fake_input_dict
60
+
61
+ def __call__(self, inputs: ...) -> tuple[Any, dict[str, Any]]:
62
+ # Make sure this code is not used on untested cases.
63
+ if self.mode not in ['init', 'beam_search']:
64
+ raise ValueError(f'{type(self)} cannot do mode {self.mode}')
65
+ if self.decoder.supports_generate():
66
+ raise ValueError(f'{type(self)}.decoder cannot supports_generate()')
67
+
68
+ self.decoder(
69
+ input_tokens=inputs['targets'][:, 0:1],
70
+ target_tokens=None,
71
+ start_of_sequence=inputs['start_of_sequence'],
72
+ )
73
+
74
+ b = inputs['targets'].shape[0]
75
+ no_start_of_seq = jnp.array([False] * b, dtype=jnp.bool_)
76
+
77
+ # This fn is used in both beam_search or topk_sampling.
78
+ def tokens_to_logits_fn(
79
+ input_token: jnp.ndarray, dstate: tuple[dict[str, jnp.ndarray], ...]
80
+ ) -> tuple[jnp.ndarray, tuple[dict[str, jnp.ndarray], ...]]:
81
+ (logits, dstate, _) = self.decoder(
82
+ input_tokens=input_token,
83
+ target_tokens=None,
84
+ start_of_sequence=no_start_of_seq,
85
+ decoder_state=dstate,
86
+ )
87
+ return logits[:, -1, :], dstate
88
+
89
+ last_token = jax.lax.dynamic_slice_in_dim(
90
+ inputs['targets'], inputs['length'] - 1, 1, axis=1
91
+ )
92
+
93
+ # last token is used to seed beam_search
94
+ inputs['targets'] = inputs['targets'][:, 0:-1]
95
+ dstate = jax.lax.cond(
96
+ inputs['start_of_sequence'][0],
97
+ lambda: self.generate(inputs)[0],
98
+ lambda: inputs['dstate'],
99
+ )
100
+
101
+ # Then we run beam search, init with last_token & dstate.
102
+ finished_seqs, finished_scores, dstate = beam_search.beam_search_flat(
103
+ last_token,
104
+ dstate,
105
+ tokens_to_logits_fn,
106
+ max_decode_len=512,
107
+ eos=inputs['eos'].reshape((1, 1, -1)),
108
+ mask=inputs['mask'].reshape((1, 1, -1)),
109
+ )
110
+
111
+ return 0.0, {
112
+ 'finished_seqs': finished_seqs,
113
+ 'finished_scores': finished_scores,
114
+ 'dstate': dstate,
115
+ }
116
+
117
+ def generate(
118
+ self, inputs: ...
119
+ ) -> tuple[tuple[dict[str, jnp.ndarray, ...], ...], jnp.ndarray]:
120
+ """Generate an output sequence.
121
+
122
+ Args:
123
+ inputs: the same as argument to _call_.
124
+
125
+ Returns:
126
+ An array of generated tokens of shape (batch_size, sequence_length).
127
+ """
128
+ input_tokens = inputs['targets'] # [b,seq_len]
129
+ start_of_sequence = inputs['start_of_sequence'] # [b]
130
+ target_tokens = jnp.pad(input_tokens[:, 1:], [(0, 0), (0, 1)])
131
+ batch_size = target_tokens.shape[0]
132
+
133
+ # Assuming all sequences start at the same time.
134
+ start0 = inputs['start_of_sequence'][0]
135
+ dstate = jax.lax.cond(
136
+ start0,
137
+ lambda: self.decoder.init_decoder_state_vanilla( # pylint: disable=g-long-lambda
138
+ 1024, start_of_sequence
139
+ ),
140
+ lambda: inputs['dstate'],
141
+ )
142
+
143
+ first_token = input_tokens[:, 0:1]
144
+ no_start_of_seq = jnp.array([False] * batch_size, dtype=jnp.bool_)
145
+ temperature = 1
146
+ if 'temperature' in inputs:
147
+ temperature = inputs['temperature']
148
+
149
+ num_steps = inputs['length']
150
+ if self.mode == 'beam_search':
151
+ num_steps -= 1
152
+
153
+ def cond_fn(scan_state) -> jnp.bool_:
154
+ _, _, i, _ = scan_state
155
+ return i < num_steps
156
+
157
+ def loop_fn(scan_state: Any) -> Tuple[Any, Any, Any, Any]:
158
+ (dstate, input_token, i, _) = scan_state
159
+
160
+ (logits, dstate, _) = self.decoder(
161
+ input_tokens=input_token,
162
+ target_tokens=None,
163
+ start_of_sequence=no_start_of_seq,
164
+ decoder_state=dstate,
165
+ )
166
+
167
+ logits = logits / temperature
168
+ output_token = jax.lax.dynamic_slice_in_dim(target_tokens, i, 1, axis=1)
169
+
170
+ return (dstate, output_token, i + 1, logits)
171
+
172
+ # Scan over the sequence length.
173
+ dummy_logits = jnp.zeros((batch_size, 1, 1024))
174
+ initial_scan_state = (dstate, first_token, 0, dummy_logits)
175
+ dstate, _, _, logits = jax.lax.while_loop(
176
+ cond_fn, loop_fn, initial_scan_state
177
+ )
178
+ return dstate, logits
ag4masses/alphageometry/numericals.py CHANGED
@@ -25,14 +25,13 @@ from matplotlib import pyplot as plt
25
  import matplotlib.colors as mcolors
26
  import numpy as np
27
  from numpy.random import uniform as unif # pylint: disable=g-importing-member
 
28
 
29
-
30
- matplotlib.use('Agg')
31
 
32
 
33
  ATOM = 1e-12
34
 
35
-
36
  # Some variables are there for better code reading.
37
  # pylint: disable=unused-assignment
38
  # pylint: disable=unused-argument
@@ -440,6 +439,75 @@ class Circle:
440
  return [result]
441
 
442
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  class HoleCircle(Circle):
444
  """Numerical circle with a missing point."""
445
 
@@ -565,6 +633,18 @@ def circle_segment_intersect(
565
  result.append(py)
566
  return result
567
 
 
 
 
 
 
 
 
 
 
 
 
 
568
 
569
  def line_segment_intersection(l: Line, A: Point, B: Point) -> Point: # pylint: disable=invalid-name
570
  a, b, c = l.coefficients
@@ -656,6 +736,13 @@ def check_circle(points: list[Point]) -> bool:
656
  oa, ob, oc = o.distance(a), o.distance(b), o.distance(c)
657
  return close_enough(oa, ob) and close_enough(ob, oc)
658
 
 
 
 
 
 
 
 
659
 
660
  def check_coll(points: list[Point]) -> bool:
661
  a, b = points[:2]
@@ -894,10 +981,12 @@ def naming_position(
894
  _ = ax
895
  r = 0.08
896
  c = Circle(center=p, radius=r)
 
897
  avoid = []
898
  for p1, p2 in lines:
899
  try:
900
  avoid.extend(circle_segment_intersect(c, p1, p2))
 
901
  except InvalidQuadSolveError:
902
  continue
903
  for x in circles:
@@ -928,6 +1017,7 @@ def draw_point(
928
  name: str,
929
  lines: list[Line],
930
  circles: list[Circle],
 
931
  color: Any = 'white',
932
  size: float = 15,
933
  ) -> None:
@@ -1029,6 +1119,133 @@ def draw_circle(
1029
  _draw_circle(ax, circle, color)
1030
  return circle
1031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1032
 
1033
  def mark_segment(
1034
  ax: matplotlib.axes.Axes, p1: Point, p2: Point, color: Any, alpha: float
@@ -1126,7 +1343,9 @@ def highlight(
1126
  _draw_line(ax, c, d, color=color2, lw=2.0, alpha=0.5)
1127
  _draw_line(ax, m, n, color=color1, lw=2.0, alpha=0.5)
1128
  _draw_line(ax, p, q, color=color2, lw=2.0, alpha=0.5)
1129
-
 
 
1130
 
1131
  HCOLORS = None
1132
 
@@ -1136,6 +1355,7 @@ def _draw(
1136
  points: list[gm.Point],
1137
  lines: list[gm.Line],
1138
  circles: list[gm.Circle],
 
1139
  goal: Any,
1140
  equals: list[tuple[Any, Any]],
1141
  highlights: list[tuple[str, list[gm.Point]]],
@@ -1158,9 +1378,10 @@ def _draw(
1158
  p1, p2 = draw_line(ax, l, color=lcolor)
1159
  line_boundaries.append((p1, p2))
1160
  circles = [draw_circle(ax, c, color=ccolor) for c in circles]
 
1161
 
1162
  for p in points:
1163
- draw_point(ax, p.num, p.name, line_boundaries, circles, color=pcolor)
1164
 
1165
  if equals:
1166
  for i, segs in enumerate(equals['segments']):
@@ -1204,6 +1425,7 @@ def draw(
1204
  points: list[gm.Point],
1205
  lines: list[gm.Line],
1206
  circles: list[gm.Circle],
 
1207
  segments: list[gm.Segment],
1208
  goal: Any = None,
1209
  highlights: list[tuple[str, list[gm.Point]]] = None,
@@ -1214,8 +1436,8 @@ def draw(
1214
  ) -> None:
1215
  """Draw everything on the same canvas."""
1216
  plt.close()
1217
- imsize = 512 / 100
1218
- fig, ax = plt.subplots(figsize=(imsize, imsize), dpi=100)
1219
 
1220
  set_theme(theme)
1221
 
@@ -1224,7 +1446,7 @@ def draw(
1224
  else:
1225
  ax.set_facecolor((1.0, 1.0, 1.0))
1226
 
1227
- _draw(ax, points, lines, circles, goal, equals, highlights)
1228
 
1229
  plt.axis('equal')
1230
  fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
@@ -1238,8 +1460,6 @@ def draw(
1238
  plt.savefig(save_to)
1239
  # plt.show(block=block)
1240
 
1241
-
1242
-
1243
  def close_enough(a: float, b: float, tol: float = 1e-12) -> bool:
1244
  return abs(a - b) < tol
1245
 
@@ -1560,6 +1780,9 @@ def sketch_circle(args: tuple[gm.Point, ...]) -> Circle:
1560
  a, b, c = args
1561
  return Circle(center=a, radius=b.distance(c))
1562
 
 
 
 
1563
 
1564
  def sketch_cc_tangent(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1565
  """Sketch tangents to two circles."""
 
25
  import matplotlib.colors as mcolors
26
  import numpy as np
27
  from numpy.random import uniform as unif # pylint: disable=g-importing-member
28
+ import graph as gh
29
 
30
+ matplotlib.use('TkAgg')
 
31
 
32
 
33
  ATOM = 1e-12
34
 
 
35
  # Some variables are there for better code reading.
36
  # pylint: disable=unused-assignment
37
  # pylint: disable=unused-argument
 
439
  return [result]
440
 
441
 
442
class SemiCircle(Circle):
  """Numerical semicircle, inheriting the circle geometry from Circle.

  The active half is taken to be the side of the center with polar angles
  in [-pi/2, pi/2] (the positive-x half-plane), matching the angular range
  that sample_within draws from.
  """

  def __init__(
      self,
      center: Optional[Point] = None,
      radius: Optional[float] = None,
      p1: Optional[Point] = None,
      p2: Optional[Point] = None,
      p3: Optional[Point] = None,
  ):
    self.p1 = p1
    self.p2 = p2
    self.p3 = p3
    # Initialize as a plain Circle first; the diameter case below overrides
    # center/radius when applicable.
    super().__init__(center, radius, p1, p2, p3)
    # If p1 and p2 are given without an explicit center, interpret p1-p2 as
    # the diameter: center at their midpoint, radius half their distance.
    # (Explicit `is not None` checks instead of truthiness, so a Point that
    # happens to evaluate falsy cannot silently skip this branch.)
    if p1 is not None and p2 is not None and center is None:
      self.center = Point((p1.x + p2.x) / 2, (p1.y + p2.y) / 2)
      self.radius = p1.distance(p2) / 2
      self.r2 = self.radius ** 2

  def is_within_boundary(self, point: Point) -> bool:
    """Check whether a point lies on the active half of the circle.

    Bug fix: the previous version first normalized the atan2 angle into
    [0, 2*pi) and then tested -pi/2 <= angle <= pi/2, which wrongly
    rejected every point with a negative raw angle (e.g. -pi/4 became
    7*pi/4 and failed the test) even though sample_within generates such
    angles. math.atan2 already returns values in [-pi, pi], so the range
    test is applied directly.
    """
    vector_to_point = point - self.center
    angle = math.atan2(vector_to_point.y, vector_to_point.x)
    # Consistent with sample_within, which draws angles in (-pi/2, pi/2).
    return -np.pi / 2 <= angle <= np.pi / 2

  def sample_within(self, points: list[Point], n: int = 5) -> list[Point]:
    """Sample a point on the semicircle, maximizing distance to `points`.

    Returns a single-element list (None inside if no candidate passed the
    boundary check), mirroring Circle.sample_within's interface.
    """
    result = None
    best = -1.0
    for _ in range(n):
      # Random angle in (-pi/2, pi/2): the semicircle's angular range.
      ang = unif(-0.5, 0.5) * np.pi
      x = self.center + Point(np.cos(ang), np.sin(ang)) * self.radius

      # Skip candidates outside the active half (defensive; the sampled
      # range above already lies within it after the boundary fix).
      if not self.is_within_boundary(x):
        continue

      # Keep the candidate farthest from all existing points.
      mind = min([x.distance(p) for p in points])
      if mind > best:
        best = mind
        result = x

    return [result]

  def intersect(self, obj: Union[Line, Circle]) -> tuple[Point, ...]:
    """Intersect with a Line or Circle, keeping only points on the semicircle."""
    if isinstance(obj, Line):
      intersections = obj.intersect(self)
    elif isinstance(obj, Circle):
      intersections = circle_circle_intersection(self, obj)
    else:
      return tuple()

    # Filter intersections to only return points within the semicircle.
    return tuple(p for p in intersections if self.is_within_boundary(p))
+
510
+
511
  class HoleCircle(Circle):
512
  """Numerical circle with a missing point."""
513
 
 
633
  result.append(py)
634
  return result
635
 
636
def semicircle_segment_intersect(
    circle: SemiCircle, p1: Point, p2: Point
) -> list[Point]:
  """Intersect the segment p1-p2 with the semicircle's underlying circle.

  Returns the intersection points of the full circle with the infinite line
  through p1 and p2, filtered down to those lying between p1 and p2.
  """
  chord = Line(p1, p2)
  hits = line_circle_intersection(chord, circle)
  return [pt for pt in hits if _check_between(pt, p1, p2)]
648
 
649
  def line_segment_intersection(l: Line, A: Point, B: Point) -> Point: # pylint: disable=invalid-name
650
  a, b, c = l.coefficients
 
736
  oa, ob, oc = o.distance(a), o.distance(b), o.distance(c)
737
  return close_enough(oa, ob) and close_enough(ob, oc)
738
 
739
def check_semicircle(points: list[Point]) -> bool:
  """True iff points = [o, a, b, c] with a, b, c equidistant from o.

  Mirrors check_circle; no semicircle-specific (half-plane) condition is
  verified here.
  """
  if len(points) != 4:
    return False
  o = points[0]
  dists = [o.distance(p) for p in points[1:]]
  return close_enough(dists[0], dists[1]) and close_enough(dists[1], dists[2])
745
+
746
 
747
  def check_coll(points: list[Point]) -> bool:
748
  a, b = points[:2]
 
981
  _ = ax
982
  r = 0.08
983
  c = Circle(center=p, radius=r)
984
+ sc = SemiCircle(center=p, radius=r)
985
  avoid = []
986
  for p1, p2 in lines:
987
  try:
988
  avoid.extend(circle_segment_intersect(c, p1, p2))
989
+ avoid.extend(semicircle_segment_intersect(sc, p1, p2))
990
  except InvalidQuadSolveError:
991
  continue
992
  for x in circles:
 
1017
  name: str,
1018
  lines: list[Line],
1019
  circles: list[Circle],
1020
+ semicircles: list[SemiCircle],
1021
  color: Any = 'white',
1022
  size: float = 15,
1023
  ) -> None:
 
1119
  _draw_circle(ax, circle, color)
1120
  return circle
1121
 
1122
def check_points_semicircle(p1, p2, p3):
  """
  Compute the circumcircle of three points and detect a diameter pair.

  Parameters:
  p1, p2, p3 (tuple): Three points as (x, y) coordinates.

  Returns:
  dict: With keys:
    - 'center': np.ndarray (cx, cy), the circumcircle center.
    - 'radius': The circle radius.
    - 'diameter_points': The pair of input points whose midpoint is the
      center (i.e. they span a diameter), or None.
    - 'is_valid': True if a circle exists; for collinear input only
      {'is_valid': False} is returned.
  """
  (x1, y1), (x2, y2), (x3, y3) = p1, p2, p3

  # Circumcenter from the two perpendicular-bisector equations.
  lhs = np.array([[x1 - x2, y1 - y2], [x1 - x3, y1 - y3]])
  rhs = np.array([
      ((x1**2 - x2**2) + (y1**2 - y2**2)) / 2,
      ((x1**2 - x3**2) + (y1**2 - y3**2)) / 2,
  ])

  try:
    center = np.linalg.solve(lhs, rhs)
  except np.linalg.LinAlgError:
    # Collinear points: no unique circumcircle.
    return {'is_valid': False}

  cx, cy = center
  radius = np.sqrt((x1 - cx) ** 2 + (y1 - cy) ** 2)

  def spans_diameter(p, q):
    # A pair spans a diameter iff its midpoint coincides with the center.
    return np.isclose((p[0] + q[0]) / 2, cx) and np.isclose((p[1] + q[1]) / 2, cy)

  # Check pairs in the same order as before: (p1,p2), (p1,p3), (p2,p3).
  diameter_points = None
  for pair in ((p1, p2), (p1, p3), (p2, p3)):
    if spans_diameter(*pair):
      diameter_points = pair
      break

  return {
      'center': center,
      'radius': radius,
      'diameter_points': diameter_points,
      'is_valid': True,
  }
1176
+
1177
def _draw_semicircle(
    ax: matplotlib.axes.Axes, P1: Point, P2: Point, P3: Point, color: Any = 'cyan', lw: float = 1.2
) -> None:
  """
  Draws a semicircle passing through three points or with one or two points on the diameter.

  Parameters:
  ax (matplotlib.axes.Axes): The Matplotlib Axes on which the semicircle will be drawn.
  P1, P2, P3 (Point): The three points through which the semicircle will pass.
  color (Any): Color of the semicircle.
  lw (float): Line width of the semicircle.

  Prints a message and draws nothing when the points are collinear.
  """
  # Circumcircle (center/radius) plus which pair of points, if any, spans a diameter.
  result = check_points_semicircle((P1.x, P1.y), (P2.x, P2.y), (P3.x, P3.y))
  if not result['is_valid']:
    print("Points are collinear; cannot form a semicircle.")
    return

  cx, cy = result['center']
  radius = result['radius']
  diameter_points = result['diameter_points']

  # If no pair forms a diameter, determine angles for all three points
  if diameter_points is None:
    # Calculate angles of all three points relative to the circle's center
    angles = np.arctan2(
        [P1.y - cy, P2.y - cy, P3.y - cy],
        [P1.x - cx, P2.x - cx, P3.x - cx]
    )
    angles = (angles + 2 * np.pi) % (2 * np.pi)  # Normalize to [0, 2π]

    # Determine the start and end angle for the semicircle.
    # NOTE(review): the arc is swept from min(angles) to max(angles); a point
    # whose angle lies outside that sweep (after the wrap below) may not lie
    # on the drawn arc — confirm this matches the intended half.
    start_angle = np.min(angles)
    end_angle = np.max(angles)
    if end_angle - start_angle > np.pi:
      # Sweep the other way around (through 0/2π) when the span exceeds π.
      start_angle, end_angle = end_angle, start_angle + 2 * np.pi
  else:
    # Use diameter points to define the semicircle angles
    px, py = diameter_points[0]
    qx, qy = diameter_points[1]
    start_angle = np.arctan2(py - cy, px - cx)
    end_angle = np.arctan2(qy - cy, qx - cx)
    if end_angle - start_angle > np.pi:
      # NOTE(review): this picks one of the two half-arcs; which half is drawn
      # depends on the order of the diameter points — verify against callers.
      start_angle, end_angle = end_angle, start_angle + 2 * np.pi

  # Generate points for the semicircle (100-sample polyline approximation).
  t = np.linspace(start_angle, end_angle, 100)
  x = cx + radius * np.cos(t)
  y = cy + radius * np.sin(t)

  # Plot the semicircle
  ax.plot(x, y, color=color, lw=lw)
1228
+
1229
def draw_semicircle(
    ax: matplotlib.axes.Axes, semicircle: SemiCircle, color: Any = 'cyan'
) -> SemiCircle:
  """Draw a semicircle (its arc plus the chords between its defining points).

  Accepts either an object with a numerical counterpart in `.num`, or a
  symbolic semicircle whose defining points are recovered from its Point
  neighbors. Returns the numerical SemiCircle that was drawn, or None when
  fewer than three points are available.

  NOTE(review): when the `.num` branch is taken, p1/p2/p3 on the numerical
  object may be None unless upstream always populates them — confirm.
  """
  if semicircle.num is not None:
    semicircle = semicircle.num
  else:
    points = semicircle.neighbors(gm.Point)
    if len(points) <= 2:
      return  # not enough points to determine a semicircle
    points = [p.num for p in points]
    p1, p2, p3 = points[:3]
    semicircle = SemiCircle(p1=p1, p2=p2, p3=p3)
  # Fixed: removed a leftover debug print of (p1, p2, p3) that polluted
  # stdout on every draw.
  _draw_semicircle(ax, semicircle.p1, semicircle.p2, semicircle.p3, color=color)
  _draw_line(ax, semicircle.p1, semicircle.p2)
  _draw_line(ax, semicircle.p2, semicircle.p3)
  _draw_line(ax, semicircle.p1, semicircle.p3)
  return semicircle
1248
+
1249
 
1250
  def mark_segment(
1251
  ax: matplotlib.axes.Axes, p1: Point, p2: Point, color: Any, alpha: float
 
1343
  _draw_line(ax, c, d, color=color2, lw=2.0, alpha=0.5)
1344
  _draw_line(ax, m, n, color=color1, lw=2.0, alpha=0.5)
1345
  _draw_line(ax, p, q, color=color2, lw=2.0, alpha=0.5)
1346
+ elif name == 'semicircle':
1347
+ o, a, b, c = args
1348
+ _draw_semicircle(ax, SemiCircle(center=o, p1=a, p2=b, p3=c), color=color1, lw=2.0)
1349
 
1350
  HCOLORS = None
1351
 
 
1355
  points: list[gm.Point],
1356
  lines: list[gm.Line],
1357
  circles: list[gm.Circle],
1358
+ semicircles: list[gm.SemiCircle],
1359
  goal: Any,
1360
  equals: list[tuple[Any, Any]],
1361
  highlights: list[tuple[str, list[gm.Point]]],
 
1378
  p1, p2 = draw_line(ax, l, color=lcolor)
1379
  line_boundaries.append((p1, p2))
1380
  circles = [draw_circle(ax, c, color=ccolor) for c in circles]
1381
+ semicircles = [draw_semicircle(ax, c, color=ccolor) for c in semicircles]
1382
 
1383
  for p in points:
1384
+ draw_point(ax, p.num, p.name, line_boundaries, circles, semicircles, color=pcolor)
1385
 
1386
  if equals:
1387
  for i, segs in enumerate(equals['segments']):
 
1425
  points: list[gm.Point],
1426
  lines: list[gm.Line],
1427
  circles: list[gm.Circle],
1428
+ semicircles: list[gm.SemiCircle],
1429
  segments: list[gm.Segment],
1430
  goal: Any = None,
1431
  highlights: list[tuple[str, list[gm.Point]]] = None,
 
1436
  ) -> None:
1437
  """Draw everything on the same canvas."""
1438
  plt.close()
1439
+ imsize = 1280 / 200
1440
+ fig, ax = plt.subplots(figsize=(imsize, imsize), dpi=200)
1441
 
1442
  set_theme(theme)
1443
 
 
1446
  else:
1447
  ax.set_facecolor((1.0, 1.0, 1.0))
1448
 
1449
+ _draw(ax, points, lines, circles, semicircles, goal, equals, highlights)
1450
 
1451
  plt.axis('equal')
1452
  fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
 
1460
  plt.savefig(save_to)
1461
  # plt.show(block=block)
1462
 
 
 
1463
  def close_enough(a: float, b: float, tol: float = 1e-12) -> bool:
1464
  return abs(a - b) < tol
1465
 
 
1780
  a, b, c = args
1781
  return Circle(center=a, radius=b.distance(c))
1782
 
1783
def sketch_semicircle(args: tuple[gm.Point, ...]) -> SemiCircle:
  """Sketch a semicircle centered at args[0] with radius |args[1] - args[2]|."""
  center, b, c = args
  return SemiCircle(center=center, radius=b.distance(c), p1=b, p2=c)
1786
 
1787
  def sketch_cc_tangent(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1788
  """Sketch tangents to two circles."""
ag4masses/alphageometry/pretty.py CHANGED
@@ -1,216 +1,216 @@
1
- # Copyright 2023 DeepMind Technologies Limited
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """Utilities for string manipulation in the DSL."""
17
-
18
- MAP_SYMBOL = {
19
- 'T': 'perp',
20
- 'P': 'para',
21
- 'D': 'cong',
22
- 'S': 'simtri',
23
- 'I': 'circle',
24
- 'M': 'midp',
25
- 'O': 'cyclic',
26
- 'C': 'coll',
27
- '^': 'eqangle',
28
- '/': 'eqratio',
29
- '%': 'eqratio',
30
- '=': 'contri',
31
- 'X': 'collx',
32
- 'A': 'acompute',
33
- 'R': 'rcompute',
34
- 'Q': 'fixc',
35
- 'E': 'fixl',
36
- 'V': 'fixb',
37
- 'H': 'fixt',
38
- 'Z': 'fixp',
39
- 'Y': 'ind',
40
- }
41
-
42
-
43
- def map_symbol(c: str) -> str:
44
- return MAP_SYMBOL[c]
45
-
46
-
47
- def map_symbol_inv(c: str) -> str:
48
- return {v: k for k, v in MAP_SYMBOL.items()}[c]
49
-
50
-
51
- def _gcd(x: int, y: int) -> int:
52
- while y:
53
- x, y = y, x % y
54
- return x
55
-
56
-
57
- def simplify(n: int, d: int) -> tuple[int, int]:
58
- g = _gcd(n, d)
59
- return (n // g, d // g)
60
-
61
-
62
- def pretty2r(a: str, b: str, c: str, d: str) -> str:
63
- if b in (c, d):
64
- a, b = b, a
65
-
66
- if a == d:
67
- c, d = d, c
68
-
69
- return f'{a} {b} {c} {d}'
70
-
71
-
72
- def pretty2a(a: str, b: str, c: str, d: str) -> str:
73
- if b in (c, d):
74
- a, b = b, a
75
-
76
- if a == d:
77
- c, d = d, c
78
-
79
- return f'{a} {b} {c} {d}'
80
-
81
-
82
- def pretty_angle(a: str, b: str, c: str, d: str) -> str:
83
- if b in (c, d):
84
- a, b = b, a
85
- if a == d:
86
- c, d = d, c
87
-
88
- if a == c:
89
- return f'\u2220{b}{a}{d}'
90
- return f'\u2220({a}{b}-{c}{d})'
91
-
92
-
93
- def pretty_nl(name: str, args: list[str]) -> str:
94
- """Natural lang formatting a predicate."""
95
- if name == 'aconst':
96
- a, b, c, d, y = args
97
- return f'{pretty_angle(a, b, c, d)} = {y}'
98
- if name == 'rconst':
99
- a, b, c, d, y = args
100
- return f'{a}{b}:{c}{d} = {y}'
101
- if name == 'acompute':
102
- a, b, c, d = args
103
- return f'{pretty_angle(a, b, c, d)}'
104
- if name in ['coll', 'C']:
105
- return '' + ','.join(args) + ' are collinear'
106
- if name == 'collx':
107
- return '' + ','.join(list(set(args))) + ' are collinear'
108
- if name in ['cyclic', 'O']:
109
- return '' + ','.join(args) + ' are concyclic'
110
- if name in ['midp', 'midpoint', 'M']:
111
- x, a, b = args
112
- return f'{x} is midpoint of {a}{b}'
113
- if name in ['eqangle', 'eqangle6', '^']:
114
- a, b, c, d, e, f, g, h = args
115
- return f'{pretty_angle(a, b, c, d)} = {pretty_angle(e, f, g, h)}'
116
- if name in ['eqratio', 'eqratio6', '/']:
117
- return '{}{}:{}{} = {}{}:{}{}'.format(*args)
118
- if name == 'eqratio3':
119
- a, b, c, d, o, o = args # pylint: disable=redeclared-assigned-name
120
- return f'S {o} {a} {b} {o} {c} {d}'
121
- if name in ['cong', 'D']:
122
- a, b, c, d = args
123
- return f'{a}{b} = {c}{d}'
124
- if name in ['perp', 'T']:
125
- if len(args) == 2: # this is algebraic derivation.
126
- ab, cd = args # ab = 'd( ... )'
127
- return f'{ab} \u27c2 {cd}'
128
- a, b, c, d = args
129
- return f'{a}{b} \u27c2 {c}{d}'
130
- if name in ['para', 'P']:
131
- if len(args) == 2: # this is algebraic derivation.
132
- ab, cd = args # ab = 'd( ... )'
133
- return f'{ab} \u2225 {cd}'
134
- a, b, c, d = args
135
- return f'{a}{b} \u2225 {c}{d}'
136
- if name in ['simtri2', 'simtri', 'simtri*']:
137
- a, b, c, x, y, z = args
138
- return f'\u0394{a}{b}{c} is similar to \u0394{x}{y}{z}'
139
- if name in ['contri2', 'contri', 'contri*']:
140
- a, b, c, x, y, z = args
141
- return f'\u0394{a}{b}{c} is congruent to \u0394{x}{y}{z}'
142
- if name in ['circle', 'I']:
143
- o, a, b, c = args
144
- return f'{o} is the circumcenter of \\Delta {a}{b}{c}'
145
- if name == 'foot':
146
- a, b, c, d = args
147
- return f'{a} is the foot of {b} on {c}{d}'
148
-
149
-
150
- def pretty(txt: str) -> str:
151
- """Pretty formating a predicate string."""
152
- if isinstance(txt, str):
153
- txt = txt.split(' ')
154
- name, *args = txt
155
- if name == 'ind':
156
- return 'Y ' + ' '.join(args)
157
- if name in ['fixc', 'fixl', 'fixb', 'fixt', 'fixp']:
158
- return map_symbol_inv(name) + ' ' + ' '.join(args)
159
- if name == 'acompute':
160
- a, b, c, d = args
161
- return 'A ' + ' '.join(args)
162
- if name == 'rcompute':
163
- a, b, c, d = args
164
- return 'R ' + ' '.join(args)
165
- if name == 'aconst':
166
- a, b, c, d, y = args
167
- return f'^ {pretty2a(a, b, c, d)} {y}'
168
- if name == 'rconst':
169
- a, b, c, d, y = args
170
- return f'/ {pretty2r(a, b, c, d)} {y}'
171
- if name == 'coll':
172
- return 'C ' + ' '.join(args)
173
- if name == 'collx':
174
- return 'X ' + ' '.join(args)
175
- if name == 'cyclic':
176
- return 'O ' + ' '.join(args)
177
- if name in ['midp', 'midpoint']:
178
- x, a, b = args
179
- return f'M {x} {a} {b}'
180
- if name == 'eqangle':
181
- a, b, c, d, e, f, g, h = args
182
- return f'^ {pretty2a(a, b, c, d)} {pretty2a(e, f, g, h)}'
183
- if name == 'eqratio':
184
- a, b, c, d, e, f, g, h = args
185
- return f'/ {pretty2r(a, b, c, d)} {pretty2r(e, f, g, h)}'
186
- if name == 'eqratio3':
187
- a, b, c, d, o, o = args # pylint: disable=redeclared-assigned-name
188
- return f'S {o} {a} {b} {o} {c} {d}'
189
- if name == 'cong':
190
- a, b, c, d = args
191
- return f'D {a} {b} {c} {d}'
192
- if name == 'perp':
193
- if len(args) == 2: # this is algebraic derivation.
194
- ab, cd = args # ab = 'd( ... )'
195
- return f'T {ab} {cd}'
196
- a, b, c, d = args
197
- return f'T {a} {b} {c} {d}'
198
- if name == 'para':
199
- if len(args) == 2: # this is algebraic derivation.
200
- ab, cd = args # ab = 'd( ... )'
201
- return f'P {ab} {cd}'
202
- a, b, c, d = args
203
- return f'P {a} {b} {c} {d}'
204
- if name in ['simtri2', 'simtri', 'simtri*']:
205
- a, b, c, x, y, z = args
206
- return f'S {a} {b} {c} {x} {y} {z}'
207
- if name in ['contri2', 'contri', 'contri*']:
208
- a, b, c, x, y, z = args
209
- return f'= {a} {b} {c} {x} {y} {z}'
210
- if name == 'circle':
211
- o, a, b, c = args
212
- return f'I {o} {a} {b} {c}'
213
- if name == 'foot':
214
- a, b, c, d = args
215
- return f'F {a} {b} {c} {d}'
216
- return ' '.join(txt)
 
1
+ # Copyright 2023 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Utilities for string manipulation in the DSL."""
17
+
18
# One-character symbol -> predicate name used by the DSL.
MAP_SYMBOL = {
    'T': 'perp',
    'P': 'para',
    'D': 'cong',
    'S': 'simtri',
    'I': 'circle',
    'M': 'midp',
    'O': 'cyclic',
    'C': 'coll',
    '^': 'eqangle',
    '/': 'eqratio',
    '%': 'eqratio',
    '=': 'contri',
    'X': 'collx',
    'A': 'acompute',
    'R': 'rcompute',
    'Q': 'fixc',
    'E': 'fixl',
    'V': 'fixb',
    'H': 'fixt',
    'Z': 'fixp',
    'Y': 'ind',
}

# Inverse lookup, built once instead of on every map_symbol_inv call.
# For duplicate values ('/' and '%' both map to 'eqratio') the later key
# wins, matching the behavior of the original per-call comprehension.
_MAP_SYMBOL_INV = {v: k for k, v in MAP_SYMBOL.items()}


def map_symbol(c: str) -> str:
  """Return the predicate name for a one-character symbol."""
  return MAP_SYMBOL[c]


def map_symbol_inv(c: str) -> str:
  """Return the one-character symbol for a predicate name."""
  return _MAP_SYMBOL_INV[c]
49
+
50
+
51
+ def _gcd(x: int, y: int) -> int:
52
+ while y:
53
+ x, y = y, x % y
54
+ return x
55
+
56
+
57
+ def simplify(n: int, d: int) -> tuple[int, int]:
58
+ g = _gcd(n, d)
59
+ return (n // g, d // g)
60
+
61
+
62
def pretty2r(a: str, b: str, c: str, d: str) -> str:
  """Normalize the point order of two ratio segments for printing."""
  # Put any point shared with the second segment first.
  if b == c or b == d:
    a, b = b, a
  if a == d:
    c, d = d, c
  return ' '.join((a, b, c, d))
70
+
71
+
72
def pretty2a(a: str, b: str, c: str, d: str) -> str:
  """Normalize the point order of two angle rays for printing."""
  # Put any point shared with the second ray first.
  if b in {c, d}:
    b, a = a, b
  if a == d:
    d, c = c, d
  return f'{a} {b} {c} {d}'
80
+
81
+
82
def pretty_angle(a: str, b: str, c: str, d: str) -> str:
  """Format the angle between rays ab and cd; a shared vertex uses the ∠XYZ form."""
  if b == c or b == d:
    a, b = b, a
  if a == d:
    c, d = d, c

  if a == c:
    # Common vertex a: render as ∠bad.
    return '\u2220' + b + a + d
  return '\u2220(' + a + b + '-' + c + d + ')'
91
+
92
+
93
def pretty_nl(name: str, args: list[str]) -> str:
  """Render a single predicate as an English sentence / formula string."""
  if name == 'aconst':
    a, b, c, d, y = args
    return f'{pretty_angle(a, b, c, d)} = {y}'
  if name == 'rconst':
    a, b, c, d, y = args
    return f'{a}{b}:{c}{d} = {y}'
  if name == 'acompute':
    a, b, c, d = args
    return pretty_angle(a, b, c, d)
  if name in ('coll', 'C'):
    return ','.join(args) + ' are collinear'
  if name == 'collx':
    # Deduplicate point names before printing.
    return ','.join(set(args)) + ' are collinear'
  if name in ('cyclic', 'O'):
    return ','.join(args) + ' are concyclic'
  if name in ('midp', 'midpoint', 'M'):
    x, a, b = args
    return f'{x} is midpoint of {a}{b}'
  if name in ('eqangle', 'eqangle6', '^'):
    a, b, c, d, e, f, g, h = args
    return f'{pretty_angle(a, b, c, d)} = {pretty_angle(e, f, g, h)}'
  if name in ('eqratio', 'eqratio6', '/'):
    return '{}{}:{}{} = {}{}:{}{}'.format(*args)
  if name == 'eqratio3':
    a, b, c, d, _, o = args
    return f'S {o} {a} {b} {o} {c} {d}'
  if name in ('cong', 'D'):
    a, b, c, d = args
    return f'{a}{b} = {c}{d}'
  if name in ('perp', 'T'):
    if len(args) == 2:  # algebraic derivation: args are 'd( ... )' strings
      ab, cd = args
      return f'{ab} \u27c2 {cd}'
    a, b, c, d = args
    return f'{a}{b} \u27c2 {c}{d}'
  if name in ('para', 'P'):
    if len(args) == 2:  # algebraic derivation: args are 'd( ... )' strings
      ab, cd = args
      return f'{ab} \u2225 {cd}'
    a, b, c, d = args
    return f'{a}{b} \u2225 {c}{d}'
  if name in ('simtri2', 'simtri', 'simtri*'):
    a, b, c, x, y, z = args
    return f'\u0394{a}{b}{c} is similar to \u0394{x}{y}{z}'
  if name in ('contri2', 'contri', 'contri*'):
    a, b, c, x, y, z = args
    return f'\u0394{a}{b}{c} is congruent to \u0394{x}{y}{z}'
  if name in ('circle', 'I'):
    o, a, b, c = args
    return f'{o} is the circumcenter of \\Delta {a}{b}{c}'
  if name == 'foot':
    a, b, c, d = args
    return f'{a} is the foot of {b} on {c}{d}'
  # Unknown predicate names fall through and implicitly return None,
  # as in the original implementation.
148
+
149
+
150
def pretty(txt: str) -> str:
  """Compact one-letter formatting of a predicate string."""
  if isinstance(txt, str):
    txt = txt.split(' ')
  name, *args = txt
  joined = ' '.join(args)

  if name == 'ind':
    return 'Y ' + joined
  if name in ('fixc', 'fixl', 'fixb', 'fixt', 'fixp'):
    return map_symbol_inv(name) + ' ' + joined
  if name == 'acompute':
    a, b, c, d = args  # unpack enforces exactly four points, as before
    return 'A ' + joined
  if name == 'rcompute':
    a, b, c, d = args  # unpack enforces exactly four points, as before
    return 'R ' + joined
  if name == 'aconst':
    a, b, c, d, y = args
    return f'^ {pretty2a(a, b, c, d)} {y}'
  if name == 'rconst':
    a, b, c, d, y = args
    return f'/ {pretty2r(a, b, c, d)} {y}'
  if name == 'coll':
    return 'C ' + joined
  if name == 'collx':
    return 'X ' + joined
  if name == 'cyclic':
    return 'O ' + joined
  if name in ('midp', 'midpoint'):
    x, a, b = args
    return f'M {x} {a} {b}'
  if name == 'eqangle':
    a, b, c, d, e, f, g, h = args
    return f'^ {pretty2a(a, b, c, d)} {pretty2a(e, f, g, h)}'
  if name == 'eqratio':
    a, b, c, d, e, f, g, h = args
    return f'/ {pretty2r(a, b, c, d)} {pretty2r(e, f, g, h)}'
  if name == 'eqratio3':
    a, b, c, d, _, o = args
    return f'S {o} {a} {b} {o} {c} {d}'
  if name == 'cong':
    a, b, c, d = args
    return f'D {a} {b} {c} {d}'
  if name == 'perp':
    if len(args) == 2:  # algebraic derivation: args are 'd( ... )' strings
      ab, cd = args
      return f'T {ab} {cd}'
    a, b, c, d = args
    return f'T {a} {b} {c} {d}'
  if name == 'para':
    if len(args) == 2:  # algebraic derivation: args are 'd( ... )' strings
      ab, cd = args
      return f'P {ab} {cd}'
    a, b, c, d = args
    return f'P {a} {b} {c} {d}'
  if name in ('simtri2', 'simtri', 'simtri*'):
    a, b, c, x, y, z = args
    return f'S {a} {b} {c} {x} {y} {z}'
  if name in ('contri2', 'contri', 'contri*'):
    a, b, c, x, y, z = args
    return f'= {a} {b} {c} {x} {y} {z}'
  if name == 'circle':
    o, a, b, c = args
    return f'I {o} {a} {b} {c}'
  if name == 'foot':
    a, b, c, d = args
    return f'F {a} {b} {c} {d}'
  # Unknown predicate: echo the whole (split) input back.
  return ' '.join(txt)
ag4masses/alphageometry/problem.py CHANGED
@@ -1,1133 +1,1152 @@
1
- # Copyright 2023 DeepMind Technologies Limited
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """Implements objects to represent problems, theorems, proofs, traceback."""
17
-
18
- from __future__ import annotations
19
-
20
- from collections import defaultdict # pylint: disable=g-importing-member
21
- from typing import Any
22
-
23
- import geometry as gm
24
- import pretty as pt
25
-
26
-
27
- # pylint: disable=protected-access
28
- # pylint: disable=unused-variable
29
- # pylint: disable=unused-argument
30
- # pylint: disable=unused-assignment
31
-
32
-
33
def reshape(l: list[Any], n: int = 1) -> list[list[Any]]:
  """Deal `l` round-robin into n buckets, then zip into consecutive n-tuples.

  For a list whose length is a multiple of n, this yields tuples of n
  consecutive elements: reshape([1, 2, 3, 4], 2) -> (1, 2), (3, 4).
  """
  assert len(l) % n == 0
  buckets = [[] for _ in range(n)]
  for idx, item in enumerate(l):
    buckets[idx % n].append(item)
  return zip(*buckets)
39
-
40
-
41
def isint(x: str) -> bool:
  """Return True iff `x` parses as a base-10 integer via int().

  Fix: the original used a bare `except:`, which also swallows
  KeyboardInterrupt/SystemExit. int() raises only ValueError (bad string)
  or TypeError (non-string-like input), so we catch exactly those.
  """
  try:
    int(x)
    return True
  except (ValueError, TypeError):
    return False
47
-
48
-
49
class Construction:
  """A single predicate, e.g. 'coll a b c' (a name plus point arguments)."""

  @classmethod
  def from_txt(cls, data: str) -> Construction:
    """Parse 'name arg1 arg2 ...' into a Construction."""
    name, *args = data.split(' ')
    return Construction(name, args)

  def __init__(self, name: str, args: list[str]):
    self.name = name
    self.args = args

  def translate(self, mapping: dict[str, str]) -> Construction:
    """Return a copy with non-numeric args renamed through `mapping`."""
    renamed = [arg if isint(arg) else mapping[arg] for arg in self.args]
    return Construction(self.name, renamed)

  def txt(self) -> str:
    """Serialize back to the 'name arg1 arg2 ...' text form."""
    return ' '.join([self.name] + list(self.args))
67
-
68
-
69
class Clause:
  """One construction statement: new points plus >= 1 predicates."""

  @classmethod
  def from_txt(cls, data: str) -> Clause:
    """Parse 'p1 p2 = name args, name args' into a Clause."""
    if data == ' =':
      return Clause([], [])
    points, constructions = data.split(' = ')
    cons = [Construction.from_txt(c) for c in constructions.split(', ')]
    return Clause(points.split(' '), cons)

  def __init__(self, points: list[str], constructions: list[Construction]):
    self.points = []
    self.nums = []

    for point in points:
      coords = None
      # Optional embedded coordinates: 'name@x_y'.
      if isinstance(point, str) and '@' in point:
        point, raw = point.split('@')
        x_str, y_str = raw.split('_')
        coords = float(x_str), float(y_str)
      self.points.append(point)
      self.nums.append(coords)

    self.constructions = constructions

  def translate(self, mapping: dict[str, str]) -> Clause:
    """Rename this clause's points to fresh single-char names via `mapping`."""
    renamed = []
    for point in self.points:
      count = len(mapping) + 1
      fresh = chr(96 + count)
      if fresh > 'z':  # count > 26: wrap to a0, b0, ...
        fresh = chr(97 + (count - 1) % 26) + str((count - 1) // 26)

      target = mapping.get(point, fresh)
      mapping[point] = target
      renamed.append(target)
    return Clause(renamed, [c.translate(mapping) for c in self.constructions])

  def add(self, name: str, args: list[str]) -> None:
    """Append one more predicate to this clause."""
    self.constructions.append(Construction(name, args))

  def txt(self) -> str:
    """Serialize back to 'p1 p2 = name args, name args' text form."""
    cons = ', '.join(c.txt() for c in self.constructions)
    return ' '.join(self.points) + ' = ' + cons
119
-
120
-
121
- def _gcd(x: int, y: int) -> int:
122
- while y:
123
- x, y = y, x % y
124
- return x
125
-
126
-
127
def simplify(n: int, d: int) -> tuple[int, int]:
  """Reduce the fraction n/d to lowest terms."""
  divisor = _gcd(n, d)
  return (n // divisor, d // divisor)
130
-
131
-
132
def compare_fn(dep: Dependency) -> tuple[Dependency, str]:
  # Sort key: pair each dependency with its pretty-printed text.
  # NOTE(review): tuple comparison falls back to comparing the Dependency
  # objects themselves when two keys share the same first element;
  # Dependency defines no __lt__, so this relies on the first elements
  # always being distinct objects — verify against callers.
  return (dep, pt.pretty(dep))
134
-
135
-
136
def sort_deps(deps: list[Dependency]) -> list[Dependency]:
  # Deterministic ordering of dependencies (keyed on their pretty text)
  # so generated premise strings are stable across runs.
  return sorted(deps, key=compare_fn)
138
-
139
-
140
class Problem:
  """Describe one problem to solve: construction clauses plus an optional goal."""

  @classmethod
  def from_txt_file(
      cls, fname: str, to_dict: bool = False, translate: bool = True
  ):
    """Load problems from a text file.

    The file alternates (url line, problem line); empty lines are skipped.
    Returns a list of Problem, or a {url: Problem} dict if `to_dict`.
    """
    with open(fname, 'r') as f:
      lines = f.read().split('\n')

    lines = [l for l in lines if l]
    data = [
        cls.from_txt(url + '\n' + problem, translate)
        for (url, problem) in reshape(lines, 2)
    ]
    if to_dict:
      return cls.to_dict(data)
    return data

  @classmethod
  def from_txt(cls, data: str, translate: bool = True) -> Problem:
    """Load a problem from a str object ('url\\nclauses ? goal')."""
    url = ''
    if '\n' in data:
      url, data = data.split('\n')

    if ' ? ' in data:
      clauses, goal = data.split(' ? ')
      goal = Construction.from_txt(goal)
    else:
      clauses, goal = data, None

    clauses = clauses.split('; ')
    problem = Problem(
        url=url, clauses=[Clause.from_txt(c) for c in clauses], goal=goal
    )
    if translate:
      return problem.translate()
    return problem

  @classmethod
  def to_dict(cls, data: list[Problem]) -> dict[str, Problem]:
    """Index a list of problems by their url."""
    return {p.url: p for p in data}

  def __init__(self, url: str, clauses: list[Clause], goal: Construction):
    self.url = url
    self.clauses = clauses
    self.goal = goal

  def copy(self) -> Problem:
    # Shallow copy: new clauses list, shared Clause objects.
    return Problem(self.url, list(self.clauses), self.goal)

  def translate(self) -> Problem:  # to single-char point names
    """Translate point names into alphabetical."""
    mapping = {}
    clauses = []

    for clause in self.clauses:
      clauses.append(clause.translate(mapping))

    if self.goal:
      goal = self.goal.translate(mapping)
    else:
      goal = self.goal

    p = Problem(self.url, clauses, goal)
    p.mapping = mapping
    return p

  def txt(self) -> str:
    """Serialize to 'clause; clause ? goal' text form.

    Bug fix: the original single-expression form parenthesized as
    `(clauses + ' ? ' + goal) if goal else ''`, so a goal-less problem
    serialized to '' and silently dropped all clauses. Clauses are now
    always emitted; the ' ? goal' suffix is appended only when present.
    """
    out = '; '.join([c.txt() for c in self.clauses])
    if self.goal:
      out += ' ? ' + self.goal.txt()
    return out

  def setup_str_from_problem(self, definitions: list[Definition]) -> str:
    """Construct the <theorem_premises> string from Problem object."""
    ref = 0  # running index used to number premise statements '[00]', '[01]', ...

    string = []
    for clause in self.clauses:
      group = {}
      p2deps = defaultdict(list)
      for c in clause.constructions:
        cdef = definitions[c.name]

        # Implicit-arg form: the clause's new points are prepended.
        if len(c.args) != len(cdef.construction.args):
          assert len(c.args) + len(clause.points) == len(cdef.construction.args)
          c.args = clause.points + c.args

        mapping = dict(zip(cdef.construction.args, c.args))
        for points, bs in cdef.basics:
          points = tuple([mapping[x] for x in points])
          for p in points:
            group[p] = points

          for b in bs:
            args = [mapping[a] for a in b.args]
            name = b.name
            # Normalize s_angle/aconst into the canonical aconst form
            # with the angle expressed as a reduced fraction of pi.
            if b.name in ['s_angle', 'aconst']:
              x, y, z, v = args
              name = 'aconst'
              v = int(v)

              if v < 0:
                v = -v
                x, z = z, x

              m, n = simplify(int(v), 180)
              args = [y, z, y, x, f'{m}pi/{n}']

            p2deps[points].append(hashed_txt(name, args))

      for k, v in p2deps.items():
        p2deps[k] = sort_deps(v)

      # Emit each group of co-constructed points with its premises.
      points = clause.points
      while points:
        p = points[0]
        gr = group[p]
        points = [x for x in points if x not in gr]

        deps_str = []
        for dep in p2deps[gr]:
          ref_str = '{:02}'.format(ref)
          dep_str = pt.pretty(dep)

          if dep[0] == 'aconst':
            m, n = map(int, dep[-1].split('pi/'))
            mn = f'{m}. pi / {n}.'
            dep_str = ' '.join(dep_str.split()[:-1] + [mn])

          deps_str.append(dep_str + ' ' + ref_str)
          ref += 1

        string.append(' '.join(gr) + ' : ' + ' '.join(deps_str))

    string = '{S} ' + ' ; '.join([s.strip() for s in string])
    goal = self.goal
    string += ' ? ' + pt.pretty([goal.name] + goal.args)
    return string
283
-
284
-
285
def parse_rely(s: str) -> dict[str, list[str]]:
  """Parse a rely spec like 'a b : c d, e : f' into {point: relied-on points}.

  Fix: the return annotation was dict[str, str], but the values are lists
  of point names (each source point maps to the full target list).
  Returns {} for an empty spec.
  """
  result = {}
  if not s:
    return result
  for chunk in (part.strip() for part in s.split(',')):
    left, right = chunk.split(':')
    sources, targets = left.strip().split(), right.strip().split()
    for src in sources:
      result[src] = targets
  return result
295
-
296
-
297
class Definition:
  """Definitions of construction statements.

  Each definition carries: the construction signature, a rely map,
  dependency requirements, basic predicates it establishes, and the
  numeric constructions used to place the new points.
  """

  @classmethod
  def from_txt_file(cls, fname: str, to_dict: bool = False) -> Definition:
    with open(fname, 'r') as f:
      lines = f.read()
    return cls.from_string(lines, to_dict)

  @classmethod
  def from_string(cls, string: str, to_dict: bool = False) -> Definition:
    # The definitions file is groups of 6 consecutive lines per
    # construction (see from_txt for the line layout).
    lines = string.split('\n')
    data = [cls.from_txt('\n'.join(group)) for group in reshape(lines, 6)]
    if to_dict:
      return cls.to_dict(data)
    return data

  @classmethod
  def to_dict(cls, data: list[Definition]) -> dict[str, Definition]:
    # Index definitions by construction name.
    return {d.construction.name: d for d in data}

  @classmethod
  def from_txt(cls, data: str) -> Definition:
    """Load definitions from a str object.

    Line layout: construction / rely / deps / basics / numerics / (blank).
    """
    construction, rely, deps, basics, numerics, _ = data.split('\n')
    basics = [] if not basics else [b.strip() for b in basics.split(';')]

    # Each ';'-separated group may be prefixed by 'points :'.
    levels = []
    for bs in basics:
      if ':' in bs:
        points, bs = bs.split(':')
        points = points.strip().split()
      else:
        points = []
      if bs.strip():
        bs = [Construction.from_txt(b.strip()) for b in bs.strip().split(',')]
      else:
        bs = []
      levels.append((points, bs))

    numerics = [] if not numerics else numerics.split(', ')

    return Definition(
        construction=Construction.from_txt(construction),
        rely=parse_rely(rely),
        deps=Clause.from_txt(deps),
        basics=levels,
        numerics=[Construction.from_txt(c) for c in numerics],
    )

  def __init__(
      self,
      construction: Construction,
      rely: dict[str, str],
      deps: Clause,
      basics: list[tuple[list[str], list[Construction]]],
      numerics: list[Construction],
  ):
    self.construction = construction
    self.rely = rely
    self.deps = deps
    self.basics = basics
    self.numerics = numerics

    # Partition the construction's declared args: names that appear in a
    # numeric construction are treated as inputs (self.args); the rest
    # are the new points this construction introduces (self.points).
    args = set()
    for num in numerics:
      args.update(num.args)

    self.points = []
    self.args = []
    for p in self.construction.args:
      if p in args:
        self.args.append(p)
      else:
        self.points.append(p)
372
-
373
-
374
class Theorem:
  """Deduction rule: premises => a single conclusion."""

  @classmethod
  def from_txt_file(cls, fname: str, to_dict: bool = False) -> Theorem:
    with open(fname, 'r') as f:
      theorems = f.read()
    return cls.from_string(theorems, to_dict)

  @classmethod
  def from_string(cls, string: str, to_dict: bool = False) -> Theorem:
    """Load deduction rule from a str object."""
    # One rule per line; '#' lines are comments.
    theorems = string.split('\n')
    theorems = [l for l in theorems if l and not l.startswith('#')]
    theorems = [cls.from_txt(l) for l in theorems]

    # Rules are named r00, r01, ... in file order.
    for i, th in enumerate(theorems):
      th.rule_name = 'r{:02}'.format(i)

    if to_dict:
      result = {}
      for t in theorems:
        # NOTE(review): this membership test checks t.name against a dict
        # keyed by rule_name, so it can never hit — looks like it intended
        # to deduplicate names; confirm before relying on the '_' suffixing.
        if t.name in result:
          t.name += '_'
        result[t.rule_name] = t

      return result

    return theorems

  @classmethod
  def from_txt(cls, data: str) -> Theorem:
    # Format: 'prem1, prem2 => conclusion'.
    premises, conclusion = data.split(' => ')
    premises = premises.split(', ')
    conclusion = conclusion.split(', ')
    return Theorem(
        premise=[Construction.from_txt(p) for p in premises],
        conclusion=[Construction.from_txt(c) for c in conclusion],
    )

  def __init__(
      self, premise: list[Construction], conclusion: list[Construction]
  ):
    if len(conclusion) != 1:
      raise ValueError('Cannot have more than one conclusion')
    self.name = '_'.join([p.name for p in premise + conclusion])
    self.premise = premise
    self.conclusion = conclusion
    self.is_arg_reduce = False

    assert len(self.conclusion) == 1
    con = self.conclusion[0]

    # These conclusion types never reduce the argument set.
    if con.name in [
        'eqratio3',
        'midp',
        'contri',
        'simtri',
        'contri2',
        'simtri2',
        'simtri*',
        'contri*',
    ]:
      return

    # Mark rules whose conclusion mentions at least as many points as all
    # premises combined (no new points need matching beyond the conclusion).
    prem_args = set(sum([p.args for p in self.premise], []))
    con_args = set(con.args)
    if len(prem_args) <= len(con_args):
      self.is_arg_reduce = True

  def txt(self) -> str:
    # Serialize back to the 'premises => conclusion' text form.
    premise_txt = ', '.join([clause.txt() for clause in self.premise])
    conclusion_txt = ', '.join([clause.txt() for clause in self.conclusion])
    return f'{premise_txt} => {conclusion_txt}'

  def conclusion_name_args(
      self, mapping: dict[str, gm.Point]
  ) -> tuple[str, list[gm.Point]]:
    # Instantiate the conclusion's args through a premise-match mapping
    # (mapping may also contain non-str keys; those are ignored).
    mapping = {arg: p for arg, p in mapping.items() if isinstance(arg, str)}
    c = self.conclusion[0]
    args = [mapping[a] for a in c.args]
    return c.name, args
456
-
457
-
458
def why_eqratio(
    d1: gm.Direction,
    d2: gm.Direction,
    d3: gm.Direction,
    d4: gm.Direction,
    level: int,
) -> list[Dependency]:
  """Why two ratios are equal, returns a Dependency objects."""
  all12 = list(gm.all_ratios(d1, d2, level))
  all34 = list(gm.all_ratios(d3, d4, level))

  # Among all pairs of ratio nodes provably equal, keep the pair whose
  # combined justification (equality proof + 4 backtracks) is shortest.
  min_why = None
  for ang12, d1s, d2s in all12:
    for ang34, d3s, d4s in all34:
      why0 = gm.why_equal(ang12, ang34, level)
      if why0 is None:
        continue
      d1_, d2_ = ang12._l
      d3_, d4_ = ang34._l
      why1 = gm.bfs_backtrack(d1, [d1_], d1s)
      why2 = gm.bfs_backtrack(d2, [d2_], d2s)
      why3 = gm.bfs_backtrack(d3, [d3_], d3s)
      why4 = gm.bfs_backtrack(d4, [d4_], d4s)
      why = why0 + why1 + why2 + why3 + why4
      if min_why is None or len(why) < len(min_why[0]):
        min_why = why, ang12, ang34, why0, why1, why2, why3, why4

  # No pair of equal ratios found at all.
  if min_why is None:
    return None

  _, ang12, ang34, why0, why1, why2, why3, why4 = min_why
  d1_, d2_ = ang12._l
  d3_, d4_ = ang34._l

  # Exact match: no bridging congruences are needed.
  if d1 == d1_ and d2 == d2_ and d3 == d3_ and d4 == d4_:
    return why0

  (a_, b_), (c_, d_) = d1_._obj.points, d2_._obj.points
  (e_, f_), (g_, h_) = d3_._obj.points, d4_._obj.points
  deps = []
  if why0:
    dep = Dependency('eqratio', [a_, b_, c_, d_, e_, f_, g_, h_], '', level)
    dep.why = why0
    deps.append(dep)

  (a, b), (c, d) = d1._obj.points, d2._obj.points
  (e, f), (g, h) = d3._obj.points, d4._obj.points
  # Bridge each queried segment to the segment actually used in the proof
  # with a 'cong' dependency where they differ.
  for why, (x, y), (x_, y_) in zip(
      [why1, why2, why3, why4],
      [(a, b), (c, d), (e, f), (g, h)],
      [(a_, b_), (c_, d_), (e_, f_), (g_, h_)],
  ):
    if why:
      dep = Dependency('cong', [x, y, x_, y_], '', level)
      dep.why = why
      deps.append(dep)

  return deps
516
-
517
-
518
def why_eqangle(
    d1: gm.Direction,
    d2: gm.Direction,
    d3: gm.Direction,
    d4: gm.Direction,
    level: int,
    verbose: bool = False,
) -> list[Dependency]:
  """Why two angles are equal, returns a Dependency objects.

  NOTE(review): unlike why_eqratio, the non-trivial paths here return a
  pair ((d1_, d2_, d3_, d4_), deps) rather than a bare list — callers
  unpack both values.
  """
  all12 = list(gm.all_angles(d1, d2, level))
  all34 = list(gm.all_angles(d3, d4, level))

  # Among all pairs of angle nodes provably equal, keep the pair whose
  # combined justification (equality proof + 4 backtracks) is shortest.
  min_why = None
  for ang12, d1s, d2s in all12:
    for ang34, d3s, d4s in all34:
      why0 = gm.why_equal(ang12, ang34, level)
      if why0 is None:
        continue
      d1_, d2_ = ang12._d
      d3_, d4_ = ang34._d
      why1 = gm.bfs_backtrack(d1, [d1_], d1s)
      why2 = gm.bfs_backtrack(d2, [d2_], d2s)
      why3 = gm.bfs_backtrack(d3, [d3_], d3s)
      why4 = gm.bfs_backtrack(d4, [d4_], d4s)
      why = why0 + why1 + why2 + why3 + why4
      if min_why is None or len(why) < len(min_why[0]):
        min_why = why, ang12, ang34, why0, why1, why2, why3, why4

  if min_why is None:
    return None

  _, ang12, ang34, why0, why1, why2, why3, why4 = min_why
  # NOTE(review): why0 is recomputed here although min_why already carries
  # it — redundant but harmless.
  why0 = gm.why_equal(ang12, ang34, level)
  d1_, d2_ = ang12._d
  d3_, d4_ = ang34._d

  # Exact match: return the equality proof directly.
  if d1 == d1_ and d2 == d2_ and d3 == d3_ and d4 == d4_:
    return (d1_, d2_, d3_, d4_), why0

  (a_, b_), (c_, d_) = d1_._obj.points, d2_._obj.points
  (e_, f_), (g_, h_) = d3_._obj.points, d4_._obj.points
  deps = []
  if why0:
    dep = Dependency('eqangle', [a_, b_, c_, d_, e_, f_, g_, h_], '', None)
    dep.why = why0
    deps.append(dep)

  (a, b), (c, d) = d1._obj.points, d2._obj.points
  (e, f), (g, h) = d3._obj.points, d4._obj.points
  # Bridge each queried line to the line used in the proof: same underlying
  # line object -> 'collx', otherwise -> 'para'.
  for why, d_xy, (x, y), d_xy_, (x_, y_) in zip(
      [why1, why2, why3, why4],
      [d1, d2, d3, d4],
      [(a, b), (c, d), (e, f), (g, h)],
      [d1_, d2_, d3_, d4_],
      [(a_, b_), (c_, d_), (e_, f_), (g_, h_)],
  ):
    xy, xy_ = d_xy._obj, d_xy_._obj
    if why:
      if xy == xy_:
        name = 'collx'
      else:
        name = 'para'
      dep = Dependency(name, [x_, y_, x, y], '', None)
      dep.why = why
      deps.append(dep)

  return (d1_, d2_, d3_, d4_), deps
585
-
586
-
587
# Rule name tagging dependencies that come straight from the problem
# construction (as opposed to deduction-rule applications r00, r01, ...).
CONSTRUCTION_RULE = 'c0'
588
-
589
-
590
class EmptyDependency:
  """Empty dependency predicate ready to get filled up.

  Collects `why` reasons before the concrete predicate (name, args)
  is known; populate() materializes it into a Dependency.
  """

  def __init__(self, level: int, rule_name: str):
    self.level = level
    self.rule_name = rule_name or ''
    self.empty = True
    self.why = []
    self.trace = None

  def populate(self, name: str, args: list[gm.Point]) -> Dependency:
    """Materialize into a concrete Dependency carrying the collected reasons."""
    filled = Dependency(name, args, self.rule_name, self.level)
    filled.trace2 = self.trace
    filled.why = list(self.why)
    return filled

  def copy(self) -> EmptyDependency:
    """Shallow copy (the `why` list is duplicated, its entries shared)."""
    clone = EmptyDependency(self.level, self.rule_name)
    clone.why = list(self.why)
    return clone

  def extend(
      self,
      g: Any,
      name0: str,
      args0: list[gm.Point],
      name: str,
      args: list[gm.Point],
  ) -> EmptyDependency:
    """New shell whose reasons are (name0, args0) plus a justified (name, args)."""
    head = self.populate(name0, args0)
    chained = EmptyDependency(level=self.level, rule_name=None)
    tail = Dependency(name, args, None, chained.level)
    chained.why = [head, tail.why_me_or_cache(g, None)]
    return chained

  def extend_many(
      self,
      g: Any,
      name0: str,
      args0: list[gm.Point],
      name_args: list[tuple[str, list[gm.Point]]],
  ) -> EmptyDependency:
    """Like extend() but chaining several (name, args) reasons at once."""
    if not name_args:
      return self
    head = self.populate(name0, args0)
    chained = EmptyDependency(level=self.level, rule_name=None)
    chained.why = [head]
    for name, args in name_args:
      dep = Dependency(name, args, None, chained.level)
      chained.why.append(dep.why_me_or_cache(g, None))
    return chained
643
-
644
-
645
def maybe_make_equal_pairs(
    a: gm.Point,
    b: gm.Point,
    c: gm.Point,
    d: gm.Point,
    m: gm.Point,
    n: gm.Point,
    p: gm.Point,
    q: gm.Point,
    ab: gm.Line,
    mn: gm.Line,
    g: Any,
    level: int,
) -> list[Dependency]:
  """Make a-b:c-d==m-n:p-q in case a-b==m-n or c-d==p-q."""
  # Only applicable when both name the same underlying Line/Segment object;
  # returns None otherwise (callers treat that as "not applicable").
  if ab != mn:
    return
  why = []
  # Lines are related by 'para', segments by 'cong'.
  eqname = 'para' if isinstance(ab, gm.Line) else 'cong'
  colls = [a, b, m, n]
  # More than two distinct points on the shared line: record collinearity.
  if len(set(colls)) > 2 and eqname == 'para':
    dep = Dependency('collx', colls, None, level)
    dep.why_me(g, level)
    why += [dep]

  # The remaining pair must then be directly equal.
  dep = Dependency(eqname, [c, d, p, q], None, level)
  dep.why_me(g, level)
  why += [dep]
  return why
674
-
675
-
676
class Dependency(Construction):
  """Dependency is a predicate that other predicates depend on.

  Extends Construction with the deduction rule that produced it, the
  search level it was found at, and `why`: the list of dependencies
  justifying it (forming a proof DAG).
  """

  def __init__(
      self, name: str, args: list[gm.Point], rule_name: str, level: int
  ):
    super().__init__(name, args)
    self.rule_name = rule_name or ''
    self.level = level
    self.why = []

    self._stat = None
    self.trace = None

  def _find(self, dep_hashed: tuple[str, ...]) -> Dependency:
    # Depth-first search of the why-DAG for a node with this hash;
    # returns None implicitly if absent.
    for w in self.why:
      f = w._find(dep_hashed)
      if f:
        return f
      if w.hashed() == dep_hashed:
        return w

  def remove_loop(self) -> Dependency:
    # If self occurs inside its own justification, return that inner
    # occurrence instead (breaks circular proofs).
    f = self._find(self.hashed())
    if f:
      return f
    return self

  def copy(self) -> Dependency:
    dep = Dependency(self.name, self.args, self.rule_name, self.level)
    dep.trace = self.trace
    dep.why = list(self.why)
    return dep

  def why_me_or_cache(self, g: Any, level: int) -> Dependency:
    # Reuse a cached justification for the same predicate if available.
    if self.hashed() in g.cache:
      return g.cache[self.hashed()]
    self.why_me(g, level)
    return self

  def populate(self, name: str, args: list[gm.Point]) -> Dependency:
    # Only construction-level dependencies may be re-populated;
    # note the (name, args) parameters are not used by this override.
    assert self.rule_name == CONSTRUCTION_RULE, self.rule_name
    dep = Dependency(self.name, self.args, self.rule_name, self.level)
    dep.why = list(self.why)
    return dep

  def why_me(self, g: Any, level: int) -> None:
    """Figure out the dependencies predicates of self.

    Dispatches on self.name and fills self.why by querying the proof
    graph `g`. Order of the branches and of appended reasons matters.
    """
    name, args = self.name, self.args

    hashed_me = hashed(name, args)
    if hashed_me in g.cache:
      dep = g.cache[hashed_me]
      self.why = dep.why
      self.rule_name = dep.rule_name
      return

    if self.name == 'para':
      a, b, c, d = self.args
      if {a, b} == {c, d}:
        self.why = []
        return

      ab = g._get_line(a, b)
      cd = g._get_line(c, d)
      if ab == cd:
        # Same line object: parallelism reduces to collinearity.
        if {a, b} == {c, d}:
          self.why = []
          self.rule_name = ''
          return
        dep = Dependency('coll', list({a, b, c, d}), 't??', None)
        self.why = [dep.why_me_or_cache(g, level)]
        return

      # Bridge each point pair to its line's canonical points if they differ.
      for (x, y), xy in zip([(a, b), (c, d)], [ab, cd]):
        x_, y_ = xy.points
        if {x, y} == {x_, y_}:
          continue
        d = Dependency('collx', [x, y, x_, y_], None, level)
        self.why += [d.why_me_or_cache(g, level)]

      whypara = g.why_equal(ab, cd, None)
      self.why += whypara

    elif self.name == 'midp':
      m, a, b = self.args
      ma = g._get_segment(m, a)
      mb = g._get_segment(m, b)
      dep = Dependency('coll', [m, a, b], None, None).why_me_or_cache(g, None)
      self.why = [dep] + g.why_equal(ma, mb, level)

    elif self.name == 'perp':
      a, b, c, d = self.args
      ab = g._get_line(a, b)
      cd = g._get_line(c, d)
      for (x, y), xy in zip([(a, b), (c, d)], [ab, cd]):
        x_, y_ = xy.points
        if {x, y} == {x_, y_}:
          continue
        d = Dependency('collx', [x, y, x_, y_], None, level)
        self.why += [d.why_me_or_cache(g, level)]

      # Perpendicularity as an angle equality between the two directions.
      _, why = why_eqangle(ab._val, cd._val, cd._val, ab._val, level)
      a, b = ab.points
      c, d = cd.points

      if hashed(self.name, [a, b, c, d]) != self.hashed():
        d = Dependency(self.name, [a, b, c, d], None, level)
        d.why = why
        why = [d]

      self.why += why

    elif self.name == 'cong':
      a, b, c, d = self.args
      ab = g._get_segment(a, b)
      cd = g._get_segment(c, d)

      self.why = g.why_equal(ab, cd, level)

    elif self.name == 'coll':
      _, why = gm.line_of_and_why(self.args, level)
      self.why = why

    elif self.name == 'collx':
      # collx degrades to 'coll' when the points really are collinear,
      # otherwise it is treated as 'para'.
      if g.check_coll(self.args):
        args = list(set(self.args))
        hashed_me = hashed('coll', args)
        if hashed_me in g.cache:
          dep = g.cache[hashed_me]
          self.why = [dep]
          self.rule_name = ''
          return
        _, self.why = gm.line_of_and_why(args, level)
      else:
        self.name = 'para'
        self.why_me(g, level)

    elif self.name == 'cyclic':
      _, why = gm.circle_of_and_why(self.args, level)
      self.why = why

    elif self.name == 'circle':
      # o is the center: equidistant from a, b, c.
      o, a, b, c = self.args
      oa = g._get_segment(o, a)
      ob = g._get_segment(o, b)
      oc = g._get_segment(o, c)
      self.why = g.why_equal(oa, ob, level) + g.why_equal(oa, oc, level)

    elif self.name in ['eqangle', 'eqangle6']:
      a, b, c, d, m, n, p, q = self.args

      ab, why1 = g.get_line_thru_pair_why(a, b)
      cd, why2 = g.get_line_thru_pair_why(c, d)
      mn, why3 = g.get_line_thru_pair_why(m, n)
      pq, why4 = g.get_line_thru_pair_why(p, q)

      # Degenerate cases: one pair of sides coincides, so the equality
      # reduces to a single parallelism.
      if ab is None or cd is None or mn is None or pq is None:
        if {a, b} == {m, n}:
          d = Dependency('para', [c, d, p, q], None, level)
          self.why = [d.why_me_or_cache(g, level)]
        if {a, b} == {c, d}:
          d = Dependency('para', [p, q, m, n], None, level)
          self.why = [d.why_me_or_cache(g, level)]
        if {c, d} == {p, q}:
          d = Dependency('para', [a, b, m, n], None, level)
          self.why = [d.why_me_or_cache(g, level)]
        if {p, q} == {m, n}:
          d = Dependency('para', [a, b, c, d], None, level)
          self.why = [d.why_me_or_cache(g, level)]
        return

      # Bridge each point pair to its line's canonical points.
      for (x, y), xy, whyxy in zip(
          [(a, b), (c, d), (m, n), (p, q)],
          [ab, cd, mn, pq],
          [why1, why2, why3, why4],
      ):
        x_, y_ = xy.points
        if {x, y} == {x_, y_}:
          continue
        d = Dependency('collx', [x, y, x_, y_], None, level)
        d.why = whyxy
        self.why += [d]

      a, b = ab.points
      c, d = cd.points
      m, n = mn.points
      p, q = pq.points
      diff = hashed(self.name, [a, b, c, d, m, n, p, q]) != self.hashed()

      whyeqangle = None
      if ab._val and cd._val and mn._val and pq._val:
        whyeqangle = why_eqangle(ab._val, cd._val, mn._val, pq._val, level)

      if whyeqangle:
        (dab, dcd, dmn, dpq), whyeqangle = whyeqangle
        if diff:
          d = Dependency('eqangle', [a, b, c, d, m, n, p, q], None, level)
          d.why = whyeqangle
          whyeqangle = [d]
        self.why += whyeqangle

      else:
        # Fall back on shared-line / parallelism special cases.
        if (ab == cd and mn == pq) or (ab == mn and cd == pq):
          self.why += []
        elif ab == mn:
          self.why += maybe_make_equal_pairs(
              a, b, c, d, m, n, p, q, ab, mn, g, level
          )
        elif cd == pq:
          self.why += maybe_make_equal_pairs(
              c, d, a, b, p, q, m, n, cd, pq, g, level
          )
        elif ab == cd:
          self.why += maybe_make_equal_pairs(
              a, b, m, n, c, d, p, q, ab, cd, g, level
          )
        elif mn == pq:
          self.why += maybe_make_equal_pairs(
              m, n, a, b, p, q, c, d, mn, pq, g, level
          )
        elif g.is_equal(ab, mn) or g.is_equal(cd, pq):
          dep1 = Dependency('para', [a, b, m, n], None, level)
          dep1.why_me(g, level)
          dep2 = Dependency('para', [c, d, p, q], None, level)
          dep2.why_me(g, level)
          self.why += [dep1, dep2]
        elif g.is_equal(ab, cd) or g.is_equal(mn, pq):
          dep1 = Dependency('para', [a, b, c, d], None, level)
          dep1.why_me(g, level)
          dep2 = Dependency('para', [m, n, p, q], None, level)
          dep2.why_me(g, level)
          self.why += [dep1, dep2]
        elif ab._val and cd._val and mn._val and pq._val:
          self.why = why_eqangle(ab._val, cd._val, mn._val, pq._val, level)

    elif self.name in ['eqratio', 'eqratio6']:
      a, b, c, d, m, n, p, q = self.args
      ab = g._get_segment(a, b)
      cd = g._get_segment(c, d)
      mn = g._get_segment(m, n)
      pq = g._get_segment(p, q)

      # Degenerate cases: one pair of segments coincides, so the equality
      # reduces to a single congruence.
      if ab is None or cd is None or mn is None or pq is None:
        if {a, b} == {m, n}:
          d = Dependency('cong', [c, d, p, q], None, level)
          self.why = [d.why_me_or_cache(g, level)]
        if {a, b} == {c, d}:
          d = Dependency('cong', [p, q, m, n], None, level)
          self.why = [d.why_me_or_cache(g, level)]
        if {c, d} == {p, q}:
          d = Dependency('cong', [a, b, m, n], None, level)
          self.why = [d.why_me_or_cache(g, level)]
        if {p, q} == {m, n}:
          d = Dependency('cong', [a, b, c, d], None, level)
          self.why = [d.why_me_or_cache(g, level)]
        return

      if ab._val and cd._val and mn._val and pq._val:
        self.why = why_eqratio(ab._val, cd._val, mn._val, pq._val, level)

      # Fall back on shared-segment / congruence special cases.
      if self.why is None:
        self.why = []
        if (ab == cd and mn == pq) or (ab == mn and cd == pq):
          self.why = []
        elif ab == mn:
          self.why += maybe_make_equal_pairs(
              a, b, c, d, m, n, p, q, ab, mn, g, level
          )
        elif cd == pq:
          self.why += maybe_make_equal_pairs(
              c, d, a, b, p, q, m, n, cd, pq, g, level
          )
        elif ab == cd:
          self.why += maybe_make_equal_pairs(
              a, b, m, n, c, d, p, q, ab, cd, g, level
          )
        elif mn == pq:
          self.why += maybe_make_equal_pairs(
              m, n, a, b, p, q, c, d, mn, pq, g, level
          )
        elif g.is_equal(ab, mn) or g.is_equal(cd, pq):
          dep1 = Dependency('cong', [a, b, m, n], None, level)
          dep1.why_me(g, level)
          dep2 = Dependency('cong', [c, d, p, q], None, level)
          dep2.why_me(g, level)
          self.why += [dep1, dep2]
        elif g.is_equal(ab, cd) or g.is_equal(mn, pq):
          dep1 = Dependency('cong', [a, b, c, d], None, level)
          dep1.why_me(g, level)
          dep2 = Dependency('cong', [m, n, p, q], None, level)
          dep2.why_me(g, level)
          self.why += [dep1, dep2]
        elif ab._val and cd._val and mn._val and pq._val:
          # NOTE(review): this calls why_eqangle (not why_eqratio) on
          # segment values — looks copy-pasted from the eqangle branch;
          # confirm intent before changing.
          self.why = why_eqangle(ab._val, cd._val, mn._val, pq._val, level)

    elif self.name in ['diff', 'npara', 'nperp', 'ncoll', 'sameside']:
      # Numerical (non-deduced) predicates need no justification.
      self.why = []

    elif self.name == 'simtri':
      a, b, c, x, y, z = self.args
      dep1 = Dependency('eqangle', [a, b, a, c, x, y, x, z], '', level)
      dep1.why_me(g, level)
      dep2 = Dependency('eqangle', [b, a, b, c, y, x, y, z], '', level)
      dep2.why_me(g, level)
      self.rule_name = 'r34'
      self.why = [dep1, dep2]

    elif self.name == 'contri':
      a, b, c, x, y, z = self.args
      dep1 = Dependency('cong', [a, b, x, y], '', level)
      dep1.why_me(g, level)
      dep2 = Dependency('cong', [b, c, y, z], '', level)
      dep2.why_me(g, level)
      dep3 = Dependency('cong', [c, a, z, x], '', level)
      dep3.why_me(g, level)
      self.rule_name = 'r32'
      self.why = [dep1, dep2, dep3]

    elif self.name == 'ind':
      pass

    elif self.name == 'aconst':
      a, b, c, d, ang0 = self.args

      measure = ang0._val

      # Find another angle with the same measure whose sides are
      # para-or-coll with ours, and justify through it.
      for ang in measure.neighbors(gm.Angle):
        if ang == ang0:
          continue
        d1, d2 = ang._d
        l1, l2 = d1._obj, d2._obj
        (a1, b1), (c1, d1) = l1.points, l2.points

        if not g.check_para_or_coll([a, b, a1, b1]) or not g.check_para_or_coll(
            [c, d, c1, d1]
        ):
          continue

        self.why = []
        for args in [(a, b, a1, b1), (c, d, c1, d1)]:
          if g.check_coll(args):
            if len(set(args)) > 2:
              dep = Dependency('coll', args, None, None)
              self.why.append(dep.why_me_or_cache(g, level))
          else:
            dep = Dependency('para', args, None, None)
            self.why.append(dep.why_me_or_cache(g, level))

        self.why += gm.why_equal(ang, ang0)
        break

    elif self.name == 'rconst':
      a, b, c, d, rat0 = self.args

      val = rat0._val

      # Find another ratio with the same value whose segments are
      # congruent to ours, and justify through it.
      for rat in val.neighbors(gm.Ratio):
        if rat == rat0:
          continue
        l1, l2 = rat._l
        s1, s2 = l1._obj, l2._obj
        (a1, b1), (c1, d1) = list(s1.points), list(s2.points)

        if not g.check_cong([a, b, a1, b1]) or not g.check_cong([c, d, c1, d1]):
          continue

        self.why = []
        for args in [(a, b, a1, b1), (c, d, c1, d1)]:
          if len(set(args)) > 2:
            dep = Dependency('cong', args, None, None)
            self.why.append(dep.why_me_or_cache(g, level))

        self.why += gm.why_equal(rat, rat0)
        break

    else:
      raise ValueError('Not recognize', self.name)

  def hashed(self, rename: bool = False) -> tuple[str, ...]:
    # Canonical, permutation-invariant identity of this predicate.
    return hashed(self.name, self.args, rename=rename)
1057
-
1058
-
1059
def hashed(
    name: str, args: list[gm.Point], rename: bool = False
) -> tuple[str, ...]:
  """Canonical, permutation-invariant hash of a predicate over Points.

  Each point is replaced by its string name (or `new_name` when `rename`
  is set), then the tuple is canonicalized by hashed_txt.  For 's_angle'
  the final argument is an angle value rather than a point, so it is
  stringified instead of name-extracted.
  """
  def label(p):
    return p.new_name if rename else p.name

  if name == 's_angle':
    names = [label(p) for p in args[:-1]]
    names.append(str(args[-1]))
  else:
    names = [label(p) for p in args]
  return hashed_txt(name, names)
1069
-
1070
-
1071
def hashed_txt(name: str, args: list[str]) -> tuple[str, ...]:
  """Return a tuple unique to (name, args) up to argument-permutation symmetry.

  Each predicate family has its own symmetry group (pairs of points are
  unordered, pair groups may swap, triangles permute jointly); this
  canonicalizes args so equivalent statements hash identically.
  """

  if name in ['const', 'aconst', 'rconst']:
    # Last arg is the constant value; each point pair is order-free.
    a, b, c, d, y = args
    a, b = min(a, b), max(a, b)
    c, d = min(c, d), max(c, d)
    return name, a, b, c, d, y

  if name in ['npara', 'nperp', 'para', 'cong', 'perp', 'collx']:
    # Two unordered pairs, and the pairs themselves are interchangeable.
    a, b, c, d = args
    a, b = min(a, b), max(a, b)
    c, d = min(c, d), max(c, d)
    if (c, d) < (a, b):
      a, b, c, d = c, d, a, b
    return (name, a, b, c, d)

  if name in ['midp', 'midpoint']:
    # First arg is the midpoint itself; the endpoints are unordered.
    mid, x, y = args
    x, y = min(x, y), max(x, y)
    return (name, mid, x, y)

  if name in ['coll', 'cyclic', 'ncoll', 'diff', 'triangle']:
    # Fully symmetric: dedupe and sort.
    return (name,) + tuple(sorted(set(args)))

  if name == 'circle':
    # Center stays first; the three circumference points are unordered.
    center, p1, p2, p3 = args
    return (name, center) + tuple(sorted([p1, p2, p3]))

  if name in ['eqangle', 'eqratio', 'eqangle6', 'eqratio6']:
    a, b, c, d, e, f, g, h = args
    # Each of the four point pairs is unordered.
    a, b = min(a, b), max(a, b)
    c, d = min(c, d), max(c, d)
    e, f = min(e, f), max(e, f)
    g, h = min(g, h), max(g, h)
    # The (a,b,e,f) group may swap with the (c,d,g,h) group ...
    if tuple(sorted([a, b, e, f])) > tuple(sorted([c, d, g, h])):
      a, b, e, f, c, d, g, h = c, d, g, h, a, b, e, f
    # ... and the two sides of the equality may swap.
    if (a, b, c, d) > (e, f, g, h):
      a, b, c, d, e, f, g, h = e, f, g, h, a, b, c, d

    # The *6 variants hash under the base predicate name.
    canonical = {'eqangle6': 'eqangle', 'eqratio6': 'eqratio'}.get(name, name)
    return (canonical, a, b, c, d, e, f, g, h)

  if name in ['contri', 'simtri', 'simtri2', 'contri2', 'contri*', 'simtri*']:
    # Vertex correspondences permute jointly; triangles may swap sides.
    a, b, c, x, y, z = args
    (a, x), (b, y), (c, z) = sorted([(a, x), (b, y), (c, z)], key=sorted)
    (a, b, c), (x, y, z) = sorted([(a, b, c), (x, y, z)], key=sorted)
    return (name, a, b, c, x, y, z)

  if name in ['eqratio3']:
    # Final two args are the same center point repeated.
    a, b, c, d, _, o = args
    (a, c), (b, d) = sorted([(a, c), (b, d)], key=sorted)
    (a, b), (c, d) = sorted([(a, b), (c, d)], key=sorted)
    return (name, a, b, c, d, o, o)

  if name in ['sameside', 's_angle']:
    # Order-sensitive predicates: keep args as given.
    return (name,) + tuple(args)

  raise ValueError(f'Not recognize {name} to hash.')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Implements objects to represent problems, theorems, proofs, traceback."""
17
+
18
+ from __future__ import annotations
19
+
20
+ from collections import defaultdict # pylint: disable=g-importing-member
21
+ from typing import Any
22
+
23
+ import geometry as gm
24
+ import pretty as pt
25
+
26
+
27
+ # pylint: disable=protected-access
28
+ # pylint: disable=unused-variable
29
+ # pylint: disable=unused-argument
30
+ # pylint: disable=unused-assignment
31
+
32
+
33
def reshape(l: list[Any], n: int = 1) -> list[tuple[Any, ...]]:
  """Split `l` into consecutive chunks of size `n`.

  Example: reshape([1, 2, 3, 4], 2) -> [(1, 2), (3, 4)].

  Fix: the previous version was annotated ``list[list[Any]]`` but returned
  a lazy ``zip`` iterator of tuples (single-use).  We materialize the list,
  which matches the corrected annotation and keeps every caller (all of
  which only iterate the result) working.
  """
  assert len(l) % n == 0
  # l[i::n] is exactly the i-th "column" the old manual loop built.
  columns = [l[i::n] for i in range(n)]
  return list(zip(*columns))
39
+
40
+
41
def isint(x: str) -> bool:
  """Return True iff `x` parses as a base-10 integer.

  Fix: narrowed the bare ``except`` (which also swallowed
  KeyboardInterrupt/SystemExit) to the two exceptions ``int()`` can
  actually raise on bad input.
  """
  try:
    int(x)
  except (TypeError, ValueError):
    return False
  return True
47
+
48
+
49
class Construction:
  """A single predicate: a name plus its (string) arguments."""

  def __init__(self, name: str, args: list[str]):
    self.name = name
    self.args = args

  @classmethod
  def from_txt(cls, data: str) -> Construction:
    """Parse 'name arg1 arg2 ...' into a Construction."""
    name, *rest = data.split(' ')
    return Construction(name, rest)

  def translate(self, mapping: dict[str, str]) -> Construction:
    """Rename point arguments via `mapping`; integer literals pass through."""
    new_args = [arg if isint(arg) else mapping[arg] for arg in self.args]
    return Construction(self.name, new_args)

  def txt(self) -> str:
    """Serialize back to 'name arg1 arg2 ...'."""
    return ' '.join([self.name, *self.args])
67
+
68
+
69
class Clause:
  """One construction step: new points plus the predicates defining them."""

  @classmethod
  def from_txt(cls, data: str) -> Clause:
    """Parse 'p1 p2 = pred1 a b, pred2 c d' into a Clause."""
    if data == ' =':
      return Clause([], [])
    points, constructions = data.split(' = ')
    cons = [Construction.from_txt(c) for c in constructions.split(', ')]
    return Clause(points.split(' '), cons)

  def __init__(self, points: list[str], constructions: list[Construction]):
    self.points = []
    self.nums = []

    # A point may carry explicit coordinates in the form 'name@x_y'.
    for point in points:
      coords = None
      if isinstance(point, str) and '@' in point:
        point, raw = point.split('@')
        x_str, y_str = raw.split('_')
        coords = (float(x_str), float(y_str))
      self.points.append(point)
      self.nums.append(coords)

    self.constructions = constructions

  def translate(self, mapping: dict[str, str]) -> Clause:
    """Rename this clause's points to fresh alphabetical names via `mapping`."""
    renamed = []
    for point in self.points:
      idx = len(mapping) + 1
      fresh = chr(96 + idx)
      if fresh > 'z':  # past 'z' (idx > 26): fall back to a1, b1, ... names
        fresh = chr(97 + (idx - 1) % 26) + str((idx - 1) // 26)

      # Reuse an existing mapping; otherwise claim the fresh name.
      new_name = mapping.get(point, fresh)
      mapping[point] = new_name
      renamed.append(new_name)
    return Clause(renamed, [c.translate(mapping) for c in self.constructions])

  def add(self, name: str, args: list[str]) -> None:
    """Append one more defining predicate to this clause."""
    self.constructions.append(Construction(name, args))

  def txt(self) -> str:
    """Serialize back to 'p1 p2 = pred1 a b, pred2 c d'."""
    cons_txt = ', '.join(c.txt() for c in self.constructions)
    return ' '.join(self.points) + ' = ' + cons_txt
119
+
120
+
121
+ def _gcd(x: int, y: int) -> int:
122
+ while y:
123
+ x, y = y, x % y
124
+ return x
125
+
126
+
127
+ def simplify(n: int, d: int) -> tuple[int, int]:
128
+ g = _gcd(n, d)
129
+ return (n // g, d // g)
130
+
131
+
132
def compare_fn(dep: Dependency) -> tuple[Dependency, str]:
  # Sort key: the item itself first, then its pretty-printed text.
  # NOTE(review): setup_str_from_problem sorts lists of hashed_txt tuples
  # (which compare fine); actual Dependency objects define no ordering, so
  # unequal deps as first elements would raise TypeError -- presumably the
  # callers only pass tuples here.  Confirm.
  return (dep, pt.pretty(dep))


def sort_deps(deps: list[Dependency]) -> list[Dependency]:
  """Sort dependencies deterministically (by value, then pretty form)."""
  return sorted(deps, key=compare_fn)
138
+
139
+
140
class Problem:
  """Describe one problem to solve."""

  @classmethod
  def from_txt_file(
      cls, fname: str, to_dict: bool = False, translate: bool = True
  ):
    """Load a problem from a text file.

    The file alternates lines: a problem name/url line followed by the
    problem statement line.
    """
    with open(fname, 'r') as f:
      lines = f.read().split('\n')

    lines = [l for l in lines if l]
    data = [
        cls.from_txt(url + '\n' + problem, translate)
        for (url, problem) in reshape(lines, 2)
    ]
    if to_dict:
      return cls.to_dict(data)
    return data

  @classmethod
  def from_txt(cls, data: str, translate: bool = True) -> Problem:
    """Load a problem from a str object."""
    url = ''
    if '\n' in data:
      url, data = data.split('\n')

    # ' ? ' separates the construction clauses from the goal predicate.
    if ' ? ' in data:
      clauses, goal = data.split(' ? ')
      goal = Construction.from_txt(goal)
    else:
      clauses, goal = data, None  # goal-free problem

    clauses = clauses.split('; ')
    problem = Problem(
        url=url, clauses=[Clause.from_txt(c) for c in clauses], goal=goal
    )
    if translate:
      return problem.translate()
    return problem

  @classmethod
  def to_dict(cls, data: list[Problem]) -> dict[str, Problem]:
    # Key problems by their name/url line.
    return {p.url: p for p in data}

  def __init__(self, url: str, clauses: list[Clause], goal: Construction):
    self.url = url          # problem name / identifier
    self.clauses = clauses  # ordered construction steps
    self.goal = goal        # predicate to prove; may be None

  def copy(self) -> Problem:
    # Shallow copy: the clauses list is new but Clause objects are shared.
    return Problem(self.url, list(self.clauses), self.goal)

  def translate(self) -> Problem:  # to single-char point names
    """Translate point names into alphabetical."""
    mapping = {}
    clauses = []

    for clause in self.clauses:
      clauses.append(clause.translate(mapping))

    if self.goal:
      goal = self.goal.translate(mapping)
    else:
      goal = self.goal

    p = Problem(self.url, clauses, goal)
    p.mapping = mapping  # original -> new name map, kept for callers
    return p

  def txt(self) -> str:
    # NOTE(review): when there is no goal this returns '' (the conditional
    # spans the whole expression), dropping the clauses -- confirm intended.
    return (
        '; '.join([c.txt() for c in self.clauses]) + ' ? ' + self.goal.txt()
        if self.goal
        else ''
    )

  def setup_str_from_problem(self, definitions: dict[str, Definition]) -> str:
    """Construct the <theorem_premises> string from Problem object."""
    ref = 0  # running index used to number the premise statements

    string = []
    for clause in self.clauses:
      group = {}                 # point -> tuple of points built together
      p2deps = defaultdict(list)  # point group -> its hashed predicates
      for c in clause.constructions:
        cdef = definitions[c.name]

        # Definitions may implicitly take the clause's new points as the
        # leading arguments.
        if len(c.args) != len(cdef.construction.args):
          assert len(c.args) + len(clause.points) == len(cdef.construction.args)
          c.args = clause.points + c.args

        mapping = dict(zip(cdef.construction.args, c.args))
        for points, bs in cdef.basics:
          points = tuple([mapping[x] for x in points])
          for p in points:
            group[p] = points

          for b in bs:
            args = [mapping[a] for a in b.args]
            name = b.name
            # s_angle/aconst are normalized to an 'aconst' over a reduced
            # fraction of pi.
            if b.name in ['s_angle', 'aconst']:
              x, y, z, v = args
              name = 'aconst'
              v = int(v)

              if v < 0:  # negative angle: flip orientation instead
                v = -v
                x, z = z, x

              m, n = simplify(int(v), 180)
              args = [y, z, y, x, f'{m}pi/{n}']

            p2deps[points].append(hashed_txt(name, args))

      # Deterministic premise order within each point group.
      for k, v in p2deps.items():
        p2deps[k] = sort_deps(v)

      # Emit each point group once, consuming its points.
      points = clause.points
      while points:
        p = points[0]
        gr = group[p]
        points = [x for x in points if x not in gr]

        deps_str = []
        for dep in p2deps[gr]:
          ref_str = '{:02}'.format(ref)
          dep_str = pt.pretty(dep)

          # Rewrite 'Mpi/N' into the decimalized 'M. pi / N.' form the
          # language model expects.
          if dep[0] == 'aconst':
            m, n = map(int, dep[-1].split('pi/'))
            mn = f'{m}. pi / {n}.'
            dep_str = ' '.join(dep_str.split()[:-1] + [mn])

          deps_str.append(dep_str + ' ' + ref_str)
          ref += 1

        string.append(' '.join(gr) + ' : ' + ' '.join(deps_str))

    string = '{S} ' + ' ; '.join([s.strip() for s in string])
    goal = self.goal
    string += ' ? ' + pt.pretty([goal.name] + goal.args)
    return string
283
+
284
+
285
def parse_rely(s: str) -> dict[str, list[str]]:
  """Parse a rely spec like 'a b : c d, e : f' into {point: [points...]}.

  Each comma-separated group has the form 'points : reliance points';
  every point on the left maps to the (shared) list of point names on the
  right.  Returns {} for an empty spec.

  Fix: the return annotation previously claimed ``dict[str, str]``, but
  the values are lists of point names.
  """
  result: dict[str, list[str]] = {}
  if not s:
    return result
  for chunk in s.split(','):
    left, right = chunk.strip().split(':')
    targets = right.strip().split()
    for point in left.strip().split():
      result[point] = targets
  return result
295
+
296
+
297
class Definition:
  """Definitions of construction statements."""

  @classmethod
  def from_txt_file(cls, fname: str, to_dict: bool = False) -> Definition:
    """Load definitions from a text file (6 lines per definition)."""
    with open(fname, 'r') as f:
      lines = f.read()
    return cls.from_string(lines, to_dict)

  @classmethod
  def from_string(cls, string: str, to_dict: bool = False) -> Definition:
    # Each definition occupies exactly 6 consecutive lines.
    lines = string.split('\n')
    data = [cls.from_txt('\n'.join(group)) for group in reshape(lines, 6)]
    if to_dict:
      return cls.to_dict(data)
    return data

  @classmethod
  def to_dict(cls, data: list[Definition]) -> dict[str, Definition]:
    # Key definitions by their construction name.
    return {d.construction.name: d for d in data}

  @classmethod
  def from_txt(cls, data: str) -> Definition:
    """Load definitions from a str object.

    The 6 lines are: construction signature, rely spec, dependency clause,
    basic predicates (';'-separated levels), numerics, trailing blank.
    """
    construction, rely, deps, basics, numerics, _ = data.split('\n')
    basics = [] if not basics else [b.strip() for b in basics.split(';')]

    levels = []
    for bs in basics:
      # Optional 'points : constructions' prefix names the points
      # introduced at this level.
      if ':' in bs:
        points, bs = bs.split(':')
        points = points.strip().split()
      else:
        points = []
      if bs.strip():
        bs = [Construction.from_txt(b.strip()) for b in bs.strip().split(',')]
      else:
        bs = []
      levels.append((points, bs))

    numerics = [] if not numerics else numerics.split(', ')

    return Definition(
        construction=Construction.from_txt(construction),
        rely=parse_rely(rely),
        deps=Clause.from_txt(deps),
        basics=levels,
        numerics=[Construction.from_txt(c) for c in numerics],
    )

  def __init__(
      self,
      construction: Construction,
      rely: dict[str, str],
      deps: Clause,
      basics: list[tuple[list[str], list[Construction]]],
      numerics: list[Construction],
  ):
    self.construction = construction  # signature predicate
    self.rely = rely                  # point -> points it relies on
    self.deps = deps                  # precondition clause
    self.basics = basics              # per-level (new points, predicates)
    self.numerics = numerics          # numerical construction recipes

    # Split the signature args: names used by the numerics are true
    # arguments; the rest are the points this definition constructs.
    args = set()
    for num in numerics:
      args.update(num.args)

    self.points = []
    self.args = []
    for p in self.construction.args:
      if p in args:
        self.args.append(p)
      else:
        self.points.append(p)
372
+
373
+
374
class Theorem:
  """Deduction rule."""

  @classmethod
  def from_txt_file(cls, fname: str, to_dict: bool = False) -> Theorem:
    """Load deduction rules from a text file (one rule per line)."""
    with open(fname, 'r') as f:
      theorems = f.read()
    return cls.from_string(theorems, to_dict)

  @classmethod
  def from_string(cls, string: str, to_dict: bool = False) -> Theorem:
    """Load deduction rule from a str object."""
    theorems = string.split('\n')
    theorems = [l for l in theorems if l and not l.startswith('#')]
    theorems = [cls.from_txt(l) for l in theorems]

    # Rule names are positional: r00, r01, ... in file order.
    for i, th in enumerate(theorems):
      th.rule_name = 'r{:02}'.format(i)

    if to_dict:
      result = {}
      for t in theorems:
        # NOTE(review): duplicate auto-generated names get a '_' suffix,
        # but the dict is keyed by the (unique) rule_name, so this only
        # disambiguates t.name itself -- confirm intended.
        if t.name in result:
          t.name += '_'
        result[t.rule_name] = t

      return result

    return theorems

  @classmethod
  def from_txt(cls, data: str) -> Theorem:
    # Format: 'prem1 a b, prem2 c d => concl x y'.
    premises, conclusion = data.split(' => ')
    premises = premises.split(', ')
    conclusion = conclusion.split(', ')
    return Theorem(
        premise=[Construction.from_txt(p) for p in premises],
        conclusion=[Construction.from_txt(c) for c in conclusion],
    )

  def __init__(
      self, premise: list[Construction], conclusion: list[Construction]
  ):
    if len(conclusion) != 1:
      raise ValueError('Cannot have more than one conclusion')
    # Auto-generated name: premise names + conclusion name joined by '_'.
    self.name = '_'.join([p.name for p in premise + conclusion])
    self.premise = premise
    self.conclusion = conclusion
    self.is_arg_reduce = False

    assert len(self.conclusion) == 1
    con = self.conclusion[0]

    # These conclusion types are never treated as argument-reducing.
    if con.name in [
        'eqratio3',
        'midp',
        'contri',
        'simtri',
        'contri2',
        'simtri2',
        'simtri*',
        'contri*',
    ]:
      return

    # "Arg reducing": the conclusion mentions at least as many distinct
    # argument names as all premises combined.
    prem_args = set(sum([p.args for p in self.premise], []))
    con_args = set(con.args)
    if len(prem_args) <= len(con_args):
      self.is_arg_reduce = True

  def txt(self) -> str:
    """Serialize back to 'premises => conclusion' form."""
    premise_txt = ', '.join([clause.txt() for clause in self.premise])
    conclusion_txt = ', '.join([clause.txt() for clause in self.conclusion])
    return f'{premise_txt} => {conclusion_txt}'

  def conclusion_name_args(
      self, mapping: dict[str, gm.Point]
  ) -> tuple[str, list[gm.Point]]:
    """Instantiate the conclusion's args through a variable->Point mapping."""
    # Drop non-string keys (the mapping may also contain reverse entries).
    mapping = {arg: p for arg, p in mapping.items() if isinstance(arg, str)}
    c = self.conclusion[0]
    args = [mapping[a] for a in c.args]
    return c.name, args
456
+
457
+
458
def why_eqratio(
    d1: gm.Direction,
    d2: gm.Direction,
    d3: gm.Direction,
    d4: gm.Direction,
    level: int,
) -> list[Dependency]:
  """Why two ratios are equal, returns a Dependency objects.

  Searches all ratio nodes reachable from (d1, d2) and from (d3, d4) for a
  pair of equal ratios, keeping the pairing with the shortest combined
  justification.  Returns None when no equality can be established.
  """
  all12 = list(gm.all_ratios(d1, d2, level))
  all34 = list(gm.all_ratios(d3, d4, level))

  # Pick the (ratio12, ratio34) pairing whose total explanation is shortest.
  min_why = None
  for ang12, d1s, d2s in all12:
    for ang34, d3s, d4s in all34:
      why0 = gm.why_equal(ang12, ang34, level)
      if why0 is None:
        continue
      d1_, d2_ = ang12._l
      d3_, d4_ = ang34._l
      # How each requested length connects to the matched ratio's lengths.
      why1 = gm.bfs_backtrack(d1, [d1_], d1s)
      why2 = gm.bfs_backtrack(d2, [d2_], d2s)
      why3 = gm.bfs_backtrack(d3, [d3_], d3s)
      why4 = gm.bfs_backtrack(d4, [d4_], d4s)
      why = why0 + why1 + why2 + why3 + why4
      if min_why is None or len(why) < len(min_why[0]):
        min_why = why, ang12, ang34, why0, why1, why2, why3, why4

  if min_why is None:
    return None

  _, ang12, ang34, why0, why1, why2, why3, why4 = min_why
  d1_, d2_ = ang12._l
  d3_, d4_ = ang34._l

  # Exact match: no bridging congruences needed.
  if d1 == d1_ and d2 == d2_ and d3 == d3_ and d4 == d4_:
    return why0

  (a_, b_), (c_, d_) = d1_._obj.points, d2_._obj.points
  (e_, f_), (g_, h_) = d3_._obj.points, d4_._obj.points
  deps = []
  if why0:
    dep = Dependency('eqratio', [a_, b_, c_, d_, e_, f_, g_, h_], '', level)
    dep.why = why0
    deps.append(dep)

  # Bridge each requested segment to the matched one via a 'cong' step.
  (a, b), (c, d) = d1._obj.points, d2._obj.points
  (e, f), (g, h) = d3._obj.points, d4._obj.points
  for why, (x, y), (x_, y_) in zip(
      [why1, why2, why3, why4],
      [(a, b), (c, d), (e, f), (g, h)],
      [(a_, b_), (c_, d_), (e_, f_), (g_, h_)],
  ):
    if why:
      dep = Dependency('cong', [x, y, x_, y_], '', level)
      dep.why = why
      deps.append(dep)

  return deps
516
+
517
+
518
def why_eqangle(
    d1: gm.Direction,
    d2: gm.Direction,
    d3: gm.Direction,
    d4: gm.Direction,
    level: int,
    verbose: bool = False,
) -> list[Dependency]:
  """Why two angles are equal, returns a Dependency objects.

  NOTE(review): despite the annotation, every successful return path
  yields a tuple ((d1_, d2_, d3_, d4_), deps) -- the matched directions
  plus the dependency list; only the failure path returns None.  The
  `verbose` parameter is accepted but unused.
  """
  all12 = list(gm.all_angles(d1, d2, level))
  all34 = list(gm.all_angles(d3, d4, level))

  # Pick the (angle12, angle34) pairing whose total explanation is shortest.
  min_why = None
  for ang12, d1s, d2s in all12:
    for ang34, d3s, d4s in all34:
      why0 = gm.why_equal(ang12, ang34, level)
      if why0 is None:
        continue
      d1_, d2_ = ang12._d
      d3_, d4_ = ang34._d
      # How each requested direction connects to the matched angle's sides.
      why1 = gm.bfs_backtrack(d1, [d1_], d1s)
      why2 = gm.bfs_backtrack(d2, [d2_], d2s)
      why3 = gm.bfs_backtrack(d3, [d3_], d3s)
      why4 = gm.bfs_backtrack(d4, [d4_], d4s)
      why = why0 + why1 + why2 + why3 + why4
      if min_why is None or len(why) < len(min_why[0]):
        min_why = why, ang12, ang34, why0, why1, why2, why3, why4

  if min_why is None:
    return None

  _, ang12, ang34, why0, why1, why2, why3, why4 = min_why
  # NOTE(review): why0 is recomputed although min_why already holds it;
  # harmless if gm.why_equal is pure, but redundant.
  why0 = gm.why_equal(ang12, ang34, level)
  d1_, d2_ = ang12._d
  d3_, d4_ = ang34._d

  # Exact match: no bridging para/collx dependencies needed.
  if d1 == d1_ and d2 == d2_ and d3 == d3_ and d4 == d4_:
    return (d1_, d2_, d3_, d4_), why0

  (a_, b_), (c_, d_) = d1_._obj.points, d2_._obj.points
  (e_, f_), (g_, h_) = d3_._obj.points, d4_._obj.points
  deps = []
  if why0:
    dep = Dependency('eqangle', [a_, b_, c_, d_, e_, f_, g_, h_], '', None)
    dep.why = why0
    deps.append(dep)

  # Bridge each requested line to the matched one: the same Line object
  # means the points are merely collinear ('collx'); otherwise 'para'.
  (a, b), (c, d) = d1._obj.points, d2._obj.points
  (e, f), (g, h) = d3._obj.points, d4._obj.points
  for why, d_xy, (x, y), d_xy_, (x_, y_) in zip(
      [why1, why2, why3, why4],
      [d1, d2, d3, d4],
      [(a, b), (c, d), (e, f), (g, h)],
      [d1_, d2_, d3_, d4_],
      [(a_, b_), (c_, d_), (e_, f_), (g_, h_)],
  ):
    xy, xy_ = d_xy._obj, d_xy_._obj
    if why:
      if xy == xy_:
        name = 'collx'
      else:
        name = 'para'
      dep = Dependency(name, [x_, y_, x, y], '', None)
      dep.why = why
      deps.append(dep)

  return (d1_, d2_, d3_, d4_), deps
585
+
586
+
587
+ CONSTRUCTION_RULE = 'c0'
588
+
589
+
590
class EmptyDependency:
  """Empty dependency predicate ready to get filled up."""

  def __init__(self, level: int, rule_name: str):
    self.level = level
    self.rule_name = rule_name or ''
    self.empty = True
    self.why = []
    self.trace = None

  def populate(self, name: str, args: list[gm.Point]) -> Dependency:
    """Materialize this placeholder as a concrete Dependency."""
    dep = Dependency(name, args, self.rule_name, self.level)
    dep.trace2 = self.trace
    dep.why = list(self.why)
    return dep

  def copy(self) -> EmptyDependency:
    """Clone with an independent `why` list."""
    clone = EmptyDependency(self.level, self.rule_name)
    clone.why = list(self.why)
    return clone

  def extend(
      self,
      g: Any,
      name0: str,
      args0: list[gm.Point],
      name: str,
      args: list[gm.Point],
  ) -> EmptyDependency:
    """Extend the dependency list by (name, args)."""
    base = self.populate(name0, args0)
    extended = EmptyDependency(level=self.level, rule_name=None)
    extra = Dependency(name, args, None, extended.level)
    extended.why = [base, extra.why_me_or_cache(g, None)]
    return extended

  def extend_many(
      self,
      g: Any,
      name0: str,
      args0: list[gm.Point],
      name_args: list[tuple[str, list[gm.Point]]],
  ) -> EmptyDependency:
    """Extend the dependency list by many name_args."""
    if not name_args:
      return self
    base = self.populate(name0, args0)
    extended = EmptyDependency(level=self.level, rule_name=None)
    extended.why = [base] + [
        Dependency(n, a, None, extended.level).why_me_or_cache(g, None)
        for n, a in name_args
    ]
    return extended
643
+
644
+
645
def maybe_make_equal_pairs(
    a: gm.Point,
    b: gm.Point,
    c: gm.Point,
    d: gm.Point,
    m: gm.Point,
    n: gm.Point,
    p: gm.Point,
    q: gm.Point,
    ab: gm.Line,
    mn: gm.Line,
    g: Any,
    level: int,
) -> list[Dependency]:
  """Make a-b:c-d==m-n:p-q in case a-b==m-n or c-d==p-q."""
  # Guard: only applicable when both point pairs lie on the same object.
  # NOTE(review): this bare `return` yields None; callers do
  # `self.why += result`, so they must only invoke this when ab == mn
  # (the call sites check this first) -- confirm.
  if ab != mn:
    return
  why = []
  # Lines compare by direction ('para'); segments by length ('cong').
  eqname = 'para' if isinstance(ab, gm.Line) else 'cong'
  colls = [a, b, m, n]
  # With more than two distinct collinear points, record the collinearity.
  if len(set(colls)) > 2 and eqname == 'para':
    dep = Dependency('collx', colls, None, level)
    dep.why_me(g, level)
    why += [dep]

  dep = Dependency(eqname, [c, d, p, q], None, level)
  dep.why_me(g, level)
  why += [dep]
  return why
674
+
675
+
676
+ class Dependency(Construction):
677
+ """Dependency is a predicate that other predicates depend on."""
678
+
679
  def __init__(
      self, name: str, args: list[gm.Point], rule_name: str, level: int
  ):
    # A Dependency is a Construction (predicate) plus provenance.
    super().__init__(name, args)
    self.rule_name = rule_name or ''  # deduction rule that produced it
    self.level = level  # proof depth at which it was derived
    self.why = []  # premise Dependency objects

    self._stat = None
    self.trace = None
689
+
690
+ def _find(self, dep_hashed: tuple[str, ...]) -> Dependency:
691
+ for w in self.why:
692
+ f = w._find(dep_hashed)
693
+ if f:
694
+ return f
695
+ if w.hashed() == dep_hashed:
696
+ return w
697
+
698
+ def remove_loop(self) -> Dependency:
699
+ f = self._find(self.hashed())
700
+ if f:
701
+ return f
702
+ return self
703
+
704
+ def copy(self) -> Dependency:
705
+ dep = Dependency(self.name, self.args, self.rule_name, self.level)
706
+ dep.trace = self.trace
707
+ dep.why = list(self.why)
708
+ return dep
709
+
710
+ def why_me_or_cache(self, g: Any, level: int) -> Dependency:
711
+ if self.hashed() in g.cache:
712
+ return g.cache[self.hashed()]
713
+ self.why_me(g, level)
714
+ return self
715
+
716
  def populate(self, name: str, args: list[gm.Point]) -> Dependency:
    # NOTE(review): unlike EmptyDependency.populate, the `name` and `args`
    # parameters are ignored here -- the copy reuses self.name/self.args.
    # Looks intentional for construction-rule deps, but confirm.
    assert self.rule_name == CONSTRUCTION_RULE, self.rule_name
    dep = Dependency(self.name, self.args, self.rule_name, self.level)
    dep.why = list(self.why)
    return dep
721
+
722
+ def why_me(self, g: Any, level: int) -> None:
723
+ """Figure out the dependencies predicates of self."""
724
+ name, args = self.name, self.args
725
+
726
+ hashed_me = hashed(name, args)
727
+ if hashed_me in g.cache:
728
+ dep = g.cache[hashed_me]
729
+ self.why = dep.why
730
+ self.rule_name = dep.rule_name
731
+ return
732
+
733
+ if self.name == 'para':
734
+ a, b, c, d = self.args
735
+ if {a, b} == {c, d}:
736
+ self.why = []
737
+ return
738
+
739
+ ab = g._get_line(a, b)
740
+ cd = g._get_line(c, d)
741
+ if ab == cd:
742
+ if {a, b} == {c, d}:
743
+ self.why = []
744
+ self.rule_name = ''
745
+ return
746
+ dep = Dependency('coll', list({a, b, c, d}), 't??', None)
747
+ self.why = [dep.why_me_or_cache(g, level)]
748
+ return
749
+
750
+ for (x, y), xy in zip([(a, b), (c, d)], [ab, cd]):
751
+ x_, y_ = xy.points
752
+ if {x, y} == {x_, y_}:
753
+ continue
754
+ d = Dependency('collx', [x, y, x_, y_], None, level)
755
+ self.why += [d.why_me_or_cache(g, level)]
756
+
757
+ whypara = g.why_equal(ab, cd, None)
758
+ self.why += whypara
759
+
760
+ elif self.name == 'midp':
761
+ m, a, b = self.args
762
+ ma = g._get_segment(m, a)
763
+ mb = g._get_segment(m, b)
764
+ dep = Dependency('coll', [m, a, b], None, None).why_me_or_cache(g, None)
765
+ self.why = [dep] + g.why_equal(ma, mb, level)
766
+
767
+ elif self.name == 'perp':
768
+ a, b, c, d = self.args
769
+ ab = g._get_line(a, b)
770
+ cd = g._get_line(c, d)
771
+ for (x, y), xy in zip([(a, b), (c, d)], [ab, cd]):
772
+ x_, y_ = xy.points
773
+ if {x, y} == {x_, y_}:
774
+ continue
775
+ d = Dependency('collx', [x, y, x_, y_], None, level)
776
+ self.why += [d.why_me_or_cache(g, level)]
777
+
778
+ _, why = why_eqangle(ab._val, cd._val, cd._val, ab._val, level)
779
+ a, b = ab.points
780
+ c, d = cd.points
781
+
782
+ if hashed(self.name, [a, b, c, d]) != self.hashed():
783
+ d = Dependency(self.name, [a, b, c, d], None, level)
784
+ d.why = why
785
+ why = [d]
786
+
787
+ self.why += why
788
+
789
+ elif self.name == 'cong':
790
+ a, b, c, d = self.args
791
+ ab = g._get_segment(a, b)
792
+ cd = g._get_segment(c, d)
793
+
794
+ self.why = g.why_equal(ab, cd, level)
795
+
796
+ elif self.name == 'coll':
797
+ _, why = gm.line_of_and_why(self.args, level)
798
+ self.why = why
799
+
800
+ elif self.name == 'collx':
801
+ if g.check_coll(self.args):
802
+ args = list(set(self.args))
803
+ hashed_me = hashed('coll', args)
804
+ if hashed_me in g.cache:
805
+ dep = g.cache[hashed_me]
806
+ self.why = [dep]
807
+ self.rule_name = ''
808
+ return
809
+ _, self.why = gm.line_of_and_why(args, level)
810
+ else:
811
+ self.name = 'para'
812
+ self.why_me(g, level)
813
+
814
+ elif self.name == 'cyclic':
815
+ _, why = gm.circle_of_and_why(self.args, level)
816
+ self.why = why
817
+
818
+ elif self.name == 'circle':
819
+ o, a, b, c = self.args
820
+ oa = g._get_segment(o, a)
821
+ ob = g._get_segment(o, b)
822
+ oc = g._get_segment(o, c)
823
+ self.why = g.why_equal(oa, ob, level) + g.why_equal(oa, oc, level)
824
+
825
+ elif self.name == 'semicircle':
826
+ o, a, b, c = self.args # o: center, a & b: endpoints, c: another point to check
827
+ oa = g._get_segment(o, a) # Segment from o to a
828
+ ob = g._get_segment(o, b) # Segment from o to b
829
+ oc = g._get_segment(o, c) # Segment from o to c
830
+
831
+ # Check that segments are equal (radius check)
832
+ self.why = g.why_equal(oa, ob, level) + g.why_equal(oa, oc, level)
833
+
834
+ # Additional checks for semicircle properties can be added here
835
+ # For example, ensure that point c lies on the semicircle arc defined by a and b
836
+ self.why += g.why_on_arc(a, b, c, level) # This function needs to be implemented to check if c is on the arc
837
+
838
+
839
+
840
+ elif self.name in ['eqangle', 'eqangle6']:
841
+ a, b, c, d, m, n, p, q = self.args
842
+
843
+ ab, why1 = g.get_line_thru_pair_why(a, b)
844
+ cd, why2 = g.get_line_thru_pair_why(c, d)
845
+ mn, why3 = g.get_line_thru_pair_why(m, n)
846
+ pq, why4 = g.get_line_thru_pair_why(p, q)
847
+
848
+ if ab is None or cd is None or mn is None or pq is None:
849
+ if {a, b} == {m, n}:
850
+ d = Dependency('para', [c, d, p, q], None, level)
851
+ self.why = [d.why_me_or_cache(g, level)]
852
+ if {a, b} == {c, d}:
853
+ d = Dependency('para', [p, q, m, n], None, level)
854
+ self.why = [d.why_me_or_cache(g, level)]
855
+ if {c, d} == {p, q}:
856
+ d = Dependency('para', [a, b, m, n], None, level)
857
+ self.why = [d.why_me_or_cache(g, level)]
858
+ if {p, q} == {m, n}:
859
+ d = Dependency('para', [a, b, c, d], None, level)
860
+ self.why = [d.why_me_or_cache(g, level)]
861
+ return
862
+
863
+ for (x, y), xy, whyxy in zip(
864
+ [(a, b), (c, d), (m, n), (p, q)],
865
+ [ab, cd, mn, pq],
866
+ [why1, why2, why3, why4],
867
+ ):
868
+ x_, y_ = xy.points
869
+ if {x, y} == {x_, y_}:
870
+ continue
871
+ d = Dependency('collx', [x, y, x_, y_], None, level)
872
+ d.why = whyxy
873
+ self.why += [d]
874
+
875
+ a, b = ab.points
876
+ c, d = cd.points
877
+ m, n = mn.points
878
+ p, q = pq.points
879
+ diff = hashed(self.name, [a, b, c, d, m, n, p, q]) != self.hashed()
880
+
881
+ whyeqangle = None
882
+ if ab._val and cd._val and mn._val and pq._val:
883
+ whyeqangle = why_eqangle(ab._val, cd._val, mn._val, pq._val, level)
884
+
885
+ if whyeqangle:
886
+ (dab, dcd, dmn, dpq), whyeqangle = whyeqangle
887
+ if diff:
888
+ d = Dependency('eqangle', [a, b, c, d, m, n, p, q], None, level)
889
+ d.why = whyeqangle
890
+ whyeqangle = [d]
891
+ self.why += whyeqangle
892
+
893
+ else:
894
+ if (ab == cd and mn == pq) or (ab == mn and cd == pq):
895
+ self.why += []
896
+ elif ab == mn:
897
+ self.why += maybe_make_equal_pairs(
898
+ a, b, c, d, m, n, p, q, ab, mn, g, level
899
+ )
900
+ elif cd == pq:
901
+ self.why += maybe_make_equal_pairs(
902
+ c, d, a, b, p, q, m, n, cd, pq, g, level
903
+ )
904
+ elif ab == cd:
905
+ self.why += maybe_make_equal_pairs(
906
+ a, b, m, n, c, d, p, q, ab, cd, g, level
907
+ )
908
+ elif mn == pq:
909
+ self.why += maybe_make_equal_pairs(
910
+ m, n, a, b, p, q, c, d, mn, pq, g, level
911
+ )
912
+ elif g.is_equal(ab, mn) or g.is_equal(cd, pq):
913
+ dep1 = Dependency('para', [a, b, m, n], None, level)
914
+ dep1.why_me(g, level)
915
+ dep2 = Dependency('para', [c, d, p, q], None, level)
916
+ dep2.why_me(g, level)
917
+ self.why += [dep1, dep2]
918
+ elif g.is_equal(ab, cd) or g.is_equal(mn, pq):
919
+ dep1 = Dependency('para', [a, b, c, d], None, level)
920
+ dep1.why_me(g, level)
921
+ dep2 = Dependency('para', [m, n, p, q], None, level)
922
+ dep2.why_me(g, level)
923
+ self.why += [dep1, dep2]
924
+ elif ab._val and cd._val and mn._val and pq._val:
925
+ self.why = why_eqangle(ab._val, cd._val, mn._val, pq._val, level)
926
+
927
+ elif self.name in ['eqratio', 'eqratio6']:
928
+ a, b, c, d, m, n, p, q = self.args
929
+ ab = g._get_segment(a, b)
930
+ cd = g._get_segment(c, d)
931
+ mn = g._get_segment(m, n)
932
+ pq = g._get_segment(p, q)
933
+
934
+ if ab is None or cd is None or mn is None or pq is None:
935
+ if {a, b} == {m, n}:
936
+ d = Dependency('cong', [c, d, p, q], None, level)
937
+ self.why = [d.why_me_or_cache(g, level)]
938
+ if {a, b} == {c, d}:
939
+ d = Dependency('cong', [p, q, m, n], None, level)
940
+ self.why = [d.why_me_or_cache(g, level)]
941
+ if {c, d} == {p, q}:
942
+ d = Dependency('cong', [a, b, m, n], None, level)
943
+ self.why = [d.why_me_or_cache(g, level)]
944
+ if {p, q} == {m, n}:
945
+ d = Dependency('cong', [a, b, c, d], None, level)
946
+ self.why = [d.why_me_or_cache(g, level)]
947
+ return
948
+
949
+ if ab._val and cd._val and mn._val and pq._val:
950
+ self.why = why_eqratio(ab._val, cd._val, mn._val, pq._val, level)
951
+
952
+ if self.why is None:
953
+ self.why = []
954
+ if (ab == cd and mn == pq) or (ab == mn and cd == pq):
955
+ self.why = []
956
+ elif ab == mn:
957
+ self.why += maybe_make_equal_pairs(
958
+ a, b, c, d, m, n, p, q, ab, mn, g, level
959
+ )
960
+ elif cd == pq:
961
+ self.why += maybe_make_equal_pairs(
962
+ c, d, a, b, p, q, m, n, cd, pq, g, level
963
+ )
964
+ elif ab == cd:
965
+ self.why += maybe_make_equal_pairs(
966
+ a, b, m, n, c, d, p, q, ab, cd, g, level
967
+ )
968
+ elif mn == pq:
969
+ self.why += maybe_make_equal_pairs(
970
+ m, n, a, b, p, q, c, d, mn, pq, g, level
971
+ )
972
+ elif g.is_equal(ab, mn) or g.is_equal(cd, pq):
973
+ dep1 = Dependency('cong', [a, b, m, n], None, level)
974
+ dep1.why_me(g, level)
975
+ dep2 = Dependency('cong', [c, d, p, q], None, level)
976
+ dep2.why_me(g, level)
977
+ self.why += [dep1, dep2]
978
+ elif g.is_equal(ab, cd) or g.is_equal(mn, pq):
979
+ dep1 = Dependency('cong', [a, b, c, d], None, level)
980
+ dep1.why_me(g, level)
981
+ dep2 = Dependency('cong', [m, n, p, q], None, level)
982
+ dep2.why_me(g, level)
983
+ self.why += [dep1, dep2]
984
+ elif ab._val and cd._val and mn._val and pq._val:
985
+ self.why = why_eqangle(ab._val, cd._val, mn._val, pq._val, level)
986
+
987
+ elif self.name in ['diff', 'npara', 'nperp', 'ncoll', 'sameside']:
988
+ self.why = []
989
+
990
+ elif self.name == 'simtri':
991
+ a, b, c, x, y, z = self.args
992
+ dep1 = Dependency('eqangle', [a, b, a, c, x, y, x, z], '', level)
993
+ dep1.why_me(g, level)
994
+ dep2 = Dependency('eqangle', [b, a, b, c, y, x, y, z], '', level)
995
+ dep2.why_me(g, level)
996
+ self.rule_name = 'r34'
997
+ self.why = [dep1, dep2]
998
+
999
+ elif self.name == 'contri':
1000
+ a, b, c, x, y, z = self.args
1001
+ dep1 = Dependency('cong', [a, b, x, y], '', level)
1002
+ dep1.why_me(g, level)
1003
+ dep2 = Dependency('cong', [b, c, y, z], '', level)
1004
+ dep2.why_me(g, level)
1005
+ dep3 = Dependency('cong', [c, a, z, x], '', level)
1006
+ dep3.why_me(g, level)
1007
+ self.rule_name = 'r32'
1008
+ self.why = [dep1, dep2, dep3]
1009
+
1010
+ elif self.name == 'ind':
1011
+ pass
1012
+
1013
+ elif self.name == 'aconst':
1014
+ a, b, c, d, ang0 = self.args
1015
+
1016
+ measure = ang0._val
1017
+
1018
+ for ang in measure.neighbors(gm.Angle):
1019
+ if ang == ang0:
1020
+ continue
1021
+ d1, d2 = ang._d
1022
+ l1, l2 = d1._obj, d2._obj
1023
+ (a1, b1), (c1, d1) = l1.points, l2.points
1024
+
1025
+ if not g.check_para_or_coll([a, b, a1, b1]) or not g.check_para_or_coll(
1026
+ [c, d, c1, d1]
1027
+ ):
1028
+ continue
1029
+
1030
+ self.why = []
1031
+ for args in [(a, b, a1, b1), (c, d, c1, d1)]:
1032
+ if g.check_coll(args):
1033
+ if len(set(args)) > 2:
1034
+ dep = Dependency('coll', args, None, None)
1035
+ self.why.append(dep.why_me_or_cache(g, level))
1036
+ else:
1037
+ dep = Dependency('para', args, None, None)
1038
+ self.why.append(dep.why_me_or_cache(g, level))
1039
+
1040
+ self.why += gm.why_equal(ang, ang0)
1041
+ break
1042
+
1043
+ elif self.name == 'rconst':
1044
+ a, b, c, d, rat0 = self.args
1045
+
1046
+ val = rat0._val
1047
+
1048
+ for rat in val.neighbors(gm.Ratio):
1049
+ if rat == rat0:
1050
+ continue
1051
+ l1, l2 = rat._l
1052
+ s1, s2 = l1._obj, l2._obj
1053
+ (a1, b1), (c1, d1) = list(s1.points), list(s2.points)
1054
+
1055
+ if not g.check_cong([a, b, a1, b1]) or not g.check_cong([c, d, c1, d1]):
1056
+ continue
1057
+
1058
+ self.why = []
1059
+ for args in [(a, b, a1, b1), (c, d, c1, d1)]:
1060
+ if len(set(args)) > 2:
1061
+ dep = Dependency('cong', args, None, None)
1062
+ self.why.append(dep.why_me_or_cache(g, level))
1063
+
1064
+ self.why += gm.why_equal(rat, rat0)
1065
+ break
1066
+
1067
+ else:
1068
+ raise ValueError('Not recognize', self.name)
1069
+
1070
+ def hashed(self, rename: bool = False) -> tuple[str, ...]:
1071
+ return hashed(self.name, self.args, rename=rename)
1072
+
1073
+
1074
+ def hashed(
1075
+ name: str, args: list[gm.Point], rename: bool = False
1076
+ ) -> tuple[str, ...]:
1077
+ if name == 's_angle':
1078
+ args = [p.name if not rename else p.new_name for p in args[:-1]] + [
1079
+ str(args[-1])
1080
+ ]
1081
+ else:
1082
+ args = [p.name if not rename else p.new_name for p in args]
1083
+ return hashed_txt(name, args)
1084
+
1085
+
1086
+ def hashed_txt(name: str, args: list[str]) -> tuple[str, ...]:
1087
+ """Return a tuple unique to name and args upto arg permutation equivariant."""
1088
+
1089
+ if name in ['const', 'aconst', 'rconst']:
1090
+ a, b, c, d, y = args
1091
+ a, b = sorted([a, b])
1092
+ c, d = sorted([c, d])
1093
+ return name, a, b, c, d, y
1094
+
1095
+ if name in ['npara', 'nperp', 'para', 'cong', 'perp', 'collx']:
1096
+ a, b, c, d = args
1097
+
1098
+ a, b = sorted([a, b])
1099
+ c, d = sorted([c, d])
1100
+ (a, b), (c, d) = sorted([(a, b), (c, d)])
1101
+
1102
+ return (name, a, b, c, d)
1103
+
1104
+ if name in ['midp', 'midpoint']:
1105
+ a, b, c = args
1106
+ b, c = sorted([b, c])
1107
+ return (name, a, b, c)
1108
+
1109
+ if name in ['coll', 'cyclic', 'ncoll', 'diff', 'triangle']:
1110
+ return (name,) + tuple(sorted(list(set(args))))
1111
+
1112
+ if name == 'circle':
1113
+ x, a, b, c = args
1114
+ return (name, x) + tuple(sorted([a, b, c]))
1115
+
1116
+ if name == 'semicircle':
1117
+ x, a, b, c = args
1118
+ return (name, x) + tuple(sorted([a, b, c]))
1119
+
1120
+ if name in ['eqangle', 'eqratio', 'eqangle6', 'eqratio6']:
1121
+ a, b, c, d, e, f, g, h = args
1122
+ a, b = sorted([a, b])
1123
+ c, d = sorted([c, d])
1124
+ e, f = sorted([e, f])
1125
+ g, h = sorted([g, h])
1126
+ if tuple(sorted([a, b, e, f])) > tuple(sorted([c, d, g, h])):
1127
+ a, b, e, f, c, d, g, h = c, d, g, h, a, b, e, f
1128
+ if (a, b, c, d) > (e, f, g, h):
1129
+ a, b, c, d, e, f, g, h = e, f, g, h, a, b, c, d
1130
+
1131
+ if name == 'eqangle6':
1132
+ name = 'eqangle'
1133
+ if name == 'eqratio6':
1134
+ name = 'eqratio'
1135
+ return (name,) + (a, b, c, d, e, f, g, h)
1136
+
1137
+ if name in ['contri', 'simtri', 'simtri2', 'contri2', 'contri*', 'simtri*']:
1138
+ a, b, c, x, y, z = args
1139
+ (a, x), (b, y), (c, z) = sorted([(a, x), (b, y), (c, z)], key=sorted)
1140
+ (a, b, c), (x, y, z) = sorted([(a, b, c), (x, y, z)], key=sorted)
1141
+ return (name, a, b, c, x, y, z)
1142
+
1143
+ if name in ['eqratio3']:
1144
+ a, b, c, d, o, o = args # pylint: disable=redeclared-assigned-name
1145
+ (a, c), (b, d) = sorted([(a, c), (b, d)], key=sorted)
1146
+ (a, b), (c, d) = sorted([(a, b), (c, d)], key=sorted)
1147
+ return (name, a, b, c, d, o, o)
1148
+
1149
+ if name in ['sameside', 's_angle']:
1150
+ return (name,) + tuple(args)
1151
+
1152
+ raise ValueError(f'Not recognize {name} to hash.')
ag4masses/alphageometry/rules.txt CHANGED
@@ -41,3 +41,7 @@ eqratio6 B A B C Q P Q R, eqangle6 B A B C Q P Q R, ncoll A B C => simtri* A B C
41
  eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C, cong A B P Q => contri* A B C P Q R
42
  para a b c d, coll m a d, coll n b c, eqratio6 m a m d n b n c, sameside m a d n b c => para m n a b
43
  para a b c d, coll m a d, coll n b c, para m n a b => eqratio6 m a m d n b n c
 
 
 
 
 
41
  eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C, cong A B P Q => contri* A B C P Q R
42
  para a b c d, coll m a d, coll n b c, eqratio6 m a m d n b n c, sameside m a d n b c => para m n a b
43
  para a b c d, coll m a d, coll n b c, para m n a b => eqratio6 m a m d n b n c
44
+ semicircle O A B C, perp O A A X => eqangle A X A B C A C B
45
+ semicircle O A B C, eqangle A X A B C A C B => perp O A A X
46
+ semicircle O A B C, midp M B C => eqangle A B A C O B O M
47
+ semicircle O A B C, coll M B C, eqangle A B A C O B O M => midp M B C
ag4masses/alphageometry/trace_back.py CHANGED
@@ -1,374 +1,374 @@
1
- # Copyright 2023 DeepMind Technologies Limited
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """Implements DAG-level traceback."""
17
-
18
- from typing import Any
19
-
20
- import geometry as gm
21
- import pretty as pt
22
- import problem
23
-
24
-
25
- pretty = pt.pretty
26
-
27
-
28
- def point_levels(
29
- setup: list[problem.Dependency], existing_points: list[gm.Point]
30
- ) -> list[tuple[set[gm.Point], list[problem.Dependency]]]:
31
- """Reformat setup into levels of point constructions."""
32
- levels = []
33
- for con in setup:
34
- plevel = max([p.plevel for p in con.args if isinstance(p, gm.Point)])
35
-
36
- while len(levels) - 1 < plevel:
37
- levels.append((set(), []))
38
-
39
- for p in con.args:
40
- if not isinstance(p, gm.Point):
41
- continue
42
- if existing_points and p in existing_points:
43
- continue
44
-
45
- levels[p.plevel][0].add(p)
46
-
47
- cons = levels[plevel][1]
48
- cons.append(con)
49
-
50
- return [(p, c) for p, c in levels if p or c]
51
-
52
-
53
- def point_log(
54
- setup: list[problem.Dependency],
55
- ref_id: dict[tuple[str, ...], int],
56
- existing_points=list[gm.Point],
57
- ) -> list[tuple[list[gm.Point], list[problem.Dependency]]]:
58
- """Reformat setup into groups of point constructions."""
59
- log = []
60
-
61
- levels = point_levels(setup, existing_points)
62
-
63
- for points, cons in levels:
64
- for con in cons:
65
- if con.hashed() not in ref_id:
66
- ref_id[con.hashed()] = len(ref_id)
67
-
68
- log.append((points, cons))
69
-
70
- return log
71
-
72
-
73
- def setup_to_levels(
74
- setup: list[problem.Dependency],
75
- ) -> list[list[problem.Dependency]]:
76
- """Reformat setup into levels of point constructions."""
77
- levels = []
78
- for d in setup:
79
- plevel = max([p.plevel for p in d.args if isinstance(p, gm.Point)])
80
- while len(levels) - 1 < plevel:
81
- levels.append([])
82
-
83
- levels[plevel].append(d)
84
-
85
- levels = [lvl for lvl in levels if lvl]
86
- return levels
87
-
88
-
89
- def separate_dependency_difference(
90
- query: problem.Dependency,
91
- log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
92
- ) -> tuple[
93
- list[tuple[list[problem.Dependency], list[problem.Dependency]]],
94
- list[problem.Dependency],
95
- list[problem.Dependency],
96
- set[gm.Point],
97
- set[gm.Point],
98
- ]:
99
- """Identify and separate the dependency difference."""
100
- setup = []
101
- log_, log = log, []
102
- for prems, cons in log_:
103
- if not prems:
104
- setup.extend(cons)
105
- continue
106
- cons_ = []
107
- for con in cons:
108
- if con.rule_name == 'c0':
109
- setup.append(con)
110
- else:
111
- cons_.append(con)
112
- if not cons_:
113
- continue
114
-
115
- prems = [p for p in prems if p.name != 'ind']
116
- log.append((prems, cons_))
117
-
118
- points = set(query.args)
119
- queue = list(query.args)
120
- i = 0
121
- while i < len(queue):
122
- q = queue[i]
123
- i += 1
124
- if not isinstance(q, gm.Point):
125
- continue
126
- for p in q.rely_on:
127
- if p not in points:
128
- points.add(p)
129
- queue.append(p)
130
-
131
- setup_, setup, aux_setup, aux_points = setup, [], [], set()
132
- for con in setup_:
133
- if con.name == 'ind':
134
- continue
135
- elif any([p not in points for p in con.args if isinstance(p, gm.Point)]):
136
- aux_setup.append(con)
137
- aux_points.update(
138
- [p for p in con.args if isinstance(p, gm.Point) and p not in points]
139
- )
140
- else:
141
- setup.append(con)
142
-
143
- return log, setup, aux_setup, points, aux_points
144
-
145
-
146
- def recursive_traceback(
147
- query: problem.Dependency,
148
- ) -> list[tuple[list[problem.Dependency], list[problem.Dependency]]]:
149
- """Recursively traceback from the query, i.e. the conclusion."""
150
- visited = set()
151
- log = []
152
- stack = []
153
-
154
- def read(q: problem.Dependency) -> None:
155
- q = q.remove_loop()
156
- hashed = q.hashed()
157
- if hashed in visited:
158
- return
159
-
160
- if hashed[0] in ['ncoll', 'npara', 'nperp', 'diff', 'sameside']:
161
- return
162
-
163
- nonlocal stack
164
-
165
- stack.append(hashed)
166
- prems = []
167
-
168
- if q.rule_name != problem.CONSTRUCTION_RULE:
169
- all_deps = []
170
- dep_names = set()
171
- for d in q.why:
172
- if d.hashed() in dep_names:
173
- continue
174
- dep_names.add(d.hashed())
175
- all_deps.append(d)
176
-
177
- for d in all_deps:
178
- h = d.hashed()
179
- if h not in visited:
180
- read(d)
181
- if h in visited:
182
- prems.append(d)
183
-
184
- visited.add(hashed)
185
- hashs = sorted([d.hashed() for d in prems])
186
- found = False
187
- for ps, qs in log:
188
- if sorted([d.hashed() for d in ps]) == hashs:
189
- qs += [q]
190
- found = True
191
- break
192
- if not found:
193
- log.append((prems, [q]))
194
-
195
- stack.pop(-1)
196
-
197
- read(query)
198
-
199
- # post process log: separate multi-conclusion lines
200
- log_, log = log, []
201
- for ps, qs in log_:
202
- for q in qs:
203
- log.append((ps, [q]))
204
-
205
- return log
206
-
207
-
208
- def collx_to_coll_setup(
209
- setup: list[problem.Dependency],
210
- ) -> list[problem.Dependency]:
211
- """Convert collx to coll in setups."""
212
- result = []
213
- for level in setup_to_levels(setup):
214
- hashs = set()
215
- for dep in level:
216
- if dep.name == 'collx':
217
- dep.name = 'coll'
218
- dep.args = list(set(dep.args))
219
-
220
- if dep.hashed() in hashs:
221
- continue
222
- hashs.add(dep.hashed())
223
- result.append(dep)
224
-
225
- return result
226
-
227
-
228
- def collx_to_coll(
229
- setup: list[problem.Dependency],
230
- aux_setup: list[problem.Dependency],
231
- log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
232
- ) -> tuple[
233
- list[problem.Dependency],
234
- list[problem.Dependency],
235
- list[tuple[list[problem.Dependency], list[problem.Dependency]]],
236
- ]:
237
- """Convert collx to coll and dedup."""
238
- setup = collx_to_coll_setup(setup)
239
- aux_setup = collx_to_coll_setup(aux_setup)
240
-
241
- con_set = set([p.hashed() for p in setup + aux_setup])
242
- log_, log = log, []
243
- for prems, cons in log_:
244
- prem_set = set()
245
- prems_, prems = prems, []
246
- for p in prems_:
247
- if p.name == 'collx':
248
- p.name = 'coll'
249
- p.args = list(set(p.args))
250
- if p.hashed() in prem_set:
251
- continue
252
- prem_set.add(p.hashed())
253
- prems.append(p)
254
-
255
- cons_, cons = cons, []
256
- for c in cons_:
257
- if c.name == 'collx':
258
- c.name = 'coll'
259
- c.args = list(set(c.args))
260
- if c.hashed() in con_set:
261
- continue
262
- con_set.add(c.hashed())
263
- cons.append(c)
264
-
265
- if not cons or not prems:
266
- continue
267
-
268
- log.append((prems, cons))
269
-
270
- return setup, aux_setup, log
271
-
272
-
273
- def get_logs(
274
- query: problem.Dependency, g: Any, merge_trivials: bool = False
275
- ) -> tuple[
276
- list[problem.Dependency],
277
- list[problem.Dependency],
278
- list[tuple[list[problem.Dependency], list[problem.Dependency]]],
279
- set[gm.Point],
280
- ]:
281
- """Given a DAG and conclusion N, return the premise, aux, proof."""
282
- query = query.why_me_or_cache(g, query.level)
283
- log = recursive_traceback(query)
284
- log, setup, aux_setup, setup_points, _ = separate_dependency_difference(
285
- query, log
286
- )
287
-
288
- setup, aux_setup, log = collx_to_coll(setup, aux_setup, log)
289
-
290
- setup, aux_setup, log = shorten_and_shave(
291
- setup, aux_setup, log, merge_trivials
292
- )
293
-
294
- return setup, aux_setup, log, setup_points
295
-
296
-
297
- def shorten_and_shave(
298
- setup: list[problem.Dependency],
299
- aux_setup: list[problem.Dependency],
300
- log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
301
- merge_trivials: bool = False,
302
- ) -> tuple[
303
- list[problem.Dependency],
304
- list[problem.Dependency],
305
- list[tuple[list[problem.Dependency], list[problem.Dependency]]],
306
- ]:
307
- """Shorten the proof by removing unused predicates."""
308
- log, _ = shorten_proof(log, merge_trivials=merge_trivials)
309
-
310
- all_prems = sum([list(prems) for prems, _ in log], [])
311
- all_prems = set([p.hashed() for p in all_prems])
312
- setup = [d for d in setup if d.hashed() in all_prems]
313
- aux_setup = [d for d in aux_setup if d.hashed() in all_prems]
314
- return setup, aux_setup, log
315
-
316
-
317
- def join_prems(
318
- con: problem.Dependency,
319
- con2prems: dict[tuple[str, ...], list[problem.Dependency]],
320
- expanded: set[tuple[str, ...]],
321
- ) -> list[problem.Dependency]:
322
- """Join proof steps with the same premises."""
323
- h = con.hashed()
324
- if h in expanded or h not in con2prems:
325
- return [con]
326
-
327
- result = []
328
- for p in con2prems[h]:
329
- result += join_prems(p, con2prems, expanded)
330
- return result
331
-
332
-
333
- def shorten_proof(
334
- log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
335
- merge_trivials: bool = False,
336
- ) -> tuple[
337
- list[tuple[list[problem.Dependency], list[problem.Dependency]]],
338
- dict[tuple[str, ...], list[problem.Dependency]],
339
- ]:
340
- """Join multiple trivials proof steps into one."""
341
- pops = set()
342
- con2prem = {}
343
- for prems, cons in log:
344
- assert len(cons) == 1
345
- con = cons[0]
346
- if con.rule_name == '': # pylint: disable=g-explicit-bool-comparison
347
- con2prem[con.hashed()] = prems
348
- elif not merge_trivials:
349
- # except for the ones that are premises to non-trivial steps.
350
- pops.update({p.hashed() for p in prems})
351
-
352
- for p in pops:
353
- if p in con2prem:
354
- con2prem.pop(p)
355
-
356
- expanded = set()
357
- log2 = []
358
- for i, (prems, cons) in enumerate(log):
359
- con = cons[0]
360
- if i < len(log) - 1 and con.hashed() in con2prem:
361
- continue
362
-
363
- hashs = set()
364
- new_prems = []
365
-
366
- for p in sum([join_prems(p, con2prem, expanded) for p in prems], []):
367
- if p.hashed() not in hashs:
368
- new_prems.append(p)
369
- hashs.add(p.hashed())
370
-
371
- log2 += [(new_prems, [con])]
372
- expanded.add(con.hashed())
373
-
374
- return log2, con2prem
 
1
+ # Copyright 2023 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Implements DAG-level traceback."""
17
+
18
+ from typing import Any
19
+
20
+ import geometry as gm
21
+ import pretty as pt
22
+ import problem
23
+
24
+
25
+ pretty = pt.pretty
26
+
27
+
28
+ def point_levels(
29
+ setup: list[problem.Dependency], existing_points: list[gm.Point]
30
+ ) -> list[tuple[set[gm.Point], list[problem.Dependency]]]:
31
+ """Reformat setup into levels of point constructions."""
32
+ levels = []
33
+ for con in setup:
34
+ plevel = max([p.plevel for p in con.args if isinstance(p, gm.Point)])
35
+
36
+ while len(levels) - 1 < plevel:
37
+ levels.append((set(), []))
38
+
39
+ for p in con.args:
40
+ if not isinstance(p, gm.Point):
41
+ continue
42
+ if existing_points and p in existing_points:
43
+ continue
44
+
45
+ levels[p.plevel][0].add(p)
46
+
47
+ cons = levels[plevel][1]
48
+ cons.append(con)
49
+
50
+ return [(p, c) for p, c in levels if p or c]
51
+
52
+
53
+ def point_log(
54
+ setup: list[problem.Dependency],
55
+ ref_id: dict[tuple[str, ...], int],
56
+ existing_points=list[gm.Point],
57
+ ) -> list[tuple[list[gm.Point], list[problem.Dependency]]]:
58
+ """Reformat setup into groups of point constructions."""
59
+ log = []
60
+
61
+ levels = point_levels(setup, existing_points)
62
+
63
+ for points, cons in levels:
64
+ for con in cons:
65
+ if con.hashed() not in ref_id:
66
+ ref_id[con.hashed()] = len(ref_id)
67
+
68
+ log.append((points, cons))
69
+
70
+ return log
71
+
72
+
73
+ def setup_to_levels(
74
+ setup: list[problem.Dependency],
75
+ ) -> list[list[problem.Dependency]]:
76
+ """Reformat setup into levels of point constructions."""
77
+ levels = []
78
+ for d in setup:
79
+ plevel = max([p.plevel for p in d.args if isinstance(p, gm.Point)])
80
+ while len(levels) - 1 < plevel:
81
+ levels.append([])
82
+
83
+ levels[plevel].append(d)
84
+
85
+ levels = [lvl for lvl in levels if lvl]
86
+ return levels
87
+
88
+
89
+ def separate_dependency_difference(
90
+ query: problem.Dependency,
91
+ log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
92
+ ) -> tuple[
93
+ list[tuple[list[problem.Dependency], list[problem.Dependency]]],
94
+ list[problem.Dependency],
95
+ list[problem.Dependency],
96
+ set[gm.Point],
97
+ set[gm.Point],
98
+ ]:
99
+ """Identify and separate the dependency difference."""
100
+ setup = []
101
+ log_, log = log, []
102
+ for prems, cons in log_:
103
+ if not prems:
104
+ setup.extend(cons)
105
+ continue
106
+ cons_ = []
107
+ for con in cons:
108
+ if con.rule_name == 'c0':
109
+ setup.append(con)
110
+ else:
111
+ cons_.append(con)
112
+ if not cons_:
113
+ continue
114
+
115
+ prems = [p for p in prems if p.name != 'ind']
116
+ log.append((prems, cons_))
117
+
118
+ points = set(query.args)
119
+ queue = list(query.args)
120
+ i = 0
121
+ while i < len(queue):
122
+ q = queue[i]
123
+ i += 1
124
+ if not isinstance(q, gm.Point):
125
+ continue
126
+ for p in q.rely_on:
127
+ if p not in points:
128
+ points.add(p)
129
+ queue.append(p)
130
+
131
+ setup_, setup, aux_setup, aux_points = setup, [], [], set()
132
+ for con in setup_:
133
+ if con.name == 'ind':
134
+ continue
135
+ elif any([p not in points for p in con.args if isinstance(p, gm.Point)]):
136
+ aux_setup.append(con)
137
+ aux_points.update(
138
+ [p for p in con.args if isinstance(p, gm.Point) and p not in points]
139
+ )
140
+ else:
141
+ setup.append(con)
142
+
143
+ return log, setup, aux_setup, points, aux_points
144
+
145
+
146
+ def recursive_traceback(
147
+ query: problem.Dependency,
148
+ ) -> list[tuple[list[problem.Dependency], list[problem.Dependency]]]:
149
+ """Recursively traceback from the query, i.e. the conclusion."""
150
+ visited = set()
151
+ log = []
152
+ stack = []
153
+
154
+ def read(q: problem.Dependency) -> None:
155
+ q = q.remove_loop()
156
+ hashed = q.hashed()
157
+ if hashed in visited:
158
+ return
159
+
160
+ if hashed[0] in ['ncoll', 'npara', 'nperp', 'diff', 'sameside']:
161
+ return
162
+
163
+ nonlocal stack
164
+
165
+ stack.append(hashed)
166
+ prems = []
167
+
168
+ if q.rule_name != problem.CONSTRUCTION_RULE:
169
+ all_deps = []
170
+ dep_names = set()
171
+ for d in q.why:
172
+ if d.hashed() in dep_names:
173
+ continue
174
+ dep_names.add(d.hashed())
175
+ all_deps.append(d)
176
+
177
+ for d in all_deps:
178
+ h = d.hashed()
179
+ if h not in visited:
180
+ read(d)
181
+ if h in visited:
182
+ prems.append(d)
183
+
184
+ visited.add(hashed)
185
+ hashs = sorted([d.hashed() for d in prems])
186
+ found = False
187
+ for ps, qs in log:
188
+ if sorted([d.hashed() for d in ps]) == hashs:
189
+ qs += [q]
190
+ found = True
191
+ break
192
+ if not found:
193
+ log.append((prems, [q]))
194
+
195
+ stack.pop(-1)
196
+
197
+ read(query)
198
+
199
+ # post process log: separate multi-conclusion lines
200
+ log_, log = log, []
201
+ for ps, qs in log_:
202
+ for q in qs:
203
+ log.append((ps, [q]))
204
+
205
+ return log
206
+
207
+
208
+ def collx_to_coll_setup(
209
+ setup: list[problem.Dependency],
210
+ ) -> list[problem.Dependency]:
211
+ """Convert collx to coll in setups."""
212
+ result = []
213
+ for level in setup_to_levels(setup):
214
+ hashs = set()
215
+ for dep in level:
216
+ if dep.name == 'collx':
217
+ dep.name = 'coll'
218
+ dep.args = list(set(dep.args))
219
+
220
+ if dep.hashed() in hashs:
221
+ continue
222
+ hashs.add(dep.hashed())
223
+ result.append(dep)
224
+
225
+ return result
226
+
227
+
228
+ def collx_to_coll(
229
+ setup: list[problem.Dependency],
230
+ aux_setup: list[problem.Dependency],
231
+ log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
232
+ ) -> tuple[
233
+ list[problem.Dependency],
234
+ list[problem.Dependency],
235
+ list[tuple[list[problem.Dependency], list[problem.Dependency]]],
236
+ ]:
237
+ """Convert collx to coll and dedup."""
238
+ setup = collx_to_coll_setup(setup)
239
+ aux_setup = collx_to_coll_setup(aux_setup)
240
+
241
+ con_set = set([p.hashed() for p in setup + aux_setup])
242
+ log_, log = log, []
243
+ for prems, cons in log_:
244
+ prem_set = set()
245
+ prems_, prems = prems, []
246
+ for p in prems_:
247
+ if p.name == 'collx':
248
+ p.name = 'coll'
249
+ p.args = list(set(p.args))
250
+ if p.hashed() in prem_set:
251
+ continue
252
+ prem_set.add(p.hashed())
253
+ prems.append(p)
254
+
255
+ cons_, cons = cons, []
256
+ for c in cons_:
257
+ if c.name == 'collx':
258
+ c.name = 'coll'
259
+ c.args = list(set(c.args))
260
+ if c.hashed() in con_set:
261
+ continue
262
+ con_set.add(c.hashed())
263
+ cons.append(c)
264
+
265
+ if not cons or not prems:
266
+ continue
267
+
268
+ log.append((prems, cons))
269
+
270
+ return setup, aux_setup, log
271
+
272
+
273
+ def get_logs(
274
+ query: problem.Dependency, g: Any, merge_trivials: bool = False
275
+ ) -> tuple[
276
+ list[problem.Dependency],
277
+ list[problem.Dependency],
278
+ list[tuple[list[problem.Dependency], list[problem.Dependency]]],
279
+ set[gm.Point],
280
+ ]:
281
+ """Given a DAG and conclusion N, return the premise, aux, proof."""
282
+ query = query.why_me_or_cache(g, query.level)
283
+ log = recursive_traceback(query)
284
+ log, setup, aux_setup, setup_points, _ = separate_dependency_difference(
285
+ query, log
286
+ )
287
+
288
+ setup, aux_setup, log = collx_to_coll(setup, aux_setup, log)
289
+
290
+ setup, aux_setup, log = shorten_and_shave(
291
+ setup, aux_setup, log, merge_trivials
292
+ )
293
+
294
+ return setup, aux_setup, log, setup_points
295
+
296
+
297
+ def shorten_and_shave(
298
+ setup: list[problem.Dependency],
299
+ aux_setup: list[problem.Dependency],
300
+ log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
301
+ merge_trivials: bool = False,
302
+ ) -> tuple[
303
+ list[problem.Dependency],
304
+ list[problem.Dependency],
305
+ list[tuple[list[problem.Dependency], list[problem.Dependency]]],
306
+ ]:
307
+ """Shorten the proof by removing unused predicates."""
308
+ log, _ = shorten_proof(log, merge_trivials=merge_trivials)
309
+
310
+ all_prems = sum([list(prems) for prems, _ in log], [])
311
+ all_prems = set([p.hashed() for p in all_prems])
312
+ setup = [d for d in setup if d.hashed() in all_prems]
313
+ aux_setup = [d for d in aux_setup if d.hashed() in all_prems]
314
+ return setup, aux_setup, log
315
+
316
+
317
+ def join_prems(
318
+ con: problem.Dependency,
319
+ con2prems: dict[tuple[str, ...], list[problem.Dependency]],
320
+ expanded: set[tuple[str, ...]],
321
+ ) -> list[problem.Dependency]:
322
+ """Join proof steps with the same premises."""
323
+ h = con.hashed()
324
+ if h in expanded or h not in con2prems:
325
+ return [con]
326
+
327
+ result = []
328
+ for p in con2prems[h]:
329
+ result += join_prems(p, con2prems, expanded)
330
+ return result
331
+
332
+
333
+ def shorten_proof(
334
+ log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
335
+ merge_trivials: bool = False,
336
+ ) -> tuple[
337
+ list[tuple[list[problem.Dependency], list[problem.Dependency]]],
338
+ dict[tuple[str, ...], list[problem.Dependency]],
339
+ ]:
340
+ """Join multiple trivials proof steps into one."""
341
+ pops = set()
342
+ con2prem = {}
343
+ for prems, cons in log:
344
+ assert len(cons) == 1
345
+ con = cons[0]
346
+ if con.rule_name == '': # pylint: disable=g-explicit-bool-comparison
347
+ con2prem[con.hashed()] = prems
348
+ elif not merge_trivials:
349
+ # except for the ones that are premises to non-trivial steps.
350
+ pops.update({p.hashed() for p in prems})
351
+
352
+ for p in pops:
353
+ if p in con2prem:
354
+ con2prem.pop(p)
355
+
356
+ expanded = set()
357
+ log2 = []
358
+ for i, (prems, cons) in enumerate(log):
359
+ con = cons[0]
360
+ if i < len(log) - 1 and con.hashed() in con2prem:
361
+ continue
362
+
363
+ hashs = set()
364
+ new_prems = []
365
+
366
+ for p in sum([join_prems(p, con2prem, expanded) for p in prems], []):
367
+ if p.hashed() not in hashs:
368
+ new_prems.append(p)
369
+ hashs.add(p.hashed())
370
+
371
+ log2 += [(new_prems, [con])]
372
+ expanded.add(con.hashed())
373
+
374
+ return log2, con2prem
ag4masses/alphageometry/transformer_layer.py CHANGED
@@ -1,527 +1,526 @@
1
- # Copyright 2023 DeepMind Technologies Limited
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """A single transformer layer in inference mode.
17
-
18
- Modified
19
- https://github.com/google-research/meliad/blob/main/transformer/transformer_layer.py
20
- To accommodate sequence packing + kv cache + relative position during test time.
21
- """
22
-
23
- from typing import Callable, Mapping, NewType, Optional, Tuple
24
-
25
- from absl import logging
26
- import gin
27
- import jax
28
- import jax.numpy as jnp
29
- from transformer import attention
30
- from transformer import nn_components
31
- from transformer import position
32
- from transformer import transformer_layer
33
-
34
-
35
- Array = jnp.ndarray
36
- DecoderState = NewType("DecoderState", Mapping[str, Array])
37
- WindowState = Optional[Tuple[attention.KVITuple, Array]]
38
-
39
-
40
- @jax.vmap
41
- def update_slice_in_dim_1(array: Array, update: Array, idx: Array) -> Array:
42
- """Update a stored keys/values slice for different-lengthed seqs in batch."""
43
- return jax.lax.dynamic_update_slice_in_dim(array, update, idx, axis=0)
44
-
45
-
46
- def slice_in_dim_1(window_length: int) -> Callable[[Array, Array], Array]:
47
- @jax.vmap
48
- def fn(array: Array, idx: Array) -> Array:
49
- return jax.lax.dynamic_slice_in_dim(array, idx, window_length, axis=0)
50
-
51
- return fn
52
-
53
-
54
- @gin.configurable
55
- class TransformerLayerGenerate(transformer_layer.TransformerLayer):
56
- """Full transformer layer, with attention."""
57
-
58
- def _next_decoder_state(
59
- self, decoder_state: DecoderState, keys: Array, values: Array
60
- ) -> Tuple[DecoderState, Array, Array]:
61
- """Compute the next decoder state, and return keys,values to attend to.
62
-
63
- The keys,values returned from this function are drawn from the prior
64
- decoding state, and comprise a full window of local context.
65
-
66
- Args:
67
- decoder_state: The current decoder state, initially created using
68
- init_decoder_state().
69
- keys: The key for the current token, of shape (batch_size, 1, dim)
70
- values: The value for the current token of shape (batch_size, 1, dim)
71
-
72
- Returns:
73
- (next_decoder_state,
74
- window of keys of shape (batch_size, window_length, dim),
75
- window of values of shape (batch_size, window_length, dim))
76
- """
77
-
78
- assert keys.shape[1] == 1 # single-token autoregressive decoding.
79
-
80
- # Unpack decoder_state
81
- stored_keys = decoder_state["keys"]
82
- stored_values = decoder_state["values"]
83
- curr_index = decoder_state["current_index"]
84
-
85
- # Slice to get window_length-sized chunk of previous keys,values.
86
- out_decoder_state = {}
87
- curr_win_index = curr_index - self.window_length
88
-
89
- # out_keys = jax.lax.dynamic_slice_in_dim(
90
- # stored_keys, curr_win_index, self.window_length, axis=1)
91
- out_keys = slice_in_dim_1(self.window_length)(stored_keys, curr_win_index)
92
-
93
- # out_values = jax.lax.dynamic_slice_in_dim(
94
- # stored_values, curr_win_index, self.window_length, axis=1)
95
- out_values = slice_in_dim_1(self.window_length)(
96
- stored_values, curr_win_index
97
- )
98
-
99
- # Write current keys,values to stored keys, values.
100
- # stored_keys = jax.lax.dynamic_update_slice_in_dim(
101
- # stored_keys, keys, curr_index, axis=1)
102
- stored_keys = update_slice_in_dim_1(stored_keys, keys, curr_index)
103
- # stored_values = jax.lax.dynamic_update_slice_in_dim(
104
- # stored_values, values, curr_index, axis=1)
105
- stored_values = update_slice_in_dim_1(stored_values, values, curr_index)
106
- curr_index = curr_index + 1
107
-
108
- # Pack a new decoder_state object.
109
- out_decoder_state["keys"] = stored_keys
110
- out_decoder_state["values"] = stored_values
111
- out_decoder_state["current_index"] = curr_index
112
- out_decoder_state["relative_position_bias"] = decoder_state[
113
- "relative_position_bias"
114
- ]
115
- out_decoder_state["recurrent_kvq"] = decoder_state["recurrent_kvq"]
116
-
117
- return (DecoderState(out_decoder_state), out_keys, out_values)
118
-
119
- def __call__(
120
- self,
121
- xs: Array,
122
- start_of_sequence: Array,
123
- *,
124
- importance: Optional[Array] = None,
125
- cross_attention_kv: Optional[Tuple[Array, Array]] = None,
126
- window_state: Optional[WindowState] = None,
127
- decoder_state: Optional[DecoderState] = None,
128
- ):
129
- """Computes attention over a sequence of inputs.
130
-
131
- Args:
132
- xs: input sequence of shape (batch_size, sequence_length, num_hidden)
133
- start_of_sequence: An input array of shape (batch_size) --- The following
134
- must be passed by keyword only. ---
135
- importance: Array of shape (batch_size, sequence_length). An importance
136
- bias for attention.
137
- cross_attention_kv: Keys and values from encoder for cross-attention.
138
- window_state: State object which contains context from the prior window
139
- when using a transformer-XL or sliding window. Initially created with
140
- load_window_state().
141
- decoder_state: State object for autoregressive decoding, initially created
142
- with from init_decoder_state().
143
-
144
- Returns:
145
- (ys: outputs of shape (batch_size, sequence_length, num_hidden),
146
- importance_score: importance score for the next layer,
147
- next_window_state: state to pass to the next window,
148
- next_decoder_state: next decoder state for autoregressive decoding,
149
- viz_dict: dictionary of visualizations
150
- )
151
- """
152
-
153
- xs = jnp.asarray(xs, dtype=self.dtype)
154
- logging.info("tlayer: recurrent = %r", self.recurrent_attention)
155
- logging.info("tlayer: compute_importance = %r", self.compute_importance)
156
-
157
- is_training = self.mode == "train"
158
-
159
- # Compute keys, values and queries.
160
- # ---------------------------------
161
- logging.info("tlayer: compute keys,values,queries.")
162
- (keys, values, queries, queries2) = self.tbase.kvq(xs)
163
- attention_scale_factors = self.tbase.attention_scale_factors()
164
- (_, sequence_length, num_heads, _) = queries.shape # (b, k, h, d)
165
-
166
- # Get biases and masks that are shared across windows.
167
- # ----------------------------------------------------
168
- if decoder_state is not None:
169
- logging.info("tlayer: using autoregressive decoder.")
170
- # When decoding, prior keys,values are loaded from the decoder state.
171
- # Other values are precomputed, and loaded from the decoder state.
172
- # The decoder state will be updated with the current token.
173
- assert window_state is None
174
-
175
- prev_kvi = None
176
- recurrent_state = None # Use precomputed recurrent_kvq.
177
- cross_attention_kv = None
178
- rel_position_bias = decoder_state["relative_position_bias"]
179
- causal_mask = None
180
- dropout_multiplier = None
181
-
182
- # Reuse cached recurrent keys,values for each token.
183
- cached_recurrent_kvq = decoder_state["recurrent_kvq"]
184
- if cached_recurrent_kvq is not None:
185
- assert cross_attention_kv is None
186
- cross_attention_kv = (cached_recurrent_kvq[0], cached_recurrent_kvq[1])
187
- del cached_recurrent_kvq
188
-
189
- # Get a full window of keys,values and update decoder state.
190
- (decoder_state, keys, values) = self._next_decoder_state(
191
- decoder_state, keys, values
192
- )
193
-
194
- # Each query attends to window_length prior keys.
195
- assert keys.shape[1] == self.window_length
196
- kq_relative_offset = self.window_length
197
-
198
- if not self.use_long_xl_architecture:
199
- kqpos = position.relative_positions(
200
- 1, self.window_length, offset=0
201
- ) # 2D mask
202
- current_idx = decoder_state["current_index"]
203
-
204
- # add (batch, heads) dims for kqpos
205
- kqpos = jnp.expand_dims(kqpos, axis=(0, 1))
206
- kqpos = jnp.tile(kqpos, (1, self.num_heads, 1, 1))
207
-
208
- # add (_, heads, _) dim for current_idx
209
- current_idx = jnp.expand_dims(current_idx, axis=(1, 2, 3))
210
-
211
- causal_mask = kqpos > self.window_length * 2 - current_idx
212
- else:
213
- logging.info("tlayer: windowed attention.")
214
- # When training, attention is done using windows or chunks, and prior
215
- # context (e.g. keys,values from the previous window) is stored in the
216
- # window_state object.
217
- (prev_kvi, recurrent_state) = (
218
- window_state # pytype: disable=attribute-error
219
- )
220
-
221
- # Get the size of the sliding window for pos bias, dropout, & causal mask.
222
- (num_queries, num_keys) = attention.sliding_attention_window_shape(
223
- (keys, values, importance),
224
- prev_kvi,
225
- queries,
226
- window_length=self.window_length,
227
- )
228
- kq_relative_offset = num_keys - num_queries
229
-
230
- # Get the relative position bias.
231
- # The bias doesn't depend on the query content, and so can be precomputed.
232
- if self.relative_positions is not None:
233
- rel_position_bias = self.relative_positions(
234
- num_queries, num_keys, bidirectional=False
235
- )
236
- else:
237
- rel_position_bias = None
238
-
239
- # Get causal mask.
240
- if self.use_causal_mask:
241
- causal_mask = position.causal_mask(
242
- num_queries, num_keys, window_length=self.window_length
243
- )
244
- else:
245
- causal_mask = None
246
-
247
- # Apply dropout to the attention matrix.
248
- # The mask will be broadcast across batches and windows.
249
- if self.attn_dropout_rate > 0.0 and is_training:
250
- dropout_rng = self.make_rng("dropout")
251
- attn_shape = (self.num_heads, num_queries, num_keys)
252
- dropout_multiplier = nn_components.dropout_multiplier_mask(
253
- dropout_rng, self.attn_dropout_rate, attn_shape, self.dtype
254
- )
255
- else:
256
- dropout_multiplier = None
257
-
258
- # Load and store values into external memory, if memory is not None.
259
- # ------------------------------------------------------------------
260
- (mode, _, update_memory) = self._get_cache_name_from_mode(self.mode)
261
- external_kv = self._query_external_memory(
262
- keys,
263
- values,
264
- queries,
265
- start_of_sequence=start_of_sequence,
266
- mode=mode,
267
- update_memory=decoder_state is None and update_memory,
268
- )
269
-
270
- if (
271
- self.memory is not None
272
- and self.memory_combine_with_local == "TRAINABLE_WEIGHTED_MEAN"
273
- ):
274
- external_memory_bias = jnp.asarray(self.memory_bias, dtype=self.dtype)
275
- external_memory_bias = jnp.reshape(
276
- external_memory_bias, (1, 1, num_heads, 1)
277
- )
278
- external_memory_bias = jax.nn.sigmoid(external_memory_bias)
279
- else:
280
- external_memory_bias = None
281
-
282
- # Compute the number of windows.
283
- # ------------------------------
284
- if sequence_length < self.window_length:
285
- num_windows = 1 # Happens with autoregressive decoding.
286
- elif sequence_length == self.window_length:
287
- num_windows = 1
288
- if self.use_long_xl_architecture:
289
- assert prev_kvi is not None
290
- else:
291
- if not self.use_long_xl_architecture:
292
- raise ValueError("Can only use sliding window with Transformer XL.")
293
- num_windows = sequence_length // self.window_length
294
- if (num_windows * self.window_length) != sequence_length:
295
- raise ValueError(
296
- f"Window length {self.window_length} must be a "
297
- + f"multiple of sequence length {sequence_length}"
298
- )
299
- logging.info("tlayer: num_windows = %d.", num_windows)
300
-
301
- # Define the function to do attention within a single window.
302
- # ---------------------------------------------------------
303
- def single_window_attention(
304
- carry: tuple[Array, Array], inputs_w: tuple[Array, Array]
305
- ) -> tuple[tuple[Array, Array], tuple[Array, Array]]:
306
- # This function uses the following variables from the outer scope.
307
- # They are listed here for clarity.
308
- nonlocal rel_position_bias
309
- nonlocal causal_mask
310
- nonlocal kq_relative_offset
311
- nonlocal dropout_multiplier
312
- nonlocal attention_scale_factors
313
- nonlocal external_memory_bias
314
- nonlocal cross_attention_kv # externally supplied.
315
-
316
- # keys,values,queries over the whole sequence will be split into chunks.
317
- # xs_w, kvqi_w, etc. are the chunk for the current window.
318
- (prev_kvi_w, rec_state) = carry # carried from one window to the next.
319
- (kvqi_w, external_kv_w) = inputs_w # inputs to the current window.
320
- # (keys_curr_w, values_curr_w, _, _, importance_curr_w) = kvqi_w
321
-
322
- # Concatenate keys,values from the previous window with the current
323
- # window to implement sliding window attention.
324
- (kvqi_w, next_kvi_w) = attention.concat_kvqi(kvqi_w, prev_kvi_w)
325
- (keys_w, values_w, queries_w, queries2_w, importance_w) = kvqi_w
326
-
327
- # Perform recurrent attention within the current window to get the next
328
- # recurrent state, and set up cross attention.
329
- if rec_state is not None:
330
- logging.info("tlayer: recurrent attention.")
331
-
332
- # NOTE -- recurrent states and input tokens are handled separately,
333
- # because they have separate learned positional embeddings. Due to
334
- # the way TransformerBase does cross-attention, this means that we use
335
- # separate key,value layers for rec_state and tokens_w.
336
-
337
- # Keys, values, queries from recurrent state.
338
- logging.info("tlayer: recurrent kvq.")
339
- rec_kvq = self.recurrent_tbase.kvq(rec_state)
340
- r_scale_factors = self.recurrent_tbase.attention_scale_factors()
341
- (r_keys, r_values, r_queries, r_queries2) = rec_kvq
342
-
343
- # Joint attention over both recurrent states and input tokens.
344
- logging.info("tlayer: recurrent self-attention.")
345
- r_attn_ys = attention.simple_attention(
346
- r_keys,
347
- r_values,
348
- r_queries,
349
- None,
350
- scale_factor=r_scale_factors[0],
351
- dtype=self.dtype,
352
- )
353
-
354
- logging.info("tlayer: recurrent cross-attention.")
355
- r_cross_attn_ys = attention.simple_attention(
356
- keys_w,
357
- values_w,
358
- r_queries2,
359
- importance_w,
360
- scale_factor=r_scale_factors[1],
361
- dtype=self.dtype,
362
- )
363
-
364
- # Recurrent post-attention FFN.
365
- logging.info("tlayer: recurrent ffn.")
366
- next_rec_state = self.recurrent_tbase.post_attn_ffn(
367
- rec_state, r_attn_ys, r_cross_attn_ys
368
- )
369
-
370
- # Get keys and values for cross-attention from recurrent state.
371
- assert cross_attention_kv is None
372
- local_cross_attention_kv = (r_keys, r_values)
373
- else:
374
- # Get keys and values for cross-attention from external argument.
375
- next_rec_state = None
376
- local_cross_attention_kv = cross_attention_kv
377
-
378
- # If using RoPE, keys and queries are rotated before self-attention.
379
- if self.relative_position_type == "rotary":
380
- logging.info(
381
- "Using rotary position encodings (RoPE), offset = %d",
382
- kq_relative_offset,
383
- )
384
- (keys_w, queries_w) = position.rotate_kq(
385
- keys_w, queries_w, max_wavelength=10_000, offset=kq_relative_offset
386
- )
387
-
388
- # Self-attention over input tokens.
389
- logging.info("tlayer: self-attention.")
390
- attn_ys_w = attention.simple_attention(
391
- keys_w,
392
- values_w,
393
- queries_w,
394
- importance_w,
395
- relative_position_bias=rel_position_bias,
396
- scale_factor=attention_scale_factors[0],
397
- causal_mask=causal_mask,
398
- dropout_multiplier=dropout_multiplier,
399
- dtype=self.dtype,
400
- )
401
-
402
- # Attention over external memory.
403
- if external_kv_w is not None:
404
- (external_keys_w, external_values_w) = external_kv_w
405
- y_ext = attention.external_attention(
406
- external_keys_w,
407
- external_values_w,
408
- queries_w,
409
- scale_factor=attention_scale_factors[0],
410
- )
411
- if external_memory_bias is not None:
412
- ebias = external_memory_bias
413
- attn_ys_w = (attn_ys_w * (1 - ebias)) + (y_ext * ebias)
414
- elif self.memory_combine_with_local == "ADD":
415
- attn_ys_w += y_ext
416
- elif self.memory_combine_with_local == "STOP_FORWARD":
417
- attn_ys_w = y_ext + (attn_ys_w - jax.lax.stop_gradient(attn_ys_w))
418
- else:
419
- raise ValueError(
420
- f"Unexpected setting: {self.memory_combine_with_local = }"
421
- )
422
-
423
- # Cross attention from input tokens to encoder or recurrent state.
424
- if local_cross_attention_kv is not None:
425
- logging.info("tlayer: cross-attention.")
426
- (c_keys, c_values) = local_cross_attention_kv
427
-
428
- # Cross-attention using queries2.
429
- cross_attn_ys_w = attention.simple_attention(
430
- c_keys,
431
- c_values,
432
- queries2_w,
433
- None,
434
- scale_factor=attention_scale_factors[1],
435
- dtype=self.dtype,
436
- )
437
- else:
438
- cross_attn_ys_w = None
439
-
440
- # End function single_window_attention(...)
441
- return ((next_kvi_w, next_rec_state), (attn_ys_w, cross_attn_ys_w))
442
-
443
- # Initialize recurrent_tbase before calling jax.lax.scan.
444
- # Otherwise flax will throw a tantrum.
445
- if (
446
- self.recurrent_attention
447
- and 0 <= self.max_unrolled_windows
448
- and self.max_unrolled_windows < num_windows
449
- ):
450
- logging.info("tlayer: force initialization of recurrent_tbase.")
451
- self.recurrent_tbase.force_init(recurrent_state)
452
-
453
- # Perform sliding window attention over all keys,values,queries.
454
- # --------------------------------------------------------------
455
- initial_carry = (prev_kvi, recurrent_state) # window state.
456
- kvqi = (keys, values, queries, queries2, importance)
457
- attn_inputs = (kvqi, external_kv)
458
- (next_carry, attn_outputs) = attention.split_and_scan(
459
- single_window_attention,
460
- initial_carry,
461
- attn_inputs,
462
- sections=num_windows,
463
- axis=1,
464
- max_unrolled_windows=self.max_unrolled_windows,
465
- )
466
- (attn_ys, cross_attn_ys) = attn_outputs
467
-
468
- logging.info("tlayer: End windows.")
469
-
470
- # Post-attention MLP, resnet, and FFN.
471
- # ------------------------------------
472
- logging.info("tlayer: final FFN.")
473
- ys = self.tbase.post_attn_ffn(xs, attn_ys, cross_attn_ys)
474
-
475
- # Compute importance scores for each token if requested.
476
- if self.compute_importance:
477
- (batch_size, sequence_length, _) = ys.shape
478
- importance_score = self.importance_layer(ys)
479
- importance_score = importance_score.reshape((batch_size, sequence_length))
480
- else:
481
- importance_score = None
482
-
483
- next_window_state = next_carry if window_state is not None else None
484
- viz_dict = {} # Visualizations, not currently enabled.
485
- return (ys, importance_score, next_window_state, decoder_state, viz_dict)
486
-
487
- def init_decoder_state_vanilla(
488
- self, sequence_length: int, start_of_sequence: Array
489
- ) -> DecoderState:
490
- """Initialize decoder state for autoregressive generation.
491
-
492
- Args:
493
- sequence_length: The maximum length of the sequence to generate.
494
- start_of_sequence: Array of boolean of shape (batch_size,) True if
495
- starting a new sequence (with no prefix).
496
-
497
- Returns:
498
- A state object that can be passed to __call__.
499
- """
500
-
501
- if not self.use_causal_mask:
502
- raise ValueError("Generator must have been trained with a causal mask.")
503
-
504
- # Get relative position bias.
505
- rel_position_bias = self.relative_positions(
506
- 1, self.window_length, offset=self.window_length, bidirectional=False
507
- )
508
- rel_position_bias = jnp.tile(rel_position_bias, (self.batch_size, 1, 1, 1))
509
-
510
- # Initialize autoregressive storage for (key, value) pairs.
511
- # Include space for a prefix of window_length tokens.
512
- num_keys = sequence_length + self.window_length
513
- stored_shape = (self.batch_size, num_keys, self.num_heads, self.head_size)
514
- stored_keys = jnp.zeros(stored_shape, dtype=self.dtype)
515
- stored_values = jnp.zeros(stored_shape, dtype=self.dtype)
516
-
517
- recurrent_kvq = None
518
- current_index = jnp.array([self.window_length] * self.batch_size)
519
-
520
- decoder_state_dict = {
521
- "keys": stored_keys,
522
- "values": stored_values,
523
- "current_index": current_index,
524
- "relative_position_bias": rel_position_bias,
525
- "recurrent_kvq": recurrent_kvq,
526
- }
527
- return DecoderState(decoder_state_dict)
 
1
+ # Copyright 2023 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """A single transformer layer in inference mode.
17
+
18
+ Modified
19
+ https://github.com/google-research/meliad/blob/main/transformer/transformer_layer.py
20
+ To accommodate sequence packing + kv cache + relative position during test time.
21
+ """
22
+
23
+ from typing import Callable, Mapping, NewType, Optional, Tuple
24
+
25
+ from absl import logging
26
+ import gin
27
+ import jax
28
+ import jax.numpy as jnp
29
+ from meliad_lib.meliad.transformer import attention
30
+ from meliad_lib.meliad.transformer import nn_components
31
+ from meliad_lib.meliad.transformer import position
32
+ from meliad_lib.meliad.transformer import transformer_layer
33
+
34
+ Array = jnp.ndarray
35
+ DecoderState = NewType("DecoderState", Mapping[str, Array])
36
+ WindowState = Optional[Tuple[attention.KVITuple, Array]]
37
+
38
+
39
@jax.vmap
def update_slice_in_dim_1(array: Array, update: Array, idx: Array) -> Array:
  """Update a stored keys/values slice for different-lengthed seqs in batch.

  Because of the @jax.vmap wrapper, `array`, `update`, and `idx` are batched:
  each example writes its own `update` at its own offset `idx` along the
  per-example leading axis (axis 1 of the full batched array).
  """
  return jax.lax.dynamic_update_slice_in_dim(array, update, idx, axis=0)
43
+
44
+
45
def slice_in_dim_1(window_length: int) -> Callable[[Array, Array], Array]:
  """Return a batched slicer of `window_length` items along axis 1.

  The returned function takes a batch of arrays and a batch of start
  indices, and slices each example independently at its own offset along
  its per-example leading axis.
  """

  @jax.vmap
  def sliced(array: Array, start: Array) -> Array:
    return jax.lax.dynamic_slice_in_dim(array, start, window_length, axis=0)

  return sliced
51
+
52
+
53
+ @gin.configurable
54
+ class TransformerLayerGenerate(transformer_layer.TransformerLayer):
55
+ """Full transformer layer, with attention."""
56
+
57
  def _next_decoder_state(
      self, decoder_state: DecoderState, keys: Array, values: Array
  ) -> Tuple[DecoderState, Array, Array]:
    """Compute the next decoder state, and return keys,values to attend to.

    The keys,values returned from this function are drawn from the prior
    decoding state, and comprise a full window of local context.

    Args:
      decoder_state: The current decoder state, initially created using
        init_decoder_state().
      keys: The key for the current token, of shape (batch_size, 1, dim)
      values: The value for the current token of shape (batch_size, 1, dim)

    Returns:
      (next_decoder_state,
       window of keys of shape (batch_size, window_length, dim),
       window of values of shape (batch_size, window_length, dim))
    """

    assert keys.shape[1] == 1  # single-token autoregressive decoding.

    # Unpack decoder_state
    stored_keys = decoder_state["keys"]
    stored_values = decoder_state["values"]
    curr_index = decoder_state["current_index"]

    # Slice to get window_length-sized chunk of previous keys,values.
    # `current_index` is tracked per batch example, so the vmapped slice
    # helpers are used instead of a single dynamic_slice over the batch
    # (the commented-out calls below show the single-offset equivalent).
    out_decoder_state = {}
    curr_win_index = curr_index - self.window_length

    # out_keys = jax.lax.dynamic_slice_in_dim(
    #     stored_keys, curr_win_index, self.window_length, axis=1)
    out_keys = slice_in_dim_1(self.window_length)(stored_keys, curr_win_index)

    # out_values = jax.lax.dynamic_slice_in_dim(
    #     stored_values, curr_win_index, self.window_length, axis=1)
    out_values = slice_in_dim_1(self.window_length)(
        stored_values, curr_win_index
    )

    # Write current keys,values to stored keys, values.
    # stored_keys = jax.lax.dynamic_update_slice_in_dim(
    #     stored_keys, keys, curr_index, axis=1)
    stored_keys = update_slice_in_dim_1(stored_keys, keys, curr_index)
    # stored_values = jax.lax.dynamic_update_slice_in_dim(
    #     stored_values, values, curr_index, axis=1)
    stored_values = update_slice_in_dim_1(stored_values, values, curr_index)
    curr_index = curr_index + 1

    # Pack a new decoder_state object.  The relative position bias and the
    # cached recurrent keys/values do not change between decode steps and
    # are carried over unchanged.
    out_decoder_state["keys"] = stored_keys
    out_decoder_state["values"] = stored_values
    out_decoder_state["current_index"] = curr_index
    out_decoder_state["relative_position_bias"] = decoder_state[
        "relative_position_bias"
    ]
    out_decoder_state["recurrent_kvq"] = decoder_state["recurrent_kvq"]

    return (DecoderState(out_decoder_state), out_keys, out_values)
117
+
118
+ def __call__(
119
+ self,
120
+ xs: Array,
121
+ start_of_sequence: Array,
122
+ *,
123
+ importance: Optional[Array] = None,
124
+ cross_attention_kv: Optional[Tuple[Array, Array]] = None,
125
+ window_state: Optional[WindowState] = None,
126
+ decoder_state: Optional[DecoderState] = None,
127
+ ):
128
+ """Computes attention over a sequence of inputs.
129
+
130
+ Args:
131
+ xs: input sequence of shape (batch_size, sequence_length, num_hidden)
132
+ start_of_sequence: An input array of shape (batch_size) --- The following
133
+ must be passed by keyword only. ---
134
+ importance: Array of shape (batch_size, sequence_length). An importance
135
+ bias for attention.
136
+ cross_attention_kv: Keys and values from encoder for cross-attention.
137
+ window_state: State object which contains context from the prior window
138
+ when using a transformer-XL or sliding window. Initially created with
139
+ load_window_state().
140
+ decoder_state: State object for autoregressive decoding, initially created
141
+ with from init_decoder_state().
142
+
143
+ Returns:
144
+ (ys: outputs of shape (batch_size, sequence_length, num_hidden),
145
+ importance_score: importance score for the next layer,
146
+ next_window_state: state to pass to the next window,
147
+ next_decoder_state: next decoder state for autoregressive decoding,
148
+ viz_dict: dictionary of visualizations
149
+ )
150
+ """
151
+
152
+ xs = jnp.asarray(xs, dtype=self.dtype)
153
+ logging.info("tlayer: recurrent = %r", self.recurrent_attention)
154
+ logging.info("tlayer: compute_importance = %r", self.compute_importance)
155
+
156
+ is_training = self.mode == "train"
157
+
158
+ # Compute keys, values and queries.
159
+ # ---------------------------------
160
+ logging.info("tlayer: compute keys,values,queries.")
161
+ (keys, values, queries, queries2) = self.tbase.kvq(xs)
162
+ attention_scale_factors = self.tbase.attention_scale_factors()
163
+ (_, sequence_length, num_heads, _) = queries.shape # (b, k, h, d)
164
+
165
+ # Get biases and masks that are shared across windows.
166
+ # ----------------------------------------------------
167
+ if decoder_state is not None:
168
+ logging.info("tlayer: using autoregressive decoder.")
169
+ # When decoding, prior keys,values are loaded from the decoder state.
170
+ # Other values are precomputed, and loaded from the decoder state.
171
+ # The decoder state will be updated with the current token.
172
+ assert window_state is None
173
+
174
+ prev_kvi = None
175
+ recurrent_state = None # Use precomputed recurrent_kvq.
176
+ cross_attention_kv = None
177
+ rel_position_bias = decoder_state["relative_position_bias"]
178
+ causal_mask = None
179
+ dropout_multiplier = None
180
+
181
+ # Reuse cached recurrent keys,values for each token.
182
+ cached_recurrent_kvq = decoder_state["recurrent_kvq"]
183
+ if cached_recurrent_kvq is not None:
184
+ assert cross_attention_kv is None
185
+ cross_attention_kv = (cached_recurrent_kvq[0], cached_recurrent_kvq[1])
186
+ del cached_recurrent_kvq
187
+
188
+ # Get a full window of keys,values and update decoder state.
189
+ (decoder_state, keys, values) = self._next_decoder_state(
190
+ decoder_state, keys, values
191
+ )
192
+
193
+ # Each query attends to window_length prior keys.
194
+ assert keys.shape[1] == self.window_length
195
+ kq_relative_offset = self.window_length
196
+
197
+ if not self.use_long_xl_architecture:
198
+ kqpos = position.relative_positions(
199
+ 1, self.window_length, offset=0
200
+ ) # 2D mask
201
+ current_idx = decoder_state["current_index"]
202
+
203
+ # add (batch, heads) dims for kqpos
204
+ kqpos = jnp.expand_dims(kqpos, axis=(0, 1))
205
+ kqpos = jnp.tile(kqpos, (1, self.num_heads, 1, 1))
206
+
207
+ # add (_, heads, _) dim for current_idx
208
+ current_idx = jnp.expand_dims(current_idx, axis=(1, 2, 3))
209
+
210
+ causal_mask = kqpos > self.window_length * 2 - current_idx
211
+ else:
212
+ logging.info("tlayer: windowed attention.")
213
+ # When training, attention is done using windows or chunks, and prior
214
+ # context (e.g. keys,values from the previous window) is stored in the
215
+ # window_state object.
216
+ (prev_kvi, recurrent_state) = (
217
+ window_state # pytype: disable=attribute-error
218
+ )
219
+
220
+ # Get the size of the sliding window for pos bias, dropout, & causal mask.
221
+ (num_queries, num_keys) = attention.sliding_attention_window_shape(
222
+ (keys, values, importance),
223
+ prev_kvi,
224
+ queries,
225
+ window_length=self.window_length,
226
+ )
227
+ kq_relative_offset = num_keys - num_queries
228
+
229
+ # Get the relative position bias.
230
+ # The bias doesn't depend on the query content, and so can be precomputed.
231
+ if self.relative_positions is not None:
232
+ rel_position_bias = self.relative_positions(
233
+ num_queries, num_keys, bidirectional=False
234
+ )
235
+ else:
236
+ rel_position_bias = None
237
+
238
+ # Get causal mask.
239
+ if self.use_causal_mask:
240
+ causal_mask = position.causal_mask(
241
+ num_queries, num_keys, window_length=self.window_length
242
+ )
243
+ else:
244
+ causal_mask = None
245
+
246
+ # Apply dropout to the attention matrix.
247
+ # The mask will be broadcast across batches and windows.
248
+ if self.attn_dropout_rate > 0.0 and is_training:
249
+ dropout_rng = self.make_rng("dropout")
250
+ attn_shape = (self.num_heads, num_queries, num_keys)
251
+ dropout_multiplier = nn_components.dropout_multiplier_mask(
252
+ dropout_rng, self.attn_dropout_rate, attn_shape, self.dtype
253
+ )
254
+ else:
255
+ dropout_multiplier = None
256
+
257
+ # Load and store values into external memory, if memory is not None.
258
+ # ------------------------------------------------------------------
259
+ (mode, _, update_memory) = self._get_cache_name_from_mode(self.mode)
260
+ external_kv = self._query_external_memory(
261
+ keys,
262
+ values,
263
+ queries,
264
+ start_of_sequence=start_of_sequence,
265
+ mode=mode,
266
+ update_memory=decoder_state is None and update_memory,
267
+ )
268
+
269
+ if (
270
+ self.memory is not None
271
+ and self.memory_combine_with_local == "TRAINABLE_WEIGHTED_MEAN"
272
+ ):
273
+ external_memory_bias = jnp.asarray(self.memory_bias, dtype=self.dtype)
274
+ external_memory_bias = jnp.reshape(
275
+ external_memory_bias, (1, 1, num_heads, 1)
276
+ )
277
+ external_memory_bias = jax.nn.sigmoid(external_memory_bias)
278
+ else:
279
+ external_memory_bias = None
280
+
281
+ # Compute the number of windows.
282
+ # ------------------------------
283
+ if sequence_length < self.window_length:
284
+ num_windows = 1 # Happens with autoregressive decoding.
285
+ elif sequence_length == self.window_length:
286
+ num_windows = 1
287
+ if self.use_long_xl_architecture:
288
+ assert prev_kvi is not None
289
+ else:
290
+ if not self.use_long_xl_architecture:
291
+ raise ValueError("Can only use sliding window with Transformer XL.")
292
+ num_windows = sequence_length // self.window_length
293
+ if (num_windows * self.window_length) != sequence_length:
294
+ raise ValueError(
295
+ f"Window length {self.window_length} must be a "
296
+ + f"multiple of sequence length {sequence_length}"
297
+ )
298
+ logging.info("tlayer: num_windows = %d.", num_windows)
299
+
300
+ # Define the function to do attention within a single window.
301
+ # ---------------------------------------------------------
302
    def single_window_attention(
        carry: tuple[Array, Array], inputs_w: tuple[Array, Array]
    ) -> tuple[tuple[Array, Array], tuple[Array, Array]]:
      """Compute attention for one window of the sliding-window scan.

      Invoked once per window by attention.split_and_scan (see the caller).

      Args:
        carry: (prev_kvi_w, rec_state) -- the keys/values/importance carried
          over from the previous window, and the recurrent state (or None).
        inputs_w: (kvqi_w, external_kv_w) -- the current window's chunk of
          keys/values/queries/queries2/importance, and the chunk of
          (keys, values) retrieved from external memory (or None).

      Returns:
        ((next_kvi_w, next_rec_state), (attn_ys_w, cross_attn_ys_w)):
        the carry for the next window, and this window's self-attention and
        cross-attention outputs.
      """
      # This function uses the following variables from the outer scope.
      # They are listed here for clarity.
      nonlocal rel_position_bias
      nonlocal causal_mask
      nonlocal kq_relative_offset
      nonlocal dropout_multiplier
      nonlocal attention_scale_factors
      nonlocal external_memory_bias
      nonlocal cross_attention_kv  # externally supplied.

      # keys,values,queries over the whole sequence will be split into chunks.
      # xs_w, kvqi_w, etc. are the chunk for the current window.
      (prev_kvi_w, rec_state) = carry  # carried from one window to the next.
      (kvqi_w, external_kv_w) = inputs_w  # inputs to the current window.
      # (keys_curr_w, values_curr_w, _, _, importance_curr_w) = kvqi_w

      # Concatenate keys,values from the previous window with the current
      # window to implement sliding window attention.
      (kvqi_w, next_kvi_w) = attention.concat_kvqi(kvqi_w, prev_kvi_w)
      (keys_w, values_w, queries_w, queries2_w, importance_w) = kvqi_w

      # Perform recurrent attention within the current window to get the next
      # recurrent state, and set up cross attention.
      if rec_state is not None:
        logging.info("tlayer: recurrent attention.")

        # NOTE -- recurrent states and input tokens are handled separately,
        # because they have separate learned positional embeddings.  Due to
        # the way TransformerBase does cross-attention, this means that we use
        # separate key,value layers for rec_state and tokens_w.

        # Keys, values, queries from recurrent state.
        logging.info("tlayer: recurrent kvq.")
        rec_kvq = self.recurrent_tbase.kvq(rec_state)
        r_scale_factors = self.recurrent_tbase.attention_scale_factors()
        (r_keys, r_values, r_queries, r_queries2) = rec_kvq

        # Joint attention over both recurrent states and input tokens.
        logging.info("tlayer: recurrent self-attention.")
        r_attn_ys = attention.simple_attention(
            r_keys,
            r_values,
            r_queries,
            None,
            scale_factor=r_scale_factors[0],
            dtype=self.dtype,
        )

        # Recurrent state attends to the current window's tokens.
        logging.info("tlayer: recurrent cross-attention.")
        r_cross_attn_ys = attention.simple_attention(
            keys_w,
            values_w,
            r_queries2,
            importance_w,
            scale_factor=r_scale_factors[1],
            dtype=self.dtype,
        )

        # Recurrent post-attention FFN.
        logging.info("tlayer: recurrent ffn.")
        next_rec_state = self.recurrent_tbase.post_attn_ffn(
            rec_state, r_attn_ys, r_cross_attn_ys
        )

        # Get keys and values for cross-attention from recurrent state.
        # Recurrent mode is mutually exclusive with an external encoder.
        assert cross_attention_kv is None
        local_cross_attention_kv = (r_keys, r_values)
      else:
        # Get keys and values for cross-attention from external argument.
        next_rec_state = None
        local_cross_attention_kv = cross_attention_kv

      # If using RoPE, keys and queries are rotated before self-attention.
      # kq_relative_offset accounts for keys extending further back than
      # queries (the previous window prepended above).
      if self.relative_position_type == "rotary":
        logging.info(
            "Using rotary position encodings (RoPE), offset = %d",
            kq_relative_offset,
        )
        (keys_w, queries_w) = position.rotate_kq(
            keys_w, queries_w, max_wavelength=10_000, offset=kq_relative_offset
        )

      # Self-attention over input tokens.
      logging.info("tlayer: self-attention.")
      attn_ys_w = attention.simple_attention(
          keys_w,
          values_w,
          queries_w,
          importance_w,
          relative_position_bias=rel_position_bias,
          scale_factor=attention_scale_factors[0],
          causal_mask=causal_mask,
          dropout_multiplier=dropout_multiplier,
          dtype=self.dtype,
      )

      # Attention over external memory, combined with the local result
      # according to self.memory_combine_with_local.
      if external_kv_w is not None:
        (external_keys_w, external_values_w) = external_kv_w
        y_ext = attention.external_attention(
            external_keys_w,
            external_values_w,
            queries_w,
            scale_factor=attention_scale_factors[0],
        )
        if external_memory_bias is not None:
          # TRAINABLE_WEIGHTED_MEAN: sigmoid-gated mix (bias set by caller).
          ebias = external_memory_bias
          attn_ys_w = (attn_ys_w * (1 - ebias)) + (y_ext * ebias)
        elif self.memory_combine_with_local == "ADD":
          attn_ys_w += y_ext
        elif self.memory_combine_with_local == "STOP_FORWARD":
          # Forward value comes from memory; gradient flows to local attention.
          attn_ys_w = y_ext + (attn_ys_w - jax.lax.stop_gradient(attn_ys_w))
        else:
          raise ValueError(
              f"Unexpected setting: {self.memory_combine_with_local = }"
          )

      # Cross attention from input tokens to encoder or recurrent state.
      if local_cross_attention_kv is not None:
        logging.info("tlayer: cross-attention.")
        (c_keys, c_values) = local_cross_attention_kv

        # Cross-attention using queries2.
        cross_attn_ys_w = attention.simple_attention(
            c_keys,
            c_values,
            queries2_w,
            None,
            scale_factor=attention_scale_factors[1],
            dtype=self.dtype,
        )
      else:
        cross_attn_ys_w = None

      # End function single_window_attention(...)
      return ((next_kvi_w, next_rec_state), (attn_ys_w, cross_attn_ys_w))
441
+
442
+ # Initialize recurrent_tbase before calling jax.lax.scan.
443
+ # Otherwise flax will throw a tantrum.
444
+ if (
445
+ self.recurrent_attention
446
+ and 0 <= self.max_unrolled_windows
447
+ and self.max_unrolled_windows < num_windows
448
+ ):
449
+ logging.info("tlayer: force initialization of recurrent_tbase.")
450
+ self.recurrent_tbase.force_init(recurrent_state)
451
+
452
+ # Perform sliding window attention over all keys,values,queries.
453
+ # --------------------------------------------------------------
454
+ initial_carry = (prev_kvi, recurrent_state) # window state.
455
+ kvqi = (keys, values, queries, queries2, importance)
456
+ attn_inputs = (kvqi, external_kv)
457
+ (next_carry, attn_outputs) = attention.split_and_scan(
458
+ single_window_attention,
459
+ initial_carry,
460
+ attn_inputs,
461
+ sections=num_windows,
462
+ axis=1,
463
+ max_unrolled_windows=self.max_unrolled_windows,
464
+ )
465
+ (attn_ys, cross_attn_ys) = attn_outputs
466
+
467
+ logging.info("tlayer: End windows.")
468
+
469
+ # Post-attention MLP, resnet, and FFN.
470
+ # ------------------------------------
471
+ logging.info("tlayer: final FFN.")
472
+ ys = self.tbase.post_attn_ffn(xs, attn_ys, cross_attn_ys)
473
+
474
+ # Compute importance scores for each token if requested.
475
+ if self.compute_importance:
476
+ (batch_size, sequence_length, _) = ys.shape
477
+ importance_score = self.importance_layer(ys)
478
+ importance_score = importance_score.reshape((batch_size, sequence_length))
479
+ else:
480
+ importance_score = None
481
+
482
+ next_window_state = next_carry if window_state is not None else None
483
+ viz_dict = {} # Visualizations, not currently enabled.
484
+ return (ys, importance_score, next_window_state, decoder_state, viz_dict)
485
+
486
+ def init_decoder_state_vanilla(
487
+ self, sequence_length: int, start_of_sequence: Array
488
+ ) -> DecoderState:
489
+ """Initialize decoder state for autoregressive generation.
490
+
491
+ Args:
492
+ sequence_length: The maximum length of the sequence to generate.
493
+ start_of_sequence: Array of boolean of shape (batch_size,) True if
494
+ starting a new sequence (with no prefix).
495
+
496
+ Returns:
497
+ A state object that can be passed to __call__.
498
+ """
499
+
500
+ if not self.use_causal_mask:
501
+ raise ValueError("Generator must have been trained with a causal mask.")
502
+
503
+ # Get relative position bias.
504
+ rel_position_bias = self.relative_positions(
505
+ 1, self.window_length, offset=self.window_length, bidirectional=False
506
+ )
507
+ rel_position_bias = jnp.tile(rel_position_bias, (self.batch_size, 1, 1, 1))
508
+
509
+ # Initialize autoregressive storage for (key, value) pairs.
510
+ # Include space for a prefix of window_length tokens.
511
+ num_keys = sequence_length + self.window_length
512
+ stored_shape = (self.batch_size, num_keys, self.num_heads, self.head_size)
513
+ stored_keys = jnp.zeros(stored_shape, dtype=self.dtype)
514
+ stored_values = jnp.zeros(stored_shape, dtype=self.dtype)
515
+
516
+ recurrent_kvq = None
517
+ current_index = jnp.array([self.window_length] * self.batch_size)
518
+
519
+ decoder_state_dict = {
520
+ "keys": stored_keys,
521
+ "values": stored_values,
522
+ "current_index": current_index,
523
+ "relative_position_bias": rel_position_bias,
524
+ "recurrent_kvq": recurrent_kvq,
525
+ }
526
+ return DecoderState(decoder_state_dict)