HugoVoxx commited on
Commit
54f445b
·
verified ·
1 Parent(s): da8f7fe

Delete ag4masses/alphageometry/lm_inference_test.py

Browse files
ag4masses/alphageometry/lm_inference_test.py DELETED
@@ -1,89 +0,0 @@
1
- # Copyright 2023 DeepMind Technologies Limited
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """Unit tests for lm_inference.py."""
17
- import os
18
- import unittest
19
-
20
- from absl import flags
21
- from absl.testing import absltest
22
- import lm_inference as lm
23
-
24
-
25
- _DATA_PATH = flags.DEFINE_string('data_path', '', 'path to ckpt and vocab.')
26
- _MELIAD_PATH = flags.DEFINE_string(
27
- 'meliad_path', '', 'path to meliad repository.'
28
- ) # pylint: disable=line-too-long
29
-
30
-
31
- class LmInferenceTest(unittest.TestCase):
32
-
33
- @classmethod
34
- def setUpClass(cls):
35
- super().setUpClass()
36
- gin_file = [
37
- 'base_htrans.gin',
38
- 'size/medium_150M.gin',
39
- 'options/positions_t5.gin',
40
- 'options/lr_cosine_decay.gin',
41
- 'options/seq_1024_nocache.gin',
42
- 'geometry_150M_generate.gin',
43
- ]
44
-
45
- gin_param = [
46
- 'DecoderOnlyLanguageModelGenerate.output_token_losses=True',
47
- 'TransformerTaskConfig.batch_size=2',
48
- 'TransformerTaskConfig.sequence_length=128',
49
- 'Trainer.restore_state_variables=False',
50
- ]
51
-
52
- gin_search_paths = [
53
- os.path.join(_MELIAD_PATH.value, 'transformer/configs'),
54
- os.getcwd(),
55
- ]
56
-
57
- vocab_path = os.path.join(_DATA_PATH.value, 'geometry.757.model')
58
-
59
- lm.parse_gin_configuration(gin_file, gin_param, gin_paths=gin_search_paths)
60
-
61
- cls.loaded_lm = lm.LanguageModelInference(
62
- vocab_path, _DATA_PATH.value, mode='beam_search'
63
- )
64
-
65
- def test_lm_decode(self):
66
- outputs = LmInferenceTest.loaded_lm.beam_decode(
67
- '{S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c'
68
- ' {F1} x00',
69
- eos_tokens=[';'],
70
- )
71
- self.assertEqual(
72
- outputs['seqs_str'],
73
- ['e : D a b c e 02 D a c b e 03 ;', 'e : C a c e 02 C b d e 03 ;'],
74
- )
75
-
76
- def test_lm_score_may_fail_numerically_for_external_meliad(self):
77
- outputs = LmInferenceTest.loaded_lm.beam_decode(
78
- '{S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c'
79
- ' {F1} x00',
80
- eos_tokens=[';'],
81
- )
82
- self.assertEqual(
83
- outputs['scores'],
84
- [-1.18607294559478759765625, -1.10228693485260009765625],
85
- )
86
-
87
-
88
- if __name__ == '__main__':
89
- absltest.main()