Sebastien committed commit 4484b8a · 0 parent(s)

first commit
.github/workflows/python-app.yml ADDED
@@ -0,0 +1,46 @@
+ # This workflow will install Python dependencies, run tests and lint with a single version of Python
+ # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
+
+ name: Python application
+
+ on:
+   push:
+     branches: [ "main" ]
+   pull_request:
+     branches: [ "main" ]
+
+ permissions:
+   contents: read
+
+ jobs:
+   build:
+
+     runs-on: ubuntu-latest
+
+     steps:
+     - uses: actions/checkout@v4
+     - name: Set up Python 3.10
+       uses: actions/setup-python@v3
+       with:
+         python-version: "3.10"
+     - name: Cache pip
+       uses: actions/cache@v3
+       with:
+         path: ~/.cache/pip
+         key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
+         restore-keys: |
+           ${{ runner.os }}-pip-
+     - name: Install dependencies
+       run: |
+         python -m pip install --upgrade pip
+         pip install flake8 pytest
+         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+     - name: Lint with flake8
+       run: |
+         # stop the build if there are Python syntax errors or undefined names
+         flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+         # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+         flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+     - name: Test with pytest
+       run: |
+         PYTHONPATH=. pytest tests
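The final workflow step sets `PYTHONPATH=.` so that modules at the repository root are importable from the `tests/` directory, then lets pytest collect everything under `tests`. As an illustration only (the actual test suite is not shown in this commit view), a file such as `tests/test_example.py` that this step would discover might look like:

```python
# Hypothetical tests/test_example.py -- this commit view does not show the
# real test suite, so the function under test here is a stand-in.

def add(a: int, b: int) -> int:
    """Toy function standing in for code the project might expose."""
    return a + b


def test_add() -> None:
    # pytest collects functions named test_* from files named test_*.py
    assert add(2, 3) == 5
    assert add(-1, 1) == 0
```

With `PYTHONPATH=.`, a real test would instead `import` the project's own top-level modules rather than defining the function locally.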
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__
LICENSE ADDED
@@ -0,0 +1,661 @@
+                     GNU AFFERO GENERAL PUBLIC LICENSE
+                        Version 3, 19 November 2007
+
+  Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+  Everyone is permitted to copy and distribute verbatim copies
+  of this license document, but changing it is not allowed.
+
+                             Preamble
+
+   The GNU Affero General Public License is a free, copyleft license for
+ software and other kinds of works, specifically designed to ensure
+ cooperation with the community in the case of network server software.
+
+   The licenses for most software and other practical works are designed
+ to take away your freedom to share and change the works. By contrast,
+ our General Public Licenses are intended to guarantee your freedom to
+ share and change all versions of a program--to make sure it remains free
+ software for all its users.
+
+   When we speak of free software, we are referring to freedom, not
+ price. Our General Public Licenses are designed to make sure that you
+ have the freedom to distribute copies of free software (and charge for
+ them if you wish), that you receive source code or can get it if you
+ want it, that you can change the software or use pieces of it in new
+ free programs, and that you know you can do these things.
+
+   Developers that use our General Public Licenses protect your rights
+ with two steps: (1) assert copyright on the software, and (2) offer
+ you this License which gives you legal permission to copy, distribute
+ and/or modify the software.
+
+   A secondary benefit of defending all users' freedom is that
+ improvements made in alternate versions of the program, if they
+ receive widespread use, become available for other developers to
+ incorporate. Many developers of free software are heartened and
+ encouraged by the resulting cooperation. However, in the case of
+ software used on network servers, this result may fail to come about.
+ The GNU General Public License permits making a modified version and
+ letting the public access it on a server without ever releasing its
+ source code to the public.
+
+   The GNU Affero General Public License is designed specifically to
+ ensure that, in such cases, the modified source code becomes available
+ to the community. It requires the operator of a network server to
+ provide the source code of the modified version running there to the
+ users of that server. Therefore, public use of a modified version, on
+ a publicly accessible server, gives the public access to the source
+ code of the modified version.
+
+   An older license, called the Affero General Public License and
+ published by Affero, was designed to accomplish similar goals. This is
+ a different license, not a version of the Affero GPL, but Affero has
+ released a new version of the Affero GPL which permits relicensing under
+ this license.
+
+   The precise terms and conditions for copying, distribution and
+ modification follow.
+
+                        TERMS AND CONDITIONS
+
+   0. Definitions.
+
+   "This License" refers to version 3 of the GNU Affero General Public License.
+
+   "Copyright" also means copyright-like laws that apply to other kinds of
+ works, such as semiconductor masks.
+
+   "The Program" refers to any copyrightable work licensed under this
+ License. Each licensee is addressed as "you". "Licensees" and
+ "recipients" may be individuals or organizations.
+
+   To "modify" a work means to copy from or adapt all or part of the work
+ in a fashion requiring copyright permission, other than the making of an
+ exact copy. The resulting work is called a "modified version" of the
+ earlier work or a work "based on" the earlier work.
+
+   A "covered work" means either the unmodified Program or a work based
+ on the Program.
+
+   To "propagate" a work means to do anything with it that, without
+ permission, would make you directly or secondarily liable for
+ infringement under applicable copyright law, except executing it on a
+ computer or modifying a private copy. Propagation includes copying,
+ distribution (with or without modification), making available to the
+ public, and in some countries other activities as well.
+
+   To "convey" a work means any kind of propagation that enables other
+ parties to make or receive copies. Mere interaction with a user through
+ a computer network, with no transfer of a copy, is not conveying.
+
+   An interactive user interface displays "Appropriate Legal Notices"
+ to the extent that it includes a convenient and prominently visible
+ feature that (1) displays an appropriate copyright notice, and (2)
+ tells the user that there is no warranty for the work (except to the
+ extent that warranties are provided), that licensees may convey the
+ work under this License, and how to view a copy of this License. If
+ the interface presents a list of user commands or options, such as a
+ menu, a prominent item in the list meets this criterion.
+
+   1. Source Code.
+
+   The "source code" for a work means the preferred form of the work
+ for making modifications to it. "Object code" means any non-source
+ form of a work.
+
+   A "Standard Interface" means an interface that either is an official
+ standard defined by a recognized standards body, or, in the case of
+ interfaces specified for a particular programming language, one that
+ is widely used among developers working in that language.
+
+   The "System Libraries" of an executable work include anything, other
+ than the work as a whole, that (a) is included in the normal form of
+ packaging a Major Component, but which is not part of that Major
+ Component, and (b) serves only to enable use of the work with that
+ Major Component, or to implement a Standard Interface for which an
+ implementation is available to the public in source code form. A
+ "Major Component", in this context, means a major essential component
+ (kernel, window system, and so on) of the specific operating system
+ (if any) on which the executable work runs, or a compiler used to
+ produce the work, or an object code interpreter used to run it.
+
+   The "Corresponding Source" for a work in object code form means all
+ the source code needed to generate, install, and (for an executable
+ work) run the object code and to modify the work, including scripts to
+ control those activities. However, it does not include the work's
+ System Libraries, or general-purpose tools or generally available free
+ programs which are used unmodified in performing those activities but
+ which are not part of the work. For example, Corresponding Source
+ includes interface definition files associated with source files for
+ the work, and the source code for shared libraries and dynamically
+ linked subprograms that the work is specifically designed to require,
+ such as by intimate data communication or control flow between those
+ subprograms and other parts of the work.
+
+   The Corresponding Source need not include anything that users
+ can regenerate automatically from other parts of the Corresponding
+ Source.
+
+   The Corresponding Source for a work in source code form is that
+ same work.
+
+   2. Basic Permissions.
+
+   All rights granted under this License are granted for the term of
+ copyright on the Program, and are irrevocable provided the stated
+ conditions are met. This License explicitly affirms your unlimited
+ permission to run the unmodified Program. The output from running a
+ covered work is covered by this License only if the output, given its
+ content, constitutes a covered work. This License acknowledges your
+ rights of fair use or other equivalent, as provided by copyright law.
+
+   You may make, run and propagate covered works that you do not
+ convey, without conditions so long as your license otherwise remains
+ in force. You may convey covered works to others for the sole purpose
+ of having them make modifications exclusively for you, or provide you
+ with facilities for running those works, provided that you comply with
+ the terms of this License in conveying all material for which you do
+ not control copyright. Those thus making or running the covered works
+ for you must do so exclusively on your behalf, under your direction
+ and control, on terms that prohibit them from making any copies of
+ your copyrighted material outside their relationship with you.
+
+   Conveying under any other circumstances is permitted solely under
+ the conditions stated below. Sublicensing is not allowed; section 10
+ makes it unnecessary.
+
+   3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+   No covered work shall be deemed part of an effective technological
+ measure under any applicable law fulfilling obligations under article
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
+ similar laws prohibiting or restricting circumvention of such
+ measures.
+
+   When you convey a covered work, you waive any legal power to forbid
+ circumvention of technological measures to the extent such circumvention
+ is effected by exercising rights under this License with respect to
+ the covered work, and you disclaim any intention to limit operation or
+ modification of the work as a means of enforcing, against the work's
+ users, your or third parties' legal rights to forbid circumvention of
+ technological measures.
+
+   4. Conveying Verbatim Copies.
+
+   You may convey verbatim copies of the Program's source code as you
+ receive it, in any medium, provided that you conspicuously and
+ appropriately publish on each copy an appropriate copyright notice;
+ keep intact all notices stating that this License and any
+ non-permissive terms added in accord with section 7 apply to the code;
+ keep intact all notices of the absence of any warranty; and give all
+ recipients a copy of this License along with the Program.
+
+   You may charge any price or no price for each copy that you convey,
+ and you may offer support or warranty protection for a fee.
+
+   5. Conveying Modified Source Versions.
+
+   You may convey a work based on the Program, or the modifications to
+ produce it from the Program, in the form of source code under the
+ terms of section 4, provided that you also meet all of these conditions:
+
+     a) The work must carry prominent notices stating that you modified
+     it, and giving a relevant date.
+
+     b) The work must carry prominent notices stating that it is
+     released under this License and any conditions added under section
+     7. This requirement modifies the requirement in section 4 to
+     "keep intact all notices".
+
+     c) You must license the entire work, as a whole, under this
+     License to anyone who comes into possession of a copy. This
+     License will therefore apply, along with any applicable section 7
+     additional terms, to the whole of the work, and all its parts,
+     regardless of how they are packaged. This License gives no
+     permission to license the work in any other way, but it does not
+     invalidate such permission if you have separately received it.
+
+     d) If the work has interactive user interfaces, each must display
+     Appropriate Legal Notices; however, if the Program has interactive
+     interfaces that do not display Appropriate Legal Notices, your
+     work need not make them do so.
+
+   A compilation of a covered work with other separate and independent
+ works, which are not by their nature extensions of the covered work,
+ and which are not combined with it such as to form a larger program,
+ in or on a volume of a storage or distribution medium, is called an
+ "aggregate" if the compilation and its resulting copyright are not
+ used to limit the access or legal rights of the compilation's users
+ beyond what the individual works permit. Inclusion of a covered work
+ in an aggregate does not cause this License to apply to the other
+ parts of the aggregate.
+
+   6. Conveying Non-Source Forms.
+
+   You may convey a covered work in object code form under the terms
+ of sections 4 and 5, provided that you also convey the
+ machine-readable Corresponding Source under the terms of this License,
+ in one of these ways:
+
+     a) Convey the object code in, or embodied in, a physical product
+     (including a physical distribution medium), accompanied by the
+     Corresponding Source fixed on a durable physical medium
+     customarily used for software interchange.
+
+     b) Convey the object code in, or embodied in, a physical product
+     (including a physical distribution medium), accompanied by a
+     written offer, valid for at least three years and valid for as
+     long as you offer spare parts or customer support for that product
+     model, to give anyone who possesses the object code either (1) a
+     copy of the Corresponding Source for all the software in the
+     product that is covered by this License, on a durable physical
+     medium customarily used for software interchange, for a price no
+     more than your reasonable cost of physically performing this
+     conveying of source, or (2) access to copy the
+     Corresponding Source from a network server at no charge.
+
+     c) Convey individual copies of the object code with a copy of the
+     written offer to provide the Corresponding Source. This
+     alternative is allowed only occasionally and noncommercially, and
+     only if you received the object code with such an offer, in accord
+     with subsection 6b.
+
+     d) Convey the object code by offering access from a designated
+     place (gratis or for a charge), and offer equivalent access to the
+     Corresponding Source in the same way through the same place at no
+     further charge. You need not require recipients to copy the
+     Corresponding Source along with the object code. If the place to
+     copy the object code is a network server, the Corresponding Source
+     may be on a different server (operated by you or a third party)
+     that supports equivalent copying facilities, provided you maintain
+     clear directions next to the object code saying where to find the
+     Corresponding Source. Regardless of what server hosts the
+     Corresponding Source, you remain obligated to ensure that it is
+     available for as long as needed to satisfy these requirements.
+
+     e) Convey the object code using peer-to-peer transmission, provided
+     you inform other peers where the object code and Corresponding
+     Source of the work are being offered to the general public at no
+     charge under subsection 6d.
+
+   A separable portion of the object code, whose source code is excluded
+ from the Corresponding Source as a System Library, need not be
+ included in conveying the object code work.
+
+   A "User Product" is either (1) a "consumer product", which means any
+ tangible personal property which is normally used for personal, family,
+ or household purposes, or (2) anything designed or sold for incorporation
+ into a dwelling. In determining whether a product is a consumer product,
+ doubtful cases shall be resolved in favor of coverage. For a particular
+ product received by a particular user, "normally used" refers to a
+ typical or common use of that class of product, regardless of the status
+ of the particular user or of the way in which the particular user
+ actually uses, or expects or is expected to use, the product. A product
+ is a consumer product regardless of whether the product has substantial
+ commercial, industrial or non-consumer uses, unless such uses represent
+ the only significant mode of use of the product.
+
+   "Installation Information" for a User Product means any methods,
+ procedures, authorization keys, or other information required to install
+ and execute modified versions of a covered work in that User Product from
+ a modified version of its Corresponding Source. The information must
+ suffice to ensure that the continued functioning of the modified object
+ code is in no case prevented or interfered with solely because
+ modification has been made.
+
+   If you convey an object code work under this section in, or with, or
+ specifically for use in, a User Product, and the conveying occurs as
+ part of a transaction in which the right of possession and use of the
+ User Product is transferred to the recipient in perpetuity or for a
+ fixed term (regardless of how the transaction is characterized), the
+ Corresponding Source conveyed under this section must be accompanied
+ by the Installation Information. But this requirement does not apply
+ if neither you nor any third party retains the ability to install
+ modified object code on the User Product (for example, the work has
+ been installed in ROM).
+
+   The requirement to provide Installation Information does not include a
+ requirement to continue to provide support service, warranty, or updates
+ for a work that has been modified or installed by the recipient, or for
+ the User Product in which it has been modified or installed. Access to a
+ network may be denied when the modification itself materially and
+ adversely affects the operation of the network or violates the rules and
+ protocols for communication across the network.
+
+   Corresponding Source conveyed, and Installation Information provided,
+ in accord with this section must be in a format that is publicly
+ documented (and with an implementation available to the public in
+ source code form), and must require no special password or key for
+ unpacking, reading or copying.
+
+   7. Additional Terms.
+
+   "Additional permissions" are terms that supplement the terms of this
+ License by making exceptions from one or more of its conditions.
+ Additional permissions that are applicable to the entire Program shall
+ be treated as though they were included in this License, to the extent
+ that they are valid under applicable law. If additional permissions
+ apply only to part of the Program, that part may be used separately
+ under those permissions, but the entire Program remains governed by
+ this License without regard to the additional permissions.
+
+   When you convey a copy of a covered work, you may at your option
+ remove any additional permissions from that copy, or from any part of
+ it. (Additional permissions may be written to require their own
+ removal in certain cases when you modify the work.) You may place
+ additional permissions on material, added by you to a covered work,
+ for which you have or can give appropriate copyright permission.
+
+   Notwithstanding any other provision of this License, for material you
+ add to a covered work, you may (if authorized by the copyright holders of
+ that material) supplement the terms of this License with terms:
+
+     a) Disclaiming warranty or limiting liability differently from the
+     terms of sections 15 and 16 of this License; or
+
+     b) Requiring preservation of specified reasonable legal notices or
+     author attributions in that material or in the Appropriate Legal
+     Notices displayed by works containing it; or
+
+     c) Prohibiting misrepresentation of the origin of that material, or
+     requiring that modified versions of such material be marked in
+     reasonable ways as different from the original version; or
+
+     d) Limiting the use for publicity purposes of names of licensors or
+     authors of the material; or
+
+     e) Declining to grant rights under trademark law for use of some
+     trade names, trademarks, or service marks; or
+
+     f) Requiring indemnification of licensors and authors of that
+     material by anyone who conveys the material (or modified versions of
+     it) with contractual assumptions of liability to the recipient, for
+     any liability that these contractual assumptions directly impose on
+     those licensors and authors.
+
+   All other non-permissive additional terms are considered "further
+ restrictions" within the meaning of section 10. If the Program as you
+ received it, or any part of it, contains a notice stating that it is
+ governed by this License along with a term that is a further
+ restriction, you may remove that term. If a license document contains
+ a further restriction but permits relicensing or conveying under this
+ License, you may add to a covered work material governed by the terms
+ of that license document, provided that the further restriction does
+ not survive such relicensing or conveying.
+
+   If you add terms to a covered work in accord with this section, you
+ must place, in the relevant source files, a statement of the
+ additional terms that apply to those files, or a notice indicating
+ where to find the applicable terms.
+
+   Additional terms, permissive or non-permissive, may be stated in the
+ form of a separately written license, or stated as exceptions;
+ the above requirements apply either way.
+
+   8. Termination.
+
+   You may not propagate or modify a covered work except as expressly
+ provided under this License. Any attempt otherwise to propagate or
+ modify it is void, and will automatically terminate your rights under
+ this License (including any patent licenses granted under the third
+ paragraph of section 11).
+
+   However, if you cease all violation of this License, then your
+ license from a particular copyright holder is reinstated (a)
+ provisionally, unless and until the copyright holder explicitly and
+ finally terminates your license, and (b) permanently, if the copyright
+ holder fails to notify you of the violation by some reasonable means
+ prior to 60 days after the cessation.
+
+   Moreover, your license from a particular copyright holder is
+ reinstated permanently if the copyright holder notifies you of the
+ violation by some reasonable means, this is the first time you have
+ received notice of violation of this License (for any work) from that
+ copyright holder, and you cure the violation prior to 30 days after
+ your receipt of the notice.
+
+   Termination of your rights under this section does not terminate the
+ licenses of parties who have received copies or rights from you under
+ this License. If your rights have been terminated and not permanently
+ reinstated, you do not qualify to receive new licenses for the same
+ material under section 10.
+
+   9. Acceptance Not Required for Having Copies.
+
+   You are not required to accept this License in order to receive or
+ run a copy of the Program. Ancillary propagation of a covered work
+ occurring solely as a consequence of using peer-to-peer transmission
+ to receive a copy likewise does not require acceptance. However,
+ nothing other than this License grants you permission to propagate or
+ modify any covered work. These actions infringe copyright if you do
+ not accept this License. Therefore, by modifying or propagating a
+ covered work, you indicate your acceptance of this License to do so.
+
+   10. Automatic Licensing of Downstream Recipients.
+
+   Each time you convey a covered work, the recipient automatically
+ receives a license from the original licensors, to run, modify and
+ propagate that work, subject to this License. You are not responsible
+ for enforcing compliance by third parties with this License.
+
+   An "entity transaction" is a transaction transferring control of an
+ organization, or substantially all assets of one, or subdividing an
+ organization, or merging organizations. If propagation of a covered
+ work results from an entity transaction, each party to that
+ transaction who receives a copy of the work also receives whatever
+ licenses to the work the party's predecessor in interest had or could
+ give under the previous paragraph, plus a right to possession of the
+ Corresponding Source of the work from the predecessor in interest, if
+ the predecessor has it or can get it with reasonable efforts.
+
+   You may not impose any further restrictions on the exercise of the
+ rights granted or affirmed under this License. For example, you may
+ not impose a license fee, royalty, or other charge for exercise of
+ rights granted under this License, and you may not initiate litigation
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
+ any patent claim is infringed by making, using, selling, offering for
+ sale, or importing the Program or any portion of it.
+
+   11. Patents.
+
+   A "contributor" is a copyright holder who authorizes use under this
+ License of the Program or a work on which the Program is based. The
+ work thus licensed is called the contributor's "contributor version".
+
+   A contributor's "essential patent claims" are all patent claims
+ owned or controlled by the contributor, whether already acquired or
+ hereafter acquired, that would be infringed by some manner, permitted
+ by this License, of making, using, or selling its contributor version,
+ but do not include claims that would be infringed only as a
+ consequence of further modification of the contributor version. For
+ purposes of this definition, "control" includes the right to grant
+ patent sublicenses in a manner consistent with the requirements of
+ this License.
+
+   Each contributor grants you a non-exclusive, worldwide, royalty-free
+ patent license under the contributor's essential patent claims, to
+ make, use, sell, offer for sale, import and otherwise run, modify and
+ propagate the contents of its contributor version.
+
+   In the following three paragraphs, a "patent license" is any express
+ agreement or commitment, however denominated, not to enforce a patent
+ (such as an express permission to practice a patent or covenant not to
+ sue for patent infringement). To "grant" such a patent license to a
+ party means to make such an agreement or commitment not to enforce a
+ patent against the party.
+
+   If you convey a covered work, knowingly relying on a patent license,
+ and the Corresponding Source of the work is not available for anyone
+ to copy, free of charge and under the terms of this License, through a
+ publicly available network server or other readily accessible means,
+ then you must either (1) cause the Corresponding Source to be so
+ available, or (2) arrange to deprive yourself of the benefit of the
+ patent license for this particular work, or (3) arrange, in a manner
+ consistent with the requirements of this License, to extend the patent
+ license to downstream recipients. "Knowingly relying" means you have
+ actual knowledge that, but for the patent license, your conveying the
+ covered work in a country, or your recipient's use of the covered work
+ in a country, would infringe one or more identifiable patents in that
+ country that you have reason to believe are valid.
+
+   If, pursuant to or in connection with a single transaction or
+ arrangement, you convey, or propagate by procuring conveyance of, a
+ covered work, and grant a patent license to some of the parties
+ receiving the covered work authorizing them to use, propagate, modify
+ or convey a specific copy of the covered work, then the patent license
+ you grant is automatically extended to all recipients of the covered
+ work and works based on it.
+
+   A patent license is "discriminatory" if it does not include within
+ the scope of its coverage, prohibits the exercise of, or is
+ conditioned on the non-exercise of one or more of the rights that are
+ specifically granted under this License. You may not convey a covered
+ work if you are a party to an arrangement with a third party that is
+ in the business of distributing software, under which you make payment
+ to the third party based on the extent of your activity of conveying
+ the work, and under which the third party grants, to any of the
+ parties who would receive the covered work from you, a discriminatory
+ patent license (a) in connection with copies of the covered work
+ conveyed by you (or copies made from those copies), or (b) primarily
+ for and in connection with specific products or compilations that
+ contain the covered work, unless you entered into that arrangement,
+ or that patent license was granted, prior to 28 March 2007.
+
+   Nothing in this License shall be construed as excluding or limiting
+ any implied license or other defenses to infringement that may
+ otherwise be available to you under applicable patent law.
+
+   12. No Surrender of Others' Freedom.
+
+   If conditions are imposed on you (whether by court order, agreement or
+ otherwise) that contradict the conditions of this License, they do not
+ excuse you from the conditions of this License. If you cannot convey a
+ covered work so as to satisfy simultaneously your obligations under this
+ License and any other pertinent obligations, then as a consequence you may
+ not convey it at all. For example, if you agree to terms that obligate you
+ to collect a royalty for further conveying from those to whom you convey
+ the Program, the only way you could satisfy both those terms and this
+ License would be to refrain entirely from conveying the Program.
+
+   13. Remote Network Interaction; Use with the GNU General Public License.
+
+   Notwithstanding any other provision of this License, if you modify the
+ Program, your modified version must prominently offer all users
+ interacting with it remotely through a computer network (if your version
545
+ supports such interaction) an opportunity to receive the Corresponding
546
+ Source of your version by providing access to the Corresponding Source
547
+ from a network server at no charge, through some standard or customary
548
+ means of facilitating copying of software. This Corresponding Source
549
+ shall include the Corresponding Source for any work covered by version 3
550
+ of the GNU General Public License that is incorporated pursuant to the
551
+ following paragraph.
552
+
553
+ Notwithstanding any other provision of this License, you have
554
+ permission to link or combine any covered work with a work licensed
555
+ under version 3 of the GNU General Public License into a single
556
+ combined work, and to convey the resulting work. The terms of this
557
+ License will continue to apply to the part which is the covered work,
558
+ but the work with which it is combined will remain governed by version
559
+ 3 of the GNU General Public License.
560
+
561
+ 14. Revised Versions of this License.
562
+
563
+ The Free Software Foundation may publish revised and/or new versions of
564
+ the GNU Affero General Public License from time to time. Such new versions
565
+ will be similar in spirit to the present version, but may differ in detail to
566
+ address new problems or concerns.
567
+
568
+ Each version is given a distinguishing version number. If the
569
+ Program specifies that a certain numbered version of the GNU Affero General
570
+ Public License "or any later version" applies to it, you have the
571
+ option of following the terms and conditions either of that numbered
572
+ version or of any later version published by the Free Software
573
+ Foundation. If the Program does not specify a version number of the
574
+ GNU Affero General Public License, you may choose any version ever published
575
+ by the Free Software Foundation.
576
+
577
+ If the Program specifies that a proxy can decide which future
578
+ versions of the GNU Affero General Public License can be used, that proxy's
579
+ public statement of acceptance of a version permanently authorizes you
580
+ to choose that version for the Program.
581
+
582
+ Later license versions may give you additional or different
583
+ permissions. However, no additional obligations are imposed on any
584
+ author or copyright holder as a result of your choosing to follow a
585
+ later version.
586
+
587
+ 15. Disclaimer of Warranty.
588
+
589
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
590
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
591
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
592
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
593
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
594
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
595
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
596
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
597
+
598
+ 16. Limitation of Liability.
599
+
600
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
601
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
602
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
603
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
604
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
605
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
606
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
607
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
608
+ SUCH DAMAGES.
609
+
610
+ 17. Interpretation of Sections 15 and 16.
611
+
612
+ If the disclaimer of warranty and limitation of liability provided
613
+ above cannot be given local legal effect according to their terms,
614
+ reviewing courts shall apply local law that most closely approximates
615
+ an absolute waiver of all civil liability in connection with the
616
+ Program, unless a warranty or assumption of liability accompanies a
617
+ copy of the Program in return for a fee.
618
+
619
+ END OF TERMS AND CONDITIONS
620
+
621
+ How to Apply These Terms to Your New Programs
622
+
623
+ If you develop a new program, and you want it to be of the greatest
624
+ possible use to the public, the best way to achieve this is to make it
625
+ free software which everyone can redistribute and change under these terms.
626
+
627
+ To do so, attach the following notices to the program. It is safest
628
+ to attach them to the start of each source file to most effectively
629
+ state the exclusion of warranty; and each file should have at least
630
+ the "copyright" line and a pointer to where the full notice is found.
631
+
632
+ <one line to give the program's name and a brief idea of what it does.>
633
+ Copyright (C) <year> <name of author>
634
+
635
+ This program is free software: you can redistribute it and/or modify
636
+ it under the terms of the GNU Affero General Public License as published
637
+ by the Free Software Foundation, either version 3 of the License, or
638
+ (at your option) any later version.
639
+
640
+ This program is distributed in the hope that it will be useful,
641
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
642
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643
+ GNU Affero General Public License for more details.
644
+
645
+ You should have received a copy of the GNU Affero General Public License
646
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
647
+
648
+ Also add information on how to contact you by electronic and paper mail.
649
+
650
+ If your software can interact with users remotely through a computer
651
+ network, you should also make sure that it provides a way for users to
652
+ get its source. For example, if your program is a web application, its
653
+ interface could display a "Source" link that leads users to an archive
654
+ of the code. There are many ways you could offer source, and different
655
+ solutions will be better for different programs; see section 13 for the
656
+ specific requirements.
657
+
658
+ You should also get your employer (if you work as a programmer) or school,
659
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
660
+ For more information on this, and how to apply and follow the GNU AGPL, see
661
+ <https://www.gnu.org/licenses/>.
Readme.md ADDED
@@ -0,0 +1,77 @@
1
+ ## metrics to track
2
+ - loss per epoch per model boost layer
3
+ - number of errors per epoch per model boost layer
4
+ - number of resolved puzzles per epoch
5
+ - threshold per epoch per model layer
6
+ - number of filled digits per model boost layer per epoch, for both pos and abs
7
+
8
+ ## TODO
9
+ - convert the Jupyter notebooks to Python files
10
+ - compute the threshold on the test set (adding a margin) each epoch, and initialise the training threshold with the test threshold, letting it evolve on each error during training.
11
+
12
+ ## Possible way
13
+ - it might be smart to store the intermediate states in per-boost-layer "buffers"; at the end, the first X go to model layer 0. Let's write it as pseudo-code.
14
+
15
+ ### Method threshold
16
+ ```
17
+ global init
18
+ th -> -10
19
+ training step
20
+ init
21
+ pass
22
+ training loop
23
+ keep th behind the error limit
24
+ validation step
25
+ init
26
+ compute_th = -10
27
+ validation loop
28
+ keep compute_th behind error limit + margin
29
+ but use th
30
+ end
31
+ th = compute_th
32
+ ```
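Sketched in Python, one possible reading of the pseudo-code above. The class name, the epsilon, and the direction of the update (raising the threshold past each erroneous score) are assumptions, not taken from the code base:

```python
class ThresholdTracker:
    """Hypothetical bookkeeping for the training/validation thresholds."""

    def __init__(self, margin=0.5, eps=1e-6):
        self.th = -10.0           # global init: th -> -10
        self.margin = margin      # the margin added during validation
        self.eps = eps
        self._compute_th = -10.0

    def on_train_error(self, score):
        # training loop: keep th just behind the erroneous score
        self.th = max(self.th, score + self.eps)

    def on_validation_start(self):
        # validation init: compute_th = -10
        self._compute_th = -10.0

    def on_validation_error(self, score):
        # keep compute_th behind the error limit plus a margin,
        # while th remains the one actually used for decisions
        self._compute_th = max(self._compute_th, score + self.margin)

    def on_validation_end(self):
        # end: th = compute_th
        self.th = self._compute_th
```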
33
+
34
+ ### Method training
35
+ ```
36
+ Xs -> the x initial batch vector
37
+ Y -> the y batch vector
38
+ Xs' = M0(Xs)
39
+ then we filter Xs' == Y -> resolved sudokus
40
+ Xs'==Xs -> we add the rows to X1 buffer
41
+ and the remaining Xs' is added to the X0 buffer.
42
+ ```
43
+
44
+ then we look at each buffer X0 to Xn and we process each of them whose size is >= batch size.
45
+
46
+ When every buffer is smaller than the batch size, the process is finished.
47
+
48
+ object
49
+ ```
50
+ Buffers
51
+ get_batch(limit_batch_size=True) -> idx, Xb # Xb could be None (Xb should be a shuffled sample of the buffer)
52
+ add_batch(Xp, idx)
53
+ ```
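A minimal sketch of the buffer mechanism described above, using plain Python lists instead of tensors; the method names mirror the spec, while the actual tensor implementation lives in sudoku/buffer.py:

```python
import random

class Buffers:
    """Per-boost-layer buffers; layer idx processes rows queued at index idx."""

    def __init__(self, n_layers, batch_size):
        self.batch_size = batch_size
        self.buffers = [[] for _ in range(n_layers)]

    def get_batch(self, limit_batch_size=True):
        # return the index of the first buffer holding >= batch_size rows,
        # together with a shuffled sample of them; (0, None) if none is ready
        for idx, buf in enumerate(self.buffers):
            if len(buf) >= self.batch_size:
                random.shuffle(buf)
                batch, self.buffers[idx] = buf[:self.batch_size], buf[self.batch_size:]
                return idx, batch
        return 0, None

    def add_batch(self, Xp, idx):
        # rows that layer idx-1 could not finish queue up for layer idx;
        # rows past the last layer are dropped
        if idx < len(self.buffers):
            self.buffers[idx].extend(Xp)
```

The processing loop then alternates `get_batch` / model step / `add_batch` until `get_batch` returns `None` for every layer.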
54
+
55
+
56
+ ### Loss optimisation
57
+ The 0 and 1 targets differ in the way we should gradient-descend them.
58
+ A y==0 point is easy: its score should be as low as possible; I think we can use the usual log-loss function on it.
59
+ A y==1 point is different: several cases are possible:
60
+ - the point could be "unpredictable": in that case the gradient descent should be tuned low; we expect the predictive function to give it a low score.
61
+ - the point could be well predicted: in that case we hope the value is pretty high, and we would like to apply the gradient descent more heavily.
62
+ This could be applied by weighting the loss with a sigmoid centered on the threshold.
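As a hedged sketch of that idea (the exact weighting is an assumption, not taken from the training code): scale the log-loss of y==1 points by a sigmoid of (score - threshold), so well-predicted points keep nearly full weight and "unpredictable" ones are nearly ignored, while y==0 points keep the plain log-loss.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def point_loss(score, y, th):
    # plain log-loss on the sigmoid of the raw score
    p = sigmoid(score)
    base = -math.log(p) if y == 1 else -math.log(1.0 - p)
    if y == 1:
        # weight y==1 points by a sigmoid centered on the threshold:
        # score >> th keeps full weight, score << th is nearly ignored
        base *= sigmoid(score - th)
    return base
```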
63
+
64
+
65
+ ### Paper writing
66
+
67
+ The higher levels call on various types of chains:
68
+
69
+ 11.6 Dynamic + Dynamic Forcing Chains (145-192 nodes) Cell Forcing Chains
70
+ 11.7 Dynamic + Dynamic Forcing Chains (193-288 nodes) Double Forcing Chains
71
+ These Dynamic Forcing Chains are a form of trial and error.
72
+
73
+ ### Trial and error solving technique
74
+ We applied a trial-and-error solving technique to reach 100% accuracy on Sudoku. The reasoning is simple: we find the best digit/position to test and produce two child grids, one with the digit placed and one with it excluded; then we process each grid until one of them breaks Sudoku's rules.
75
+
76
+ The V1 of this algorithm should stop at a single trial-and-error test (no binary tree search); it should be simpler and feasible, and if not, we will still see an improvement and can try the next step.
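The single-level step above can be sketched like this. The grid representation (a dict mapping cells to candidate-digit sets) and `propagate` (a stand-in for the deep-learning solving loop) are assumptions for illustration:

```python
def trial_and_error(grid, cell, digit, propagate):
    """One V1 trial-and-error step: branch on (cell, digit), keep the
    child that does not break the rules after propagation."""
    # child 1: place the digit in the cell
    with_digit = {c: ({digit} if c == cell else s.copy()) for c, s in grid.items()}
    # child 2: exclude the digit from the cell
    without_digit = {c: (s - {digit} if c == cell else s.copy()) for c, s in grid.items()}
    a, b = propagate(with_digit), propagate(without_digit)
    # a child that empties some candidate set has broken Sudoku's rules
    broken = lambda g: any(len(s) == 0 for s in g.values())
    return b if broken(a) else a
```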
77
+
app.py ADDED
@@ -0,0 +1,126 @@
1
+ import streamlit as st
2
+ from sudoku.train import SudokuTrialErrorLightning
3
+ from sudoku.helper import display_as_dataframe, get_grid_number_soluce
4
+ import numpy as np
5
+ import re
6
+ import torch
7
+
8
+ model = SudokuTrialErrorLightning.load_from_checkpoint(checkpoint_path="model_9_m3_5epoch.ckpt")
9
+
10
+ empty_grid = '''
11
+ 0 0 0 | 0 0 0 | 0 0 0
12
+ 0 0 0 | 0 0 0 | 0 0 0
13
+ 0 0 0 | 0 0 0 | 0 0 0
14
+ ---------------------
15
+ 0 0 0 | 0 0 0 | 0 0 0
16
+ 0 0 0 | 0 0 0 | 0 0 0
17
+ 0 0 0 | 0 0 0 | 0 0 0
18
+ ---------------------
19
+ 0 0 0 | 0 0 0 | 0 0 0
20
+ 0 0 0 | 0 0 0 | 0 0 0
21
+ 0 0 0 | 0 0 0 | 0 0 0
22
+ '''
23
+ grid_evil_0 = '''
24
+ [0,0,0, 1,6,0, 2,0,0],
25
+ [0,3,0, 0,0,5, 0,0,4],
26
+ [0,0,7, 0,0,0, 0,9,0],
27
+
28
+ [0,0,1, 0,0,0, 0,8,2],
29
+ [3,0,0, 7,0,4, 0,0,5],
30
+ [4,2,0, 0,0,0, 6,0,0],
31
+
32
+ [0,8,0, 0,0,0, 5,0,0],
33
+ [7,0,0, 8,0,0, 0,2,0],
34
+ [0,0,6, 0,5,1, 0,0,0],
35
+ '''
36
+ grid_evil_1 = '''
37
+ [0,0,0, 2,0,0, 0,0,0],
38
+ [0,0,9, 7,0,0, 3,8,0],
39
+ [3,6,0, 0,0,0, 0,7,0],
40
+
41
+ [1,0,6, 8,0,0, 0,0,0],
42
+ [0,0,0, 3,5,9, 0,0,0],
43
+ [0,0,0, 0,0,2, 8,0,5],
44
+
45
+ [0,4,0, 0,0,0, 0,9,1],
46
+ [0,2,5, 0,0,8, 7,0,0],
47
+ [0,0,0, 0,0,6, 0,0,0],
48
+ ]
49
+ '''
50
+ grid_evil_2 = '''
51
+ [0,2,0, 1,0,0, 0,7,0],
52
+ [0,0,0, 3,0,2, 0,0,0],
53
+ [0,0,1, 0,8,0, 0,4,0],
54
+
55
+ [9,0,0, 0,1,0, 0,0,7],
56
+ [0,0,8, 0,6,0, 0,5,0],
57
+ [0,0,0, 0,0,0, 0,0,0],
58
+
59
+ [0,8,0, 0,3,6, 0,0,9],
60
+ [0,0,5, 0,7,0, 0,0,6],
61
+ [0,0,2, 0,0,0, 0,0,3],
62
+ ]
63
+ '''
64
+ grids = {
65
+ 'empty': empty_grid,
66
+ 'evil_1': grid_evil_0,
67
+ 'evil_2': grid_evil_1,
68
+ 'evil_3': grid_evil_2
69
+ }
70
+
71
+ def str_to_row_col_grid(su_str):
72
+ su_str = re.sub(r'[^\d.]', '', su_str).replace('.','0')
73
+
74
+ assert len(su_str)==81
75
+ return [[int(su_str[j*9+i]) for i in range(9)] for j in range(9)]
76
+
77
+ def build_tensor_out_of_grid(grid):
78
+ np_grid = np.asarray(grid)
79
+ big_grid = np.zeros((1,2,9,9,9), dtype=np.float32)
80
+ for i in range(9):
81
+ big_grid[0,1,:,:,i] = (np_grid==i+1)
82
+ big_grid[0,0,:,:,i] = ((np_grid!=i+1) & (np_grid!=0))
83
+ return torch.from_numpy(big_grid).view(1,2,729)
84
+
85
+ def build_grid_and_display(grid):
86
+ test_x = build_tensor_out_of_grid(grid)
87
+ return test_x, display_as_dataframe(test_x)
88
+ st.markdown('# Deep learning sudoku Solver')
89
+ st.markdown('### Author: Sébastien Guissart')
90
+
91
+ option = st.selectbox(
92
+ "select template grids",
93
+ grids,
94
+ )
95
+ x = st.text_area('grid', value=grids[option], height=300)
96
+
97
+ tensor_grid, df_styler = build_grid_and_display(str_to_row_col_grid(x))
98
+ # st.table(df_styler)
99
+ # st.dataframe(df_styler)
100
+ html = df_styler.to_html(escape=False, index=False)
101
+ # # st.text(html)
102
+ st.html(html)
103
+ st.markdown('''
104
+ ## Checking the number of solutions
105
+ (with backtracking algorithm)
106
+ it must equal 1
107
+ ''')
108
+ n_sol = get_grid_number_soluce(str_to_row_col_grid(x))
109
+
110
+ st.markdown(f"number of solutions: {'>1' if n_sol >= 2 else n_sol}")
111
+ if n_sol==1:
112
+ new_X = model.predict(tensor_grid)
113
+ X_sum= new_X.sum()
114
+ st.html(display_as_dataframe(new_X).to_html(escape=False, index=False))
115
+ i=1
116
+ while new_X.sum()<729:
117
+ i+=1
118
+ st.markdown(f'iteration {i}')
119
+ new_X = model.predict(new_X)
120
+ st.html(display_as_dataframe(new_X).to_html(escape=False, index=False))
121
+ new_X_sum = new_X.sum()
122
+ assert new_X_sum > X_sum
123
+ X_sum = new_X_sum
124
+ st.markdown('## Grid solved!')
125
+
126
+
experiments/boost_training_7_base_2_loading.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
experiments/boost_training_7_m3_base.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
experiments/boost_training_7_m3_base_2.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
experiments/boost_training_7_m3_base_2_output-Copy2.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
experiments/boost_training_7_m3_base_2_output.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
experiments/boost_training_7_m3_base_2_retrain.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
experiments/boost_training_7_m3_harder_and_harder.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
experiments/boost_training_8_m3_train_only_once.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
experiments/boost_training_8_m3_train_only_once_analysis.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
experiments/boost_training_8_m3_train_only_once_clean.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
experiments/boost_training_9_m3_analysis.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
experiments/boost_training_9_m3_test_3m_sample.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
experiments/boost_training_9_m3_train_based_on_8_train_trial_error.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
justfile ADDED
@@ -0,0 +1,10 @@
1
+ init:
2
+ rm -rf venv
3
+ python3 -m venv venv
4
+ venv/bin/pip install -r requirements.txt
5
+
6
+ test:
7
+ PYTHONPATH=. venv/bin/pytest tests
8
+
9
+ tensorboard:
10
+ venv/bin/tensorboard --logdir lightning_logs
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ pytorch_lightning
2
+ lightning_lite
3
+ tensorboard
4
+ pandas
5
+ seaborn
6
+ torch==1.12.1
7
+ jupyter
8
+ scikit-learn
9
+ pytest
10
+ black
sudoku/__init__.py ADDED
File without changes
sudoku/buffer.py ADDED
@@ -0,0 +1,93 @@
1
+ import random
2
+ import torch
3
+
4
+ # TODO: test this intensively!
5
+ # we once hit a "pop index out of range" error during training
6
+
7
+
8
+ class BufferArray:
9
+ def __init__(self, array_size, batch_size):
10
+ self.array_size = array_size
11
+ self.batch_size = batch_size
12
+ self.batch_n_vectors = None
13
+
14
+ def init_buffers(self):
15
+ self.buffers = [
16
+ [[] for _ in range(self.batch_n_vectors)] for _ in range(self.array_size)
17
+ ]
18
+
19
+ def get_batch(self):
20
+ "return the idx of the first buffer reaching batch_size and a batch"
21
+ assert hasattr(self, "buffers")
22
+ for idx, buffer in enumerate(self.buffers):
23
+ if len(buffer[0]) >= self.batch_size:
24
+ vectors = [[] for _ in range(self.batch_n_vectors)]
25
+ for _ in range(self.batch_size):
26
+ pop_idx = random.randrange(len(buffer[0]))
27
+ for v, b in zip(vectors, buffer):
28
+ v.append(b.pop(pop_idx))
29
+ return idx, tuple([torch.stack(v, dim=0) for v in vectors])
30
+ return 0, None
31
+
32
+ def append(self, idx, batch: tuple):
33
+ "append batch to the buffer at idx; batch is assumed to be a tuple such as (x, y)"
34
+
35
+ if idx >= self.array_size:
36
+ return
37
+ if self.batch_n_vectors is None:
38
+ self.batch_n_vectors = len(batch)
39
+ self.init_buffers()
40
+ else:
41
+ assert len(batch) == self.batch_n_vectors
42
+ for i, element_vectors in enumerate(batch):
43
+ self.buffers[idx][i] = self.buffers[idx][i] + [
44
+ vector for vector in element_vectors
45
+ ]
46
+
47
+
48
+ # def append(self, X, Y) -> None:
49
+ # """Add experience to the buffer.
50
+
51
+ # Args:
52
+ # experience: tuple (state, action, reward, done, new_state)
53
+ # """
54
+
55
+ # X[Y == 0] = 0
56
+ # mask = ~(X == Y).view(-1, 2 * 729).all(dim=1)
57
+
58
+ # for x, y in zip(X[mask], Y[mask]):
59
+ # self.buffer.append((x, y))
60
+
61
+ class Buffer:
62
+ def __init__(self, batch_size):
63
+ self.batch_size = batch_size
64
+ self.batch_n_vectors = None
65
+
66
+ def init_buffer(self):
67
+ self.buffer = [[] for _ in range(self.batch_n_vectors)]
68
+
69
+ def get_batch(self):
70
+ "return a batch once the buffer reaches batch_size, otherwise None"
71
+ if not hasattr(self, "buffer"):
72
+ return None
73
+ if len(self.buffer[0]) >= self.batch_size:
74
+ vectors = [[] for _ in range(self.batch_n_vectors)]
75
+ for _ in range(self.batch_size):
76
+ pop_idx = random.randrange(len(self.buffer[0]))
77
+ for v, b in zip(vectors, self.buffer):
78
+ v.append(b.pop(pop_idx))
79
+ return tuple([torch.stack(v, dim=0) for v in vectors])
80
+ return None
81
+
82
+ def append(self, batch: tuple):
83
+ "append batch to the buffer; batch is assumed to be a tuple such as (x, y)"
84
+
85
+ if self.batch_n_vectors is None:
86
+ self.batch_n_vectors = len(batch)
87
+ self.init_buffer()
88
+ else:
89
+ assert len(batch) == self.batch_n_vectors
90
+ for i, element_vectors in enumerate(batch):
91
+ self.buffer[i] = self.buffer[i] + [
92
+ vector for vector in element_vectors
93
+ ]
sudoku/helper.py ADDED
@@ -0,0 +1,160 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import torch.nn.functional as F
4
+
5
+
6
+ x_grid = np.array([x for x in range(1, 10)] * 9 * 9).reshape(9, 9, 3, 3)
7
+ df_grid = pd.DataFrame(x_grid.swapaxes(1, 2).reshape(27, 27))
8
+
9
+ cell = { # for row hover use <tr> instead of <td>
10
+ "selector": "td",
11
+ "props": [("color", "black"), ("text-align", "center")],
12
+ }
13
+ line_height_1 = {"selector": "tr", "props": [("line-height", "2em")]}
14
+ line_height_2 = {"selector": "td", "props": "line-height: inherit; padding: 0;"}
15
+ border_global = {"selector": "", "props": [("border", "2px solid")]}
16
+ border_left_property_1 = [{"selector": "td", "props": "border-left: 1px solid black"}]
17
+ border_left_property_2 = [{"selector": "td", "props": "border-left: 2px solid black"}]
18
+
19
+ border_1 = {
20
+ 3: border_left_property_1,
21
+ 6: border_left_property_1,
22
+ 9: border_left_property_2,
23
+ 12: border_left_property_1,
24
+ 15: border_left_property_1,
25
+ 18: border_left_property_2,
26
+ 21: border_left_property_1,
27
+ 24: border_left_property_1,
28
+ }
29
+ border_top_property_1 = [{"selector": "td", "props": "border-top: 1px solid black"}]
30
+ border_top_property_2 = [{"selector": "td", "props": "border-top: 2px solid black"}]
31
+ border_2 = {
32
+ 3: border_top_property_1,
33
+ 6: border_top_property_1,
34
+ 9: border_top_property_2,
35
+ 12: border_top_property_1,
36
+ 15: border_top_property_1,
37
+ 18: border_top_property_2,
38
+ 21: border_top_property_1,
39
+ 24: border_top_property_1,
40
+ }
41
+
42
+
43
+ def display_as_dataframe(x_input, output=None, display="neg"):
44
+ x_mask_pos = (
45
+ x_input[0, 1, :].numpy().reshape(9, 9, 3, 3).swapaxes(1, 2).reshape(27, 27)
46
+ )
47
+ x_mask_is_not_pos = x_input[0, 1, :].numpy().reshape(9, 9, 9).max(axis=2) == 0
48
+ x_mask_neg_pos = x_input[0, 0, :].numpy().reshape(9, 9, 9).copy()
49
+ x_mask_neg_pos[x_mask_is_not_pos] = np.array([0 for _ in range(9)])
50
+ x_mask_full_neg = (
51
+ x_input[0, 0, :].numpy().reshape(9, 9, 3, 3).swapaxes(1, 2).reshape(27, 27)
52
+ )
53
+ x_mask_neg_pos = x_mask_neg_pos.reshape(9, 9, 3, 3).swapaxes(1, 2).reshape(27, 27)
54
+
55
+ cell_color = pd.DataFrame("black", index=df_grid.index, columns=df_grid.columns)
56
+ cell_color[x_mask_pos == 1] = "blue"
57
+ cell_color[x_mask_full_neg == 1] = "red"
58
+ cell_color[x_mask_neg_pos == 1] = "white"
59
+
60
+ styler = (
61
+ df_grid.style.hide(axis=1)
62
+ .hide(axis=0)
63
+ .set_properties(subset=[x for x in range(27)], **{"width": "2em"})
64
+ .set_table_styles(
65
+ [ # create internal CSS classes
66
+ {"selector": ".red", "props": "color: red; font-weight: bold"},
67
+ {"selector": ".blue", "props": "color: blue; font-weight: bold"},
68
+ {"selector": ".white", "props": "color: white; font-weight: bold"},
69
+ ],
70
+ overwrite=False,
71
+ )
72
+ .set_td_classes(cell_color)
73
+ .set_table_styles(
74
+ [cell, border_global, line_height_1, line_height_2], overwrite=False
75
+ )
76
+ .set_table_styles(border_1, overwrite=False)
77
+ .set_table_styles(border_2, overwrite=False, axis=1)
78
+ )
79
+ if (output is not None) and (display == "neg"):
80
+ x_output = (
81
+ output[0, 0, :].numpy().reshape(9, 9, 3, 3).swapaxes(1, 2).reshape(27, 27)
82
+ )
83
+ return styler.background_gradient(
84
+ axis=None, vmin=-3, vmax=3, cmap="bwr", gmap=x_output
85
+ )
86
+ if (output is not None) and (display == "pos"):
87
+ x_output = (
88
+ output[0, 1, :].numpy().reshape(9, 9, 3, 3).swapaxes(1, 2).reshape(27, 27)
89
+ )
90
+ return styler.background_gradient(
91
+ axis=None, vmin=-3, vmax=3, cmap="bwr", gmap=x_output
92
+ )
93
+
94
+ return styler
95
+
96
+
97
+ def compute_loss(x, y, output, new_x):
98
+ mask_0_error = (new_x == 1) & (y == 0)
99
+ mask_error = mask_0_error.view(-1, 2 * 729).any(dim=1)
100
+ mask_no_improve = new_x.sum(dim=(1, 2)) <= x.sum(dim=(1, 2))
101
+ mask_no_improve[mask_error] = False
102
+ mask_1_no_improve = y == 1
103
+ mask_1_no_improve[~mask_no_improve] = False
104
+ loss = F.binary_cross_entropy_with_logits(output, y, reduction="none")
105
+ loss_error = loss[mask_0_error].mean()
106
+ loss_no_improve = loss[mask_1_no_improve].mean()
107
+ # loss_error = F.binary_cross_entropy_with_logits(output[mask_0_error], y[mask_0_error])
108
+ # loss_no_improve = F.binary_cross_entropy_with_logits(output[mask_1_no_improve], y[mask_1_no_improve])
109
+
110
+ return loss_error, loss_no_improve, mask_error.sum(), mask_no_improve.sum()
111
+
112
+
113
+ # returns 0, 1 or 2, where 2 means "2 or more" solutions were found
114
+ def solve(i, j, cells, count): # initially called with count = 0
115
+ if i == 9:
116
+ i = 0
117
+ j += 1
118
+ if j == 9:
119
+ return 1 + count
120
+ if cells[i][j] != 0: # skip filled cells
121
+ return solve(i + 1, j, cells, count)
122
+
123
+ for val in range(1, 10):
124
+ if count < 2 and legal(i, j, val, cells):
125
+ cells[i][j] = val
126
+ count = solve(i + 1, j, cells, count)
127
+
128
+ cells[i][j] = 0 # reset on backtrack
129
+ return count
130
+
131
+
132
+ def legal(row, col, num, grid):
133
+
134
+ # Check if we find the same num
135
+ # in the similar row , we
136
+ # return false
137
+ for x in range(9):
138
+ if grid[row][x] == num:
139
+ return False
140
+
141
+ # Check if we find the same num in
142
+ # the similar column , we
143
+ # return false
144
+ for x in range(9):
145
+ if grid[x][col] == num:
146
+ return False
147
+
148
+ # Check if we find the same num in
149
+ # the particular 3*3 matrix,
150
+ # we return false
151
+ startRow = row - row % 3
152
+ startCol = col - col % 3
153
+ for i in range(3):
154
+ for j in range(3):
155
+ if grid[i + startRow][j + startCol] == num:
156
+ return False
157
+ return True
158
+
159
+ def get_grid_number_soluce(grid):
160
+ return solve(0,0,grid,0)
sudoku/loader.py ADDED
@@ -0,0 +1,190 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch.utils.data import DataLoader
4
+ from sklearn.model_selection import train_test_split
5
+ from torch.utils.data.dataset import IterableDataset
6
+ from collections import deque
7
+
8
+ from numpy.random import default_rng
9
+
10
+ DATA = np.load(
11
+ # "/home/seb/Perso/git/sudoku/sudoku_clean/data/sudoku_reshaped_million.npz"
12
+ "sudoku_reshaped_3_million.npz"
13
+ )
14
+
15
+ rng = np.random.default_rng()
16
+
17
+
18
+ def get_datasets(
19
+ add_proba_fill=False, train_size=1280 // 2, test_size=1280 // 2, max_holes=None
20
+ ):
21
+ quizzes = DATA["quizzes"][: train_size + test_size]
22
+ solutions = DATA["solutions"][: train_size + test_size]
23
+ X = quizzes
24
+ if max_holes:
25
+ while True:
26
+ x_holes = X[:, 1].sum(-1) == 0
27
+ x_nb_holes = x_holes.sum((1, 2))
28
+ mask_x_max_holes = x_nb_holes > max_holes
29
+ if not any(mask_x_max_holes):
30
+ break
31
+ for idx_x in np.nonzero(mask_x_max_holes)[0]:
32
+ sub_x_holes = x_holes[idx_x]
33
+ idx_fill = rng.choice(np.transpose(np.nonzero(sub_x_holes)))
34
+ X[idx_x, :, idx_fill[0], idx_fill[1], :] = solutions[
35
+ idx_x, :, idx_fill[0], idx_fill[1], :
36
+ ]
37
+ X = X.reshape(X.shape[0], 2, 9 * 9 * 9)
38
+ solutions = solutions.reshape(solutions.shape[0], 2, 9 * 9 * 9)
39
+
40
+ X_train, X_test, solutions_train, solutions_test = train_test_split(
41
+ X, solutions, test_size=test_size, random_state=42
42
+ )
43
+ if add_proba_fill:
44
+ X_train_bis = X_train.copy()
45
+ mask = solutions_train == 1
46
+ X_train_bis[mask] = np.random.randint(0, 2, size=mask.sum())
47
+ X_train = np.concatenate([X_train, X_train_bis])
48
+ solutions_train = np.concatenate([solutions_train, solutions_train])
49
+
50
+ train = torch.utils.data.TensorDataset(
51
+ torch.Tensor(X_train), torch.Tensor(solutions_train)
52
+ )
53
+ test = torch.utils.data.TensorDataset(
54
+ torch.Tensor(X_test), torch.Tensor(solutions_test)
55
+ )
56
+ return train, test
57
+
58
+
59
+ train_dataset, test_dataset = get_datasets()
60
+
61
+
62
+ def data_loader(batch_size=32, add_proba_fill=False):
63
+ train, test = get_datasets(add_proba_fill=add_proba_fill)
64
+
65
+ train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size)
66
+
67
+ test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size)
68
+
69
+ return train_loader, test_loader
70
+
71
+
72
+ class DataIterBuffer(IterableDataset):
73
+ def __init__(self, raw_dataset=[], buffer_optim=50, prop_new=0.1, seed=1):
74
+ self.raw_dataset = raw_dataset
75
+ # self.raw_dataset = iter(raw_dataset)
76
+ self.buffer = deque()
77
+ self.buffer_optim = buffer_optim
78
+ self.prop_new = prop_new
79
+ self.rng = default_rng(seed=seed)
80
+ self.idx_dataset = 0
81
+
82
+ def __iter__(self):
83
+ # while True:
84
+ # if (np.random.random() < self.prop_new) and (
85
+ # len(self.buffer) <= self.buffer_optim
86
+ # ):
87
+ # try:
88
+ # yield next(self.raw_dataset)
89
+ # except StopIteration:
90
+ # if len(self.buffer) != 0:
91
+ # yield self.buffer.popleft()
92
+ # else:
93
+ # break
94
+ # else:
95
+ # if len(self.buffer) != 0:
96
+ # yield self.buffer.popleft()
97
+ # else:
98
+ # try:
99
+ # yield next(self.raw_dataset)
100
+ # except StopIteration:
101
+ # break
102
+ while True:
103
+ if (np.random.random() < self.prop_new) and (
104
+ len(self.buffer) <= self.buffer_optim
105
+ ):
106
+ if self.idx_dataset >= len(self.raw_dataset):
107
+ if len(self.buffer) != 0:
108
+ yield self.buffer.popleft()
109
+ else:
110
+ break
111
+ else:
112
+ yield self.raw_dataset[self.idx_dataset]
113
+ self.idx_dataset += 1
114
+ else:
115
+ if len(self.buffer) != 0:
116
+ yield self.buffer.popleft()
117
+ else:
118
+ if self.idx_dataset >= len(self.raw_dataset):
119
+ break
120
+ else:
121
+ yield self.raw_dataset[self.idx_dataset]
122
+ self.idx_dataset += 1
123
+
124
+ def append(self, X, Y) -> None:
125
+ """Add experience to the buffer.
126
+
127
+ Args:
128
+ experience: tuple (state, action, reward, done, new_state)
129
+ """
130
+
131
+ X[Y == 0] = 0
132
+ mask = ~(X == Y).view(-1, 2 * 729).all(dim=1)
133
+
134
+ for x, y in zip(X[mask], Y[mask]):
135
+ self.buffer.append((x, y))
136
+
137
+ def __len__(self):
138
+ return len(self.buffer) + len(self.raw_dataset)
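A minimal, self-contained sketch of the sampling policy implemented by `DataIterBuffer.__iter__`: with probability `prop_new` (while the buffer is below its target size) yield a fresh sample from the raw dataset, otherwise drain the replay buffer first. The names `mixed_iter`, `raw_dataset` and `buffer` here are illustrative stand-ins, not the project's API.

```python
from collections import deque
from numpy.random import default_rng


def mixed_iter(raw_dataset, buffer, prop_new=0.1, buffer_optim=50, seed=1):
    """Yield a mix of fresh samples and buffered (replayed) samples."""
    rng = default_rng(seed)
    idx = 0
    while True:
        take_new = rng.random() < prop_new and len(buffer) <= buffer_optim
        if take_new and idx < len(raw_dataset):
            yield raw_dataset[idx]
            idx += 1
        elif buffer:
            yield buffer.popleft()
        elif idx < len(raw_dataset):
            yield raw_dataset[idx]
            idx += 1
        else:
            return  # both sources exhausted


items = list(mixed_iter(list(range(5)), deque(["a", "b"])))
# every raw sample and every buffered sample is yielded exactly once
```

Whatever the random draws, each element of both sources comes out exactly once; only the interleaving order changes with the seed.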
+
+
+# class DataIterDeepBuffer(IterableDataset):
+#     def __init__(self, raw_dataset=[], buffer_target_size=32, prop_new=0.1, seed=1, prof=6):
+#         self.raw_dataset = iter(raw_dataset)
+#         # self.buffer = deque()
+#         self.buffer_target_size = buffer_target_size
+#         self.prop_new = prop_new
+#         self.rng = default_rng(seed=seed)
+#         self.prof = prof
+#         self.buffers = [deque() for _ in range(prof)]
+
+#     def __iter__(self):
+#         while True:
+#             buffer_sizes = np.array([len(buffer) for buffer in self.buffers])
+#             if any(buffer_sizes >= self.buffer_target_size):
+#                 #
+
+#             if (np.random.random() < self.prop_new) and (
+#                 len(self.buffer) <= self.buffer_optim
+#             ):
+#                 try:
+#                     yield next(self.raw_dataset)
+#                 except StopIteration:
+#                     if len(self.buffer) != 0:
+#                         yield self.buffer.popleft()
+#                     else:
+#                         break
+#             else:
+#                 if len(self.buffer) != 0:
+#                     yield self.buffer.popleft()
+#                 else:
+#                     try:
+#                         yield next(self.raw_dataset)
+#                     except StopIteration:
+#                         break
+
+#     def append(self, X, Y) -> None:
+#         """Add the still-unsolved (X, Y) pairs of a batch to the buffer."""
+#         X[Y == 0] = 0
+#         mask = ~(X == Y).view(-1, 2 * 729).all(dim=1)
+
+#         for x, y in zip(X[mask], Y[mask]):
+#             self.buffer.append((x, y))
+
+#     def __len__(self):
+#         return len(self.buffer) + len(self.raw_dataset)
sudoku/loader_buffer.py ADDED
@@ -0,0 +1,28 @@
+
+
+from torch.utils.data.dataset import IterableDataset
+from torch.utils.data import DataLoader
+import numpy as np
+
+from sudoku.loader import DataIterBuffer, train_dataset, test_dataset, data_loader, get_datasets
+
+
+
+class CustomDataLoader(DataLoader):
+    def __init__(self, data_iters, batch_size):
+        self.data_iters = data_iters
+        self.batch_size = batch_size
+        self.data_loaders = [
+            iter(DataLoader(data_iter, batch_size=batch_size))
+            for data_iter in data_iters
+        ]
+
+    def __iter__(self):
+        while True:
+            buffer_sizes = np.array([len(buffer) for buffer in self.data_iters])
+            if any(buffer_sizes >= self.batch_size):
+                # pick the deepest iterator that can already fill a whole batch
+                idx_yield = len(buffer_sizes) - 1 - np.argmax(buffer_sizes[::-1] >= self.batch_size)
+            else:
+                # otherwise fall back to the fullest one
+                idx_yield = np.argmax(buffer_sizes)
+            yield [idx_yield] + next(self.data_loaders[idx_yield])
+
+    @property
+    def num_workers(self):
+        return 0
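The index trick in `CustomDataLoader.__iter__` is worth isolating: `np.argmax` on the reversed comparison finds the *last* buffer that already holds a full batch, so the deepest ready layer is drained first; when none is full, the fullest buffer is used. The helper name `pick_buffer` and the numbers below are made up for illustration.

```python
import numpy as np


def pick_buffer(buffer_sizes, batch_size):
    """Return the index of the buffer to draw the next batch from."""
    buffer_sizes = np.asarray(buffer_sizes)
    if (buffer_sizes >= batch_size).any():
        # last (deepest) buffer with at least batch_size items
        return len(buffer_sizes) - 1 - np.argmax(buffer_sizes[::-1] >= batch_size)
    # no buffer is full yet: take the fullest one
    return int(np.argmax(buffer_sizes))


print(pick_buffer([10, 40, 33], 32))  # -> 2 (deepest full buffer, not index 1)
print(pick_buffer([10, 5, 7], 32))    # -> 0 (none full, fullest wins)
```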
sudoku/models.py ADDED
@@ -0,0 +1,124 @@
+from torch import nn
+import torch
+from sudoku.symetries import mat_sym
+
+
+class SudokuNet(nn.Module):
+    def __init__(self, n_output=2, coef_hidden=4):
+        super(SudokuNet, self).__init__()
+        self.hidden_neural_number = mat_sym.shape[1]
+
+        self.conv111 = nn.Conv1d(
+            self.hidden_neural_number * 2,
+            self.hidden_neural_number * 2 * coef_hidden,
+            1,
+            groups=self.hidden_neural_number * 2,
+        )
+        self.conv111_last = nn.Conv1d(
+            self.hidden_neural_number * 2 * coef_hidden, n_output, 1
+        )
+
+        sym_tensor = torch.from_numpy(mat_sym).type(torch.FloatTensor)
+        self.sym_tensor = nn.Parameter(sym_tensor, requires_grad=False)
+
+    def forward(self, x):
+        x = torch.tensordot(x, self.sym_tensor, dims=([2], [2]))
+        x = x.view(-1, 2, 9 * 9 * 9, self.hidden_neural_number)
+        x = x.permute(0, 1, 3, 2)
+        x = x.contiguous().view(-1, self.hidden_neural_number * 2, 9 * 9 * 9)
+        x = torch.relu(self.conv111(x))
+        x = self.conv111_last(x)
+        return x
+
+
+class SymPreprocess(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.hidden_neural_number = mat_sym.shape[1]
+        sym_tensor = torch.from_numpy(mat_sym).type(torch.FloatTensor)
+        self.sym_tensor = nn.Parameter(sym_tensor, requires_grad=False)
+
+    def forward(self, x):
+        n_channel = x.shape[1]
+        x = torch.tensordot(x, self.sym_tensor, dims=([2], [2]))
+        x = x.view(-1, n_channel, 9 * 9 * 9, self.hidden_neural_number)
+        x = x.permute(0, 1, 3, 2)
+        x = x.contiguous().view(-1, self.hidden_neural_number * n_channel, 9 * 9 * 9)
+        return x
+
+
+class SmallNet(nn.Module):
+    def __init__(self, n_output=2, coef_hidden=4, n_input_channel=2):
+        super(SmallNet, self).__init__()
+        self.hidden_neural_number = mat_sym.shape[1]
+
+        self.conv111 = nn.Conv1d(
+            self.hidden_neural_number * n_input_channel,
+            self.hidden_neural_number * n_input_channel * coef_hidden,
+            1,
+            groups=self.hidden_neural_number * n_input_channel,
+        )
+        self.conv111_last = nn.Conv1d(
+            self.hidden_neural_number * n_input_channel * coef_hidden, n_output, 1
+        )
+
+    def forward(self, x):
+        x = torch.relu(self.conv111(x))
+        x = self.conv111_last(x)
+        return x
+
+
+class SmallNetBis(nn.Module):
+    def __init__(self, n_output=2, coef_hidden=2, n_input_channel=2):
+        super(SmallNetBis, self).__init__()
+        # doubled: forward concatenates x with its complement (1 - x)
+        self.hidden_neural_number = mat_sym.shape[1] * 2
+
+        self.conv111 = nn.Conv1d(
+            self.hidden_neural_number * n_input_channel,
+            self.hidden_neural_number * n_input_channel * coef_hidden,
+            1,
+            groups=self.hidden_neural_number * n_input_channel,
+        )
+        self.conv111_last = nn.Conv1d(
+            self.hidden_neural_number * n_input_channel * coef_hidden, n_output, 1
+        )
+
+    def forward(self, x):
+        x = torch.cat([x, 1 - x], dim=1)
+        x = torch.relu(self.conv111(x))
+        x = self.conv111_last(x)
+        return x
+
+
+class SplittedSmallNet(nn.Module):
+    def __init__(self, coef_hidden=4, n_input_channel=2):
+        super().__init__()
+        self.hidden_neural_number = mat_sym.shape[1]
+
+        self.conv111_0 = nn.Conv1d(
+            self.hidden_neural_number * n_input_channel,
+            self.hidden_neural_number * n_input_channel * coef_hidden,
+            1,
+            groups=self.hidden_neural_number * n_input_channel,
+        )
+        self.conv111_1 = nn.Conv1d(
+            self.hidden_neural_number * n_input_channel,
+            self.hidden_neural_number * n_input_channel * coef_hidden,
+            1,
+            groups=self.hidden_neural_number * n_input_channel,
+        )
+
+        self.conv111_last_0 = nn.Conv1d(
+            self.hidden_neural_number * n_input_channel * coef_hidden, 1, 1
+        )
+        self.conv111_last_1 = nn.Conv1d(
+            self.hidden_neural_number * n_input_channel * coef_hidden, 1, 1
+        )
+
+    def forward(self, x):
+        x_0 = torch.relu(self.conv111_0(x))
+        x_0 = self.conv111_last_0(x_0)
+
+        x_1 = torch.relu(self.conv111_1(x))
+        x_1 = self.conv111_last_1(x_1)
+        return torch.cat([x_0, x_1], dim=1)
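A NumPy shape walk-through of the `SymPreprocess.forward` pipeline may help: `tensordot` contracts the 729-cell axis against the symmetry matrix, and the symmetry axis is then folded into the channel axis. Random arrays stand in for the real tensors, and `n_sym = 16` is assumed from `mat_sym.shape[1]` (the 16 symmetry reductions built in `sudoku/symetries.py`); this is a sketch, not the project code.

```python
import numpy as np

B, C, n_sym = 4, 2, 16
x = np.random.rand(B, C, 9 * 9 * 9)                 # (batch, channels, cells)
sym = np.random.rand(9 * 9 * 9, n_sym, 9 * 9 * 9)   # mat_sym-like layout

out = np.tensordot(x, sym, axes=([2], [2]))         # (B, C, 729, n_sym)
out = np.transpose(out, (0, 1, 3, 2))               # (B, C, n_sym, 729)
out = out.reshape(B, n_sym * C, 9 * 9 * 9)          # (B, 32, 729)
```

The result has `n_sym * C` channels per cell, which is exactly the input width the `SmallNet*` 1x1 grouped convolutions expect.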
sudoku/symetries.py ADDED
@@ -0,0 +1,171 @@
+import numpy as np
+
+
+def generate_symetrie_reduction():
+    # axes: digit, row, column, digit, row, column
+    a1_array = np.zeros((9, 9, 9, 9, 9, 9), dtype=float)
+    a2_array = np.zeros((9, 9, 9, 9, 9, 9), dtype=float)
+    a3_array = np.zeros((9, 9, 9, 9, 9, 9), dtype=float)
+    a4_array = np.zeros((9, 9, 9, 9, 9, 9), dtype=float)
+    a5_array = np.zeros((9, 9, 9, 9, 9, 9), dtype=float)
+    a6_array = np.zeros((9, 9, 9, 9, 9, 9), dtype=float)
+    a7_array = np.zeros((9, 9, 9, 9, 9, 9), dtype=float)
+    a8_array = np.zeros((9, 9, 9, 9, 9, 9), dtype=float)
+
+    b1_array = np.zeros((9, 9, 9, 9, 9, 9), dtype=float)
+    b2_array = np.zeros((9, 9, 9, 9, 9, 9), dtype=float)
+    b3_array = np.zeros((9, 9, 9, 9, 9, 9), dtype=float)
+    b4_array = np.zeros((9, 9, 9, 9, 9, 9), dtype=float)
+    b5_array = np.zeros((9, 9, 9, 9, 9, 9), dtype=float)
+    b6_array = np.zeros((9, 9, 9, 9, 9, 9), dtype=float)
+    b7_array = np.zeros((9, 9, 9, 9, 9, 9), dtype=float)
+    b8_array = np.zeros((9, 9, 9, 9, 9, 9), dtype=float)
+
+    for i in range(9):
+        for j in range(9):
+            for n in range(9):
+                a1_array[i, j, n, i, j, n] = 1.0
+
+                a2_array[autre_chiffre(i), j, n, i, j, n] = 1.0 / 8.0
+
+                a3_array[i, autre_chiffre(j), n, i, j, n] = 1.0 / 8.0
+
+                a4_array[np.ix_([i], meme_bloc(j), [n], [i], [j], [n])] = 1.0 / 8.0
+                a4_array[np.ix_(meme_bloc(i), [j], [n], [i], [j], [n])] = 1.0 / 8.0
+                a4_array[np.ix_(meme_bloc(i), meme_bloc(j), [n], [i], [j], [n])] = (
+                    1.0 / 8.0
+                )
+
+                a5_array[np.ix_(meme_bloc(i), autre_bloc_b(j), [n], [i], [j], [n])] = (
+                    1.0 / 6.0
+                )
+
+                a6_array[np.ix_(meme_bloc(i), autre_bloc_h(j), [n], [i], [j], [n])] = (
+                    1.0 / 6.0
+                )
+
+                a7_array[np.ix_(autre_bloc_b(i), meme_bloc(j), [n], [i], [j], [n])] = (
+                    1.0 / 6.0
+                )
+
+                a8_array[np.ix_(autre_bloc_h(i), meme_bloc(j), [n], [i], [j], [n])] = (
+                    1.0 / 6.0
+                )
+
+                b1_array[i, j, n, i, j, autre_chiffre(n)] = 1.0 / 8.0
+
+                b2_array[
+                    np.ix_(autre_chiffre(i), [j], autre_chiffre(n), [i], [j], [n])
+                ] = (1.0 / 64.0)
+
+                b3_array[
+                    np.ix_([i], autre_chiffre(j), autre_chiffre(n), [i], [j], [n])
+                ] = (1.0 / 64.0)
+
+                b4_array[np.ix_([i], meme_bloc(j), autre_chiffre(n), [i], [j], [n])] = (
+                    1.0 / 64.0
+                )
+                b4_array[np.ix_(meme_bloc(i), [j], autre_chiffre(n), [i], [j], [n])] = (
+                    1.0 / 64.0
+                )
+                b4_array[
+                    np.ix_(meme_bloc(i), meme_bloc(j), autre_chiffre(n), [i], [j], [n])
+                ] = (1.0 / 64.0)
+
+                b5_array[
+                    np.ix_(
+                        meme_bloc(i), autre_bloc_b(j), autre_chiffre(n), [i], [j], [n]
+                    )
+                ] = (1.0 / 48.0)
+
+                b6_array[
+                    np.ix_(
+                        meme_bloc(i), autre_bloc_h(j), autre_chiffre(n), [i], [j], [n]
+                    )
+                ] = (1.0 / 48.0)
+
+                b7_array[
+                    np.ix_(
+                        autre_bloc_b(i), meme_bloc(j), autre_chiffre(n), [i], [j], [n]
+                    )
+                ] = (1.0 / 48.0)
+
+                b8_array[
+                    np.ix_(
+                        autre_bloc_h(i), meme_bloc(j), autre_chiffre(n), [i], [j], [n]
+                    )
+                ] = (1.0 / 48.0)
+
+    list_arr = [
+        a1_array,
+        a2_array,
+        a3_array,
+        a4_array,
+        a5_array,
+        a6_array,
+        a7_array,
+        a8_array,
+        b1_array,
+        b2_array,
+        b3_array,
+        b4_array,
+        b5_array,
+        b6_array,
+        b7_array,
+        b8_array,
+    ]
+    full_sym = np.zeros((9, 9, 9, 9, 9, 9, len(list_arr)))
+    for i, arr in enumerate(list_arr):
+        full_sym[:, :, :, :, :, :, i] = arr
+
+    full_sym = full_sym.reshape(9 * 9 * 9, 9 * 9 * 9, len(list_arr))
+    return full_sym
+
+
+def meme_bloc(i):
+    bloc_start = 3 * (i // 3)
+    bloc_end = bloc_start + 2
+    if i == bloc_start:
+        return [bloc_start + 1, bloc_end]
+    elif i == bloc_end:
+        return [bloc_start, bloc_start + 1]
+    else:
+        return [bloc_start, bloc_end]
+
+
+def autre_bloc(i):
+    bloc_start = 3 * (i // 3)
+    all_in_bloc = [bloc_start, bloc_start + 1, bloc_start + 2]
+    list_to_return = []
+    for n in range(9):
+        if n not in all_in_bloc:
+            list_to_return.append(n)
+    return list_to_return
+
+
+def autre_bloc_b(i):
+    bloc_start = 3 * (i // 3)
+    all_in_bloc = [bloc_start, bloc_start + 1, bloc_start + 2]
+    list_to_return = []
+    for n in range(9):
+        if n not in all_in_bloc:
+            list_to_return.append(n)
+    return list_to_return[:3]
+
+
+def autre_bloc_h(i):
+    bloc_start = 3 * (i // 3)
+    all_in_bloc = [bloc_start, bloc_start + 1, bloc_start + 2]
+    list_to_return = []
+    for n in range(9):
+        if n not in all_in_bloc:
+            list_to_return.append(n)
+    return list_to_return[3:]
+
+
+def autre_chiffre(i):
+    return [n for n in range(9) if n != i]
+
+
+mat_sym = generate_symetrie_reduction()
+mat_sym = np.moveaxis(mat_sym, (0, 1, 2), (2, 0, 1))
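The block-neighbourhood helpers are pure functions of an index in 0..8, so they are easy to sanity-check in isolation. Two of them, copied verbatim from the file above:

```python
def meme_bloc(i):
    # the two *other* indices inside the same 3-wide block as i
    bloc_start = 3 * (i // 3)
    bloc_end = bloc_start + 2
    if i == bloc_start:
        return [bloc_start + 1, bloc_end]
    elif i == bloc_end:
        return [bloc_start, bloc_start + 1]
    else:
        return [bloc_start, bloc_end]


def autre_chiffre(i):
    # all digits except i
    return [n for n in range(9) if n != i]


print(meme_bloc(4))      # block 3..5 without 4 -> [3, 5]
print(autre_chiffre(0))  # [1, 2, 3, 4, 5, 6, 7, 8]
```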
sudoku/train.py ADDED
@@ -0,0 +1,885 @@
+from math import copysign
+import torch
+from torch import nn
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+import pytorch_lightning as pl
+from sudoku.models import SmallNetBis, SymPreprocess
+import torch.nn.functional as F
+
+from sudoku.buffer import BufferArray, Buffer
+from sudoku.trial_grid import TrialGrid
+
+from copy import deepcopy
+
+
+class SudokuLightning(pl.LightningModule):
+    def __init__(
+        self,
+        lr=0.1,
+        margin=0.1,  # threshold margin
+        coef_0=10,
+        nets_number=6,
+        nets_training_number=1,
+        batch_size=32,
+    ):
+        super().__init__()
+        self.nets_number = nets_number
+        self.batch_size = batch_size
+        self.nets_training_number = nets_training_number
+        self.nets = nn.ModuleList([SmallNetBis() for _ in range(self.nets_number)])
+        self.buffer = BufferArray(self.nets_number, self.batch_size)
+        self.sym_preprocess = SymPreprocess()
+        pos_weight = torch.ones((2, 9 * 9 * 9))
+        pos_weight[0, :] = 1.0 / 8.0
+        pos_weight[1, :] = 1.0
+        pos_weight /= coef_0
+        weight = torch.ones((2, 9 * 9 * 9))
+        weight[0, :] = 8.0
+        weight[1, :] = 1.0
+        weight *= coef_0
+
+        # `reduce=False` is deprecated; `reduction="none"` keeps the
+        # per-element losses so they can be masked in computing_loss
+        self.bcewll = nn.BCEWithLogitsLoss(
+            pos_weight=pos_weight, weight=weight, reduction="none"
+        )
+        self.lr = lr
+
+        self.margin = margin
+        self.th_epsilon = margin * 0.01
+        self.threshold_pres = torch.tensor([-10.0 for _ in range(nets_number)])
+        self.threshold_abs = torch.tensor([-10.0 for _ in range(nets_number)])
+
+        self.automatic_optimization = False
+        self.reset_threshold_on_validation = True
+
+    def configure_optimizers(self):
+        # no scheduler config needed here -> manual optimization
+        optimizers = []
+        for net in self.nets:
+            opti = torch.optim.Adam(net.parameters(), lr=self.lr)
+            optimizers.append(
+                {
+                    "optimizer": opti,
+                    "lr_scheduler": ReduceLROnPlateau(opti, "min"),
+                }
+            )
+        return optimizers
+
+    # def configure_optimizers(self):
+    #     optimizer1 = Adam(...)
+    #     optimizer2 = SGD(...)
+    #     scheduler1 = ReduceLROnPlateau(optimizer1, ...)
+    #     scheduler2 = LambdaLR(optimizer2, ...)
+    #     return (
+    #         {
+    #             "optimizer": optimizer1,
+    #             "lr_scheduler": {
+    #                 "scheduler": scheduler1,
+    #                 "monitor": "metric_to_track",
+    #             },
+    #         },
+    #         {"optimizer": optimizer2, "lr_scheduler": scheduler2},
+    #     )
+
+    # lr_scheduler_config = {
+    #     # REQUIRED: The scheduler instance
+    #     "scheduler": lr_scheduler,
+    #     # The unit of the scheduler's step size; could also be 'step'.
+    #     # 'epoch' updates the scheduler on epoch end whereas 'step'
+    #     # updates it after an optimizer update.
+    #     "interval": "epoch",
+    #     # How many epochs/steps should pass between calls to
+    #     # `scheduler.step()`. 1 corresponds to updating the learning
+    #     # rate after every epoch/step.
+    #     "frequency": 1,
+    #     # Metric to monitor for schedulers like `ReduceLROnPlateau`
+    #     "monitor": "val_loss",
+    #     # If set to `True`, will enforce that the value specified in 'monitor'
+    #     # is available when the scheduler is updated, thus stopping
+    #     # training if not found. If set to `False`, it will only produce a warning
+    #     "strict": True,
+    #     # If using the `LearningRateMonitor` callback to monitor the
+    #     # learning rate progress, this keyword can be used to specify
+    #     # a custom logged name
+    #     "name": None,
+    # }
+    # lr_scheduler_config = {'scheduler': lr_sch, 'interval': 'epoch', 'frequency': 1, 'monitor': 'val_loss'}
+
+    def forward_layer(self, x, idx=0):
+        x = self.sym_preprocess.forward(x)
+        return self.nets[idx](x)
+
+    def forward(self, x):
+        for idx in range(self.nets_number):
+            output = self.forward_layer(x, idx)
+            new_X = self.compute_new_X(output, x, idx, None, train=False)
+            improved_mask = ((new_X == 1) & (x == 0)).any(dim=1).any(dim=1)
+            if improved_mask.sum() > 0:
+                return idx, new_X
+        return idx, new_X
+
+    def predict_from_net(self, x, net, th_abs, th_pres):
+        x = self.sym_preprocess.forward(x)
+        x = net(x)
+        new_x = torch.empty(x.shape, device=x.device)
+        new_x[:, 0] = (x[:, 0] > th_abs).float()
+        new_x[:, 1] = (x[:, 1] > th_pres).float()
+        return new_x
+
+    @staticmethod
+    def mask_uncomplete(x, y):
+        mask_uncomplete = x.reshape(-1, 2, 9, 9, 9).sum(-1) < torch.tensor((8, 1)).to(
+            x
+        ).reshape(1, 2, 1, 1)
+        mask_uncomplete = mask_uncomplete.reshape(-1, 2, 9, 9, 1)
+        mask = ((x == 0).reshape(-1, 2, 9, 9, 9) * mask_uncomplete).reshape(
+            -1, 2, 9**3
+        )
+        mask = mask.float()
+        return mask
+
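A NumPy rendition of the masking rule in `mask_uncomplete`: a cell contributes to the loss only if it is still 0 *and* its group along the last axis is not yet complete, where 8 marks complete an "absent" (channel 0) group and 1 mark completes a "present" (channel 1) group, matching the `(8, 1)` thresholds above. The tiny example is made up.

```python
import numpy as np

x = np.zeros((1, 2, 9, 9, 9))
x[0, 1, 0, 0, 0] = 1  # one digit placed -> that presence group is complete

# per-group completion test, broadcast over the two channels
incomplete = x.sum(-1) < np.array([8, 1]).reshape(1, 2, 1, 1)
# keep only empty cells belonging to incomplete groups
mask = (x == 0) * incomplete[..., None]
```

The completed presence group is masked out entirely, while every untouched group keeps all nine of its cells.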
+    def computing_loss(self, x, y, output):
+        loss = self.bcewll(output, y)
+        mask = self.mask_uncomplete(x, y)
+        loss = (loss * mask).sum()
+
+        return loss
+
+    def training_step(self, batch, batch_idx):
+        self.log(
+            "train_grid_count",
+            batch[0].shape[0],
+            reduce_fx=torch.sum,
+            on_epoch=True,
+            on_step=False,
+        )
+
+        self.layer_training_step(0, batch)
+        while True:
+            idx, batch = self.buffer.get_batch()
+            if batch is None:
+                break
+            # check whether training should run by comparing lr from sch = self.lr_schedulers()
+            # if self.lr != sch[idx].get_last_lr():
+            self.layer_training_step(idx, batch)
+
+    def validation_step(self, batch, batch_idx):
+        self.layer_training_step(0, batch, train=False)
+        while True:
+            idx, batch = self.buffer.get_batch()
+            if batch is None:
+                break
+            self.layer_training_step(idx, batch, train=False)
+
+    def layer_training_step(self, idx, batch, train=True):
+        x, y = batch
+
+        prefix = "train" if train else "val"
+        self.log(
+            f"{prefix}_grid_count_{idx}",
+            batch[0].shape[0],
+            reduce_fx=torch.sum,
+            on_epoch=True,
+            on_step=False,
+        )
+
+        output = self.forward_layer(x, idx)
+        loss = self.computing_loss(x, y, output)
+        if train:
+            opt = self.optimizers()  # a single optimizer or a list
+            if isinstance(opt, list):
+                opt = opt[idx]
+            opt.zero_grad()
+            self.manual_backward(loss)
+            opt.step()
+
+        loss_0 = F.binary_cross_entropy_with_logits(output[:, [0], :], y[:, [0], :])
+        loss_1 = F.binary_cross_entropy_with_logits(output[:, [1], :], y[:, [1], :])
+        self.log_dict(
+            {f"{prefix}_loss_pos": loss_1, f"{prefix}_loss_neg": loss_0}, on_epoch=True
+        )
+
+        # accuracy_1 = torch.mean(torch.eq(transform_to_number_1(output), transform_to_number_1(x)).type(torch.float))
+        # accuracy_0 = torch.mean(torch.eq(transform_to_number_0(output), transform_to_number_0(x)).type(torch.float))
+        # self.log_dict({'accuracy_1': accuracy_1, 'accuracy_0': accuracy_0}, on_epoch=True)
+        self.log(f"{prefix}_loss_{idx}", loss)
+        # add a count log on (X and x == y)
+
+        new_X = self.compute_new_X(output, x, idx, y, train=train)
+        solved_mask = (new_X == y).all(dim=1).all(dim=1)
+        new_X = new_X[~solved_mask]
+        x = x[~solved_mask]
+        y = y[~solved_mask]
+        self.log(
+            f"{prefix}_resolved_grid_count",
+            solved_mask.sum(),
+            on_epoch=True,
+            on_step=False,
+            reduce_fx=torch.sum,
+        )
+        mask_no_improve = new_X.sum(dim=(1, 2)) <= x.sum(dim=(1, 2))
+        self.log(
+            f"{prefix}_improved_grid_count_{idx}",
+            (~mask_no_improve).sum(),
+            on_epoch=True,
+            on_step=False,
+            reduce_fx=torch.sum,
+        )
+        # store_new_x
+        # TODO keep the log in this method
+        # loss per epoch per model boost layer
+        # number of errors per epoch per model boost layer
+        # number of resolved puzzles per epoch
+        # threshold per epoch per model layer
+        # number of sudoku grids
+
+        # number of filled digits per model boost layer per epoch, for both pres and abs
+        # add parameter reduce_fx=torch.sum to the counts
+        # th -> on_epoch=False
+
+        self.store_new_x(idx, new_X, x, y)
+
+    def store_new_x(self, idx, new_X, x, y):
+        mask_improve = new_X.sum(dim=(1, 2)) > x.sum(dim=(1, 2))
+        self.buffer.append(
+            idx + 1, (new_X[~mask_improve].clone(), y[~mask_improve].clone())
+        )
+        self.buffer.append(0, (new_X[mask_improve].clone(), y[mask_improve].clone()))
+        # TODO if improvement after no improvement -> add one digit from y to new_X and add it to idx=0
+    def compute_new_X(self, output, x, idx, y=None, train=True, mask_adapt_th=None):
+        # y may be None at inference time
+        prefix = "train" if train else "val"
+        new_X = torch.empty(output.shape, device=output.device)
+        if y is not None:
+            max_th_abs = output[:, 0][(y[:, 0] == 0)].max().item() + self.th_epsilon
+            max_th_pres = (
+                output[:, 1][(x[:, 1] == 0) & (y[:, 1] == 0)].max().item()
+                + self.th_epsilon
+            )
+            if mask_adapt_th is None or (mask_adapt_th.sum() > 0):
+                if mask_adapt_th is not None and (mask_adapt_th.sum() > 0):
+                    max_th_abs = (
+                        output[mask_adapt_th, 0][(y[mask_adapt_th, 0] == 0)]
+                        .max()
+                        .item()
+                        + self.th_epsilon
+                    )
+                    max_th_pres = (
+                        output[mask_adapt_th, 1][
+                            (x[mask_adapt_th, 1] == 0) & (y[mask_adapt_th, 1] == 0)
+                        ]
+                        .max()
+                        .item()
+                        + self.th_epsilon
+                    )
+                self.threshold_abs[idx] = max(max_th_abs, self.threshold_abs[idx])
+                self.threshold_pres[idx] = max(max_th_pres, self.threshold_pres[idx])
+                self.log_dict(
+                    {
+                        f"{prefix}_th_abs_{idx}": self.threshold_abs[idx],
+                        f"{prefix}_th_pres_{idx}": self.threshold_pres[idx],
+                    },
+                    on_step=True,
+                )
+                if not train:
+                    self.threshold_abs_compute[idx] = max(
+                        max_th_abs + self.margin, self.threshold_abs_compute[idx]
+                    )
+                    self.threshold_pres_compute[idx] = max(
+                        max_th_pres + self.margin, self.threshold_pres_compute[idx]
+                    )
+
+        # the same thresholding applies in training and evaluation mode
+        new_X[:, 0] = (output[:, 0].detach() > self.threshold_abs[idx]).float()
+        new_X[:, 1] = (output[:, 1].detach() > self.threshold_pres[idx]).float()
+        new_X[x.detach() == 1] = 1
+        if y is not None:
+            self.log(
+                f"{prefix}_count_error_grid_{idx}",
+                ((new_X == 1) & (y == 0)).any(dim=1).any(dim=1).sum(),
+                on_epoch=True,
+                on_step=False,
+                reduce_fx=torch.sum,
+            )
+            if mask_adapt_th is None:
+                new_X[y.detach() == 0] = 0  # reset wrongly filled cells using the ground truth
+            else:
+                y_bis = y.detach().clone()
+                y_bis[~mask_adapt_th] = 1
+                new_X[y_bis == 0] = 0
+        return new_X
+
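The adaptive-threshold trick in `compute_new_X`, reduced to NumPy: the threshold is pushed just above the highest logit among cells whose target is 0, so thresholding this batch cannot produce a false positive. The numbers below are illustrative, and `eps` plays the role of `th_epsilon`.

```python
import numpy as np

eps = 0.001
logits = np.array([[0.9, 0.2, -0.5], [0.1, 0.8, 0.3]])
y = np.array([[1, 0, 0], [0, 1, 1]])  # ground-truth cells that must stay 0/1

# threshold just above the best-scoring cell that must stay 0
th = logits[y == 0].max() + eps
pred = (logits > th).astype(int)
```

Keeping the running maximum of `th` across batches (as `threshold_abs`/`threshold_pres` do) extends this no-false-positive guarantee, approximately, to the whole epoch.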
+    # TODO add idx stuff (one lr scheduler per net)
+    # def on_train_epoch_end(self):
+    #     sch = self.lr_schedulers()
+
+    #     # If the selected scheduler is a ReduceLROnPlateau scheduler.
+    #     if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau):
+    #         sch.step(self.trainer.callback_metrics["loss"])
+
+    def on_validation_epoch_start(self) -> None:
+        if self.reset_threshold_on_validation:
+            self.threshold_abs_compute = torch.tensor(
+                [-10.0 for _ in range(self.nets_number)]
+            )
+            self.threshold_pres_compute = torch.tensor(
+                [-10.0 for _ in range(self.nets_number)]
+            )
+        else:
+            self.threshold_abs_compute = self.threshold_abs
+            self.threshold_pres_compute = self.threshold_pres
+
+        self.buffer = BufferArray(self.nets_number, self.batch_size)
+
+    def on_train_epoch_start(self) -> None:
+        self.buffer = BufferArray(self.nets_number, self.batch_size)
+        return super().on_train_epoch_start()
+
+    def on_validation_epoch_end(self):
+        self.threshold_abs = self.threshold_abs_compute
+        self.threshold_pres = self.threshold_pres_compute
+
+        schs = self.lr_schedulers()
+        if not isinstance(schs, list):
+            schs = [schs]
+        for idx, sch in enumerate(schs):
+            try:
+                sch.step(self.trainer.callback_metrics[f"val_loss_{idx}"])
+            except KeyError:
+                # the metric is missing when no batch reached layer idx
+                pass
+
+    def on_save_checkpoint(self, checkpoint) -> None:
+        """Objects to include in the checkpoint file."""
+        checkpoint["ths_abs"] = self.threshold_abs
+        checkpoint["ths_pres"] = self.threshold_pres
+
+    def on_load_checkpoint(self, checkpoint) -> None:
+        """Objects to retrieve from the checkpoint file."""
+        self.threshold_abs = checkpoint["ths_abs"]
+        self.threshold_pres = checkpoint["ths_pres"]
+        self.nets = nn.ModuleList([SmallNetBis() for _ in self.threshold_abs])
+
+    def validate_grids(self, x) -> "torch.tensor":
+        # channels 17-19 of the symmetry features flag constraint violations;
+        # compute the preprocessing once instead of three times
+        x_sym = self.sym_preprocess(x)
+        return ~(
+            (x_sym[:, 17].max(dim=1).values > (1 / 8))
+            | (x_sym[:, 18].max(dim=1).values > (1 / 8))
+            | (x_sym[:, 19].max(dim=1).values > (1 / 8))
+        )
+
379
+ # steps to trial error
380
+ # - get stops
381
+ # - choose a number -> store it
382
+ # - process to get either a new stop either a a validation grid fail
383
+ # if validation grid fail back propagate
384
+ # else choose a number
385
+
386
+ # add counter to each grid,
387
+ # add id to each grid id=batch_id + position
388
+ # add validation
389
+
390
+ # if non improvment stop ->
391
+ # - check if id already exist, if true add non improve counter
392
+ # if non improve counter = 2 -> add grid to trial_error_model_buffer with 1000 step target.
393
+ # - store the grid to trial_error_model deep search dict
394
+ # - create two grids with counter to 0, same id
395
+ # add them in the buffer
396
+
397
+ # - when validation fail ->
398
+ # - check if id already exist
399
+ # if true: add grid to trial_error_model with the counter
400
+ # if false: raise error
401
+
402
+ # TODO adapt training to something softer
403
+ #
404
+ class SudokuTrialErrorLightning(SudokuLightning):
405
+ def __init__(self, **kwargs):
406
+ super().__init__(**kwargs)
407
+ self.deep_backtrack_regressor = SmallNetBis(n_output=1)
408
+ self.trial_error_buffer = Buffer(self.batch_size)
409
+ self.trial_grids = [None]
410
+ # schema:
411
+ # [
412
+ # idx:
413
+ # "tried_pos": [
414
+ # ]
415
+ # "pos": pos
416
+ # "no_improve_counter": 0
417
+ # ]
418
+ #
419
+ # self.tracking_grid = []
420
+
421
+ def copy_from_model(self, model):
422
+ self.nets = model.nets
423
+ self.threshold_pres = model.threshold_pres
424
+ self.threshold_abs = model.threshold_abs
425
+
426
+ def reg(self, x):
427
+ x_reg = self.sym_preprocess.forward(x)
428
+ x_reg = self.deep_backtrack_regressor(x_reg)
429
+ return torch.softmax(x_reg, dim=1)
430
+
431
+ def configure_optimizers(self):
432
+ # no need config scheduler -> manual optimisation
433
+ # optimizers = [torch.optim.Adam(net.parameters(), lr=self.lr) for net in self.nets]
434
+ optimizers = []
435
+ for net in self.nets:
436
+ opti = torch.optim.Adam(net.parameters(), lr=self.lr)
437
+ optimizers.append(
438
+ {
439
+ "optimizer": opti,
440
+ "lr_scheduler": ReduceLROnPlateau(opti, "min"),
441
+ }
442
+ )
443
+ optimizers.append(
444
+ {
445
+ 'optimizer': torch.optim.Adam(self.deep_backtrack_regressor.parameters(), lr=self.lr),
446
+ "lr_scheduler": ReduceLROnPlateau(opti, "min"),
447
+ }
448
+ )
449
+ return optimizers
450
+
451
+ def training_step(self, batch, batch_idx):
452
+ self.log(
453
+ "train_grid_count",
454
+ batch[0].shape[0],
455
+ reduce_fx=torch.sum,
456
+ on_epoch=True,
457
+ on_step=False,
458
+ )
459
+ x, y = batch
460
+ x_idx = torch.zeros(self.batch_size) # if we are not on trial error x_idx=0
461
+ counters = torch.zeros(self.batch_size)
462
+
463
+ self.layer_training_step(0, (x, y, x_idx, counters))
464
+ idx_while=0
465
+ while True:
466
+ idx_while+=1
467
+ if idx_while ==10000:
468
+ print('a while')
469
+ idx, batch = self.buffer.get_batch()
470
+ if batch is None:
471
+ break
472
+ # check if the train should be done by comparing lr from sch = self.lr_schedulers()
473
+ # if self.lr != sch[idx].get_last_lr():
474
+ self.layer_training_step(idx, batch)
475
+
476
+ while True:
477
+ trial_error_batch = self.trial_error_buffer.get_batch()
478
+ if trial_error_batch is None:
479
+ break
480
+ self.trial_error_training_step(trial_error_batch)
481
+
482
+ def validation_step(self, batch, batch_idx):
483
+ x, y = batch
484
+ x_idx = torch.zeros(x.shape[0], dtype=torch.long) # if we are not on trial error x_idx=0
485
+ counters = torch.zeros(x.shape[0])
486
+
487
+ self.layer_training_step(0, (x, y, x_idx, counters), train=False)
488
+ while True:
489
+ idx, batch = self.buffer.get_batch()
490
+ if batch is None:
491
+ break
492
+ # check if the train should be done by comparing lr from sch = self.lr_schedulers()
493
+ # if self.lr != sch[idx].get_last_lr():
494
+ self.layer_training_step(idx, batch, train=False)
495
+
496
+ while True:
497
+ trial_error_batch = self.trial_error_buffer.get_batch()
498
+ if trial_error_batch is None:
499
+ break
500
+ self.trial_error_training_step(trial_error_batch, train=False)
501
+
502
+
503
+ def layer_training_step(
504
+ self, idx, batch, train=True
505
+ ): # to rename to layer_training_step
506
+ x, y, x_idx, counters = batch
507
+
508
+ prefix = "train" if train else "val"
509
+ self.log(
510
+ f"{prefix}_grid_count_{idx}",
511
+ batch[0].shape[0],
512
+ reduce_fx=torch.sum,
513
+ on_epoch=True,
514
+ on_step=False,
515
+ )
516
+ output = self.forward_layer(x, idx)
517
+ loss = self.computing_loss(x[x_idx==0], y[x_idx==0], output[x_idx==0])
518
+ if train:
519
+ opt = self.optimizers()[idx]
520
+ opt.zero_grad()
521
+ self.manual_backward(loss)
522
+ opt.step()
524
+
525
+ loss_0 = F.binary_cross_entropy_with_logits(output[:, [0], :], y[:, [0], :])
526
+ loss_1 = F.binary_cross_entropy_with_logits(output[:, [1], :], y[:, [1], :])
527
+ self.log_dict(
528
+ {f"{prefix}_loss_pos": loss_1, f"{prefix}_loss_neg": loss_0}, on_epoch=True
529
+ )
530
+ self.log(f"{prefix}_loss_{idx}", loss)
531
+
532
+ mask_bad_x = ((x==1)&(y==0)).any(dim=1).any(dim=1)
533
+ new_X = self.compute_new_X(output, x, idx, y, train=train, mask_adapt_th=(~mask_bad_x))
534
+ solved_mask = (new_X == y).all(dim=1).all(dim=1)
535
+
536
+ new_X = new_X[~solved_mask]
537
+ x = x[~solved_mask]
538
+ y = y[~solved_mask]
539
+ x_idx = x_idx[~solved_mask]
540
+ counters = counters[~solved_mask]
541
+
542
+ self.log(
543
+ f"{prefix}_resolved_grid_count",
544
+ solved_mask.sum(),
545
+ on_epoch=True,
546
+ on_step=False,
547
+ reduce_fx=torch.sum,
548
+ )
549
+ mask_no_improve = new_X.sum(dim=(1, 2)) <= x.sum(dim=(1, 2))
550
+ self.log(
551
+ f"{prefix}_improved_grid_count_{idx}",
552
+ (~mask_no_improve).sum(),
553
+ on_epoch=True,
554
+ on_step=False,
555
+ reduce_fx=torch.sum,
556
+ )
557
+ # self.store_new_x(idx, new_X, x, y) # TODO create another function (need to increment counter and validate)
558
+ self.process_validation(idx, new_X, x, y, x_idx, counters)
559
+
560
+ def process_validation(self, idx, new_X, x, y, x_idx, counters):
561
+ new_X = self.redresse_new_X(new_X, y, x)
562
+ mask_validated = self.validate_grids(new_X)
563
+ # mask_not_validated = (~self.validate_grids(new_X)) & ((x==0)&(y==1)).any(dim=(1,2))
564
+ mask_improve = (new_X.sum(dim=(1, 2)) > x.sum(dim=(1, 2))) & mask_validated
565
+ mask_not_improved = (new_X.sum(dim=(1, 2)) == x.sum(dim=(1, 2))) & mask_validated
566
+
567
+
568
+ for i, (failed_idx, failed_counter, s_new_X, s_y) in enumerate(zip(
569
+ x_idx[~mask_validated],
570
+ counters[~mask_validated],
571
+ new_X[~mask_validated],
572
+ y[~mask_validated],
573
+ )):
574
+ # when a grid fails validation:
575
+ # - storing the good branch's grid is not strictly needed: the other branch keeps processing.
576
+ # - we store the initial grid with its score (to train the regressor)
577
+ if failed_idx == 0:
578
+ self.failed_batch = (x[~mask_validated][i], s_y)
579
+ raise ValueError("validation error on a non-trial-error grid")
580
+ if not ((x[~mask_validated][i]==0)&(s_y==1)).any():
581
+ raise ValueError()
582
+ is_pos = copysign(1, failed_idx)==1
583
+ trial_grid: TrialGrid = self.trial_grids[int(abs(failed_idx))]
584
+ if is_pos:
585
+ trial_grid.pos_result = 'fail'
586
+ else:
587
+ trial_grid.neg_result = 'fail'
588
+
589
+ self.process_search_store_grid(int(abs(failed_idx)), trial_grid, s_y)
590
+
591
+ if idx == self.nets_number - 1:
592
+ for no_improved_idx, s_new_X, s_y in zip(
593
+ x_idx[mask_not_improved], new_X[mask_not_improved], y[mask_not_improved]
594
+ ):
595
+ if no_improved_idx == 0:
596
+ self.search_trial_buffer_trials(s_new_X, s_y)
597
+ continue
598
+
599
+ is_pos = copysign(1, no_improved_idx)==1
600
+ trial_grid: TrialGrid = self.trial_grids[int(abs(no_improved_idx.item()))]
601
+ if is_pos:
602
+ trial_grid.pos_result = 'no_improved'
603
+ else:
604
+ trial_grid.neg_result = 'no_improved'
605
+ assert s_new_X.sum()> trial_grid.initial_grid.sum()
606
+
607
+ self.process_search_store_grid(int(abs(no_improved_idx)), trial_grid, s_y)
608
+
609
+ self.buffer.append(
610
+ idx + 1,
611
+ (new_X[mask_not_improved].clone(), y[mask_not_improved].clone(), x_idx[mask_not_improved].clone(), counters[mask_not_improved].clone()),
612
+ )
613
+ # assert mask_improve.sum()>0
614
+ if ((new_X[mask_improve & (x_idx.to(self.device)==0)]==1) & (y[mask_improve & (x_idx.to(self.device)==0)]==0)).any():
615
+ self.failed_batch=(x[mask_improve & (x_idx.to(self.device)==0)],y[mask_improve & (x_idx.to(self.device)==0)] )
616
+ raise ValueError()
617
+ self.buffer.append(
618
+ 0,
619
+ (new_X[mask_improve].clone(), y[mask_improve].clone(), x_idx[mask_improve].clone(), counters[mask_improve].clone() + 1),
620
+ )
621
+
622
+ def process_search_store_grid(self, idx, trial_grid: TrialGrid, s_y):
623
+ """Dispatch a trial grid once both branch results are in.
624
+
625
+ if score is None (one branch still pending):
626
+ keep the trial grid and wait.
627
+ otherwise, store (initial_grid, score, position) in the trial_error_buffer, then:
628
+ if both branches stopped improving -> mark the position as tried and search a new trial position on the same grid (same idx)
629
+ if one branch failed and the other stopped improving -> re-enqueue the surviving branch's grid at layer 0 and free the slot
630
+ if a branch completed -> free the slot (set self.trial_grids[idx] to None)
631
+
632
+ Args:
633
+ idx: index into self.trial_grids.
634
+ trial_grid: TrialGrid holding both branch results.
635
+ s_y: target grid, used for sanity checks and re-enqueued batches.
636
+ """
660
+ score = trial_grid.score()
661
+ if score is None:
662
+ self.trial_grids[idx]=trial_grid
663
+ return
664
+ # add grid to buffer (initial_grid, score)
665
+ self.trial_error_buffer.append((
666
+ trial_grid.initial_grid.view(-1,2,729),
667
+ torch.tensor([score,],dtype=torch.float).to(self.device),
668
+ torch.tensor([trial_grid.row_col_digit_position,], dtype=torch.long).to(self.device),
669
+ ))
670
+
671
+ # find the no_improve_grid ~and search_trial~ and add it to buffer
672
+ if trial_grid.neg_result == 'no_improved':
673
+ if trial_grid.pos_result == 'no_improved':
674
+ trial_grid.tried_grid.append(trial_grid.row_col_digit_position)
675
+ trial_grid.neg_result= None
676
+ trial_grid.pos_result= None
677
+ self.trial_grids[idx] = trial_grid
678
+ self.search_trial_buffer_trials(None, s_y, idx)
679
+ # new trial with same idx
680
+ return
681
+ # add to buffer neg grid
682
+ # we get back the initial grid
683
+ # set the correct row col digit
684
+ # add it the buffer
685
+ # set trial_grids to None
686
+ grid_neg = deepcopy(trial_grid.initial_grid)
687
+ grid_neg[0,trial_grid.row_col_digit_position] = 1
688
+ if ((grid_neg==1) & (s_y==0)).any():
689
+ raise ValueError()
690
+ self.buffer.append(
691
+ 0,
692
+ (
693
+ grid_neg.view(-1,2,729),
694
+ s_y.clone().view(-1,2,729),
695
+ torch.tensor([0]),
696
+ torch.tensor([0]),
697
+ )
698
+ )
699
+ self.trial_grids[idx] = None
700
+ return
701
+ if trial_grid.pos_result == 'no_improved':
702
+ grid_pos = deepcopy(trial_grid.initial_grid)
703
+ grid_pos[1,trial_grid.row_col_digit_position] = 1
704
+ if ((grid_pos==1) & (s_y==0)).any():
705
+ raise ValueError()
706
+ self.buffer.append(
707
+ 0,
708
+ (
709
+ grid_pos.view(-1,2,729),
710
+ s_y.clone().view(-1,2,729),
711
+ torch.tensor([0]),
712
+ torch.tensor([0]),
713
+ )
714
+ )
715
+ self.trial_grids[idx] = None
716
+ # add to buffer pos grid
717
+ return
718
+
719
+
720
+ # if complete: replace grid by none.
721
+ if "complete" in [trial_grid.neg_result, trial_grid.pos_result]:
722
+ self.trial_grids[idx]=None
723
+
724
+ # def store_new_trial_error_grid(self, new_X, y):
725
+ # """build a new idx add the grid in the tracking stuff
726
+ # and add grid in the buffer
727
+
728
+ # Args:
729
+ # new_X (_type_): _description_
730
+ # y (_type_): _description_
731
+ # """
732
+ # ...
733
+
734
+ # def store_training_trail_search_batch(self, grid, score):
735
+ # """store grid to train trial_search nn model
736
+
737
+ # Args:
738
+ # grid (_type_): _description_
739
+ # score (_type_): _description_
740
+ # """
741
+ # ...
742
+
743
+ def search_trial(self, s_new_X, tried_pos):
744
+ """Use the trial_search nn model to probe a new position.
745
+
746
+ Args:
747
+ s_new_X: current grid, shape (2, 729).
748
+ tried_pos: positions already tried for this grid.
749
+ """
750
+ mask_possibility = s_new_X.sum(dim=0)==0
751
+ for pos in tried_pos:
752
+ mask_possibility[pos]=False
753
+ if mask_possibility.sum() == 0:
755
+ raise ValueError("no untried empty position left to probe")
756
+
757
+ with torch.no_grad():
758
+ x_reg = self.sym_preprocess.forward(s_new_X.view(1,2,-1))
759
+ output = self.deep_backtrack_regressor(x_reg)
760
+ # output shape: (1, 1, 729)
761
+ # whether the head is a plain or a logistic regression, we want the smallest score;
762
+ # after softmax, adding 1 to excluded positions keeps them out of the argmin
764
+ output = torch.softmax(output[0][0],dim=0)
765
+ # for pos in tried_pos:
766
+ # output[pos]=1
767
+ output[~mask_possibility]+=1
768
+ return torch.argmin(output, dim=0).item()
769
+
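The masking trick in `search_trial` above (softmax the scores, then add 1 to excluded positions so the argmin can never pick them, since a softmax output is at most 1) can be sketched without torch. The function name and values below are illustrative only, not part of the codebase:

```python
from math import exp

def pick_position(scores, excluded):
    """Softmax the scores, push excluded indices above any valid
    probability (+1), then return the index of the smallest value."""
    total = sum(exp(s) for s in scores)
    probs = [exp(s) / total for s in scores]
    for i in excluded:
        probs[i] += 1.0  # softmax outputs are <= 1, so an excluded index always loses
    return min(range(len(probs)), key=probs.__getitem__)

print(pick_position([0.2, -1.5, 0.9], excluded=[1]))  # → 0 (index 1 is masked out)
```

Without the exclusion, index 1 would win (smallest softmax probability); with it, the next-smallest valid index is chosen instead.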
770
+ def search_trial_buffer_trials(self, s_new_X, s_y, idx_trial_grids=None):
771
+
772
+ if idx_trial_grids is None:
773
+ row_col_digit_trial = self.search_trial(s_new_X, [])
774
+ trial_grid = TrialGrid(s_new_X, row_col_digit_trial)
775
+ self.trial_grids.append(trial_grid)
776
+ idx_trial_grids = len(self.trial_grids)-1
777
+ else:
778
+ trial_grid = self.trial_grids[idx_trial_grids]
779
+ s_new_X = trial_grid.initial_grid
780
+ row_col_digit_trial = self.search_trial(s_new_X, trial_grid.tried_grid)
781
+ trial_grid.row_col_digit_position = row_col_digit_trial
782
+ self.trial_grids[idx_trial_grids] = trial_grid
783
+
784
+ # and we add both into buffer.
785
+ grid_pos = deepcopy(s_new_X)
786
+ grid_neg = deepcopy(s_new_X)
787
+ grid_pos[1,row_col_digit_trial] = 1
788
+ grid_neg[0,row_col_digit_trial] = 1
789
+ self.buffer.append(
790
+ 0,
791
+ (
792
+ torch.stack([grid_pos,grid_neg], dim=0),
793
+ torch.stack([s_y.clone(),s_y.clone()], dim=0),
794
+ torch.tensor([idx_trial_grids, -idx_trial_grids]),
795
+ torch.tensor([0, 0]),
796
+ )
797
+ )
798
+
799
+ def trial_error_training_step(self, batch, train=True):
800
+ x, y, row_col_digit = batch
801
+ prefix = "train" if train else "val"
802
+ self.log(
803
+ f"{prefix}_grid_count_trial_error_training",
804
+ batch[0].shape[0],
805
+ reduce_fx=torch.sum,
806
+ on_epoch=True,
807
+ on_step=False,
808
+ )
809
+
810
+
811
+ x_reg = self.sym_preprocess.forward(x)
812
+ output = self.deep_backtrack_regressor(x_reg)
813
+ loss = nn.functional.binary_cross_entropy_with_logits(output[torch.arange(self.batch_size), 0, row_col_digit], y, reduction='mean')
814
+ # # depending the distribution of the target, the weight could be different
815
+ # loss = binary (output[:,0,row_col_digit], y)
816
+ # loss = self.computing_loss(x, y, output)
817
+ if train:
818
+ opt = self.optimizers()[-1]
819
+ opt.zero_grad()
820
+ self.manual_backward(loss)
821
+ opt.step()
822
+
823
+ self.log(f"{prefix}_loss_trial_error", loss)
824
+ self.log(f"{prefix}_loss_{self.nets_number}", loss)
825
+ self.log(f"{prefix}_y_pos_trial_error", y.sum())
826
+ self.log(f"{prefix}_y_neg_trial_error", y.shape[0]-y.sum())
827
+
828
+ def predict(self, x):
829
+ """Return an improved version of x, falling back to trial and error if the nets stall."""
832
+
833
+ idx, new_X = self.forward(x.view(-1,2,729))
834
+ if (new_X.sum()>x.sum()) or (new_X.sum()==729):
835
+ return new_X
836
+ else:
837
+ # call trial error until we find a solution
838
+ tried_position = []
839
+ while True:
840
+ pos = self.search_trial(x.view(2,729), tried_position)
841
+ tried_position.append(pos)
842
+ # create the pos/neg trial pair
843
+ grid_pos = deepcopy(x.view(2,729))
844
+ grid_neg = deepcopy(x.view(2,729))
845
+ grid_pos[1,pos] = 1
846
+ grid_neg[0,pos] = 1
847
+ X_tried = torch.stack([grid_neg, grid_pos], dim=0)
848
+ # process it
849
+ while True:
850
+ idx, new_X = self.forward(X_tried)
851
+ mask_validated = self.validate_grids(new_X)
852
+ if mask_validated.sum()<2:
853
+ x[0, mask_validated, pos] = 1 # TODO check that this indexing works
854
+ return x
855
+ if X_tried.sum()==new_X.sum():
856
+ # if both stop to improve -> break it will tried an new pos
857
+ break
858
+ mask_complete = (X_tried.sum(dim=1) == 729) # TODO check that this works
859
+ if mask_complete.sum()>0:
860
+ x[0, mask_complete, pos] = 1
861
+ return x
862
+ X_tried = new_X
863
+ # if one of X_tried is complete (unlikely but possible) -> return x with the mask_complete position set to 1 (we still want a step-by-step resolution)
864
+
865
+ def on_validation_epoch_start(self) -> None:
866
+ # self.buffer = BufferArray(self.nets_number, self.batch_size)
867
+ self.trial_error_buffer = Buffer(self.batch_size)
868
+ self.trial_grids = [None]
869
+ return super().on_validation_epoch_start()
870
+
871
+ def on_train_epoch_start(self) -> None:
872
+ self.trial_error_buffer = Buffer(self.batch_size)
873
+ self.trial_grids = [None]
874
+ return super().on_train_epoch_start()
875
+
876
+ def redresse_new_X(self, new_X, y, x):
877
+ # "redresse" (straighten): zero out cells that contradict the target, except on known-bad grids
+ mask_bad_x = ((x==1)&(y==0)).any(dim=1).any(dim=1)
878
+ y_bis = y.clone()
879
+ y_bis[mask_bad_x]=1
880
+ new_X[y_bis==0]=0
881
+ return new_X
882
+
883
+
884
+ # TODO: add threshold adjustment during prediction
885
+ # or maybe run validation on it? yes, definitely.
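The training and validation steps above share one control-flow pattern: drain a buffer until `get_batch` returns `None`. A minimal stand-in makes that contract explicit; `MiniBuffer` here is a hypothetical sketch, not the project's `sudoku.buffer` implementation:

```python
from collections import deque

class MiniBuffer:
    """Hypothetical stand-in: accumulates items, releases them one batch at a time."""
    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.items = deque()

    def append(self, item):
        self.items.append(item)

    def get_batch(self):
        # mirror the solver's contract: return None once fewer than a full batch remains
        if len(self.items) < self.batch_size:
            return None
        return [self.items.popleft() for _ in range(self.batch_size)]

buf = MiniBuffer(batch_size=2)
for item in range(5):
    buf.append(item)

drained = []
while True:  # same drain loop as training_step / validation_step
    batch = buf.get_batch()
    if batch is None:
        break
    drained.append(batch)

print(drained)  # → [[0, 1], [2, 3]] (item 4 waits until a full batch accumulates)
```

The `None` sentinel is what lets `training_step` interleave layer processing with trial-error processing without tracking buffer sizes explicitly.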
sudoku/trial_grid.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ class TrialGrid:
4
+ def __init__(self, initial_grid, row_col_digit_position):
5
+ self.initial_grid = initial_grid
6
+ self.row_col_digit_position = row_col_digit_position
7
+ self.neg_result = None
8
+ self.pos_result = None
9
+ self.tried_grid = []
10
+
11
+ def score(self):
12
+ # currently returns 0 as soon as either branch is decisive (fail or complete), else 1
13
+ if "fail" in [self.neg_result, self.pos_result]:
14
+ return 0
15
+ if "complete" in [self.neg_result, self.pos_result]:
16
+ return 0
17
+ if None in [self.neg_result, self.pos_result]:
18
+ return None
19
+ return 1
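The `score` contract can be exercised standalone (the class is re-declared below so the snippet runs on its own): a decisive branch (fail or complete) scores 0, a double stall scores 1, and a pending branch yields None.

```python
class TrialGrid:
    # re-declaration of the class above, for a self-contained check
    def __init__(self, initial_grid, row_col_digit_position):
        self.initial_grid = initial_grid
        self.row_col_digit_position = row_col_digit_position
        self.neg_result = None
        self.pos_result = None
        self.tried_grid = []

    def score(self):
        if "fail" in [self.neg_result, self.pos_result]:
            return 0
        if "complete" in [self.neg_result, self.pos_result]:
            return 0
        if None in [self.neg_result, self.pos_result]:
            return None
        return 1

tg = TrialGrid(initial_grid=None, row_col_digit_position=42)
print(tg.score())                 # → None (both branches still pending)
tg.neg_result = "fail"
print(tg.score())                 # → 0 (a decisive branch wins, even with pos pending)
tg.neg_result = tg.pos_result = "no_improved"
print(tg.score())                 # → 1 (both branches stalled)
```

Note that `process_search_store_grid` only dispatches once `score()` is non-None, i.e. once "fail"/"complete" short-circuits or both branch results are filled in.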
tensorboard.sh ADDED
@@ -0,0 +1 @@
 
 
1
+ ~/git/prose-dl/venv/bin/tensorboard --logdir lightning_logs/
tests/test_buffer.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pytest
3
+ import numpy as np
4
+
5
+ from sudoku.buffer import BufferArray
6
+
7
+
8
+ # tests for get_batch
9
+ def test_get_batch_on_empty():
10
+ ba = BufferArray(array_size=3, batch_size=2)
11
+ with pytest.raises(AssertionError):
12
+ ba.get_batch()
13
+
14
+
15
+ def test_append_get_batch():
16
+ x = torch.tensor(
17
+ [
18
+ [[1, 2, 3], [4, 5, 6]],
19
+ [[11, 12, 13], [14, 15, 16]],
20
+ ]
21
+ )
22
+ y = torch.tensor(
23
+ [
24
+ 1,
25
+ 2,
26
+ ]
27
+ )
28
+ ba = BufferArray(array_size=3, batch_size=2)
29
+ ba.append(0, (x, y))
30
+ idx, batch = ba.get_batch()
31
+ assert batch[0].shape == x.shape
32
+ assert batch[1].shape == y.shape
33
+ assert idx == 0
34
+ assert len(ba.buffers[0][0]) == 0
35
+
36
+
37
+ @pytest.mark.parametrize("idx", [0, 1])
38
+ def test_append_get_batch_2(idx):
39
+ x = torch.tensor(
40
+ [
41
+ [[1, 2, 3], [4, 5, 6]],
42
+ [[11, 12, 13], [14, 15, 16]],
43
+ ]
44
+ )
45
+ y = torch.tensor(
46
+ [
47
+ 1,
48
+ 2,
49
+ ]
50
+ )
51
+ ba = BufferArray(array_size=3, batch_size=2)
52
+ ba.append(idx, (x, y))
53
+ ba.append(idx, (x, y))
54
+ batch_idx, batch = ba.get_batch()
55
+ assert batch[0].shape == x.shape
56
+ assert batch[1].shape == y.shape
57
+ assert idx == batch_idx
58
+ assert len(ba.buffers[idx][0]) == 2
59
+
60
+
61
+ # buffer array tests ->
62
+
63
+ # TODO -> build scenarios:
64
+ # append
65
+ # append
66
+ # append
67
+ # batch batch
68
+
69
+ # check that each batch is properly separated
70
+ # that the batch size limit is respected
71
+ # that sampling happens
72
+
73
+ # and that the asserts fire
tests/test_helper.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sudoku.helper import compute_loss
2
+ import torch
3
+
4
+
5
+ def test_compute_loss():
6
+ x = torch.zeros((3, 2, 729))
7
+ y = torch.zeros((3, 2, 729))
8
+ output = torch.zeros((3, 2, 729))
9
+ y[:, 0, 0] = 1
10
+ output[0, 0, 0] = 0.1
11
+ output[1, 0, 0] = 0.1
12
+ output[2, 0, 1] = 0.1
13
+ new_x = (output > 0).float()
14
+ loss_error, loss_no_improve, n_error, n_no_improve = compute_loss(
15
+ x, y, output, new_x
16
+ )
tests/test_loader.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sudoku.loader import data_loader, DataIterBuffer, get_datasets, train_dataset
2
+ from torch.utils.data import DataLoader
3
+
4
+
5
+ def test_data_loader():
6
+ train_loader, test_loader = data_loader()
7
+
8
+ X, Y = next(iter(train_loader))
9
+ assert X.shape == Y.shape == (32, 2, 9 * 9 * 9)
10
+ assert (Y.sum(-1)[:, 0] == (9 * 9 * 8)).all()
11
+ assert (Y.sum(-1)[:, 1] == (9 * 9)).all()
12
+ assert (X.sum(-1)[:, 0] < (9 * 9 * 8)).all()
13
+ assert (X.sum(-1)[:, 1] < (9 * 9)).all()
14
+
15
+
16
+ def test_data_iter_buffer():
17
+ i = 0
18
+ data_iter = DataIterBuffer(train_dataset)
19
+ data_loader_buffer = DataLoader(data_iter, 32)
20
+ for X, Y in data_loader_buffer:
21
+ i += 1
22
+ if i % 3 == 0:
23
+ data_iter.append(X, Y)
24
+ assert i == 29
25
+
26
+
27
+ def test_max_holes():
28
+ train_dataset, test_dataset = get_datasets(train_size=2, test_size=2, max_holes=2)
29
+ X, Y = next(iter(train_dataset))
30
+ x_holes = X[1].reshape(9, 9, 9).sum(-1) == 0
31
+ print(X.shape, x_holes.shape)
32
+ assert x_holes.sum() == 2
tests/test_models.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from sudoku.models import SudokuNet
4
+
5
+
6
+ def test_same_output_under_rotation():
7
+ model = SudokuNet()
8
+ arr1 = torch.zeros((1, 2, 9, 9, 9))
9
+ arr1[0, 0, 1, 2, 3] = 1
10
+
11
+ output_1 = model.forward(arr1.view(1, 2, 9 * 9 * 9))
12
+ assert output_1.shape == (1, 2, 9 * 9 * 9), output_1
13
+ arr2 = torch.zeros((1, 2, 9, 9, 9))
14
+ arr2[0, 0, 2, 3, 4] = 1
15
+
16
+ output_2 = model.forward(arr2.view(1, 2, 9 * 9 * 9))
17
+ assert (
18
+ output_1.view(1, 2, 9, 9, 9)[0, 0, 1, 2, 3]
19
+ == output_2.view(1, 2, 9, 9, 9)[0, 0, 2, 3, 4]
20
+ )
21
+ assert (
22
+ output_1.view(1, 2, 9, 9, 9)[0, 0, 1, 2, 4]
23
+ == output_2.view(1, 2, 9, 9, 9)[0, 0, 2, 3, 6]
24
+ )
25
+ assert (
26
+ output_1.view(1, 2, 9, 9, 9)[0, 1, 1, 2, 4]
27
+ == output_2.view(1, 2, 9, 9, 9)[0, 1, 2, 3, 6]
28
+ )
29
+ assert (
30
+ output_1.view(1, 2, 9, 9, 9)[0, 1, 2, 2, 4]
31
+ == output_2.view(1, 2, 9, 9, 9)[0, 1, 1, 3, 6]
32
+ )
33
+ assert (
34
+ output_1.view(1, 2, 9, 9, 9)[0, 1, 2, 3, 4]
35
+ == output_2.view(1, 2, 9, 9, 9)[0, 1, 1, 2, 6]
36
+ )
37
+
38
+
39
+ # 0, 1, 2 | 3, 4, 5 | 6, 7, 8
40
+ # 0, 1, a | 3, 4, 5 | 6, 7, 8
41
+ # 0, 1, 2 | b, 4, 5 | 6, 7, 8
42
+ # ----------------------------
43
+ # 0, 1, 2 | 3, 4, 5 | 6, 7, 8
44
+ # 0, 1, 2 | 3, 4, 5 | 6, 7, 8
45
+ # 0, 1, 2 | 3, 4, 5 | 6, 7, 8
46
+ # ----------------------------
47
+ # 0, 1, 2 | 3, 4, 5 | 6, 7, 8
48
+ # 0, 1, 2 | 3, 4, 5 | 6, 7, 8
49
+ # 0, 1, 2 | 3, 4, 5 | 6, 7, 8
tests/test_symetries.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import numpy as np
3
+
4
+ from sudoku.symetries import mat_sym
5
+
6
+
7
+ def test_same_output_under_rotation():
8
+ assert mat_sym.shape == (9 * 9 * 9, 16, 9 * 9 * 9)
9
+
10
+ arr1 = np.zeros((9, 9, 9))
11
+ arr1[0, 1, 2] = 1
12
+ arr1_sym = np.dot(mat_sym, arr1.reshape(9 * 9 * 9, -1)).reshape(9, 9, 9, -1)
13
+ arr2 = np.zeros((9, 9, 9))
14
+ arr2[1, 2, 3] = 1
15
+ arr2_sym = np.dot(mat_sym, arr2.reshape(9 * 9 * 9, -1)).reshape(9, 9, 9, -1)
16
+ assert arr1_sym.shape == (9, 9, 9, 16)
17
+ assert (arr1_sym[0, 1, 2] == arr2_sym[1, 2, 3]).all()
18
+ assert (arr1_sym[5, 1, 2] == arr2_sym[5, 2, 3]).all()
19
+ assert (arr1_sym[0, 1, 1] == arr2_sym[1, 2, 4]).all()
20
+ assert (arr1_sym[0, 5, 1] == arr2_sym[1, 8, 4]).all()
21
+ assert (arr1_sym[6, 5, 1] == arr2_sym[6, 8, 4]).all()