fix code
Signed-off-by: QscQ <[email protected]>
- LICENSE +201 -0
- README.md +224 -3
- config.json +127 -0
- configuration_minimax_m1.py +152 -0
- docs/function_call_guide.md +270 -0
- docs/function_call_guide_cn.md +267 -0
- docs/transformers_deployment_guide.md +97 -0
- docs/transformers_deployment_guide_cn.md +95 -0
- docs/vllm_deployment_guide.md +171 -0
- docs/vllm_deployment_guide_cn.md +166 -0
- figures/TextBench.png +0 -0
- figures/wechat-qrcode.jpeg +0 -0
- main.py +106 -0
- merges.txt +0 -0
- model.safetensors.index.json +0 -0
- modeling_minimax_m1.py +1701 -0
- tokenizer.json +0 -0
- tokenizer_config.json +10 -0
- vocab.json +0 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright 2025 MiniMax

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md
CHANGED
@@ -1,3 +1,224 @@
(removed: `---` followed by two blank lines)
---
pipeline_tag: text-generation
license: apache-2.0
library_name: transformers
tags:
- vllm
---

<div align="center">

<!-- MiniMax wordmark (inline SVG) -->

</div>
<hr>

<div align="center" style="line-height: 1;">
  <a href="https://www.minimax.io" target="_blank" style="margin: 2px;">
    <img alt="Homepage" src="https://img.shields.io/badge/_Homepage-MiniMax-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
  </a>
  <a href="https://arxiv.org/abs/2506.13585" target="_blank" style="margin: 2px;">
    <img alt="Paper" src="https://img.shields.io/badge/📖_Paper-MiniMax--M1-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
  </a>
  <a href="https://chat.minimax.io/" target="_blank" style="margin: 2px;">
    <img alt="Chat" src="https://img.shields.io/badge/_MiniMax_Chat-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
  </a>
  <a href="https://www.minimax.io/platform" style="margin: 2px;">
    <img alt="API" src="https://img.shields.io/badge/⚡_API-Platform-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
  </a>
  <a href="https://github.com/MiniMax-AI/MiniMax-MCP" style="margin: 2px;">
    <img alt="MCP" src="https://img.shields.io/badge/🚀_MCP-MiniMax_MCP-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
  </a>
</div>
<div align="center" style="line-height: 1;">
  <a href="https://huggingface.co/MiniMaxAI" target="_blank" style="margin: 2px;">
    <img alt="Hugging Face" src="https://img.shields.io/badge/🤗_Hugging_Face-MiniMax-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
  </a>
  <a href="https://github.com/MiniMax-AI/MiniMax-M1" target="_blank" style="margin: 2px;">
    <img alt="GitHub" src="https://img.shields.io/badge/🐙_GitHub-MiniMax-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
  </a>
  <a href="https://www.modelscope.cn/organization/MiniMax" target="_blank" style="margin: 2px;">
    <img alt="ModelScope" src="https://img.shields.io/badge/🤖️_ModelScope-MiniMax-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
  </a>
  <a href="https://github.com/MiniMax-AI/MiniMax-M1/blob/main/LICENSE" style="margin: 2px;">
    <img alt="License" src="https://img.shields.io/badge/⚖️_License-Apache_2.0-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
  </a>
  <a href="https://github.com/MiniMax-AI/MiniMax-01/blob/main/figures/wechat-qrcode.jpeg" target="_blank" style="margin: 2px;">
    <img alt="WeChat" src="https://img.shields.io/badge/💬_WeChat-MiniMax-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
  </a>
</div>

# MiniMax-M1

## 1. Model Overview

We introduce MiniMax-M1, the world's first open-weight, large-scale hybrid-attention reasoning model. MiniMax-M1 is powered by a hybrid Mixture-of-Experts (MoE) architecture combined with a lightning attention mechanism. The model is developed based on our previous [MiniMax-Text-01 model](https://huggingface.co/MiniMaxAI/MiniMax-Text-01), which contains a total of 456 billion parameters with 45.9 billion parameters activated per token. Consistent with MiniMax-Text-01, the M1 model natively supports a context length of 1 million tokens, 8x the context size of DeepSeek R1. Furthermore, the lightning attention mechanism in MiniMax-M1 enables efficient scaling of test-time compute: for example, compared to DeepSeek R1, M1 consumes 25% of the FLOPs at a generation length of 100K tokens. These properties make M1 particularly suitable for complex tasks that require processing long inputs and thinking extensively. MiniMax-M1 is trained using large-scale reinforcement learning (RL) on diverse problems ranging from traditional mathematical reasoning to sandbox-based, real-world software engineering environments. We develop an efficient RL scaling framework for M1, highlighting two perspectives: (1) we propose CISPO, a novel algorithm that clips importance sampling weights instead of token updates, which outperforms other competitive RL variants; (2) our hybrid-attention design naturally enhances the efficiency of RL, and we address unique challenges when scaling RL with the hybrid architecture. We train two versions of MiniMax-M1 models with [40K](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k) and [80K](https://huggingface.co/MiniMaxAI/MiniMax-M1-80k) thinking budgets respectively. Experiments on standard benchmarks show that our models outperform other strong open-weight models such as the original DeepSeek-R1 and Qwen3-235B, particularly on complex software engineering, tool-use, and long-context tasks. With efficient scaling of test-time compute, MiniMax-M1 serves as a strong foundation for next-generation language model agents to reason and tackle real-world challenges.

<p align="center">
<img width="100%" src="figures/TextBench.png">
<br>
<small><em>Benchmark performance comparison of leading commercial and open-weight models across competition-level mathematics, coding, software engineering, agentic tool use, and long-context understanding tasks. We use the MiniMax-M1-80k model here for MiniMax-M1.</em></small>
</p>


## 2. Evaluation

**Performance of MiniMax-M1 on core benchmarks.**


| **Category** | **Task** | **MiniMax-M1-80K** | **MiniMax-M1-40K** | **Qwen3-235B-A22B** | **DeepSeek-R1-0528** | **DeepSeek-R1** | **Seed-Thinking-v1.5** | **Claude 4 Opus** | **Gemini 2.5 Pro (06-05)** | **OpenAI-o3** |
|:---|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| | *Extended Thinking* | *80K* | *40K* | *32k* | *64k* | *32k* | *32k* | *64k* | *64k* | *100k* |
| ***Mathematics*** | AIME 2024 | 86.0 | 83.3 | 85.7 | 91.4 | 79.8 | 86.7 | 76.0 | 92.0 | 91.6 |
| | AIME 2025 | 76.9 | 74.6 | 81.5 | 87.5 | 70.0 | 74.0 | 75.5 | 88.0 | 88.9 |
| | MATH-500 | 96.8 | 96.0 | 96.2 | 98.0 | 97.3 | 96.7 | 98.2 | 98.8 | 98.1 |
| ***General Coding*** | LiveCodeBench *(24/8~25/5)* | 65.0 | 62.3 | 65.9 | 73.1 | 55.9 | 67.5 | 56.6 | 77.1 | 75.8 |
| | FullStackBench | 68.3 | 67.6 | 62.9 | 69.4 | 70.1 | 69.9 | 70.3 | -- | 69.3 |
| ***Reasoning & Knowledge*** | GPQA Diamond | 70.0 | 69.2 | 71.1 | 81.0 | 71.5 | 77.3 | 79.6 | 86.4 | 83.3 |
| | HLE *(no tools)* | 8.4\* | 7.2\* | 7.6\* | 17.7\* | 8.6\* | 8.2 | 10.7 | 21.6 | 20.3 |
| | ZebraLogic | 86.8 | 80.1 | 80.3 | 95.1 | 78.7 | 84.4 | 95.1 | 91.6 | 95.8 |
| | MMLU-Pro | 81.1 | 80.6 | 83.0 | 85.0 | 84.0 | 87.0 | 85.0 | 86.0 | 85.0 |
| ***Software Engineering*** | SWE-bench Verified | 56.0 | 55.6 | 34.4 | 57.6 | 49.2 | 47.0 | 72.5 | 67.2 | 69.1 |
| ***Long Context*** | OpenAI-MRCR *(128k)* | 73.4 | 76.1 | 27.7 | 51.5 | 35.8 | 54.3 | 48.9 | 76.8 | 56.5 |
| | OpenAI-MRCR *(1M)* | 56.2 | 58.6 | -- | -- | -- | -- | -- | 58.8 | -- |
| | LongBench-v2 | 61.5 | 61.0 | 50.1 | 52.1 | 58.3 | 52.5 | 55.6 | 65.0 | 58.8 |
| ***Agentic Tool Use*** | TAU-bench *(airline)* | 62.0 | 60.0 | 34.7 | 53.5 | -- | 44.0 | 59.6 | 50.0 | 52.0 |
| | TAU-bench *(retail)* | 63.5 | 67.8 | 58.6 | 63.9 | -- | 55.7 | 81.4 | 67.0 | 73.9 |
| ***Factuality*** | SimpleQA | 18.5 | 17.9 | 11.0 | 27.8 | 30.1 | 12.9 | -- | 54.0 | 49.4 |
| ***General Assistant*** | MultiChallenge | 44.7 | 44.7 | 40.0 | 45.0 | 40.7 | 43.0 | 45.8 | 51.8 | 56.5 |

\* conducted on the text-only HLE subset.

Our models are evaluated with `temperature=1.0`, `top_p=0.95`.

### SWE-bench methodology
We report results derived from the Agentless scaffold. Departing from the original pipeline, our methodology employs a two-stage localization process (without any embedding-based retrieval mechanisms): initial coarse-grained file localization followed by fine-grained localization to specific files and code elements. The values for our models are calculated on the subset of n=486 verified tasks that run on our infrastructure. The 14 excluded test cases that were incompatible with our internal infrastructure are:
`"astropy__astropy-7606"`,
`"astropy__astropy-8707"`,
`"astropy__astropy-8872"`,
`"django__django-10097"`,
`"matplotlib__matplotlib-20488"`,
`"psf__requests-2317"`,
`"psf__requests-2931"`,
`"psf__requests-5414"`,
`"pylint-dev__pylint-6528"`,
`"pylint-dev__pylint-7277"`,
`"sphinx-doc__sphinx-10435"`,
`"sphinx-doc__sphinx-7985"`,
`"sphinx-doc__sphinx-8269"`,
`"sphinx-doc__sphinx-8475"`

### TAU-bench methodology
We evaluate TAU-bench with GPT-4.1 as the user model and without any custom tools. The maximum number of interaction steps is 40.
Our general system prompt is:
```
- In each round, you need to carefully examine the tools provided to you to determine if any can be used.
- You must adhere to all of the policies. Pay attention to the details in the terms. Solutions for most situations can be found within these policies.
```

## 3. Recommendations for MiniMax-M1 Model Usage

To achieve the best results with the MiniMax-M1 model, we suggest focusing on two key points: inference parameters and the system prompt.

### 3.1. Inference Parameters
- Temperature: **`1.0`**
- Top_p: **`0.95`**

This setting is optimal for encouraging creativity and diversity in the model's responses. It allows the model to explore a wider range of linguistic possibilities, preventing outputs that are too rigid or repetitive, while still maintaining strong logical coherence.

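These settings map directly onto the standard generation APIs. A minimal sketch using the `transformers` `GenerationConfig` is below; the save path is illustrative and nothing here is specific to MiniMax-M1 beyond the recommended values:

```python
from transformers import GenerationConfig

# Recommended decoding settings for MiniMax-M1 (see 3.1 above).
gen_config = GenerationConfig(do_sample=True, temperature=1.0, top_p=0.95)

# Pass to model.generate(..., generation_config=gen_config),
# or persist it next to a local copy of the model (illustrative path):
gen_config.save_pretrained("./minimax-m1-generation-config")
```
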
### 3.2. System Prompt
Tailoring your system prompt to the specific task is crucial for guiding the model effectively. Below are suggested settings for different scenarios.

#### A. General-Purpose Scenarios
For common tasks like summarization, translation, Q&A, or creative writing:
```
You are a helpful assistant.
```
#### B. Web Development Scenarios
For complex tasks like generating code for web pages:
```
You are a web development engineer, writing web pages according to the instructions below. You are a powerful code editing assistant capable of writing code and creating artifacts in conversations with users, or modifying and updating existing artifacts as requested by users.
All code is written in a single code block to form a complete code file for display, without separating HTML and JavaScript code. An artifact refers to a runnable complete code snippet, you prefer to integrate and output such complete runnable code rather than breaking it down into several code blocks. For certain types of code, they can render graphical interfaces in a UI window. After generation, please check the code execution again to ensure there are no errors in the output.
Output only the HTML, without any additional descriptive text. Make the UI looks modern and beautiful.
```
#### C. Mathematical Scenarios
When dealing with problems that require calculation or logical deduction:
```
Please reason step by step, and put your final answer within \boxed{}.
```

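As a concrete example, the mathematical-scenario prompt plugs into the chat template like this. This is a sketch only: the structured message format follows the function call guide in this repository, and the sample question is a placeholder.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-M1-40k")

math_system_prompt = "Please reason step by step, and put your final answer within \\boxed{}."
messages = [
    {"role": "system", "content": [{"type": "text", "text": math_system_prompt}]},
    {"role": "user", "content": [{"type": "text", "text": "What is the sum of the first 100 positive integers?"}]},
]

# Render with the chat template from tokenizer_config.json, then generate with the
# recommended temperature=1.0, top_p=0.95 (see 3.1).
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
```
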
## 4. Deployment Guide

Download the model from the HuggingFace repository:
- [MiniMax-M1-40k](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k)
- [MiniMax-M1-80k](https://huggingface.co/MiniMaxAI/MiniMax-M1-80k)

For production deployment, we recommend using [vLLM](https://docs.vllm.ai/en/latest/) to serve MiniMax-M1. vLLM provides excellent performance for serving large language models, with the following features:
- 🔥 Outstanding serving throughput
- ⚡ Efficient and intelligent memory management
- 📦 Powerful batch request processing capability
- ⚙️ Deeply optimized underlying performance

For detailed vLLM deployment instructions, please refer to our [vLLM Deployment Guide](./docs/vllm_deployment_guide.md).
Alternatively, you can also deploy using Transformers directly. For detailed Transformers deployment instructions, see our [MiniMax-M1 Transformers Deployment Guide](./docs/transformers_deployment_guide.md).


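Once a vLLM server is running (the exact serve command and parallelism settings are covered in the vLLM deployment guide), it exposes an OpenAI-compatible endpoint, so requests can be sent with the standard `openai` client. A minimal sketch; the base URL, API key placeholder, and served model name below are assumptions for illustration:

```python
from openai import OpenAI

# Assumed local endpoint; adjust to wherever your vLLM server is listening.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="MiniMaxAI/MiniMax-M1-80k",  # the served model name is deployment-specific
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize the MiniMax-M1 architecture in one paragraph."},
    ],
    temperature=1.0,
    top_p=0.95,
)
print(response.choices[0].message.content)
```
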
## 5. Function Calling

The MiniMax-M1 model supports function calling, enabling the model to identify when external functions need to be called and to output function call parameters in a structured format. The [MiniMax-M1 Function Call Guide](./docs/function_call_guide.md) provides detailed instructions on how to use the function calling feature of MiniMax-M1.


## 6. Chatbot & API
For general use and evaluation, we provide a [Chatbot](https://chat.minimax.io/) with online search capabilities and the [online API](https://www.minimax.io/platform/) for developers. We also provide the [MiniMax MCP Server](https://github.com/MiniMax-AI/MiniMax-MCP), with video generation, image generation, speech synthesis, and voice cloning, for developers.


## 7. Citation
```
@misc{minimax2025minimaxm1scalingtesttimecompute,
      title={MiniMax-M1: Scaling Test-Time Compute Efficiently with Lightning Attention},
      author={MiniMax},
      year={2025},
      eprint={2506.13585},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2506.13585},
}
```

## 8. Contact Us
Contact us at [[email protected]](mailto:[email protected]).
config.json
ADDED
@@ -0,0 +1,127 @@
{
  "architectures": [
    "MiniMaxM1ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "attn_type_list": [
    0, 0, 0, 0, 0, 0, 0, 1,
    0, 0, 0, 0, 0, 0, 0, 1,
    0, 0, 0, 0, 0, 0, 0, 1,
    0, 0, 0, 0, 0, 0, 0, 1,
    0, 0, 0, 0, 0, 0, 0, 1,
    0, 0, 0, 0, 0, 0, 0, 1,
    0, 0, 0, 0, 0, 0, 0, 1,
    0, 0, 0, 0, 0, 0, 0, 1,
    0, 0, 0, 0, 0, 0, 0, 1,
    0, 0, 0, 0, 0, 0, 0, 1
  ],
  "auto_map": {
    "AutoConfig": "configuration_minimax_m1.MiniMaxM1Config",
    "AutoModelForCausalLM": "modeling_minimax_m1.MiniMaxM1ForCausalLM"
  },
  "bos_token_id": null,
  "eos_token_id": null,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 6144,
  "initializer_range": 0.02,
  "intermediate_size": 9216,
  "layernorm_full_attention_alpha": 3.5565588200778455,
  "layernorm_full_attention_beta": 1.0,
  "layernorm_linear_attention_alpha": 3.5565588200778455,
  "layernorm_linear_attention_beta": 1.0,
  "layernorm_mlp_alpha": 3.5565588200778455,
  "layernorm_mlp_beta": 1.0,
  "max_position_embeddings": 10240000,
  "model_type": "minimax_m1",
  "num_attention_heads": 64,
  "num_experts_per_tok": 2,
  "num_hidden_layers": 80,
  "num_key_value_heads": 8,
  "num_local_experts": 32,
  "output_router_logits": false,
  "postnorm": true,
  "rms_norm_eps": 1e-05,
  "rope_theta": 10000000,
  "rotary_dim": 64,
  "router_aux_loss_coef": 0.001,
  "router_jitter_noise": 0.0,
  "shared_intermediate_size": 0,
  "shared_moe_mode": "sigmoid",
  "sliding_window": null,
  "tie_word_embeddings": false,
  "transformers_version": "4.45.2",
  "use_cache": true,
  "vocab_size": 200064
}

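The `attn_type_list` above encodes the hybrid attention layout layer by layer. The config does not document the encoding itself; reading 0 as a lightning (linear) attention layer and 1 as a full softmax attention layer is an assumption based on the model description in the README, under which the pattern is seven lightning layers followed by one softmax layer, repeated across the 80 layers. A small sketch for inspecting it:

```python
import json

with open("config.json") as f:
    cfg = json.load(f)

attn_types = cfg["attn_type_list"]
# Assumption: 0 = lightning/linear attention, 1 = full softmax attention.
print(f"{len(attn_types)} layers: {attn_types.count(0)} lightning, {attn_types.count(1)} softmax")
print("softmax layers at indices:", [i for i, t in enumerate(attn_types) if t == 1])
```
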
configuration_minimax_m1.py
ADDED
@@ -0,0 +1,152 @@
""" MiniMaxM1 model configuration"""
|
2 |
+
|
3 |
+
from transformers.configuration_utils import PretrainedConfig
|
4 |
+
from transformers.utils import logging
|
5 |
+
|
6 |
+
|
7 |
+
logger = logging.get_logger(__name__)
|
8 |
+
|
9 |
+
|
10 |
+
class MiniMaxM1Config(PretrainedConfig):
|
11 |
+
r"""
|
12 |
+
This is the configuration class to store the configuration of a [`MiniMaxM1Model`]. It is used to instantiate an
|
13 |
+
MiniMaxM1 model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
14 |
+
with the defaults will yield a similar configuration to that of the MiniMaxM1.
|
15 |
+
|
16 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
17 |
+
documentation from [`PretrainedConfig`] for more information.
|
18 |
+
|
19 |
+
|
20 |
+
Args:
|
21 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
22 |
+
Vocabulary size of the MiniMaxM1 model. Defines the number of different tokens that can be represented by the
|
23 |
+
`inputs_ids` passed when calling [`MiniMaxM1Model`]
|
24 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
25 |
+
Dimension of the hidden representations.
|
26 |
+
intermediate_size (`int`, *optional*, defaults to 14336):
|
27 |
+
Dimension of the MLP representations.
|
28 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
29 |
+
Number of hidden layers in the Transformer encoder.
|
30 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
31 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
32 |
+
num_key_value_heads (`int`, *optional*, defaults to 8):
|
33 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
34 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
35 |
+
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
36 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
37 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
38 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
|
39 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
40 |
+
The non-linear activation function (function or string) in the decoder.
|
41 |
+
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
|
42 |
+
The maximum sequence length that this model might ever be used with. MiniMaxM1's sliding window attention
|
43 |
+
allows sequence of up to 4096*32 tokens.
|
44 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
45 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
46 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
47 |
+
The epsilon used by the rms normalization layers.
|
48 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
49 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
50 |
+
relevant if `config.is_decoder=True`.
|
51 |
+
pad_token_id (`int`, *optional*):
|
52 |
+
The id of the padding token.
|
53 |
+
bos_token_id (`int`, *optional*, defaults to 1):
|
54 |
+
The id of the "beginning-of-sequence" token.
|
55 |
+
eos_token_id (`int`, *optional*, defaults to 2):
|
56 |
+
The id of the "end-of-sequence" token.
|
57 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
58 |
+
Whether the model's input and output word embeddings should be tied.
|
59 |
+
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
60 |
+
The base period of the RoPE embeddings.
|
61 |
+
sliding_window (`int`, *optional*):
|
62 |
+
Sliding window attention window size. If not specified, will default to `4096`.
|
63 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
64 |
+
The dropout ratio for the attention probabilities.
|
65 |
+
num_experts_per_tok (`int`, *optional*, defaults to 2):
|
66 |
+
The number of experts to route per-token, can be also interpreted as the `top-k` routing
|
67 |
+
parameter
|
68 |
+
num_local_experts (`int`, *optional*, defaults to 8):
|
69 |
+
Number of experts per Sparse MLP layer.
|
70 |
+
output_router_logits (`bool`, *optional*, defaults to `False`):
|
71 |
+
Whether or not the router logits should be returned by the model. Enabeling this will also
|
72 |
+
allow the model to output the auxiliary loss. See [here]() for more details
|
73 |
+
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
|
74 |
+
The aux loss factor for the total loss.
|
75 |
+
router_jitter_noise (`float`, *optional*, defaults to 0.0):
|
76 |
+
Amount of noise to add to the router.
|
77 |
+
|
78 |
+
```python
|
79 |
+
>>> from transformers import MiniMaxM1Model, MiniMaxM1Config
|
80 |
+
|
81 |
+
>>> # Initializing a MiniMaxM1 style configuration
|
82 |
+
>>> configuration = MiniMaxM1Config()
|
83 |
+
|
84 |
+
>>> # Initializing a model from the MiniMaxM1 style configuration
|
85 |
+
>>> model = MiniMaxM1Model(configuration)
|
86 |
+
|
87 |
+
>>> # Accessing the model configuration
|
88 |
+
>>> configuration = model.config
|
89 |
+
```"""
|
90 |
+
|
91 |
+
model_type = "MiniMaxM1"
|
92 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
93 |
+
|
94 |
+
def __init__(
|
95 |
+
self,
|
96 |
+
vocab_size=32000,
|
97 |
+
hidden_size=4096,
|
98 |
+
intermediate_size=14336,
|
99 |
+
num_hidden_layers=32,
|
100 |
+
num_attention_heads=32,
|
101 |
+
num_key_value_heads=8,
|
102 |
+
hidden_act="silu",
|
103 |
+
max_position_embeddings=4096 * 32,
|
104 |
+
initializer_range=0.02,
|
105 |
+
rms_norm_eps=1e-5,
|
106 |
+
use_cache=True,
|
107 |
+
pad_token_id=None,
|
108 |
+
bos_token_id=None,
|
109 |
+
eos_token_id=None,
|
110 |
+
tie_word_embeddings=False,
|
111 |
+
rope_theta=1e6,
|
112 |
+
sliding_window=None,
|
113 |
+
attention_dropout=0.0,
|
114 |
+
num_experts_per_tok=2,
|
115 |
+
num_local_experts=8,
|
116 |
+
output_router_logits=False,
|
117 |
+
router_aux_loss_coef=0.001,
|
118 |
+
router_jitter_noise=0.0,
|
119 |
+
**kwargs,
|
120 |
+
):
|
121 |
+
self.vocab_size = vocab_size
|
122 |
+
self.max_position_embeddings = max_position_embeddings
|
123 |
+
self.hidden_size = hidden_size
|
124 |
+
self.intermediate_size = intermediate_size
|
125 |
+
self.num_hidden_layers = num_hidden_layers
|
126 |
+
self.num_attention_heads = num_attention_heads
|
127 |
+
self.sliding_window = sliding_window
|
128 |
+
|
129 |
+
# for backward compatibility
|
130 |
+
if num_key_value_heads is None:
|
131 |
+
num_key_value_heads = num_attention_heads
|
132 |
+
|
133 |
+
self.num_key_value_heads = num_key_value_heads
|
134 |
+
self.hidden_act = hidden_act
|
135 |
+
self.initializer_range = initializer_range
|
136 |
+
self.rms_norm_eps = rms_norm_eps
|
137 |
+
self.use_cache = use_cache
|
138 |
+
self.rope_theta = rope_theta
|
139 |
+
self.attention_dropout = attention_dropout
|
140 |
+
|
141 |
+
self.num_experts_per_tok = num_experts_per_tok
|
142 |
+
self.num_local_experts = num_local_experts
|
143 |
+
self.output_router_logits = output_router_logits
|
144 |
+
self.router_aux_loss_coef = router_aux_loss_coef
|
145 |
+
self.router_jitter_noise = router_jitter_noise
|
146 |
+
super().__init__(
|
147 |
+
pad_token_id=pad_token_id,
|
148 |
+
bos_token_id=bos_token_id,
|
149 |
+
eos_token_id=eos_token_id,
|
150 |
+
tie_word_embeddings=tie_word_embeddings,
|
151 |
+
**kwargs,
|
152 |
+
)
|
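
Because `config.json` maps `AutoConfig` and `AutoModelForCausalLM` onto these custom classes via `auto_map`, they load through the Transformers Auto API with `trust_remote_code=True`. A minimal sketch, loading only the configuration (which avoids downloading the weights); the repository id is one of the two checkpoints listed in the README:

```python
from transformers import AutoConfig

# trust_remote_code=True lets Transformers import configuration_minimax_m1.py from the repo.
config = AutoConfig.from_pretrained("MiniMaxAI/MiniMax-M1-80k", trust_remote_code=True)
print(config.model_type, config.num_hidden_layers, config.num_local_experts)
```
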
docs/function_call_guide.md
ADDED
@@ -0,0 +1,270 @@
# MiniMax-M1 Function Call Guide

[Function Call Guide (Chinese)](./function_call_guide_cn.md)

## 📖 Introduction

The MiniMax-M1 model supports function calling, enabling the model to identify when external functions need to be called and to output function call parameters in a structured format. This document provides detailed instructions on how to use the function calling feature of MiniMax-M1.

## 🚀 Quick Start

### Using Chat Template

MiniMax-M1 uses a specific chat template format to handle function calls. The chat template is defined in `tokenizer_config.json`, and you can use it in your code through the template.

```python
from transformers import AutoTokenizer

def get_default_tools():
    return [
        {
            "name": "get_current_weather",
            "description": "Get the latest weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "A certain city, such as Beijing, Shanghai"
                    }
                },
                "required": ["location"]
            }
        }
    ]

# Load the tokenizer
model_id = "MiniMaxAI/MiniMax-M1-80k"  # or a local checkout of this repository
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "What's the weather like in Shanghai today?"
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant created by Minimax based on MiniMax-M1 model."}]},
    {"role": "user", "content": [{"type": "text", "text": prompt}]},
]

# Enable function call tools
tools = get_default_tools()

# Apply chat template and add tool definitions
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    tools=tools
)
```

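To run the templated prompt end to end, a minimal sketch with plain `transformers` generation is below. It continues from the snippet above (`model_id`, `tokenizer`, `text`), assumes a machine with enough GPU memory for the checkpoint, and uses an illustrative `max_new_tokens`; for production traffic, the vLLM deployment guide is the recommended path.

```python
import torch
from transformers import AutoModelForCausalLM

# Continuing from the snippet above: quick functional check only.
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, device_map="auto", torch_dtype=torch.bfloat16
)
inputs = tokenizer(text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, do_sample=True, temperature=1.0, top_p=0.95, max_new_tokens=1024)
reply = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print(reply)  # a tool call, if emitted, appears inside <tool_calls>...</tool_calls>
```
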
## 🛠️ Function Call Definition

### Function Structure

Function calls need to be defined in the `tools` field of the request body. Each function consists of the following components:

```json
{
  "tools": [
    {
      "name": "search_web",
      "description": "Search function.",
      "parameters": {
        "properties": {
          "query_list": {
            "description": "Keywords for search, with list element count of 1.",
            "items": { "type": "string" },
            "type": "array"
          },
          "query_tag": {
            "description": "Classification of the query",
            "items": { "type": "string" },
            "type": "array"
          }
        },
        "required": [ "query_list", "query_tag" ],
        "type": "object"
      }
    }
  ]
}
```

**Field Descriptions:**
- `name`: Function name
- `description`: Function description
- `parameters`: Function parameter definition
  - `properties`: Parameter property definitions, where the key is the parameter name and the value contains a detailed parameter description
  - `required`: List of required parameters
  - `type`: Parameter type (usually "object")

### Internal Model Processing Format
|
101 |
+
|
102 |
+
When processed internally by the model, function definitions are converted to a special format and concatenated to the input text:
|
103 |
+
|
104 |
+
```
|
105 |
+
]~!b[]~b]system ai_setting=MiniMax AI
|
106 |
+
MiniMax AI is an AI assistant independently developed by MiniMax. [e~[
|
107 |
+
]~b]system tool_setting=tools
|
108 |
+
You are provided with these tools:
|
109 |
+
<tools>
|
110 |
+
{"name": "search_web", "description": "Search function.", "parameters": {"properties": {"query_list": {"description": "Keywords for search, with list element count of 1.", "items": {"type": "string"}, "type": "array"}, "query_tag": {"description": "Classification of the query", "items": {"type": "string"}, "type": "array"}}, "required": ["query_list", "query_tag"], "type": "object"}}
|
111 |
+
</tools>
|
112 |
+
|
113 |
+
If you need to call tools, please respond with <tool_calls></tool_calls> XML tags, and provide tool-name and json-object of arguments, following the format below:
|
114 |
+
<tool_calls>
|
115 |
+
{"name": <tool-name>, "arguments": <args-json-object>}
|
116 |
+
...
|
117 |
+
</tool_calls>[e~[
|
118 |
+
]~b]user name=User
|
119 |
+
When were the most recent launch events for OpenAI and Gemini?[e~[
|
120 |
+
]~b]ai name=MiniMax AI
|
121 |
+
```
|
122 |
+
|
123 |
+
### Model Output Format
|
124 |
+
|
125 |
+
The model outputs function calls in the following format:
|
126 |
+
|
127 |
+
```xml
|
128 |
+
<think>
|
129 |
+
Okay, I will search for the OpenAI and Gemini latest release.
|
130 |
+
</think>
|
131 |
+
<tool_calls>
|
132 |
+
{"name": "search_web", "arguments": {"query_tag": ["technology", "events"], "query_list": ["\"OpenAI\" \"latest\" \"release\""]}}
|
133 |
+
{"name": "search_web", "arguments": {"query_tag": ["technology", "events"], "query_list": ["\"Gemini\" \"latest\" \"release\""]}}
|
134 |
+
</tool_calls>
|
135 |
+
```

## 📥 Function Call Result Processing

### Parsing Function Calls

You can use the following code to parse function calls from the model output:

```python
import re
import json

def parse_function_calls(content: str):
    """
    Parse function calls from model output
    """
    function_calls = []

    # Match content within <tool_calls> tags
    tool_calls_pattern = r"<tool_calls>(.*?)</tool_calls>"
    tool_calls_match = re.search(tool_calls_pattern, content, re.DOTALL)

    if not tool_calls_match:
        return function_calls

    tool_calls_content = tool_calls_match.group(1).strip()

    # Parse each function call (one JSON object per line)
    for line in tool_calls_content.split('\n'):
        line = line.strip()
        if not line:
            continue

        try:
            # Parse JSON format function call
            call_data = json.loads(line)
            function_name = call_data.get("name")
            arguments = call_data.get("arguments", {})

            function_calls.append({
                "name": function_name,
                "arguments": arguments
            })

            print(f"Function call: {function_name}, Arguments: {arguments}")

        except json.JSONDecodeError as e:
            print(f"Parameter parsing failed: {line}, Error: {e}")

    return function_calls

# Example: Handle weather query function
def execute_function_call(function_name: str, arguments: dict):
    """
    Execute function call and return result
    """
    if function_name == "get_current_weather":
        location = arguments.get("location", "Unknown location")
        # Build function execution result
        return {
            "role": "tool",
            "name": function_name,
            "content": json.dumps({
                "location": location,
                "temperature": "25",
                "unit": "celsius",
                "weather": "Sunny"
            }, ensure_ascii=False)
        }
    elif function_name == "search_web":
        query_list = arguments.get("query_list", [])
        query_tag = arguments.get("query_tag", [])
        # Simulate search results
        return {
            "role": "tool",
            "name": function_name,
            "content": f"Search keywords: {query_list}, Categories: {query_tag}\nSearch results: Relevant information found"
        }

    return None
```
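
A quick way to exercise the two helpers above is to run them on a hard-coded model output shaped like the example in the Model Output Format section; the weather values returned are the simulated ones from `execute_function_call`:

```python
sample_output = """<think>
Okay, I will look up the weather.
</think>
<tool_calls>
{"name": "get_current_weather", "arguments": {"location": "Shanghai"}}
</tool_calls>"""

tool_messages = []
for call in parse_function_calls(sample_output):
    result = execute_function_call(call["name"], call["arguments"])
    if result is not None:
        tool_messages.append(result)

# tool_messages now holds {"role": "tool", "name": ..., "content": ...} entries,
# ready to be appended to the conversation history (see the next section).
print(tool_messages)
```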

### Returning Function Execution Results to the Model

After successfully parsing the function calls, add the function execution results to the conversation history so that the model can access and use this information in subsequent turns.

#### Single Result

If the model decides to call `search_web`, we suggest returning the function result in the following format, with the `name` field set to the specific tool name.

```json
{
  "data": [
    {
      "role": "tool",
      "name": "search_web",
      "content": "search_result"
    }
  ]
}
```

Corresponding model input format:
```
]~b]tool name=search_web
search_result[e~[
```

#### Multiple Results

If the model decides to call `search_web` and `get_current_weather` at the same time, we suggest returning the multiple function results in the following format, with the `name` field set to "tools" and the `content` field containing all of the results.

```json
{
  "data": [
    {
      "role": "tool",
      "name": "tools",
      "content": "Tool name: search_web\nTool result: test_result1\n\nTool name: get_current_weather\nTool result: test_result2"
    }
  ]
}
```

Corresponding model input format:
```
]~b]tool name=tools
Tool name: search_web
Tool result: test_result1

Tool name: get_current_weather
Tool result: test_result2[e~[
```

While we suggest following the formats above, as long as the model input is easy to understand, the specific values of `name` and `content` are entirely up to the caller.
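
As an illustration of the round trip, the sketch below appends the assistant turn and the tool results from the earlier snippets back into `messages` and re-applies the chat template. The exact message schema the template expects for the `tool` role (plain string content versus a typed content list) is an assumption here, so adapt it to the template defined in `tokenizer_config.json`:

```python
# Assumption: `messages`, `tools`, `tokenizer`, `sample_output`, and `tool_messages`
# come from the earlier snippets in this guide.
messages.append({"role": "assistant", "content": [{"type": "text", "text": sample_output}]})
for tool_message in tool_messages:
    messages.append({
        "role": "tool",
        "name": tool_message["name"],
        "content": [{"type": "text", "text": tool_message["content"]}],
    })

# Re-apply the chat template and generate the final answer with the tool results in context.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    tools=tools,
)
```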
docs/function_call_guide_cn.md
ADDED
@@ -0,0 +1,267 @@
1 |
+
# MiniMax-M1 函数调用(Function Call)功能指南
|
2 |
+
|
3 |
+
## 📖 简介
|
4 |
+
|
5 |
+
MiniMax-M1 模型支持函数调用功能,使模型能够识别何时需要调用外部函数,并以结构化格式输出函数调用参数。本文档详细介绍了如何使用 MiniMax-M1 的函数调用功能。
|
6 |
+
|
7 |
+
## 🚀 快速开始
|
8 |
+
|
9 |
+
### 聊天模板使用
|
10 |
+
|
11 |
+
MiniMax-M1 使用特定的聊天模板格式处理函数调用。聊天模板定义在 `tokenizer_config.json` 中,你可以在代码中通过 template 来进行使用。
|
12 |
+
|
13 |
+
```python
|
14 |
+
from transformers import AutoTokenizer
|
15 |
+
|
16 |
+
def get_default_tools():
|
17 |
+
return [
|
18 |
+
{
|
19 |
+
{
|
20 |
+
"name": "get_current_weather",
|
21 |
+
"description": "Get the latest weather for a location",
|
22 |
+
"parameters": {
|
23 |
+
"type": "object",
|
24 |
+
"properties": {
|
25 |
+
"location": {
|
26 |
+
"type": "string",
|
27 |
+
"description": "A certain city, such as Beijing, Shanghai"
|
28 |
+
}
|
29 |
+
},
|
30 |
+
}
|
31 |
+
"required": ["location"],
|
32 |
+
"type": "object"
|
33 |
+
}
|
34 |
+
}
|
35 |
+
]
|
36 |
+
|
37 |
+
# 加载模型和分词器
|
38 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
39 |
+
prompt = "What's the weather like in Shanghai today?"
|
40 |
+
messages = [
|
41 |
+
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant created by Minimax based on MiniMax-M1 model."}]},
|
42 |
+
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
43 |
+
]
|
44 |
+
|
45 |
+
# 启用函数调用工具
|
46 |
+
tools = get_default_tools()
|
47 |
+
|
48 |
+
# 应用聊天模板,并加入工具定义
|
49 |
+
text = tokenizer.apply_chat_template(
|
50 |
+
messages,
|
51 |
+
tokenize=False,
|
52 |
+
add_generation_prompt=True,
|
53 |
+
tools=tools
|
54 |
+
)
|
55 |
+
```
|
56 |
+
|
57 |
+
## 🛠️ 函数调用的定义
|
58 |
+
|
59 |
+
### 函数结构体
|
60 |
+
|
61 |
+
函数调用需要在请求体中定义 `tools` 字段,每个函数由以下部分组成:
|
62 |
+
|
63 |
+
```json
|
64 |
+
{
|
65 |
+
"tools": [
|
66 |
+
{
|
67 |
+
"name": "search_web",
|
68 |
+
"description": "搜索函数。",
|
69 |
+
"parameters": {
|
70 |
+
"properties": {
|
71 |
+
"query_list": {
|
72 |
+
"description": "进行搜索的关键词,列表元素个数为1。",
|
73 |
+
"items": { "type": "string" },
|
74 |
+
"type": "array"
|
75 |
+
},
|
76 |
+
"query_tag": {
|
77 |
+
"description": "query的分类",
|
78 |
+
"items": { "type": "string" },
|
79 |
+
"type": "array"
|
80 |
+
}
|
81 |
+
},
|
82 |
+
"required": [ "query_list", "query_tag" ],
|
83 |
+
"type": "object"
|
84 |
+
}
|
85 |
+
}
|
86 |
+
]
|
87 |
+
}
|
88 |
+
```
|
89 |
+
|
90 |
+
**字段说明:**
|
91 |
+
- `name`: 函数名称
|
92 |
+
- `description`: 函数功能描述
|
93 |
+
- `parameters`: 函数参数定义
|
94 |
+
- `properties`: 参数属性定义,key 是参数名,value 包含参数的详细描述
|
95 |
+
- `required`: 必填参数列表
|
96 |
+
- `type`: 参数类型(通常为 "object")
|
97 |
+
|
98 |
+
### 模型内部处理格式
|
99 |
+
|
100 |
+
在模型内部处理时,函数定义会被转换为特殊格式并拼接到输入文本中:
|
101 |
+
|
102 |
+
```
|
103 |
+
]~!b[]~b]system ai_setting=MiniMax AI
|
104 |
+
MiniMax AI是由上海稀宇科技有限公司(MiniMax)自主研发的AI助理。[e~[
|
105 |
+
]~b]system tool_setting=tools
|
106 |
+
You are provided with these tools:
|
107 |
+
<tools>
|
108 |
+
{"name": "search_web", "description": "搜索函数。", "parameters": {"properties": {"query_list": {"description": "进行搜索的关键词,列表元素个数为1。", "items": {"type": "string"}, "type": "array"}, "query_tag": {"description": "query的分类", "items": {"type": "string"}, "type": "array"}}, "required": ["query_list", "query_tag"], "type": "object"}}
|
109 |
+
</tools>
|
110 |
+
|
111 |
+
If you need to call tools, please respond with <tool_calls></tool_calls> XML tags, and provide tool-name and json-object of arguments, following the format below:
|
112 |
+
<tool_calls>
|
113 |
+
{"name": <tool-name>, "arguments": <args-json-object>}
|
114 |
+
...
|
115 |
+
</tool_calls>[e~[
|
116 |
+
]~b]user name=用户
|
117 |
+
OpenAI 和 Gemini 的最近一次发布会都是什么时候?[e~[
|
118 |
+
]~b]ai name=MiniMax AI
|
119 |
+
```
|
120 |
+
|
121 |
+
### 模型输出格式
|
122 |
+
|
123 |
+
模型会以以下格式输出函数调用:
|
124 |
+
|
125 |
+
```xml
|
126 |
+
<think>
|
127 |
+
Okay, I will search for the OpenAI and Gemini latest release.
|
128 |
+
</think>
|
129 |
+
<tool_calls>
|
130 |
+
{"name": "search_web", "arguments": {"query_tag": ["technology", "events"], "query_list": ["\"OpenAI\" \"latest\" \"release\""]}}
|
131 |
+
{"name": "search_web", "arguments": {"query_tag": ["technology", "events"], "query_list": ["\"Gemini\" \"latest\" \"release\""]}}
|
132 |
+
</tool_calls>
|
133 |
+
```
|
134 |
+
|
135 |
+
## 📥 函数调用结果处理
|
136 |
+
|
137 |
+
### 解析函数调用
|
138 |
+
|
139 |
+
您可以使用以下代码解析模型输出的函数调用:
|
140 |
+
|
141 |
+
```python
|
142 |
+
import re
|
143 |
+
import json
|
144 |
+
|
145 |
+
def parse_function_calls(content: str):
|
146 |
+
"""
|
147 |
+
解析模型输出中的函数调用
|
148 |
+
"""
|
149 |
+
function_calls = []
|
150 |
+
|
151 |
+
# 匹配 <tool_calls> 标签内的内容
|
152 |
+
tool_calls_pattern = r"<tool_calls>(.*?)</tool_calls>"
|
153 |
+
tool_calls_match = re.search(tool_calls_pattern, content, re.DOTALL)
|
154 |
+
|
155 |
+
if not tool_calls_match:
|
156 |
+
return function_calls
|
157 |
+
|
158 |
+
tool_calls_content = tool_calls_match.group(1).strip()
|
159 |
+
|
160 |
+
# 解析每个函数调用(每行一个JSON对象)
|
161 |
+
for line in tool_calls_content.split('\n'):
|
162 |
+
line = line.strip()
|
163 |
+
if not line:
|
164 |
+
continue
|
165 |
+
|
166 |
+
try:
|
167 |
+
# 解析JSON格式的函数调用
|
168 |
+
call_data = json.loads(line)
|
169 |
+
function_name = call_data.get("name")
|
170 |
+
arguments = call_data.get("arguments", {})
|
171 |
+
|
172 |
+
function_calls.append({
|
173 |
+
"name": function_name,
|
174 |
+
"arguments": arguments
|
175 |
+
})
|
176 |
+
|
177 |
+
print(f"调用函数: {function_name}, 参数: {arguments}")
|
178 |
+
|
179 |
+
except json.JSONDecodeError as e:
|
180 |
+
print(f"参数解析失败: {line}, 错误: {e}")
|
181 |
+
|
182 |
+
return function_calls
|
183 |
+
|
184 |
+
# 示例:处理天气查询函数
|
185 |
+
def execute_function_call(function_name: str, arguments: dict):
|
186 |
+
"""
|
187 |
+
执行函数调用并返回结果
|
188 |
+
"""
|
189 |
+
if function_name == "get_current_weather":
|
190 |
+
location = arguments.get("location", "未知位置")
|
191 |
+
# 构建函数执行结果
|
192 |
+
return {
|
193 |
+
"role": "tool",
|
194 |
+
"name": function_name,
|
195 |
+
"content": json.dumps({
|
196 |
+
"location": location,
|
197 |
+
"temperature": "25",
|
198 |
+
"unit": "celsius",
|
199 |
+
"weather": "晴朗"
|
200 |
+
}, ensure_ascii=False)
|
201 |
+
}
|
202 |
+
elif function_name == "search_web":
|
203 |
+
query_list = arguments.get("query_list", [])
|
204 |
+
query_tag = arguments.get("query_tag", [])
|
205 |
+
# 模拟搜索结果
|
206 |
+
return {
|
207 |
+
"role": "tool",
|
208 |
+
"name": function_name,
|
209 |
+
"content": f"搜索关键词: {query_list}, 分类: {query_tag}\n搜索结果: 相关信息已找到"
|
210 |
+
}
|
211 |
+
|
212 |
+
return None
|
213 |
+
```
|
214 |
+
|
215 |
+
### 将函数执行结果返回给模型
|
216 |
+
|
217 |
+
成功解析函数调用后,您应将函数执行结果添加到对话历史中,以便模型在后续交互中能够访问和利用这些信息。
|
218 |
+
|
219 |
+
#### 单个结果
|
220 |
+
|
221 |
+
假如模型调用了 `search_web` 函数,您可以参考如下格式添加执行结果,`name` 字段为具体的函数名称。
|
222 |
+
|
223 |
+
```json
|
224 |
+
{
|
225 |
+
"data": [
|
226 |
+
{
|
227 |
+
"role": "tool",
|
228 |
+
"name": "search_web",
|
229 |
+
"content": "search_result"
|
230 |
+
}
|
231 |
+
]
|
232 |
+
}
|
233 |
+
```
|
234 |
+
|
235 |
+
对应如下的模型输入格式:
|
236 |
+
```
|
237 |
+
]~b]tool name=search_web
|
238 |
+
search_result[e~[
|
239 |
+
```
|
240 |
+
|
241 |
+
|
242 |
+
#### 多个结果
|
243 |
+
假如模型同时调用了 `search_web` 和 `get_current_weather` 函数,您可以参考如下格式添加执行结果,`name` 字段为"tools",`content`包含多个结果。
|
244 |
+
|
245 |
+
```json
|
246 |
+
{
|
247 |
+
"data": [
|
248 |
+
{
|
249 |
+
"role": "tool",
|
250 |
+
"name": "tools",
|
251 |
+
"content": "Tool name: search_web\nTool result: test_result1\n\nTool name: get_current_weather\nTool result: test_result2"
|
252 |
+
}
|
253 |
+
]
|
254 |
+
}
|
255 |
+
```
|
256 |
+
|
257 |
+
对应如下的模型输入格式:
|
258 |
+
```
|
259 |
+
]~b]tool name=tools
|
260 |
+
Tool name: search_web
|
261 |
+
Tool result: test_result1
|
262 |
+
|
263 |
+
Tool name: get_current_weather
|
264 |
+
Tool result: test_result2[e~[
|
265 |
+
```
|
266 |
+
|
267 |
+
虽然我们建议您参考以上格式,但只要返回给模型的输入易于理解,`name` 和 `content` 的具体内容完全由您自主决定。
|
docs/transformers_deployment_guide.md
ADDED
@@ -0,0 +1,97 @@
# 🚀 MiniMax Model Transformers Deployment Guide

[Transformers中文版部署指南](./transformers_deployment_guide_cn.md)

## 📖 Introduction

This guide will help you deploy the MiniMax-M1 model using the [Transformers](https://huggingface.co/docs/transformers/index) library. Transformers is a widely used deep learning library that provides a rich collection of pre-trained models and flexible model operation interfaces.

## 🛠️ Environment Setup

### Installing Transformers

```bash
pip install transformers torch accelerate
```

## 📋 Basic Usage Example

The pre-trained model can be used as follows:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

MODEL_PATH = "{MODEL_PATH}"
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

messages = [
    {"role": "user", "content": [{"type": "text", "text": "What is your favourite condiment?"}]},
    {"role": "assistant", "content": [{"type": "text", "text": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"}]},
    {"role": "user", "content": [{"type": "text", "text": "Do you have mayonnaise recipes?"}]}
]

text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

model_inputs = tokenizer(text, return_tensors="pt").to(model.device)

generation_config = GenerationConfig(
    max_new_tokens=20,
    eos_token_id=tokenizer.eos_token_id,
    use_cache=True,
)

generated_ids = model.generate(**model_inputs, generation_config=generation_config)

generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)
```
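
For interactive use you may want to stream tokens as they are generated instead of waiting for the full completion. A small sketch using the `TextStreamer` helper from Transformers, reusing `model`, `tokenizer`, and `model_inputs` from the example above; the generation length is illustrative:

```python
from transformers import TextStreamer

# Prints decoded tokens to stdout as they are produced; the prompt itself is skipped.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

model.generate(
    **model_inputs,
    max_new_tokens=128,
    do_sample=False,
    streamer=streamer,
)
```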

## ⚡ Performance Optimization

### Speeding up with Flash Attention

The code snippet above showcases inference without any optimization tricks. However, one can drastically speed up the model by leveraging [Flash Attention](../perf_train_gpu_one#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.

First, make sure to install the latest version of Flash Attention 2:

```bash
pip install -U flash-attn --no-build-isolation
```

Also make sure that you have hardware that is compatible with Flash Attention 2. Read more about it in the official documentation of the [Flash Attention repository](https://github.com/Dao-AILab/flash-attention). Additionally, ensure you load your model in half-precision (e.g. `torch.float16`).

To load and run a model using Flash Attention 2, refer to the snippet below:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_PATH = "{MODEL_PATH}"
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, trust_remote_code=True, torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

prompt = "My favourite condiment is"

model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
response = tokenizer.batch_decode(generated_ids)[0]
print(response)
```

## 📮 Getting Support

If you encounter any issues while deploying the MiniMax-M1 model:
- Please check our official documentation
- Contact our technical support team through official channels
- Submit an Issue on our GitHub repository

We continuously optimize the deployment experience on Transformers and welcome your feedback!
docs/transformers_deployment_guide_cn.md
ADDED
@@ -0,0 +1,95 @@
1 |
+
# 🚀 MiniMax 模型 Transformers 部署指南
|
2 |
+
|
3 |
+
## 📖 简介
|
4 |
+
|
5 |
+
本指南将帮助您使用 [Transformers](https://huggingface.co/docs/transformers/index) 库部署 MiniMax-M1 模型。Transformers 是一个广泛使用的深度学习库,提供了丰富的预训练模型和灵活的模型操作接口。
|
6 |
+
|
7 |
+
## 🛠️ 环境准备
|
8 |
+
|
9 |
+
### 安装 Transformers
|
10 |
+
|
11 |
+
```bash
|
12 |
+
pip install transformers torch accelerate
|
13 |
+
```
|
14 |
+
|
15 |
+
## 📋 基本使用示例
|
16 |
+
|
17 |
+
预训练模型可以按照以下方式使用:
|
18 |
+
|
19 |
+
```python
|
20 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
21 |
+
|
22 |
+
MODEL_PATH = "{MODEL_PATH}"
|
23 |
+
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", trust_remote_code=True)
|
24 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
25 |
+
|
26 |
+
messages = [
|
27 |
+
{"role": "user", "content": [{"type": "text", "text": "What is your favourite condiment?"}]},
|
28 |
+
{"role": "assistant", "content": [{"type": "text", "text": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"}]},
|
29 |
+
{"role": "user", "content": [{"type": "text", "text": "Do you have mayonnaise recipes?"}]}
|
30 |
+
]
|
31 |
+
|
32 |
+
text = tokenizer.apply_chat_template(
|
33 |
+
messages,
|
34 |
+
tokenize=False,
|
35 |
+
add_generation_prompt=True
|
36 |
+
)
|
37 |
+
|
38 |
+
model_inputs = tokenizer(text, return_tensors="pt").to(model.device)
|
39 |
+
|
40 |
+
generation_config = GenerationConfig(
|
41 |
+
max_new_tokens=20,
|
42 |
+
eos_token_id=tokenizer.eos_token_id,
|
43 |
+
use_cache=True,
|
44 |
+
)
|
45 |
+
|
46 |
+
generated_ids = model.generate(**model_inputs, generation_config=generation_config)
|
47 |
+
|
48 |
+
generated_ids = [
|
49 |
+
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
50 |
+
]
|
51 |
+
|
52 |
+
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
53 |
+
print(response)
|
54 |
+
```
|
55 |
+
|
56 |
+
## ⚡ 性能优化
|
57 |
+
|
58 |
+
### 使用 Flash Attention 加速
|
59 |
+
|
60 |
+
上面的代码片段展示了不使用任何优化技巧的推理过程。但通过利用 [Flash Attention](../perf_train_gpu_one#flash-attention-2),可以大幅加速模型,因为它提供了模型内部使用的注意力机制的更快实现。
|
61 |
+
|
62 |
+
首先,确保安装最新版本的 Flash Attention 2:
|
63 |
+
|
64 |
+
```bash
|
65 |
+
pip install -U flash-attn --no-build-isolation
|
66 |
+
```
|
67 |
+
|
68 |
+
还要确保您拥有与 Flash-Attention 2 兼容的硬件。在[Flash Attention 官方仓库](https://github.com/Dao-AILab/flash-attention)的官方文档中了解更多信息。此外,请确保以半精度(例如 `torch.float16`)加载模型。
|
69 |
+
|
70 |
+
要使用 Flash Attention-2 加载和运行模型,请参考以下代码片段:
|
71 |
+
|
72 |
+
```python
|
73 |
+
import torch
|
74 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
75 |
+
|
76 |
+
MODEL_PATH = "{MODEL_PATH}"
|
77 |
+
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, trust_remote_code=True, torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto")
|
78 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
79 |
+
|
80 |
+
prompt = "My favourite condiment is"
|
81 |
+
|
82 |
+
model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
|
83 |
+
generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
|
84 |
+
response = tokenizer.batch_decode(generated_ids)[0]
|
85 |
+
print(response)
|
86 |
+
```
|
87 |
+
|
88 |
+
## 📮 获取支持
|
89 |
+
|
90 |
+
如果您在部署 MiniMax-M1 模型过程中遇到任何问题:
|
91 |
+
- 请查看我们的官方文档
|
92 |
+
- 通过官方渠道联系我们的技术支持团队
|
93 |
+
- 在我们的 GitHub 仓库提交 Issue
|
94 |
+
|
95 |
+
我们会持续优化 Transformers 上的部署体验,欢迎您的反馈!
|
docs/vllm_deployment_guide.md
ADDED
@@ -0,0 +1,171 @@
# 🚀 MiniMax Models vLLM Deployment Guide

[vLLM中文版部署指南](./vllm_deployment_guide_cn.md)

## 📖 Introduction

We recommend using [vLLM](https://docs.vllm.ai/en/latest/) to deploy the [MiniMax-M1](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k) model. Based on our testing, vLLM performs excellently when deploying this model, with the following features:

- 🔥 Outstanding service throughput performance
- ⚡ Efficient and intelligent memory management
- 📦 Powerful batch request processing capability
- ⚙️ Deeply optimized underlying performance

The MiniMax-M1 model can run efficiently on a single server equipped with 8 H800 or 8 H20 GPUs. In terms of hardware configuration, a server with 8 H800 GPUs can process context inputs of up to 2 million tokens, while a server equipped with 8 H20 GPUs can support ultra-long context processing of up to 5 million tokens.

## 💾 Obtaining MiniMax Models

### Obtaining the MiniMax-M1 Model

You can download the model from our official HuggingFace repositories: [MiniMax-M1-40k](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k), [MiniMax-M1-80k](https://huggingface.co/MiniMaxAI/MiniMax-M1-80k)

Download command:
```
pip install -U huggingface-hub
huggingface-cli download MiniMaxAI/MiniMax-M1-40k
# huggingface-cli download MiniMaxAI/MiniMax-M1-80k

# If you encounter network issues, you can set a proxy
export HF_ENDPOINT=https://hf-mirror.com
```

Or download using git:

```bash
git lfs install
git clone https://huggingface.co/MiniMaxAI/MiniMax-M1-40k
git clone https://huggingface.co/MiniMaxAI/MiniMax-M1-80k
```

⚠️ **Important Note**: Please ensure that [Git LFS](https://git-lfs.github.com/) is installed on your system, which is necessary for completely downloading the model weight files.

## 🛠️ Deployment Options

### Option 1: Deploy Using Docker (Recommended)

To ensure consistency and stability of the deployment environment, we recommend using Docker for deployment.

⚠️ **Version Requirements**:
- The MiniMax-M1 model requires vLLM version 0.8.3 or later for full support
- If you are using a Docker image with a vLLM version lower than required, you will need to:
  1. Update to the latest vLLM code
  2. Recompile vLLM from source, following the compilation instructions in Solution 2 of the Common Issues section
- Special note: for vLLM versions between 0.8.3 and 0.9.2, you need to modify the model configuration (a scripted sketch of this edit follows this list):
  1. Open `config.json`
  2. Change `config['architectures'] = ["MiniMaxM1ForCausalLM"]` to `config['architectures'] = ["MiniMaxText01ForCausalLM"]`
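
If you prefer to script that one-line change, a small hedged Python sketch (the directory name `MiniMax-M1-40k` is an example; point it at wherever you downloaded the weights):

```python
import json
from pathlib import Path

config_path = Path("MiniMax-M1-40k") / "config.json"  # adjust to your model directory
config = json.loads(config_path.read_text())

# Older vLLM builds (0.8.3 - 0.9.2) only recognize the MiniMax-Text-01 architecture name.
if config.get("architectures") == ["MiniMaxM1ForCausalLM"]:
    config["architectures"] = ["MiniMaxText01ForCausalLM"]
    config_path.write_text(json.dumps(config, indent=2, ensure_ascii=False))
```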

1. Get the container image:
```bash
docker pull vllm/vllm-openai:v0.8.3
```

2. Run the container:
```bash
# Set environment variables
IMAGE=vllm/vllm-openai:v0.8.3
MODEL_DIR=<model storage path>
CODE_DIR=<code path>
NAME=MiniMaxImage

# Docker run configuration
DOCKER_RUN_CMD="--network=host --privileged --ipc=host --ulimit memlock=-1 --shm-size=2gb --rm --gpus all --ulimit stack=67108864"

# Start the container
sudo docker run -it \
    -v $MODEL_DIR:$MODEL_DIR \
    -v $CODE_DIR:$CODE_DIR \
    --name $NAME \
    $DOCKER_RUN_CMD \
    $IMAGE /bin/bash
```

### Option 2: Direct Installation of vLLM

If your environment meets the following requirements:

- CUDA 12.1
- PyTorch 2.1

you can install vLLM directly.

Installation command:
```bash
pip install vllm
```

💡 If you are using another environment configuration, please refer to the [vLLM Installation Guide](https://docs.vllm.ai/en/latest/getting_started/installation.html)

## 🚀 Starting the Service

### Launch the MiniMax-M1 Service

```bash
export SAFETENSORS_FAST_GPU=1
export VLLM_USE_V1=0
python3 -m vllm.entrypoints.openai.api_server \
    --model <model storage path> \
    --tensor-parallel-size 8 \
    --trust-remote-code \
    --quantization experts_int8 \
    --max_model_len 4096 \
    --dtype bfloat16
```

### API Call Example

```bash
curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "MiniMaxAI/MiniMax-M1",
        "messages": [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
            {"role": "user", "content": [{"type": "text", "text": "Who won the world series in 2020?"}]}
        ]
    }'
```

## ❗ Common Issues

### Module Loading Problems

If you encounter the following error:
```
import vllm._C  # noqa
ModuleNotFoundError: No module named 'vllm._C'
```

or

```
MiniMax-M1 model is not currently supported
```

we provide two solutions:

#### Solution 1: Copy Dependency Files
```bash
cd <working directory>
git clone https://github.com/vllm-project/vllm.git
cd vllm
cp /usr/local/lib/python3.12/dist-packages/vllm/*.so vllm
cp -r /usr/local/lib/python3.12/dist-packages/vllm/vllm_flash_attn/* vllm/vllm_flash_attn
```

#### Solution 2: Install from Source
```bash
cd <working directory>
git clone https://github.com/vllm-project/vllm.git

cd vllm/
pip install -e .
```

## 📮 Getting Support

If you encounter any issues while deploying the MiniMax-M1 model:
- Please check our official documentation
- Contact our technical support team through official channels
- Submit an [Issue](https://github.com/MiniMax-AI/MiniMax-M1/issues) on our GitHub repository

We will continuously optimize the deployment experience of this model and welcome your feedback!
docs/vllm_deployment_guide_cn.md
ADDED
@@ -0,0 +1,166 @@
1 |
+
# 🚀 MiniMax 模型 vLLM 部署指南
|
2 |
+
|
3 |
+
## 📖 简介
|
4 |
+
|
5 |
+
我们推荐使用 [vLLM](https://docs.vllm.ai/en/latest/) 来部署 [MiniMax-M1](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k) 模型。经过我们的测试,vLLM 在部署这个模型时表现出色,具有以下特点:
|
6 |
+
|
7 |
+
- 🔥 卓越的服务吞吐量性能
|
8 |
+
- ⚡ 高效智能的内存管理机制
|
9 |
+
- 📦 强大的批量请求处理能力
|
10 |
+
- ⚙️ 深度优化的底层性能
|
11 |
+
|
12 |
+
MiniMax-M1 模型可在单台配备8个H800或8个H20 GPU的服务器上高效运行。在硬件配置方面,搭载8个H800 GPU的服务器可处理长达200万token的上下文输入,而配备8个H20 GPU的服务器则能够支持高达500万token的超长上下文处理能力。
|
13 |
+
|
14 |
+
## 💾 获取 MiniMax 模型
|
15 |
+
|
16 |
+
### MiniMax-M1 模型获取
|
17 |
+
|
18 |
+
您可以从我们的官方 HuggingFace 仓库下载模型:[MiniMax-M1-40k](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k)、[MiniMax-M1-80k](https://huggingface.co/MiniMaxAI/MiniMax-M1-80k)
|
19 |
+
|
20 |
+
下载命令:
|
21 |
+
```
|
22 |
+
pip install -U huggingface-hub
|
23 |
+
huggingface-cli download MiniMaxAI/MiniMax-M1-40k
|
24 |
+
# huggingface-cli download MiniMaxAI/MiniMax-M1-80k
|
25 |
+
|
26 |
+
# 如果遇到网络问题,可以设置代理
|
27 |
+
export HF_ENDPOINT=https://hf-mirror.com
|
28 |
+
```
|
29 |
+
|
30 |
+
或者使用 git 下载:
|
31 |
+
|
32 |
+
```bash
|
33 |
+
git lfs install
|
34 |
+
git clone https://huggingface.co/MiniMaxAI/MiniMax-M1-40k
|
35 |
+
git clone https://huggingface.co/MiniMaxAI/MiniMax-M1-80k
|
36 |
+
```
|
37 |
+
|
38 |
+
⚠️ **重要提示**:请确保系统已安装 [Git LFS](https://git-lfs.github.com/),这对于完整下载模型权重文件是必需的。
|
39 |
+
|
40 |
+
## 🛠️ 部署方案
|
41 |
+
|
42 |
+
### 方案一:使用 Docker 部署(推荐)
|
43 |
+
|
44 |
+
为确保部署环境的一致性和稳定性,我们推荐使用 Docker 进行部署。
|
45 |
+
|
46 |
+
⚠️ **版本要求**:
|
47 |
+
- 基础要求:vLLM 版本必须 ≥ 0.8.3,以确保对 MiniMax-M1 模型的完整支持
|
48 |
+
- 特殊说明:如果使用 vLLM 0.8.3 至 0.9.2 之间的版本,需要修改模型配置文件:
|
49 |
+
- 打开 `config.json`
|
50 |
+
- 将 `config['architectures'] = ["MiniMaxM1ForCausalLM"]` 修改为 `config['architectures'] = ["MiniMaxText01ForCausalLM"]`
|
51 |
+
|
52 |
+
1. 获取容器镜像:
|
53 |
+
```bash
|
54 |
+
docker pull vllm/vllm-openai:v0.8.3
|
55 |
+
```
|
56 |
+
|
57 |
+
2. 运行容器:
|
58 |
+
```bash
|
59 |
+
# 设置环境变量
|
60 |
+
IMAGE=vllm/vllm-openai:v0.8.3
|
61 |
+
MODEL_DIR=<模型存放路径>
|
62 |
+
CODE_DIR=<代码路径>
|
63 |
+
NAME=MiniMaxImage
|
64 |
+
|
65 |
+
# Docker运行配置
|
66 |
+
DOCKER_RUN_CMD="--network=host --privileged --ipc=host --ulimit memlock=-1 --shm-size=2gb --rm --gpus all --ulimit stack=67108864"
|
67 |
+
|
68 |
+
# 启动容器
|
69 |
+
sudo docker run -it \
|
70 |
+
-v $MODEL_DIR:$MODEL_DIR \
|
71 |
+
-v $CODE_DIR:$CODE_DIR \
|
72 |
+
--name $NAME \
|
73 |
+
$DOCKER_RUN_CMD \
|
74 |
+
$IMAGE /bin/bash
|
75 |
+
```
|
76 |
+
|
77 |
+
|
78 |
+
### 方案二:直接安装 vLLM
|
79 |
+
|
80 |
+
如果您的环境满足以下要求:
|
81 |
+
|
82 |
+
- CUDA 12.1
|
83 |
+
- PyTorch 2.1
|
84 |
+
|
85 |
+
可以直接安装 vLLM
|
86 |
+
|
87 |
+
安装命令:
|
88 |
+
```bash
|
89 |
+
pip install vllm
|
90 |
+
```
|
91 |
+
|
92 |
+
💡 如果您使用其他环境配置,请参考 [vLLM 安装指南](https://docs.vllm.ai/en/latest/getting_started/installation.html)
|
93 |
+
|
94 |
+
## 🚀 启动服务
|
95 |
+
|
96 |
+
### 启动 MiniMax-M1 服务
|
97 |
+
|
98 |
+
```bash
|
99 |
+
export SAFETENSORS_FAST_GPU=1
|
100 |
+
export VLLM_USE_V1=0
|
101 |
+
python3 -m vllm.entrypoints.openai.api_server \
|
102 |
+
--model <模型存放路径> \
|
103 |
+
--tensor-parallel-size 8 \
|
104 |
+
--trust-remote-code \
|
105 |
+
--quantization experts_int8 \
|
106 |
+
--max_model_len 4096 \
|
107 |
+
--dtype bfloat16
|
108 |
+
```
|
109 |
+
|
110 |
+
### API 调用示例
|
111 |
+
|
112 |
+
```bash
|
113 |
+
curl http://localhost:8000/v1/chat/completions \
|
114 |
+
-H "Content-Type: application/json" \
|
115 |
+
-d '{
|
116 |
+
"model": "MiniMaxAI/MiniMax-M1",
|
117 |
+
"messages": [
|
118 |
+
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
|
119 |
+
{"role": "user", "content": [{"type": "text", "text": "Who won the world series in 2020?"}]}
|
120 |
+
]
|
121 |
+
}'
|
122 |
+
```
|
123 |
+
|
124 |
+
## ❗ 常见问题
|
125 |
+
|
126 |
+
### 模块加载问题
|
127 |
+
如果遇到以下错误:
|
128 |
+
```
|
129 |
+
import vllm._C # noqa
|
130 |
+
ModuleNotFoundError: No module named 'vllm._C'
|
131 |
+
```
|
132 |
+
|
133 |
+
或
|
134 |
+
|
135 |
+
```
|
136 |
+
当前并不支持 MiniMax-M1 模型
|
137 |
+
```
|
138 |
+
|
139 |
+
我们提供两种解决方案:
|
140 |
+
|
141 |
+
#### 解决方案一:复制依赖文件
|
142 |
+
```bash
|
143 |
+
cd <工作目录>
|
144 |
+
git clone https://github.com/vllm-project/vllm.git
|
145 |
+
cd vllm
|
146 |
+
cp /usr/local/lib/python3.12/dist-packages/vllm/*.so vllm
|
147 |
+
cp -r /usr/local/lib/python3.12/dist-packages/vllm/vllm_flash_attn/* vllm/vllm_flash_attn
|
148 |
+
```
|
149 |
+
|
150 |
+
#### 解决方案二:从源码安装
|
151 |
+
```bash
|
152 |
+
cd <工作目录>
|
153 |
+
git clone https://github.com/vllm-project/vllm.git
|
154 |
+
|
155 |
+
cd vllm/
|
156 |
+
pip install -e .
|
157 |
+
```
|
158 |
+
|
159 |
+
## 📮 获取支持
|
160 |
+
|
161 |
+
如果您在部署 MiniMax-M1 模型过程中遇到任何问题:
|
162 |
+
- 请查看我们的官方文档
|
163 |
+
- 通过官方渠道联系我们的技术支持团队
|
164 |
+
- 在我们的 GitHub 仓库提交 [Issue](https://github.com/MiniMax-AI/MiniMax-M1/issues)
|
165 |
+
|
166 |
+
我们会持续优化模型的部署体验,欢迎您的反馈!
|
figures/TextBench.png
ADDED
figures/wechat-qrcode.jpeg
ADDED
main.py
ADDED
@@ -0,0 +1,106 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, QuantoConfig, GenerationConfig
import torch
import argparse

"""
usage:
export SAFETENSORS_FAST_GPU=1
python main.py --quant_type int8 --world_size 8 --model_id <model_path>
"""

def generate_quanto_config(hf_config: AutoConfig, quant_type: str):
    QUANT_TYPE_MAP = {
        "default": None,
        "int8": QuantoConfig(
            weights="int8",
            modules_to_not_convert=[
                "lm_head",
                "embed_tokens",
            ] + [f"model.layers.{i}.coefficient" for i in range(hf_config.num_hidden_layers)]
            + [f"model.layers.{i}.block_sparse_moe.gate" for i in range(hf_config.num_hidden_layers)]
        ),
    }
    return QUANT_TYPE_MAP[quant_type]


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--quant_type", type=str, default="default", choices=["default", "int8"])
    parser.add_argument("--model_id", type=str, required=True)
    parser.add_argument("--world_size", type=int, required=True)
    return parser.parse_args()


def check_params(args, hf_config: AutoConfig):
    if args.quant_type == "int8":
        assert args.world_size >= 8, "int8 weight-only quantization requires at least 8 GPUs"

    assert hf_config.num_hidden_layers % args.world_size == 0, f"num_hidden_layers({hf_config.num_hidden_layers}) must be divisible by world_size({args.world_size})"


@torch.no_grad()
def main():
    args = parse_args()
    print("\n=============== Argument ===============")
    for key in vars(args):
        print(f"{key}: {vars(args)[key]}")
    print("========================================")

    model_id = args.model_id

    hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    check_params(args, hf_config)
    quantization_config = generate_quanto_config(hf_config, args.quant_type)

    # Place the embeddings on the first GPU, the final norm and lm_head on the last one,
    # and spread the decoder layers evenly across all available GPUs.
    device_map = {
        'model.embed_tokens': 'cuda:0',
        'model.norm': f'cuda:{args.world_size - 1}',
        'lm_head': f'cuda:{args.world_size - 1}'
    }
    layers_per_device = hf_config.num_hidden_layers // args.world_size
    for i in range(args.world_size):
        for j in range(layers_per_device):
            device_map[f'model.layers.{i * layers_per_device + j}'] = f'cuda:{i}'

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    message = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {"role": "user", "content": [{"type": "text", "text": "Hello, what is the weather today?"}]}
    ]
    tools = [
        {"name": "get_location", "description": "Get the location of the user.", "parameters": {"type": "object", "properties": {}}},
        {"name": "get_weather", "description": "Get the weather of a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The name of the city"}}}},
        {"name": "get_news", "description": "Get the news.", "parameters": {"type": "object", "properties": {"domain": {"type": "string", "description": "The domain of the news"}}}}
    ]
    text = tokenizer.apply_chat_template(
        message,
        tools,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer(text, return_tensors="pt").to("cuda")
    quantized_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="bfloat16",
        device_map=device_map,
        quantization_config=quantization_config,
        trust_remote_code=True,
        offload_buffers=True,
    )
    generation_config = GenerationConfig(
        max_new_tokens=20,
        eos_token_id=200020,
        use_cache=True,
    )
    generated_ids = quantized_model.generate(**model_inputs, generation_config=generation_config)
    print(f"generated_ids: {generated_ids}")
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(response)

if __name__ == "__main__":
    main()
merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
model.safetensors.index.json
ADDED
The diff for this file is too large to render.
See raw diff
modeling_minimax_m1.py
ADDED
@@ -0,0 +1,1701 @@
1 |
+
""" PyTorch MiniMaxM1 model."""
|
2 |
+
import inspect
|
3 |
+
import math
|
4 |
+
import warnings
|
5 |
+
from typing import List, Optional, Tuple, Union
|
6 |
+
import os
|
7 |
+
import copy
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch.utils.checkpoint
|
11 |
+
from torch import nn
|
12 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
13 |
+
from einops import rearrange, repeat
|
14 |
+
from transformers.activations import ACT2FN
|
15 |
+
from transformers.cache_utils import Cache, DynamicCache
|
16 |
+
from transformers.modeling_attn_mask_utils import (
|
17 |
+
_prepare_4d_causal_attention_mask,
|
18 |
+
)
|
19 |
+
from transformers.modeling_outputs import (
|
20 |
+
MoeCausalLMOutputWithPast,
|
21 |
+
MoeModelOutputWithPast,
|
22 |
+
SequenceClassifierOutputWithPast,
|
23 |
+
)
|
24 |
+
from transformers.modeling_utils import PreTrainedModel
|
25 |
+
from transformers.utils import (
|
26 |
+
add_start_docstrings,
|
27 |
+
add_start_docstrings_to_model_forward,
|
28 |
+
is_flash_attn_2_available,
|
29 |
+
is_flash_attn_greater_or_equal_2_10,
|
30 |
+
logging,
|
31 |
+
replace_return_docstrings,
|
32 |
+
)
|
33 |
+
from transformers.utils.import_utils import is_torch_fx_available
|
34 |
+
from .configuration_minimax_m1 import MiniMaxM1Config
|
35 |
+
|
36 |
+
if is_flash_attn_2_available():
|
37 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
38 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
39 |
+
|
40 |
+
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
41 |
+
|
42 |
+
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
43 |
+
# It means that the function will not be traced through and simply appear as a node in the graph.
|
44 |
+
if is_torch_fx_available():
|
45 |
+
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
|
46 |
+
|
47 |
+
use_triton = eval(os.environ.get("use_triton", default="False"))
|
48 |
+
debug = eval(os.environ.get("debug", default="False"))
|
49 |
+
do_eval = eval(os.environ.get("do_eval", default="False"))
|
50 |
+
eval_and_not_generate = eval(os.environ.get("eval_and_not_generate", default="False"))
|
51 |
+
BLOCK = 256
|
52 |
+
|
53 |
+
logger = logging.get_logger(__name__)
|
54 |
+
|
55 |
+
_CONFIG_FOR_DOC = "MiniMaxM1Config"
|
56 |
+
|
57 |
+
|
58 |
+
def get_activation_fn(activation):
|
59 |
+
if debug:
|
60 |
+
logger.info(f"activation: {activation}")
|
61 |
+
if activation == "gelu":
|
62 |
+
return F.gelu
|
63 |
+
elif activation == "relu":
|
64 |
+
return F.relu
|
65 |
+
elif activation == "elu":
|
66 |
+
return F.elu
|
67 |
+
elif activation == "sigmoid":
|
68 |
+
return F.sigmoid
|
69 |
+
elif activation == "exp":
|
70 |
+
|
71 |
+
def f(x):
|
72 |
+
with torch.no_grad():
|
73 |
+
x_max = torch.max(x, dim=-1, keepdims=True).values
|
74 |
+
y = torch.exp(x - x_max)
|
75 |
+
|
76 |
+
return y
|
77 |
+
|
78 |
+
return f
|
79 |
+
elif activation == "leak":
|
80 |
+
return F.leaky_relu
|
81 |
+
elif activation == "1+elu":
|
82 |
+
|
83 |
+
def f(x):
|
84 |
+
return 1 + F.elu(x)
|
85 |
+
|
86 |
+
return f
|
87 |
+
elif activation == "2+elu":
|
88 |
+
|
89 |
+
def f(x):
|
90 |
+
return 2 + F.elu(x)
|
91 |
+
|
92 |
+
return f
|
93 |
+
elif activation == "silu" or activation == "swish":
|
94 |
+
return F.silu
|
95 |
+
elif activation == "sine":
|
96 |
+
return torch.sin
|
97 |
+
else:
|
98 |
+
logger.info(
|
99 |
+
f"activation: does not support {activation}, use Identity!!!")
|
100 |
+
return lambda x: x
|
101 |
+
|
102 |
+
|
103 |
+
def load_balancing_loss_func(
|
104 |
+
gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2,
|
105 |
+
attention_mask: Optional[torch.Tensor] = None
|
106 |
+
) -> float:
|
107 |
+
r"""
|
108 |
+
Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
|
109 |
+
|
110 |
+
See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
|
111 |
+
function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
|
112 |
+
experts is too unbalanced.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
|
116 |
+
Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
|
117 |
+
shape [batch_size X sequence_length, num_experts].
|
118 |
+
attention_mask (`torch.Tensor`, None):
|
119 |
+
The attention_mask used in forward function
|
120 |
+
shape [batch_size X sequence_length] if not None.
|
121 |
+
num_experts (`int`, *optional*):
|
122 |
+
Number of experts
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
The auxiliary loss.
|
126 |
+
"""
|
127 |
+
if gate_logits is None or not isinstance(gate_logits, tuple):
|
128 |
+
return 0
|
129 |
+
|
130 |
+
if isinstance(gate_logits, tuple):
|
131 |
+
compute_device = gate_logits[0].device
|
132 |
+
concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
|
133 |
+
|
134 |
+
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
|
135 |
+
|
136 |
+
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
|
137 |
+
|
138 |
+
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
|
139 |
+
|
140 |
+
if attention_mask is None:
|
141 |
+
# Compute the percentage of tokens routed to each experts
|
142 |
+
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
|
143 |
+
|
144 |
+
# Compute the average probability of routing to these experts
|
145 |
+
router_prob_per_expert = torch.mean(routing_weights, dim=0)
|
146 |
+
else:
|
147 |
+
batch_size, sequence_length = attention_mask.shape
|
148 |
+
num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
|
149 |
+
|
150 |
+
# Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
|
151 |
+
expert_attention_mask = (
|
152 |
+
attention_mask[None, :, :, None, None]
|
153 |
+
.expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
|
154 |
+
.reshape(-1, top_k, num_experts)
|
155 |
+
.to(compute_device)
|
156 |
+
)
|
157 |
+
|
158 |
+
# Compute the percentage of tokens routed to each experts
|
159 |
+
tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
|
160 |
+
expert_attention_mask, dim=0
|
161 |
+
)
|
162 |
+
|
163 |
+
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
|
164 |
+
router_per_expert_attention_mask = (
|
165 |
+
attention_mask[None, :, :, None]
|
166 |
+
.expand((num_hidden_layers, batch_size, sequence_length, num_experts))
|
167 |
+
.reshape(-1, num_experts)
|
168 |
+
.to(compute_device)
|
169 |
+
)
|
170 |
+
|
171 |
+
# Compute the average probability of routing to these experts
|
172 |
+
router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
|
173 |
+
router_per_expert_attention_mask, dim=0
|
174 |
+
)
|
175 |
+
|
176 |
+
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
|
177 |
+
return overall_loss * num_experts
|
178 |
+
|
179 |
+

# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


class GLU(nn.Module):

    def __init__(self, d1, d2, bias=False):
        super().__init__()

        self.l1 = nn.Linear(d1, d2, bias=bias)
        self.l2 = nn.Linear(d1, d2, bias=bias)
        self.l3 = nn.Linear(d2, d1, bias=bias)

    def forward(self, x):
        o1 = self.l1(x)
        o2 = self.l2(x)
        output = o1 * o2
        output = self.l3(output)
        return output


class MiniMaxM1LightningAttention(nn.Module):
    def __init__(self, config: MiniMaxM1Config, layer_idx: Optional[int] = None):
        super().__init__()
        bias = False
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, 'head_dim', self.hidden_size // self.num_heads)

        self.out_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=bias)
        self.act = get_activation_fn(config.hidden_act)
        self.norm = MiniMaxM1RMSNorm(self.head_dim * self.num_heads)

        self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias)
        self.output_gate = nn.Linear(self.hidden_size, self.head_dim * self.num_heads, bias=bias)

        # for inference only
        self.offset = 0
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states,
        attn_mask: Optional[torch.Tensor] = None,  # (b, h, n, m)
        output_attentions: bool = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        slope_rate: Optional[torch.Tensor] = None,
        **kwargs
    ):
        if (not self.training) and (not do_eval):
            return self.inference(
                hidden_states,
                attn_mask,
                output_attentions,
                past_key_value,
                use_cache,
                slope_rate,
            )

    def inference(
        self,
        x,
        attn_mask: Optional[torch.Tensor] = None,  # (b, n)
        output_attentions: bool = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        slope_rate: Optional[torch.Tensor] = None,  # (h, 1, 1)
    ):
        # x: b n d
        b, n, d = x.shape
        # linear map
        qkv = self.act(self.qkv_proj(x))
        new_shape = qkv.size()[:-1] + (self.num_heads, -1)
        qkv = qkv.view(*new_shape)
        q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if past_key_value is None:
            self.offset = q.shape[-2]
        else:
            self.offset += 1

        # for align with metaseq
        ratio = torch.exp(-slope_rate)

        # only use for the first time
        if past_key_value is None:
            slope_rate = slope_rate.to(torch.float32)
            if attn_mask is not None:
                v = v.masked_fill((1 - attn_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0)
            NUM_BLOCK = (n + BLOCK - 1) // BLOCK
            b, h, n, d = q.shape
            e = v.shape[-1]
            # other
            array = torch.arange(BLOCK).to(q) + 1
            q_decay = torch.exp(-slope_rate * array.reshape(-1, 1))
            k_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1)))
            index = array[:, None] - array[None, :]
            s_index = slope_rate * index[
                None,
                None,
            ]
            s_index = torch.where(index >= 0, -s_index, float("-inf"))
            diag_decay = torch.exp(s_index)

            kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device)
            output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
            for i in range(NUM_BLOCK):
                si = i * BLOCK
                ei = min(si + BLOCK, n)
                m = ei - si
                qi = q[:, :, si:ei].contiguous()
                ki = k[:, :, si:ei].contiguous()
                vi = v[:, :, si:ei].contiguous()
                qkv_none_diag = torch.matmul(qi * q_decay[:, :m], kv).to(torch.float32)

                # diag
                qk = torch.matmul(qi, ki.transpose(-1, -2)).to(torch.float32) * diag_decay[:, :, :m, :m]
                qkv_diag = torch.matmul(qk, vi.to(torch.float32))
                block_decay = torch.exp(-slope_rate * m)
                output[:, :, si:ei] = qkv_none_diag + qkv_diag
                kv = block_decay * kv + torch.matmul((ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi)

        else:
            kv = past_key_value
            output = []
            for i in range(n):
                kv = ratio * kv + torch.einsum(
                    "... n d, ... n e -> ... d e",
                    k[:, :, i:i + 1],
                    v[:, :, i:i + 1],
                )
                qkv = torch.einsum("... n e, ... e d -> ... n d", q[:, :, i:i + 1], kv.to(q.dtype))
                output.append(qkv)
            output = torch.concat(output, dim=-2)
        # reshape
        output = rearrange(output, "b h n d -> b n (h d)")
        # normalize
        output = self.norm(output)
        # gate
        output = F.sigmoid(self.output_gate(x)) * output
        # outproj
        output = self.out_proj(output)

        attn_weights = None

        return output, attn_weights, kv

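# ----------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the original modeling code): the decode-time recurrence used
# by MiniMaxM1LightningAttention above. Per head, the running state `kv` is decayed by exp(-slope)
# and updated with the outer product k_t^T v_t; the token's output is q_t @ kv. Names and shapes
# below are local toy values for the example.
def _toy_lightning_decode_step():
    import torch

    heads, head_dim = 2, 4
    slope = torch.rand(heads, 1, 1)                 # per-head decay rate
    kv = torch.zeros(1, heads, head_dim, head_dim)  # recurrent state carried across decode steps

    q_t = torch.randn(1, heads, 1, head_dim)
    k_t = torch.randn(1, heads, 1, head_dim)
    v_t = torch.randn(1, heads, 1, head_dim)

    ratio = torch.exp(-slope)                       # decay applied once per generated token
    kv = ratio * kv + torch.einsum("... n d, ... n e -> ... d e", k_t, v_t)
    out_t = torch.einsum("... n e, ... e d -> ... n d", q_t, kv)
    return out_t, kv                                # out_t: this token's attention output
# ----------------------------------------------------------------------------------------------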

# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxM1
class MiniMaxM1RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        MiniMaxM1RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->MiniMaxM1
class MiniMaxM1RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)

        return (
            self.cos_cached[:seq_len].to(dtype=torch.float32),
            self.sin_cached[:seq_len].to(dtype=torch.float32),
        )


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    dtype = q.dtype
    rot_dim = cos.shape[-1]
    q_, q_pass = q[..., :rot_dim], q[..., rot_dim:]
    k_, k_pass = k[..., :rot_dim], k[..., rot_dim:]
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q_ * cos) + (rotate_half(q_) * sin)
    k_embed = (k_ * cos) + (rotate_half(k_) * sin)
    return torch.cat((q_embed, q_pass), dim=-1).to(dtype), torch.cat((k_embed, k_pass), dim=-1).to(dtype)


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

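# ----------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the original modeling code): the partial rotary embedding used
# by `apply_rotary_pos_emb` above - only the first `rot_dim` channels of the query are rotated,
# the rest pass through unchanged. All shapes below are toy values for the example.
def _toy_partial_rope():
    import torch

    bsz, heads, seq, head_dim, rot_dim = 1, 2, 5, 8, 4
    base = 10000.0
    inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
    freqs = torch.outer(torch.arange(seq).float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)            # (seq, rot_dim)
    cos, sin = emb.cos(), emb.sin()

    q = torch.randn(bsz, heads, seq, head_dim)
    position_ids = torch.arange(seq).unsqueeze(0)

    q_rot, q_pass = q[..., :rot_dim], q[..., rot_dim:]
    cos_b = cos[position_ids].unsqueeze(1)             # broadcast over heads
    sin_b = sin[position_ids].unsqueeze(1)
    x1, x2 = q_rot[..., : rot_dim // 2], q_rot[..., rot_dim // 2:]
    rotated = torch.cat((-x2, x1), dim=-1)
    q_embed = torch.cat((q_rot * cos_b + rotated * sin_b, q_pass), dim=-1)
    return q_embed                                     # same shape as q, only rot_dim channels rotated
# ----------------------------------------------------------------------------------------------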

# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->MiniMaxM1
class MiniMaxM1Attention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".
    """

    def __init__(self, config: MiniMaxM1Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, 'head_dim', self.hidden_size // self.num_heads)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.rotary_dim = getattr(config, 'rotary_dim', self.head_dim)

        self.rotary_emb = MiniMaxM1RotaryEmbedding(
            self.rotary_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->MiniMaxM1
class MiniMaxM1FlashAttention2(MiniMaxM1Attention):
    """
    MiniMaxM1 flash attention module. This module inherits from `MiniMaxM1Attention` as the weights of the module stay
    untouched. The only required change would be on the forward pass, where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Union[Cache, Tuple[torch.Tensor]]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-3]

        # Because the input can be padded, the absolute sequence length depends on the max position id.
        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        use_sliding_windows = (
            _flash_supports_window_size
            and getattr(self.config, "sliding_window", None) is not None
            and kv_seq_len > self.config.sliding_window
        )

        if not _flash_supports_window_size:
            logger.warning_once(
                "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
                " make sure to upgrade flash-attn library."
            )

        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
        # therefore the input hidden states get silently cast in float32. Hence, we need
        # to cast them back in float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the expected shape for Flash Attention
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if past_key_value is not None:
            # reuse k, v, for evaluation only
            key_states = torch.cat([past_key_value[0], key_states], dim=-3)
            value_states = torch.cat([past_key_value[1], value_states], dim=-3)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_sliding_windows=use_sliding_windows,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
        use_sliding_windows=False,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
        first unpad the input, then compute the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
            use_sliding_windows (`bool`, *optional*):
                Whether to activate sliding window attention.
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            if not use_sliding_windows:
                attn_output_unpad = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_in_batch_q,
                    max_seqlen_k=max_seqlen_in_batch_k,
                    dropout_p=dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                )
            else:
                attn_output_unpad = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_in_batch_q,
                    max_seqlen_k=max_seqlen_in_batch_k,
                    dropout_p=dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                    window_size=(self.config.sliding_window, self.config.sliding_window),
                )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            if not use_sliding_windows:
                attn_output = flash_attn_func(
                    query_states,
                    key_states,
                    value_states,
                    dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                )
            else:
                attn_output = flash_attn_func(
                    query_states,
                    key_states,
                    value_states,
                    dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                    window_size=(self.config.sliding_window, self.config.sliding_window),
                )

        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape

        # On the first iteration we need to properly re-create the padding mask
        # by slicing it on the proper place
        if kv_seq_len != attention_mask.shape[-1]:
            attention_mask_num_tokens = attention_mask.shape[-1]
            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len:]

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)

        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)

        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )

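# ----------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the original modeling code): how a padding mask is turned into
# the varlen metadata (`indices`, `cu_seqlens`, max length) consumed by the flash-attention path
# above. The mask values are toy data for the example.
def _toy_unpad_metadata():
    import torch
    import torch.nn.functional as F

    # Two sequences: real lengths 3 and 2 (1 = real token, 0 = padding).
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.int32)

    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                      # tensor([3, 2])
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # flat token positions
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 5])
    max_seqlen = seqlens.max().item()                                            # 3
    return indices, cu_seqlens, max_seqlen
# ----------------------------------------------------------------------------------------------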

class MiniMaxM1MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


class MiniMaxM1BlockSparseTop2MLP(nn.Module):
    def __init__(self, config: MiniMaxM1Config):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size

        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states


class MiniMaxM1BLockSparseTop2MLP(MiniMaxM1BlockSparseTop2MLP):
    def __init__(self, *args, **kwargs):
        logger.warning_once(
            "MiniMaxM1BLockSparseTop2MLP is deprecated by MiniMaxM1BlockSparseTop2MLP and will be removed in v4.40."
        )
        super().__init__(*args, **kwargs)


class MiniMaxM1SparseMoeBlock(nn.Module):
    """
    This implementation is strictly equivalent to standard MoE with full capacity (no dropped tokens).
    It's faster since it formulates MoE operations in terms of block-sparse operations to accommodate
    imbalanced assignments of tokens to experts, whereas standard MoE either (1) drops tokens at the
    cost of reduced performance or (2) sets the capacity factor to the number of experts and thus
    wastes computation and memory on padding.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        # gating
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        self.experts = nn.ModuleList([MiniMaxM1BlockSparseTop2MLP(config) for _ in range(self.num_experts)])

        # Jitter parameters
        self.jitter_noise = config.router_jitter_noise

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """ """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        if self.training and self.jitter_noise > 0:
            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be solicited
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits

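# ----------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the original modeling code): the top-k routing performed by
# MiniMaxM1SparseMoeBlock above on toy tensors - softmax over experts, keep the top-k weights,
# renormalize them, then work out which tokens each expert processes.
def _toy_topk_routing():
    import torch

    tokens, num_experts, top_k = 6, 4, 2
    router_logits = torch.randn(tokens, num_experts)

    routing_weights = torch.softmax(router_logits, dim=-1)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

    # expert_mask[e] marks, for each (top_k slot, token), whether expert e was chosen.
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
    for expert_idx in range(num_experts):
        slot, token_idx = torch.where(expert_mask[expert_idx])
        # tokens routed to this expert, and the weights their outputs get scaled by:
        _ = token_idx, routing_weights[token_idx, slot]
    return routing_weights, selected_experts
# ----------------------------------------------------------------------------------------------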

class MiniMaxM1DecoderLayer(nn.Module):
    def __init__(self, config: MiniMaxM1Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        self.self_attn = self.build_attn(config, layer_idx)

        self.layer_idx = layer_idx

        self.block_sparse_moe = MiniMaxM1SparseMoeBlock(config)
        self.input_layernorm = MiniMaxM1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MiniMaxM1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.postnorm = getattr(config, 'postnorm', False)
        self.layernorm_attention_alpha = getattr(config, 'layernorm_linear_attention_alpha', 1) \
            if config.attention_type == 0 else getattr(config, 'layernorm_full_attention_alpha', 1)
        self.layernorm_attention_beta = getattr(config, 'layernorm_linear_attention_beta', 1) \
            if config.attention_type == 0 else getattr(config, 'layernorm_full_attention_beta', 1)
        self.layernorm_mlp_alpha = getattr(config, 'layernorm_mlp_alpha', 1)
        self.layernorm_mlp_beta = getattr(config, 'layernorm_mlp_beta', 1)

        shared_intermediate = getattr(config, 'shared_intermediate_size', 0)
        self.shared_moe = False
        if shared_intermediate > 0:
            self.shared_moe = True
            self.shared_mlp = MiniMaxM1MLP(config)
            self.coefficient = torch.nn.Linear(self.hidden_size, 1, bias=False)

    def build_attn(self, config, layer_idx):
        if config.attention_type == 0:
            Attention_module = MiniMaxM1LightningAttention
        else:
            Attention_module = MiniMaxM1FlashAttention2

        return Attention_module(
            config,
            layer_idx
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        slope_rate: Optional[float] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_router_logits (`bool`, *optional*):
                Whether or not to return the logits of all the routers. They are useful for computing the router loss,
                and should not be returned during inference.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)
        if self.postnorm:
            residual = hidden_states

        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            position_ids=position_ids,
            attn_mask=attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            slope_rate=slope_rate,
        )

        hidden_states = residual * self.layernorm_attention_alpha \
            + hidden_states * self.layernorm_attention_beta

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        if self.postnorm:
            residual = hidden_states

        moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
        if self.shared_moe:
            output_mlp = self.shared_mlp(hidden_states)
            weight_fp32 = self.coefficient.weight.float()
            coef = hidden_states.to(torch.float32) @ weight_fp32.T
            coef = torch.nn.functional.sigmoid(coef).to(hidden_states.dtype)
            hidden_states = moe_hidden_states * (1 - coef) + output_mlp * coef
        else:
            hidden_states = moe_hidden_states

        hidden_states = residual * self.layernorm_mlp_alpha \
            + hidden_states * self.layernorm_mlp_beta

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        if output_router_logits:
            outputs += (router_logits,)

        return outputs

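# ----------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the original modeling code): the shared-expert mixing used in
# MiniMaxM1DecoderLayer above when `shared_intermediate_size > 0` - a per-token sigmoid
# coefficient blends the always-on shared MLP output with the routed-expert output. The tensors
# below are toy stand-ins for the real branch outputs.
def _toy_shared_expert_blend():
    import torch

    tokens, hidden = 4, 8
    hidden_states = torch.randn(tokens, hidden)
    moe_out = torch.randn(tokens, hidden)        # stands in for the sparse-MoE branch
    shared_out = torch.randn(tokens, hidden)     # stands in for the shared MLP branch
    coefficient = torch.nn.Linear(hidden, 1, bias=False)

    coef = torch.sigmoid(hidden_states.float() @ coefficient.weight.float().T)
    return moe_out * (1 - coef) + shared_out * coef
# ----------------------------------------------------------------------------------------------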

MIXTRAL_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MiniMaxM1Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare MiniMaxM1 Model outputting raw hidden-states without any specific head on top.",
    MIXTRAL_START_DOCSTRING,
)
# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->MiniMaxM1
class MiniMaxM1PreTrainedModel(PreTrainedModel):
    config_class = MiniMaxM1Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MiniMaxM1DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


MIXTRAL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        output_router_logits (`bool`, *optional*):
            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
            should not be returned during inference.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare MiniMaxM1 Model outputting raw hidden-states without any specific head on top.",
    MIXTRAL_START_DOCSTRING,
)
# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->MiniMaxM1
class MiniMaxM1Model(MiniMaxM1PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniMaxM1DecoderLayer`]

    Args:
        config: MiniMaxM1Config
    """

    def __init__(self, config: MiniMaxM1Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.attn_type_list = config.attn_type_list
        config_copy = copy.deepcopy(config)

        self.layers = nn.ModuleList([])
        for i in range(config.num_hidden_layers):
            _config = copy.deepcopy(config)
            if self.attn_type_list[i] == 0:
                _config._attn_implementation = 'linear_attention'
                _config.attention_type = 0
            else:
                _config._attn_implementation = config_copy._attn_implementation
                _config.attention_type = 1
            self.layers.append(MiniMaxM1DecoderLayer(_config, i))

        self._attn_implementation = config_copy._attn_implementation
        self.norm = MiniMaxM1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        self.slopes = self._build_slope_tensor(config.num_attention_heads)
        # mask
        self._linear_attn_mask = torch.empty(0)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @staticmethod
    def _build_slope_tensor(n_attention_heads: int):

        def get_slopes(n):

            def get_slopes_power_of_2(n):
                start = 2 ** (-(2 ** -(math.log2(n) - 3)))
                ratio = start
                return [start * ratio ** i for i in range(n)]

            if math.log2(n).is_integer():
                return get_slopes_power_of_2(
                    n)  # In the paper, we only train models that have 2^a heads for some a. This function has
            else:  # some good properties that only occur when the input is a power of 2. To maintain that even
                closest_power_of_2 = 2 ** math.floor(
                    math.log2(n))  # when the number of heads is not a power of 2, we use this workaround.
                return (get_slopes_power_of_2(closest_power_of_2)
                        + get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2])

        # h, 1, 1
        slopes = torch.tensor(get_slopes(n_attention_heads), dtype=torch.float32).reshape(n_attention_heads, 1, 1)

        return slopes

    # Ignore copy
    @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MoeModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
            default_device = input_ids.device
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
            default_device = inputs_embeds.device
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        past_key_values_length = 0

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        seq_length_with_past = seq_length
        if past_key_values is not None:
            for idx in range(len(past_key_values)):
                if self.attn_type_list[idx] == 1:
                    past_key_values_length = past_key_values[idx][0].shape[-3]
                    seq_length_with_past = seq_length_with_past + past_key_values_length
                    break

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
            if is_padding_right:
                raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right'"
                    " this may lead to unexpected behaviour for Flash Attention version of MiniMaxM1. Make sure to "
                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                )
        slope_rates = [self.slopes.to(default_device) for _ in range(len(self.layers))]
        hidden_states = inputs_embeds
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (past_key_values[idx] if past_key_values is not None else None)
            attn_mask = attention_mask
            slope_rate = slope_rates[idx]
            slope_rate = slope_rate * (1 - idx / (len(self.layers) - 1) + 1e-5)
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    output_router_logits,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attn_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    output_router_logits=output_router_logits,
                    use_cache=use_cache,
                    slope_rate=slope_rate
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            if output_router_logits:
                all_router_logits += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                if v is not None
            )
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
        )

|
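# ---------------------------------------------------------------------------
# Editor's sketch (illustrative, not part of the original file): the decoder
# loop above rescales each layer's lightning-attention slope by
# (1 - idx / (num_layers - 1) + 1e-5), so the first layer keeps roughly the
# full slope and the last layer's slope decays to almost zero. A minimal
# stand-alone illustration of that schedule; `num_layers_example` and
# `base_slope_example` are assumed stand-ins for `len(self.layers)` and
# `self.slopes`:
num_layers_example = 8
base_slope_example = 1.0
decayed_slopes_example = [
    base_slope_example * (1 - i / (num_layers_example - 1) + 1e-5)
    for i in range(num_layers_example)
]
# -> approximately [1.0, 0.857, 0.714, 0.571, 0.429, 0.286, 0.143, 0.0]
# ---------------------------------------------------------------------------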
class MiniMaxM1ForCausalLM(MiniMaxM1PreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = MiniMaxM1Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    # Ignore copy
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MiniMaxM1ForCausalLM

        >>> model = MiniMaxM1ForCausalLM.from_pretrained(PATH_TO_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_WEIGHTS)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits if return_dict else outputs[-1],
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device

        if not return_dict:
            output = (logits,) + outputs[1:]
            if output_router_logits:
                output = (aux_loss,) + output
            return (loss,) + output if loss is not None else output

        torch.cuda.empty_cache()
        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )

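    # -----------------------------------------------------------------------
    # Editor's sketch (illustrative, not part of the original file): when
    # `output_router_logits=True` and `labels` are given, the forward pass
    # above folds the MoE load-balancing auxiliary loss into the
    # language-modeling loss via `loss += router_aux_loss_coef * aux_loss`.
    # A toy numeric illustration (all values are hypothetical):
    #
    #     lm_loss_example = 2.30                # cross-entropy over shifted tokens
    #     aux_loss_example = 0.05               # load_balancing_loss_func output
    #     router_aux_loss_coef_example = 0.001  # config.router_aux_loss_coef
    #     total = lm_loss_example + router_aux_loss_coef_example * aux_loss_example
    #     # -> 2.30005
    # -----------------------------------------------------------------------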
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update({
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        })
        return model_inputs

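    # -----------------------------------------------------------------------
    # Editor's sketch (illustrative, not part of the original file): once a
    # cache exists, `prepare_inputs_for_generation` above feeds only the
    # newest token back into the model, since all earlier positions are
    # already encoded in `past_key_values`. A minimal stand-alone illustration
    # (token ids are made up; `torch` is imported at the top of this module):
    #
    #     full_ids_example = torch.tensor([[101, 2009, 2003, 1037]])
    #     last_step_ids_example = full_ids_example[:, -1:]  # shape (1, 1), mirrors input_ids[:, -1:]
    # -----------------------------------------------------------------------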
    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past


@add_start_docstrings(
    """
    The MiniMaxM1 Model transformer with a sequence classification head on top (linear layer).

    [`MiniMaxM1ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    MIXTRAL_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->MiniMaxM1, LLAMA->MIXTRAL
class MiniMaxM1ForSequenceClassification(MiniMaxM1PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = MiniMaxM1Model(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
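Editor's note: `MiniMaxM1ForSequenceClassification` above pools the logit at the last non-padding position of each row. A minimal sketch of that index computation, assuming right-padded inputs and a known pad id (all values are hypothetical, not taken from the model's config):

import torch

pad_token_id = 0
input_ids = torch.tensor([[11, 12, 13, 0, 0],
                          [21, 22, 23, 24, 25]])
# Same modulo trick as the forward pass: argmax over the pad mask minus one,
# wrapped so rows without any padding resolve to the final position.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
# -> tensor([2, 4]): index of the last real token in each row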
tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff.
tokenizer_config.json
ADDED
@@ -0,0 +1,10 @@
{
  "add_prefix_space": false,
  "bos_token": "<beginning_of_sentence>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "<end_of_sentence>",
  "model_max_length": 40960000,
  "tokenizer_class": "GPT2Tokenizer",
  "unk_token": "<end_of_document>",
  "chat_template": "{{ '<begin_of_document>' -}}{% set ns = namespace(system_prompt='') -%}{% for message in messages -%}{% if message['role'] == 'system' -%}{% set ns.system_prompt = ns.system_prompt + message['content'][0]['text'] -%}{% endif -%}{%- endfor -%}{% if ns.system_prompt != '' -%}{{ '<beginning_of_sentence>system ai_setting=assistant\n' + ns.system_prompt + '<end_of_sentence>\n' -}}{%- endif -%}{% if tools -%}{{ '<beginning_of_sentence>system tool_setting=tools\nYou are provided with these tools:\n<tools>\n' -}}{% for tool in tools -%}{{ tool | tojson ~ '\n' -}}{%- endfor -%}{{ '</tools>\n\nIf you need to call tools, please respond with <tool_calls></tool_calls> XML tags, and provide tool-name and json-object of arguments, following the format below:\n<tool_calls>\n{''name'': <tool-name-1>, ''arguments'': <args-json-object-1>}\n...\n</tool_calls><end_of_sentence>\n' -}}{%- endif -%}{% for message in messages -%}{% if message['role'] == 'user' -%}{{ '<beginning_of_sentence>user name=user\n' + message['content'][0]['text'] + '<end_of_sentence>\n' -}}{% elif message['role'] == 'assistant' -%}{{ '<beginning_of_sentence>ai name=assistant\n' -}}{% for content in message['content'] | selectattr('type', 'equalto', 'text') -%}{{ content['text'] -}}{%- endfor -%}{{ '<end_of_sentence>\n' -}}{% elif message['role'] == 'tool' -%}{{ '<beginning_of_sentence>tool name=tools\n' }} {%- for content in message['content'] -%}{{- 'tool name: ' + content['name'] + '\n' + 'tool result: ' + content['text'] + '\n\n' -}} {%- endfor -%}{{- '<end_of_sentence>\n' -}}{% endif -%}{%- endfor -%}{% if add_generation_prompt -%}{{ '<beginning_of_sentence>ai name=assistant\n' -}}{%- endif -%}"
}
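Editor's note: the `chat_template` above wraps system, user, assistant, and tool turns in `<beginning_of_sentence>` / `<end_of_sentence>` markers and expects each message's `content` to be a list of `{"type": "text", "text": ...}` blocks. A minimal sketch of rendering a prompt with it; the repository id is a placeholder and `trust_remote_code` may not be required:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-M1", trust_remote_code=True)
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)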
vocab.json
ADDED
The diff for this file is too large to render. See raw diff.