QscQ commited on
Commit
8043ae4
·
1 Parent(s): c036f7f

Signed-off-by: QscQ <[email protected]>

LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2025 MiniMax
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,224 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ pipeline_tag: text-generation
3
+ license: apache-2.0
4
+ library_name: transformers
5
+ tags:
6
+ - vllm
7
+ ---
8
+
9
+ <div align="center">
10
+
11
+ <svg width="60%" height="auto" viewBox="0 0 144 48" fill="none" xmlns="http://www.w3.org/2000/svg">
12
+ <path d="M26.6782 7.96523C26.6782 7.02436 25.913 6.26087 24.9739 6.26087C24.0348 6.26087 23.2695 7.0261 23.2695 7.96523V36.2139C23.2695 38.4 21.4904 40.1791 19.3043 40.1791C17.1183 40.1791 15.3391 38.4 15.3391 36.2139V18.0904C15.3391 17.1496 14.5739 16.3861 13.6348 16.3861C12.6956 16.3861 11.9304 17.1513 11.9304 18.0904V25.7722C11.9304 27.9583 10.1513 29.7374 7.96518 29.7374C5.7791 29.7374 4 27.9583 4 25.7722V22.9878C4 22.3635 4.50609 21.8574 5.13043 21.8574C5.75478 21.8574 6.26087 22.3635 6.26087 22.9878V25.7722C6.26087 26.713 7.02605 27.4765 7.96518 27.4765C8.90431 27.4765 9.66954 26.7113 9.66954 25.7722V18.0904C9.66954 15.9044 11.4487 14.1252 13.6348 14.1252C15.8209 14.1252 17.6 15.9044 17.6 18.0904V36.2139C17.6 37.1548 18.3652 37.9183 19.3043 37.9183C20.2435 37.9183 21.0087 37.153 21.0087 36.2139V25.1322V7.96523C21.0087 5.77914 22.7878 4 24.9739 4C27.16 4 28.9391 5.77914 28.9391 7.96523V31.3565C28.9391 31.9809 28.433 32.487 27.8087 32.487C27.1843 32.487 26.6782 31.9809 26.6782 31.3565V7.96523ZM47.6539 14.1252C45.4678 14.1252 43.6887 15.9044 43.6887 18.0904V33.2296C43.6887 34.1704 42.9235 34.9339 41.9843 34.9339C41.0452 34.9339 40.28 34.1687 40.28 33.2296V7.96523C40.28 5.77914 38.5008 4 36.3148 4C34.1287 4 32.3496 5.77914 32.3496 7.96523V40.0348C32.3496 40.9756 31.5843 41.7391 30.6452 41.7391C29.7061 41.7391 28.9409 40.9739 28.9409 40.0348V36.0643C28.9409 35.44 28.4348 34.9339 27.8104 34.9339C27.1861 34.9339 26.68 35.44 26.68 36.0643V40.0348C26.68 42.2209 28.4591 44 30.6452 44C32.8313 44 34.6104 42.2209 34.6104 40.0348V7.96523C34.6104 7.02436 35.3756 6.26087 36.3148 6.26087C37.2539 6.26087 38.0191 7.0261 38.0191 7.96523V33.2296C38.0191 35.4156 39.7982 37.1948 41.9843 37.1948C44.1704 37.1948 45.9496 35.4156 45.9496 33.2296V18.0904C45.9496 17.1496 46.7148 16.3861 47.6539 16.3861C48.593 16.3861 49.3582 17.1513 49.3582 18.0904V31.3565C49.3582 31.9809 49.8643 32.487 50.4887 32.487C51.113 32.487 51.6191 31.9809 51.6191 31.3565V18.0904C51.6191 15.9044 49.84 14.1252 47.6539 14.1252Z" fill="url(#paint0_linear_17_483)"/>
13
+ <path d="M68.7671 16.5615H71.2541C71.3254 16.5615 71.3845 16.5859 71.435 16.6363C71.4836 16.6868 71.5097 16.7459 71.5097 16.8172V31.1824C71.5097 31.2537 71.4854 31.3128 71.435 31.3633C71.3845 31.4137 71.3254 31.4381 71.2541 31.4381H68.7671C68.6958 31.4381 68.6367 31.4137 68.5862 31.3633C68.5358 31.3146 68.5115 31.2537 68.5115 31.1824V21.812C68.5115 21.7563 68.4976 21.7268 68.4697 21.7268C68.4419 21.7268 68.4123 21.7476 68.3845 21.7911L66.1323 25.318C66.061 25.4311 65.9619 25.4885 65.8349 25.4885H64.581C64.4541 25.4885 64.3549 25.4328 64.2836 25.318L62.0315 21.7911C62.0036 21.7494 61.9741 21.7302 61.9462 21.7372C61.9184 21.7441 61.9045 21.7772 61.9045 21.8328V31.1824C61.9045 31.2537 61.8802 31.3128 61.8297 31.3633C61.7793 31.4137 61.7202 31.4381 61.6489 31.4381H59.1619C59.0906 31.4381 59.0315 31.4137 58.981 31.3633C58.9306 31.3146 58.9062 31.2537 58.9062 31.1824V16.8172C58.9062 16.7459 58.9306 16.6868 58.981 16.6363C59.0315 16.5859 59.0906 16.5615 59.1619 16.5615H61.6489C61.7758 16.5615 61.8749 16.6189 61.9462 16.732L65.1341 21.6833C65.1758 21.7685 65.2193 21.7685 65.261 21.6833L68.4697 16.732C68.541 16.6189 68.6402 16.5615 68.7671 16.5615Z" fill="currentColor"/>
14
+ <path d="M74.1764 31.3633C74.1259 31.3146 74.1016 31.2537 74.1016 31.1824V16.8172C74.1016 16.7459 74.1259 16.6868 74.1764 16.6363C74.2268 16.5859 74.2859 16.5615 74.3572 16.5615H76.8442C76.9155 16.5615 76.9746 16.5859 77.0251 16.6363C77.0737 16.6868 77.0998 16.7459 77.0998 16.8172V31.1824C77.0998 31.2537 77.0755 31.3128 77.0251 31.3633C76.9746 31.4137 76.9155 31.4381 76.8442 31.4381H74.3572C74.2859 31.4381 74.2268 31.4137 74.1764 31.3633Z" fill="currentColor"/>
15
+ <path d="M88.3066 16.6361C88.3553 16.5874 88.4162 16.5613 88.4875 16.5613H90.9744C91.0457 16.5613 91.1049 16.5857 91.1553 16.6361C91.204 16.6865 91.2301 16.7457 91.2301 16.817V31.1822C91.2301 31.2535 91.2057 31.3126 91.1553 31.363C91.1049 31.4135 91.0457 31.4378 90.9744 31.4378H88.5727C88.4301 31.4378 88.331 31.3822 88.2753 31.2674L82.771 22.1717C82.7431 22.13 82.7136 22.1109 82.6858 22.1178C82.6579 22.1248 82.644 22.1578 82.644 22.2135L82.6858 31.1805C82.6858 31.2518 82.6614 31.3109 82.611 31.3613C82.5606 31.4117 82.5014 31.4361 82.4301 31.4361H79.9431C79.8718 31.4361 79.8127 31.4117 79.7623 31.3613C79.7118 31.3126 79.6875 31.2518 79.6875 31.1805V16.8152C79.6875 16.7439 79.7118 16.6848 79.7623 16.6344C79.8127 16.5839 79.8718 16.5596 79.9431 16.5596H82.3449C82.4858 16.5596 82.5849 16.617 82.6423 16.73L88.124 25.7822C88.1518 25.8239 88.1797 25.8431 88.2092 25.8361C88.2371 25.8292 88.251 25.7978 88.251 25.7404L88.2301 16.8152C88.2301 16.7439 88.2545 16.6848 88.3049 16.6344L88.3066 16.6361Z" fill="currentColor"/>
16
+ <path d="M93.8951 31.3633C93.8446 31.3146 93.8203 31.2537 93.8203 31.1824V16.8172C93.8203 16.7459 93.8446 16.6868 93.8951 16.6363C93.9455 16.5859 94.0047 16.5615 94.076 16.5615H96.5629C96.6342 16.5615 96.6934 16.5859 96.7438 16.6363C96.7925 16.6868 96.8186 16.7459 96.8186 16.8172V31.1824C96.8186 31.2537 96.7942 31.3128 96.7438 31.3633C96.6934 31.4137 96.6342 31.4381 96.5629 31.4381H94.076C94.0047 31.4381 93.9455 31.4137 93.8951 31.3633Z" fill="currentColor"/>
17
+ <path d="M109.267 16.5615H111.754C111.825 16.5615 111.885 16.5859 111.935 16.6363C111.984 16.6868 112.01 16.7459 112.01 16.8172V31.1824C112.01 31.2537 111.985 31.3128 111.935 31.3633C111.885 31.4137 111.825 31.4381 111.754 31.4381H109.267C109.196 31.4381 109.137 31.4137 109.086 31.3633C109.036 31.3146 109.011 31.2537 109.011 31.1824V21.812C109.011 21.7563 108.998 21.7268 108.97 21.7268C108.942 21.7268 108.912 21.7476 108.885 21.7911L106.632 25.318C106.561 25.4311 106.462 25.4885 106.335 25.4885H105.081C104.954 25.4885 104.855 25.4328 104.784 25.318L102.531 21.7911C102.504 21.7494 102.474 21.7302 102.446 21.7372C102.418 21.7441 102.405 21.7772 102.405 21.8328V31.1824C102.405 31.2537 102.38 31.3128 102.33 31.3633C102.279 31.4137 102.22 31.4381 102.149 31.4381H99.6619C99.5906 31.4381 99.5315 31.4137 99.481 31.3633C99.4306 31.3146 99.4062 31.2537 99.4062 31.1824V16.8172C99.4062 16.7459 99.4306 16.6868 99.481 16.6363C99.5315 16.5859 99.5906 16.5615 99.6619 16.5615H102.149C102.276 16.5615 102.375 16.6189 102.446 16.732L105.634 21.6833C105.676 21.7685 105.719 21.7685 105.761 21.6833L108.97 16.732C109.041 16.6189 109.14 16.5615 109.267 16.5615Z" fill="currentColor"/>
18
+ <path d="M123.782 31.2241L123.144 29.1424C123.116 29.0867 123.079 29.0572 123.038 29.0572H117.81C117.768 29.0572 117.732 29.085 117.704 29.1424L117.088 31.2241C117.046 31.3668 116.954 31.4363 116.812 31.4363H114.112C114.027 31.4363 113.963 31.412 113.921 31.3615C113.879 31.3128 113.871 31.2381 113.9 31.1389L118.49 16.7737C118.532 16.6328 118.624 16.5615 118.766 16.5615H122.102C122.243 16.5615 122.335 16.6328 122.379 16.7737L126.968 31.1389C126.982 31.1668 126.989 31.2033 126.989 31.245C126.989 31.372 126.911 31.4363 126.756 31.4363H124.057C123.916 31.4363 123.824 31.365 123.78 31.2241H123.782ZM118.554 26.7407H122.295C122.38 26.7407 122.408 26.6989 122.38 26.6137L120.467 20.3024C120.453 20.2467 120.432 20.2207 120.403 20.2276C120.375 20.2346 120.352 20.2589 120.339 20.3024L118.469 26.6137C118.455 26.6989 118.483 26.7407 118.554 26.7407Z" fill="currentColor"/>
19
+ <path d="M128.222 31.353C128.18 31.2974 128.187 31.2261 128.243 31.1409L132.365 24.0643C132.393 24.0226 132.393 23.9791 132.365 23.9374L128.243 16.8609L128.201 16.7339C128.201 16.6209 128.28 16.5635 128.434 16.5635H131.133C131.274 16.5635 131.38 16.6209 131.452 16.7339L134.213 21.6C134.255 21.6852 134.299 21.6852 134.34 21.6L137.102 16.7339C137.173 16.6209 137.28 16.5635 137.42 16.5635H140.099C140.198 16.5635 140.269 16.5913 140.311 16.6487C140.353 16.7061 140.346 16.7756 140.29 16.8609L136.168 23.9374C136.154 23.9791 136.154 24.0226 136.168 24.0643L140.29 31.1409L140.332 31.2678C140.332 31.3809 140.253 31.4383 140.099 31.4383H137.42C137.278 31.4383 137.172 31.3826 137.102 31.2678L134.34 26.4226C134.299 26.3374 134.255 26.3374 134.213 26.4226L131.429 31.2678C131.358 31.3809 131.252 31.4383 131.111 31.4383H128.433C128.333 31.4383 128.262 31.4104 128.22 31.353H128.222Z" fill="currentColor"/>
20
+ <defs>
21
+ <linearGradient id="paint0_linear_17_483" x1="3.99826" y1="24" x2="51.6208" y2="24" gradientUnits="userSpaceOnUse">
22
+ <stop stop-color="#E21680"/>
23
+ <stop offset="1" stop-color="#FF633A"/>
24
+ </linearGradient>
25
+ </defs>
26
+ </svg>
27
+
28
+ </div>
29
+ <hr>
30
+
31
+ <div align="center" style="line-height: 1;">
32
+ <a href="https://www.minimax.io" target="_blank" style="margin: 2px;">
33
+ <img alt="Homepage" src="https://img.shields.io/badge/_Homepage-MiniMax-FF4040?style=flat-square&labelColor=2C3E50&logo=&logoWidth=20" style="display: inline-block; vertical-align: middle;"/>
34
+ </a>
35
+ <a href="https://arxiv.org/abs/2506.13585" target="_blank" style="margin: 2px;">
36
+ <img alt="Paper" src="https://img.shields.io/badge/📖_Paper-MiniMax--M1-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
37
+ </a>
38
+ <a href="https://chat.minimax.io/" target="_blank" style="margin: 2px;">
39
+ <img alt="Chat" src="https://img.shields.io/badge/_MiniMax_Chat-FF4040?style=flat-square&labelColor=2C3E50&logo=&logoWidth=20" style="display: inline-block; vertical-align: middle;"/>
40
+ </a>
41
+ <a href="https://www.minimax.io/platform" style="margin: 2px;">
42
+ <img alt="API" src="https://img.shields.io/badge/⚡_API-Platform-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
43
+ </a>
44
+ <a href="https://github.com/MiniMax-AI/MiniMax-MCP" style="margin: 2px;">
45
+ <img alt="MCP" src="https://img.shields.io/badge/🚀_MCP-MiniMax_MCP-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
46
+ </a>
47
+ </div>
48
+ <div align="center" style="line-height: 1;">
49
+ <a href="https://huggingface.co/MiniMaxAI" target="_blank" style="margin: 2px;">
50
+ <img alt="Hugging Face" src="https://img.shields.io/badge/🤗_Hugging_Face-MiniMax-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
51
+ </a>
52
+ <a href="https://github.com/MiniMax-AI/MiniMax-M1" target="_blank" style="margin: 2px;">
53
+ <img alt="GitHub" src="https://img.shields.io/badge/🐙_GitHub-MiniMax-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
54
+ </a>
55
+ <a href="https://www.modelscope.cn/organization/MiniMax" target="_blank" style="margin: 2px;">
56
+ <img alt="ModelScope" src="https://img.shields.io/badge/🤖️_ModelScope-MiniMax-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
57
+ </a>
58
+ <a href="https://github.com/MiniMax-AI/MiniMax-M1/blob/main/LICENSE" style="margin: 2px;">
59
+ <img alt="License" src="https://img.shields.io/badge/⚖️_License-Apache_2.0-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
60
+ </a>
61
+ <a href="https://github.com/MiniMax-AI/MiniMax-01/blob/main/figures/wechat-qrcode.jpeg" target="_blank" style="margin: 2px;">
62
+ <img alt="WeChat" src="https://img.shields.io/badge/💬_WeChat-MiniMax-FF4040?style=flat-square&labelColor=2C3E50" style="display: inline-block; vertical-align: middle;"/>
63
+ </a>
64
+ </div>
65
+
66
+ # MiniMax-M1
67
+
68
+ ## 1. Model Overview
69
+
70
+ We introduce MiniMax-M1, the world's first open-weight, large-scale hybrid-attention reasoning model.
71
+ MiniMax-M1 is powered by a hybrid Mixture-of-Experts (MoE) architecture combined with a lightning
72
+ attention mechanism. The model is developed based on our previous [MiniMax-Text-01 model](https://huggingface.co/MiniMaxAI/MiniMax-Text-01),
73
+ which contains a total of 456 billion parameters with 45.9 billion parameters activated
74
+ per token. Consistent with MiniMax-Text-01, the M1 model natively supports a context length of 1
75
+ million tokens, 8x the context size of DeepSeek R1. Furthermore, the lightning attention mechanism
76
+ in MiniMax-M1 enables efficient scaling of test-time compute – For example, compared to DeepSeek
77
+ R1, M1 consumes 25% of the FLOPs at a generation length of 100K tokens. These properties make M1
78
+ particularly suitable for complex tasks that require processing long inputs and thinking extensively.
79
+ MiniMax-M1 is trained using large-scale reinforcement learning (RL) on diverse problems ranging from
80
+ traditional mathematical reasoning to sandbox-based, real-world software engineering environments.
81
+ We develop an efficient RL scaling framework for M1 highlighting two perspectives: (1) We propose
82
+ CISPO, a novel algorithm that clips importance sampling weights instead of token updates, which
83
+ outperforms other competitive RL variants; (2) Our hybrid-attention design naturally enhances the
84
+ efficiency of RL, where we address unique challenges when scaling RL with the hybrid architecture. We
85
+ train two versions of MiniMax-M1 models with [40K](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k) and
86
+ [80K](https://huggingface.co/MiniMaxAI/MiniMax-M1-80k) thinking budgets respectively. Experiments
87
+ on standard benchmarks show that our models outperform other strong open-weight models such as
88
+ the original DeepSeek-R1 and Qwen3-235B, particularly on complex software engineering, tool using,
89
+ and long context tasks. With efficient scaling of test-time compute, MiniMax-M1 serves as a strong
90
+ foundation for next-generation language model agents to reason and tackle real-world challenges.
91
+
92
+ <p align="center">
93
+ <img width="100%" src="figures/TextBench.png">
94
+ <br>
95
+ <small><em>Benchmark performance comparison of leading commercial and open-weight models across competition-level mathematics, coding, software engineering, agentic tool use, and long-context understanding tasks. We use the MiniMax-M1-80k model here for MiniMax-M1.</em></small>
96
+ </p>
97
+
98
+
99
+ ## 2. Evaluation
100
+
101
+ **Performance of MiniMax-M1 on core benchmarks.**
102
+
103
+
104
+ | **Category** | **Task** | **MiniMax-M1-80K** | **MiniMax-M1-40K** | **Qwen3-235B-A22B** | **DeepSeek-R1-0528** | **DeepSeek-R1** | **Seed-Thinking-v1.5** | **Claude 4 Opus** | **Gemini 2.5 Pro (06-05)** | **OpenAI-o3** |
105
+ |:---|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
106
+ | | *Extended Thinking* | *80K* | *40K* | *32k* | *64k* | *32k* | *32k* | *64k* | *64k* | *100k* |
107
+ | ***Mathematics*** | AIME 2024 | 86.0 | 83.3 | 85.7 | 91.4 | 79.8 | 86.7 | 76.0 | 92.0 | 91.6 |
108
+ | | AIME 2025 | 76.9 | 74.6 | 81.5 | 87.5 | 70.0 | 74.0 | 75.5 | 88.0 | 88.9 |
109
+ | | MATH-500 | 96.8 | 96.0 | 96.2 | 98.0 | 97.3 | 96.7 | 98.2 | 98.8 | 98.1 |
110
+ | ***General Coding*** | LiveCodeBench *(24/8~25/5)* | 65.0 | 62.3 | 65.9 | 73.1 | 55.9 | 67.5 | 56.6 | 77.1 | 75.8 |
111
+ | | FullStackBench | 68.3 | 67.6 | 62.9 | 69.4 | 70.1 | 69.9 | 70.3 | -- | 69.3 |
112
+ | ***Reasoning & Knowledge***| GPQA Diamond | 70.0 | 69.2 | 71.1 | 81.0 | 71.5 | 77.3 | 79.6 | 86.4 | 83.3 |
113
+ | | HLE *(no tools)* | 8.4\* | 7.2\* | 7.6\* | 17.7\* | 8.6\* | 8.2 | 10.7 | 21.6 | 20.3 |
114
+ | | ZebraLogic | 86.8 | 80.1 | 80.3 | 95.1 | 78.7 | 84.4 | 95.1 | 91.6 | 95.8 |
115
+ | | MMLU-Pro | 81.1 | 80.6 | 83.0 | 85.0 | 84.0 | 87.0 | 85.0 | 86.0 | 85.0 |
116
+ | ***Software Engineering***| SWE-bench Verified| 56.0 | 55.6 | 34.4 | 57.6 | 49.2 | 47.0 | 72.5 | 67.2 | 69.1 |
117
+ | ***Long Context*** | OpenAI-MRCR *(128k)* | 73.4 | 76.1 | 27.7 | 51.5 | 35.8 | 54.3 | 48.9 | 76.8 | 56.5 |
118
+ | | OpenAI-MRCR *(1M)* | 56.2 | 58.6 | -- | -- | -- | -- | -- | 58.8 | -- |
119
+ | | LongBench-v2 | 61.5 | 61.0 | 50.1 | 52.1 | 58.3 | 52.5 | 55.6 | 65.0 | 58.8 |
120
+ | ***Agentic Tool Use***| TAU-bench *(airline)* | 62.0 | 60.0 | 34.7 | 53.5 | -- | 44.0 | 59.6 | 50.0 | 52.0 |
121
+ | | TAU-bench *(retail)* | 63.5 | 67.8 | 58.6 | 63.9 | -- | 55.7 | 81.4 | 67.0 | 73.9 |
122
+ | ***Factuality*** | SimpleQA | 18.5 | 17.9 | 11.0 | 27.8 | 30.1 | 12.9 | -- | 54.0 | 49.4 |
123
+ | ***General Assistant***| MultiChallenge | 44.7 | 44.7 | 40.0 | 45.0 | 40.7 | 43.0 | 45.8 | 51.8 | 56.5 |
124
+
125
+ \* conducted on the text-only HLE subset.
126
+
127
+ Our models are evaluated with `temperature=1.0`, `top_p=0.95`.
128
+
129
+ ### SWE-bench methodology
130
+ We report results derived from the Agentless scaffold. Departing from the original pipeline, our methodology employs a two-stage localization process (without any embedding-based retrieval mechanisms): initial coarse-grained file localization followed by fine-grained localization to specific files and code elements. The values for our models are calculated on the subset of n=486 verified tasks which work on our infrastructure. The excluded 14 test cases that were incompatible with our internal infrastructure are:
131
+ `"astropy__astropy-7606"`,
132
+ `"astropy__astropy-8707"`,
133
+ `"astropy__astropy-8872"`,
134
+ `"django__django-10097"`,
135
+ `"matplotlib__matplotlib-20488"`,
136
+ `"psf__requests-2317"`,
137
+ `"psf__requests-2931"`,
138
+ `"psf__requests-5414"`,
139
+ `"pylint-dev__pylint-6528"`,
140
+ `"pylint-dev__pylint-7277"`,
141
+ `"sphinx-doc__sphinx-10435"`,
142
+ `"sphinx-doc__sphinx-7985"`,
143
+ `"sphinx-doc__sphinx-8269"`,
144
+ `"sphinx-doc__sphinx-8475"`
145
+
146
+ ### TAU-bench methodology
147
+ We evaluate TAU-Bench with GPT-4.1 as user model and without any custom tools. The maximum number of interaction steps is 40.
148
+ Our general system prompt is:
149
+ ```
150
+ - In each round, you need to carefully examine the tools provided to you to determine if any can be used.
151
+ - You must adhere to all of the policies. Pay attention to the details in the terms. Solutions for most situations can be found within these policies.
152
+ ```
153
+
154
+ ## 3. Recommendations for Minimax-M1 Model Usage
155
+
156
+ To achieve the best results with the Minimax-M1 model, we suggest focusing on two key points: Inference Parameters and the System Prompt.
157
+
158
+ ### 3.1. Inference Parameters
159
+ - Temperature: **`1.0`**
160
+ - Top_p: **`0.95`**
161
+
162
+ This setting is optimal for encouraging creativity and diversity in the model's responses. It allows the model to explore a wider range of linguistic possibilities, preventing outputs that are too rigid or repetitive, while still maintaining strong logical coherence.
163
+
164
+ ### 3.2. System Prompt
165
+ Tailoring your system prompt to the specific task is crucial for guiding the model effectively. Below are suggested settings for different scenarios.
166
+
167
+ #### A. General-Purpose Scenarios
168
+ For common tasks like summarization, translation, Q&A, or creative writing:
169
+ ```
170
+ You are a helpful assistant.
171
+ ```
172
+ #### B. Web Development Scenarios
173
+ For complex tasks like generating code for web pages:
174
+ ```
175
+ You are a web development engineer, writing web pages according to the instructions below. You are a powerful code editing assistant capable of writing code and creating artifacts in conversations with users, or modifying and updating existing artifacts as requested by users.
176
+ All code is written in a single code block to form a complete code file for display, without separating HTML and JavaScript code. An artifact refers to a runnable complete code snippet, you prefer to integrate and output such complete runnable code rather than breaking it down into several code blocks. For certain types of code, they can render graphical interfaces in a UI window. After generation, please check the code execution again to ensure there are no errors in the output.
177
+ Output only the HTML, without any additional descriptive text. Make the UI looks modern and beautiful.
178
+ ```
179
+ #### C. Mathematical Scenarios
180
+ When dealing with problems that require calculation or logical deduction:
181
+ ```
182
+ Please reason step by step, and put your final answer within \boxed{}.
183
+ ```
184
+
185
+ ## 4. Deployment Guide
186
+
187
+ Download the model from HuggingFace repository:
188
+ - [MiniMax-M1-40k](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k)
189
+ - [MiniMax-M1-80k](https://huggingface.co/MiniMaxAI/MiniMax-M1-80k)
190
+
191
+ For production deployment, we recommend using [vLLM](https://docs.vllm.ai/en/latest/) to serve MiniMax-M1. vLLM provides excellent performance for serving large language models with the following features:
192
+ - 🔥 Outstanding service throughout performance
193
+ - ⚡ Efficient and intelligent memory management
194
+ - 📦 Powerful batch request processing capability
195
+ - ⚙️ Deeply optimized underlying performance
196
+
197
+ For detailed vLLM deployment instructions, please refer to our [vLLM Deployment Guide](./docs/vllm_deployment_guide.md).
198
+ Alternatively, you can also deploy using Transformers directly. For detailed Transformers deployment instructions, you can see our [MiniMax-M1 Transformers Deployment Guide](./docs/transformers_deployment_guide.md).
199
+
200
+
201
+ ## 5. Function Calling
202
+
203
+ The MiniMax-M1 model supports function calling capabilities, enabling the model to identify when external functions need to be called and output function call parameters in a structured format. [MiniMax-M1 Function Call Guide](./docs/function_call_guide.md) provides detailed instructions on how to use the function calling feature of MiniMax-M1.
204
+
205
+
206
+ ## 6. Chatbot & API
207
+ For general use and evaluation, we provide a [Chatbot](https://chat.minimax.io/) with online search capabilities and the [online API](https://www.minimax.io/platform/) for developers. For general use and evaluation, we provide the [MiniMax MCP Server](https://github.com/MiniMax-AI/MiniMax-MCP) with video generation, image generation, speech synthesis, and voice cloning for developers.
208
+
209
+
210
+ ## 7. Citation
211
+ ```
212
+ @misc{minimax2025minimaxm1scalingtesttimecompute,
213
+ title={MiniMax-M1: Scaling Test-Time Compute Efficiently with Lightning Attention},
214
+ author={MiniMax},
215
+ year={2025},
216
+ eprint={2506.13585},
217
+ archivePrefix={arXiv},
218
+ primaryClass={cs.CL},
219
+ url={https://arxiv.org/abs/2506.13585},
220
+ }
221
+ ```
222
+
223
+ ## 8. Contact Us
224
+ Contact us at [[email protected]](mailto:[email protected]).
config.json ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "MiniMaxM1ForCausalLM"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "attn_type_list": [
7
+ 0,
8
+ 0,
9
+ 0,
10
+ 0,
11
+ 0,
12
+ 0,
13
+ 0,
14
+ 1,
15
+ 0,
16
+ 0,
17
+ 0,
18
+ 0,
19
+ 0,
20
+ 0,
21
+ 0,
22
+ 1,
23
+ 0,
24
+ 0,
25
+ 0,
26
+ 0,
27
+ 0,
28
+ 0,
29
+ 0,
30
+ 1,
31
+ 0,
32
+ 0,
33
+ 0,
34
+ 0,
35
+ 0,
36
+ 0,
37
+ 0,
38
+ 1,
39
+ 0,
40
+ 0,
41
+ 0,
42
+ 0,
43
+ 0,
44
+ 0,
45
+ 0,
46
+ 1,
47
+ 0,
48
+ 0,
49
+ 0,
50
+ 0,
51
+ 0,
52
+ 0,
53
+ 0,
54
+ 1,
55
+ 0,
56
+ 0,
57
+ 0,
58
+ 0,
59
+ 0,
60
+ 0,
61
+ 0,
62
+ 1,
63
+ 0,
64
+ 0,
65
+ 0,
66
+ 0,
67
+ 0,
68
+ 0,
69
+ 0,
70
+ 1,
71
+ 0,
72
+ 0,
73
+ 0,
74
+ 0,
75
+ 0,
76
+ 0,
77
+ 0,
78
+ 1,
79
+ 0,
80
+ 0,
81
+ 0,
82
+ 0,
83
+ 0,
84
+ 0,
85
+ 0,
86
+ 1
87
+ ],
88
+ "auto_map": {
89
+ "AutoConfig": "configuration_minimax_m1.MiniMaxM1Config",
90
+ "AutoModelForCausalLM": "modeling_minimax_m1.MiniMaxM1ForCausalLM"
91
+ },
92
+ "bos_token_id": null,
93
+ "eos_token_id": null,
94
+ "head_dim": 128,
95
+ "hidden_act": "silu",
96
+ "hidden_size": 6144,
97
+ "initializer_range": 0.02,
98
+ "intermediate_size": 9216,
99
+ "layernorm_full_attention_alpha": 3.5565588200778455,
100
+ "layernorm_full_attention_beta": 1.0,
101
+ "layernorm_linear_attention_alpha": 3.5565588200778455,
102
+ "layernorm_linear_attention_beta": 1.0,
103
+ "layernorm_mlp_alpha": 3.5565588200778455,
104
+ "layernorm_mlp_beta": 1.0,
105
+ "max_position_embeddings": 10240000,
106
+ "model_type": "minimax_m1",
107
+ "num_attention_heads": 64,
108
+ "num_experts_per_tok": 2,
109
+ "num_hidden_layers": 80,
110
+ "num_key_value_heads": 8,
111
+ "num_local_experts": 32,
112
+ "output_router_logits": false,
113
+ "postnorm": true,
114
+ "rms_norm_eps": 1e-05,
115
+ "rope_theta": 10000000,
116
+ "rotary_dim": 64,
117
+ "router_aux_loss_coef": 0.001,
118
+ "router_jitter_noise": 0.0,
119
+ "shared_intermediate_size": 0,
120
+ "shared_moe_mode": "sigmoid",
121
+ "sliding_window": null,
122
+ "tie_word_embeddings": false,
123
+ "transformers_version": "4.45.2",
124
+ "use_cache": true,
125
+ "vocab_size": 200064
126
+ }
127
+
configuration_minimax_m1.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ MiniMaxM1 model configuration"""
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.utils import logging
5
+
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
10
+ class MiniMaxM1Config(PretrainedConfig):
11
+ r"""
12
+ This is the configuration class to store the configuration of a [`MiniMaxM1Model`]. It is used to instantiate an
13
+ MiniMaxM1 model according to the specified arguments, defining the model architecture. Instantiating a configuration
14
+ with the defaults will yield a similar configuration to that of the MiniMaxM1.
15
+
16
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
17
+ documentation from [`PretrainedConfig`] for more information.
18
+
19
+
20
+ Args:
21
+ vocab_size (`int`, *optional*, defaults to 32000):
22
+ Vocabulary size of the MiniMaxM1 model. Defines the number of different tokens that can be represented by the
23
+ `inputs_ids` passed when calling [`MiniMaxM1Model`]
24
+ hidden_size (`int`, *optional*, defaults to 4096):
25
+ Dimension of the hidden representations.
26
+ intermediate_size (`int`, *optional*, defaults to 14336):
27
+ Dimension of the MLP representations.
28
+ num_hidden_layers (`int`, *optional*, defaults to 32):
29
+ Number of hidden layers in the Transformer encoder.
30
+ num_attention_heads (`int`, *optional*, defaults to 32):
31
+ Number of attention heads for each attention layer in the Transformer encoder.
32
+ num_key_value_heads (`int`, *optional*, defaults to 8):
33
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
34
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
35
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
36
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
37
+ by meanpooling all the original heads within that group. For more details checkout [this
38
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
39
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
40
+ The non-linear activation function (function or string) in the decoder.
41
+ max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
42
+ The maximum sequence length that this model might ever be used with. MiniMaxM1's sliding window attention
43
+ allows sequence of up to 4096*32 tokens.
44
+ initializer_range (`float`, *optional*, defaults to 0.02):
45
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
46
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
47
+ The epsilon used by the rms normalization layers.
48
+ use_cache (`bool`, *optional*, defaults to `True`):
49
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
50
+ relevant if `config.is_decoder=True`.
51
+ pad_token_id (`int`, *optional*):
52
+ The id of the padding token.
53
+ bos_token_id (`int`, *optional*, defaults to 1):
54
+ The id of the "beginning-of-sequence" token.
55
+ eos_token_id (`int`, *optional*, defaults to 2):
56
+ The id of the "end-of-sequence" token.
57
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
58
+ Whether the model's input and output word embeddings should be tied.
59
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
60
+ The base period of the RoPE embeddings.
61
+ sliding_window (`int`, *optional*):
62
+ Sliding window attention window size. If not specified, will default to `4096`.
63
+ attention_dropout (`float`, *optional*, defaults to 0.0):
64
+ The dropout ratio for the attention probabilities.
65
+ num_experts_per_tok (`int`, *optional*, defaults to 2):
66
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
67
+ parameter
68
+ num_local_experts (`int`, *optional*, defaults to 8):
69
+ Number of experts per Sparse MLP layer.
70
+ output_router_logits (`bool`, *optional*, defaults to `False`):
71
+ Whether or not the router logits should be returned by the model. Enabeling this will also
72
+ allow the model to output the auxiliary loss. See [here]() for more details
73
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
74
+ The aux loss factor for the total loss.
75
+ router_jitter_noise (`float`, *optional*, defaults to 0.0):
76
+ Amount of noise to add to the router.
77
+
78
+ ```python
79
+ >>> from transformers import MiniMaxM1Model, MiniMaxM1Config
80
+
81
+ >>> # Initializing a MiniMaxM1 style configuration
82
+ >>> configuration = MiniMaxM1Config()
83
+
84
+ >>> # Initializing a model from the MiniMaxM1 style configuration
85
+ >>> model = MiniMaxM1Model(configuration)
86
+
87
+ >>> # Accessing the model configuration
88
+ >>> configuration = model.config
89
+ ```"""
90
+
91
+ model_type = "MiniMaxM1"
92
+ keys_to_ignore_at_inference = ["past_key_values"]
93
+
94
+ def __init__(
95
+ self,
96
+ vocab_size=32000,
97
+ hidden_size=4096,
98
+ intermediate_size=14336,
99
+ num_hidden_layers=32,
100
+ num_attention_heads=32,
101
+ num_key_value_heads=8,
102
+ hidden_act="silu",
103
+ max_position_embeddings=4096 * 32,
104
+ initializer_range=0.02,
105
+ rms_norm_eps=1e-5,
106
+ use_cache=True,
107
+ pad_token_id=None,
108
+ bos_token_id=None,
109
+ eos_token_id=None,
110
+ tie_word_embeddings=False,
111
+ rope_theta=1e6,
112
+ sliding_window=None,
113
+ attention_dropout=0.0,
114
+ num_experts_per_tok=2,
115
+ num_local_experts=8,
116
+ output_router_logits=False,
117
+ router_aux_loss_coef=0.001,
118
+ router_jitter_noise=0.0,
119
+ **kwargs,
120
+ ):
121
+ self.vocab_size = vocab_size
122
+ self.max_position_embeddings = max_position_embeddings
123
+ self.hidden_size = hidden_size
124
+ self.intermediate_size = intermediate_size
125
+ self.num_hidden_layers = num_hidden_layers
126
+ self.num_attention_heads = num_attention_heads
127
+ self.sliding_window = sliding_window
128
+
129
+ # for backward compatibility
130
+ if num_key_value_heads is None:
131
+ num_key_value_heads = num_attention_heads
132
+
133
+ self.num_key_value_heads = num_key_value_heads
134
+ self.hidden_act = hidden_act
135
+ self.initializer_range = initializer_range
136
+ self.rms_norm_eps = rms_norm_eps
137
+ self.use_cache = use_cache
138
+ self.rope_theta = rope_theta
139
+ self.attention_dropout = attention_dropout
140
+
141
+ self.num_experts_per_tok = num_experts_per_tok
142
+ self.num_local_experts = num_local_experts
143
+ self.output_router_logits = output_router_logits
144
+ self.router_aux_loss_coef = router_aux_loss_coef
145
+ self.router_jitter_noise = router_jitter_noise
146
+ super().__init__(
147
+ pad_token_id=pad_token_id,
148
+ bos_token_id=bos_token_id,
149
+ eos_token_id=eos_token_id,
150
+ tie_word_embeddings=tie_word_embeddings,
151
+ **kwargs,
152
+ )
docs/function_call_guide.md ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MiniMax-M1 Function Call Guide
2
+
3
+ [FunctionCall中文使用指南](./function_call_guide_cn.md)
4
+
5
+ ## 📖 Introduction
6
+
7
+ The MiniMax-M1 model supports function calling capabilities, enabling the model to identify when external functions need to be called and output function call parameters in a structured format. This document provides detailed instructions on how to use the function calling feature of MiniMax-M1.
8
+
9
+ ## 🚀 Quick Start
10
+
11
+ ### Using Chat Template
12
+
13
+ MiniMax-M1 uses a specific chat template format to handle function calls. The chat template is defined in `tokenizer_config.json`, and you can use it in your code through the template.
14
+
15
+ ```python
16
+ from transformers import AutoTokenizer
17
+
18
+ def get_default_tools():
19
+ return [
20
+ {
21
+ {
22
+ "name": "get_current_weather",
23
+ "description": "Get the latest weather for a location",
24
+ "parameters": {
25
+ "type": "object",
26
+ "properties": {
27
+ "location": {
28
+ "type": "string",
29
+ "description": "A certain city, such as Beijing, Shanghai"
30
+ }
31
+ },
32
+ }
33
+ "required": ["location"],
34
+ "type": "object"
35
+ }
36
+ }
37
+ ]
38
+
39
+ # Load model and tokenizer
40
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
41
+ prompt = "What's the weather like in Shanghai today?"
42
+ messages = [
43
+ {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant created by Minimax based on MiniMax-M1 model."}]},
44
+ {"role": "user", "content": [{"type": "text", "text": prompt}]},
45
+ ]
46
+
47
+ # Enable function call tools
48
+ tools = get_default_tools()
49
+
50
+ # Apply chat template and add tool definitions
51
+ text = tokenizer.apply_chat_template(
52
+ messages,
53
+ tokenize=False,
54
+ add_generation_prompt=True,
55
+ tools=tools
56
+ )
57
+ ```
58
+
59
+ ## 🛠️ Function Call Definition
60
+
61
+ ### Function Structure
62
+
63
+ Function calls need to be defined in the `tools` field of the request body. Each function consists of the following components:
64
+
65
+ ```json
66
+ {
67
+ "tools": [
68
+ {
69
+ "name": "search_web",
70
+ "description": "Search function.",
71
+ "parameters": {
72
+ "properties": {
73
+ "query_list": {
74
+ "description": "Keywords for search, with list element count of 1.",
75
+ "items": { "type": "string" },
76
+ "type": "array"
77
+ },
78
+ "query_tag": {
79
+ "description": "Classification of the query",
80
+ "items": { "type": "string" },
81
+ "type": "array"
82
+ }
83
+ },
84
+ "required": [ "query_list", "query_tag" ],
85
+ "type": "object"
86
+ }
87
+ }
88
+ ]
89
+ }
90
+ ```
91
+
92
+ **Field Descriptions:**
93
+ - `name`: Function name
94
+ - `description`: Function description
95
+ - `parameters`: Function parameter definition
96
+ - `properties`: Parameter property definitions, where key is the parameter name and value contains detailed parameter description
97
+ - `required`: List of required parameters
98
+ - `type`: Parameter type (usually "object")
99
+
100
+ ### Internal Model Processing Format
101
+
102
+ When processed internally by the model, function definitions are converted to a special format and concatenated to the input text:
103
+
104
+ ```
105
+ ]~!b[]~b]system ai_setting=MiniMax AI
106
+ MiniMax AI is an AI assistant independently developed by MiniMax. [e~[
107
+ ]~b]system tool_setting=tools
108
+ You are provided with these tools:
109
+ <tools>
110
+ {"name": "search_web", "description": "Search function.", "parameters": {"properties": {"query_list": {"description": "Keywords for search, with list element count of 1.", "items": {"type": "string"}, "type": "array"}, "query_tag": {"description": "Classification of the query", "items": {"type": "string"}, "type": "array"}}, "required": ["query_list", "query_tag"], "type": "object"}}
111
+ </tools>
112
+
113
+ If you need to call tools, please respond with <tool_calls></tool_calls> XML tags, and provide tool-name and json-object of arguments, following the format below:
114
+ <tool_calls>
115
+ {"name": <tool-name>, "arguments": <args-json-object>}
116
+ ...
117
+ </tool_calls>[e~[
118
+ ]~b]user name=User
119
+ When were the most recent launch events for OpenAI and Gemini?[e~[
120
+ ]~b]ai name=MiniMax AI
121
+ ```
122
+
123
+ ### Model Output Format
124
+
125
+ The model outputs function calls in the following format:
126
+
127
+ ```xml
128
+ <think>
129
+ Okay, I will search for the OpenAI and Gemini latest release.
130
+ </think>
131
+ <tool_calls>
132
+ {"name": "search_web", "arguments": {"query_tag": ["technology", "events"], "query_list": ["\"OpenAI\" \"latest\" \"release\""]}}
133
+ {"name": "search_web", "arguments": {"query_tag": ["technology", "events"], "query_list": ["\"Gemini\" \"latest\" \"release\""]}}
134
+ </tool_calls>
135
+ ```
136
+
137
+ ## 📥 Function Call Result Processing
138
+
139
+ ### Parsing Function Calls
140
+
141
+ You can use the following code to parse function calls from the model output:
142
+
143
+ ```python
144
+ import re
145
+ import json
146
+
147
+ def parse_function_calls(content: str):
148
+ """
149
+ Parse function calls from model output
150
+ """
151
+ function_calls = []
152
+
153
+ # Match content within <tool_calls> tags
154
+ tool_calls_pattern = r"<tool_calls>(.*?)</tool_calls>"
155
+ tool_calls_match = re.search(tool_calls_pattern, content, re.DOTALL)
156
+
157
+ if not tool_calls_match:
158
+ return function_calls
159
+
160
+ tool_calls_content = tool_calls_match.group(1).strip()
161
+
162
+ # Parse each function call (one JSON object per line)
163
+ for line in tool_calls_content.split('\n'):
164
+ line = line.strip()
165
+ if not line:
166
+ continue
167
+
168
+ try:
169
+ # Parse JSON format function call
170
+ call_data = json.loads(line)
171
+ function_name = call_data.get("name")
172
+ arguments = call_data.get("arguments", {})
173
+
174
+ function_calls.append({
175
+ "name": function_name,
176
+ "arguments": arguments
177
+ })
178
+
179
+ print(f"Function call: {function_name}, Arguments: {arguments}")
180
+
181
+ except json.JSONDecodeError as e:
182
+ print(f"Parameter parsing failed: {line}, Error: {e}")
183
+
184
+ return function_calls
185
+
186
+ # Example: Handle weather query function
187
+ def execute_function_call(function_name: str, arguments: dict):
188
+ """
189
+ Execute function call and return result
190
+ """
191
+ if function_name == "get_current_weather":
192
+ location = arguments.get("location", "Unknown location")
193
+ # Build function execution result
194
+ return {
195
+ "role": "tool",
196
+ "name": function_name,
197
+ "content": json.dumps({
198
+ "location": location,
199
+ "temperature": "25",
200
+ "unit": "celsius",
201
+ "weather": "Sunny"
202
+ }, ensure_ascii=False)
203
+ }
204
+ elif function_name == "search_web":
205
+ query_list = arguments.get("query_list", [])
206
+ query_tag = arguments.get("query_tag", [])
207
+ # Simulate search results
208
+ return {
209
+ "role": "tool",
210
+ "name": function_name,
211
+ "content": f"Search keywords: {query_list}, Categories: {query_tag}\nSearch results: Relevant information found"
212
+ }
213
+
214
+ return None
215
+ ```
216
+
217
+ ### Returning Function Execution Results to the Model
218
+
219
+ After successfully parsing function calls, you should add the function execution results to the conversation history so that the model can access and utilize this information in subsequent interactions.
220
+
221
+ #### Single Result
222
+
223
+ If the model decides to call `search_web`, we suggest you to return the function result in the following format, with the `name` field set to the specific tool name.
224
+
225
+ ```json
226
+ {
227
+ "data": [
228
+ {
229
+ "role": "tool",
230
+ "name": "search_web",
231
+ "content": "search_result"
232
+ }
233
+ ]
234
+ }
235
+ ```
236
+
237
+ Corresponding model input format:
238
+ ```
239
+ ]~b]tool name=search_web
240
+ search_result[e~[
241
+ ```
242
+
243
+
244
+ #### Multiple Result
245
+ If the model decides to call `search_web` and `get_current_weather` at the same time, we suggest you to return the multiple function results in the following format, with the `name` field set to "tools", and use the `content` field to contain multiple results.
246
+
247
+
248
+ ```json
249
+ {
250
+ "data": [
251
+ {
252
+ "role": "tool",
253
+ "name": "tools",
254
+ "content": "Tool name: search_web\nTool result: test_result1\n\nTool name: get_current_weather\nTool result: test_result2"
255
+ }
256
+ ]
257
+ }
258
+ ```
259
+
260
+ Corresponding model input format:
261
+ ```
262
+ ]~b]tool name=tools
263
+ Tool name: search_web
264
+ Tool result: test_result1
265
+
266
+ Tool name: get_current_weather
267
+ Tool result: test_result2[e~[
268
+ ```
269
+
270
+ While we suggest following the above formats, as long as the model input is easy to understand, the specific values of `name` and `content` is entirely up to the caller.
docs/function_call_guide_cn.md ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MiniMax-M1 函数调用(Function Call)功能指南
2
+
3
+ ## 📖 简介
4
+
5
+ MiniMax-M1 模型支持函数调用功能,使模型能够识别何时需要调用外部函数,并以结构化格式输出函数调用参数。本文档详细介绍了如何使用 MiniMax-M1 的函数调用功能。
6
+
7
+ ## 🚀 快速开始
8
+
9
+ ### 聊天模板使用
10
+
11
+ MiniMax-M1 使用特定的聊天模板格式处理函数调用。聊天模板定义在 `tokenizer_config.json` 中,你可以在代码中通过 template 来进行使用。
12
+
13
+ ```python
14
+ from transformers import AutoTokenizer
15
+
16
+ def get_default_tools():
17
+ return [
18
+ {
19
+ {
20
+ "name": "get_current_weather",
21
+ "description": "Get the latest weather for a location",
22
+ "parameters": {
23
+ "type": "object",
24
+ "properties": {
25
+ "location": {
26
+ "type": "string",
27
+ "description": "A certain city, such as Beijing, Shanghai"
28
+ }
29
+ },
30
+ }
31
+ "required": ["location"],
32
+ "type": "object"
33
+ }
34
+ }
35
+ ]
36
+
37
+ # 加载模型和分词器
38
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
39
+ prompt = "What's the weather like in Shanghai today?"
40
+ messages = [
41
+ {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant created by Minimax based on MiniMax-M1 model."}]},
42
+ {"role": "user", "content": [{"type": "text", "text": prompt}]},
43
+ ]
44
+
45
+ # 启用函数调用工具
46
+ tools = get_default_tools()
47
+
48
+ # 应用聊天模板,并加入工具定义
49
+ text = tokenizer.apply_chat_template(
50
+ messages,
51
+ tokenize=False,
52
+ add_generation_prompt=True,
53
+ tools=tools
54
+ )
55
+ ```
56
+
57
+ ## 🛠️ 函数调用的定义
58
+
59
+ ### 函数结构体
60
+
61
+ 函数调用需要在请求体中定义 `tools` 字段,每个函数由以下部分组成:
62
+
63
+ ```json
64
+ {
65
+ "tools": [
66
+ {
67
+ "name": "search_web",
68
+ "description": "搜索函数。",
69
+ "parameters": {
70
+ "properties": {
71
+ "query_list": {
72
+ "description": "进行搜索的关键词,列表元素个数为1。",
73
+ "items": { "type": "string" },
74
+ "type": "array"
75
+ },
76
+ "query_tag": {
77
+ "description": "query的分类",
78
+ "items": { "type": "string" },
79
+ "type": "array"
80
+ }
81
+ },
82
+ "required": [ "query_list", "query_tag" ],
83
+ "type": "object"
84
+ }
85
+ }
86
+ ]
87
+ }
88
+ ```
89
+
90
+ **字段说明:**
91
+ - `name`: 函数名称
92
+ - `description`: 函数功能描述
93
+ - `parameters`: 函数参数定义
94
+ - `properties`: 参数属性定义,key 是参数名,value 包含参数的详细描述
95
+ - `required`: 必填参数列表
96
+ - `type`: 参数类型(通常为 "object")
97
+
98
+ ### 模型内部处理格式
99
+
100
+ 在模型内部处理时,函数定义会被转换为特殊格式并拼接到输入文本中:
101
+
102
+ ```
103
+ ]~!b[]~b]system ai_setting=MiniMax AI
104
+ MiniMax AI是由上海稀宇科技有限公司(MiniMax)自主研发的AI助理。[e~[
105
+ ]~b]system tool_setting=tools
106
+ You are provided with these tools:
107
+ <tools>
108
+ {"name": "search_web", "description": "搜索函数。", "parameters": {"properties": {"query_list": {"description": "进行搜索的关键词,列表元素个数为1。", "items": {"type": "string"}, "type": "array"}, "query_tag": {"description": "query的分类", "items": {"type": "string"}, "type": "array"}}, "required": ["query_list", "query_tag"], "type": "object"}}
109
+ </tools>
110
+
111
+ If you need to call tools, please respond with <tool_calls></tool_calls> XML tags, and provide tool-name and json-object of arguments, following the format below:
112
+ <tool_calls>
113
+ {"name": <tool-name>, "arguments": <args-json-object>}
114
+ ...
115
+ </tool_calls>[e~[
116
+ ]~b]user name=用户
117
+ OpenAI 和 Gemini 的最近一次发布会都是什么时候?[e~[
118
+ ]~b]ai name=MiniMax AI
119
+ ```
120
+
121
+ ### 模型输出格式
122
+
123
+ 模型会以以下格式输出函数调用:
124
+
125
+ ```xml
126
+ <think>
127
+ Okay, I will search for the OpenAI and Gemini latest release.
128
+ </think>
129
+ <tool_calls>
130
+ {"name": "search_web", "arguments": {"query_tag": ["technology", "events"], "query_list": ["\"OpenAI\" \"latest\" \"release\""]}}
131
+ {"name": "search_web", "arguments": {"query_tag": ["technology", "events"], "query_list": ["\"Gemini\" \"latest\" \"release\""]}}
132
+ </tool_calls>
133
+ ```
134
+
135
+ ## 📥 函数调用结果处理
136
+
137
+ ### 解析函数调用
138
+
139
+ 您可以使用以下代码解析模型输出的函数调用:
140
+
141
+ ```python
142
+ import re
143
+ import json
144
+
145
+ def parse_function_calls(content: str):
146
+ """
147
+ 解析模型输出中的函数调用
148
+ """
149
+ function_calls = []
150
+
151
+ # 匹配 <tool_calls> 标签内的内容
152
+ tool_calls_pattern = r"<tool_calls>(.*?)</tool_calls>"
153
+ tool_calls_match = re.search(tool_calls_pattern, content, re.DOTALL)
154
+
155
+ if not tool_calls_match:
156
+ return function_calls
157
+
158
+ tool_calls_content = tool_calls_match.group(1).strip()
159
+
160
+ # 解析每个函数调用(每行一个JSON对象)
161
+ for line in tool_calls_content.split('\n'):
162
+ line = line.strip()
163
+ if not line:
164
+ continue
165
+
166
+ try:
167
+ # 解析JSON格式的函数调用
168
+ call_data = json.loads(line)
169
+ function_name = call_data.get("name")
170
+ arguments = call_data.get("arguments", {})
171
+
172
+ function_calls.append({
173
+ "name": function_name,
174
+ "arguments": arguments
175
+ })
176
+
177
+ print(f"调用函数: {function_name}, 参数: {arguments}")
178
+
179
+ except json.JSONDecodeError as e:
180
+ print(f"参数解析失败: {line}, 错误: {e}")
181
+
182
+ return function_calls
183
+
184
+ # 示例:处理天气查询函数
185
+ def execute_function_call(function_name: str, arguments: dict):
186
+ """
187
+ 执行函数调用并返回结果
188
+ """
189
+ if function_name == "get_current_weather":
190
+ location = arguments.get("location", "未知位置")
191
+ # 构建函数执行结果
192
+ return {
193
+ "role": "tool",
194
+ "name": function_name,
195
+ "content": json.dumps({
196
+ "location": location,
197
+ "temperature": "25",
198
+ "unit": "celsius",
199
+ "weather": "晴朗"
200
+ }, ensure_ascii=False)
201
+ }
202
+ elif function_name == "search_web":
203
+ query_list = arguments.get("query_list", [])
204
+ query_tag = arguments.get("query_tag", [])
205
+ # 模拟搜索结果
206
+ return {
207
+ "role": "tool",
208
+ "name": function_name,
209
+ "content": f"搜索关键词: {query_list}, 分类: {query_tag}\n搜索结果: 相关信息已找到"
210
+ }
211
+
212
+ return None
213
+ ```
214
+
215
+ ### 将函数执行结果返回给模型
216
+
217
+ 成功解析函数调用后,您应将函数执行结果添加到对话历史中,以便模型在后续交互中能够访问和利用这些信息。
218
+
219
+ #### 单个结果
220
+
221
+ 假如模型调用了 `search_web` 函数,您可以参考如下格式添加执行结果,`name` 字段为具体的函数名称。
222
+
223
+ ```json
224
+ {
225
+ "data": [
226
+ {
227
+ "role": "tool",
228
+ "name": "search_web",
229
+ "content": "search_result"
230
+ }
231
+ ]
232
+ }
233
+ ```
234
+
235
+ 对应如下的模型输入格式:
236
+ ```
237
+ ]~b]tool name=search_web
238
+ search_result[e~[
239
+ ```
240
+
241
+
242
+ #### 多个结果
243
+ 假如模型同时调用了 `search_web` 和 `get_current_weather` 函数,您可以参考如下格式添加执行结果,`name` 字段为"tools",`content`包含多个结果。
244
+
245
+ ```json
246
+ {
247
+ "data": [
248
+ {
249
+ "role": "tool",
250
+ "name": "tools",
251
+ "content": "Tool name: search_web\nTool result: test_result1\n\nTool name: get_current_weather\nTool result: test_result2"
252
+ }
253
+ ]
254
+ }
255
+ ```
256
+
257
+ 对应如下的模型输入格式:
258
+ ```
259
+ ]~b]tool name=tools
260
+ Tool name: search_web
261
+ Tool result: test_result1
262
+
263
+ Tool name: get_current_weather
264
+ Tool result: test_result2[e~[
265
+ ```
266
+
267
+ 虽然我们建议您参考以上格式,但只要返回给模型的输入易于理解,`name` 和 `content` 的具体内容完全由您自主决定。
docs/transformers_deployment_guide.md ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚀 MiniMax Model Transformers Deployment Guide
2
+
3
+ [Transformers中文版部署指南](./transformers_deployment_guide_cn.md)
4
+
5
+ ## 📖 Introduction
6
+
7
+ This guide will help you deploy the MiniMax-M1 model using the [Transformers](https://huggingface.co/docs/transformers/index) library. Transformers is a widely used deep learning library that provides a rich collection of pre-trained models and flexible model operation interfaces.
8
+
9
+ ## 🛠️ Environment Setup
10
+
11
+ ### Installing Transformers
12
+
13
+ ```bash
14
+ pip install transformers torch accelerate
15
+ ```
16
+
17
+ ## 📋 Basic Usage Example
18
+
19
+ The pre-trained model can be used as follows:
20
+
21
+ ```python
22
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
23
+
24
+ MODEL_PATH = "{MODEL_PATH}"
25
+ model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", trust_remote_code=True)
26
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
27
+
28
+ messages = [
29
+ {"role": "user", "content": [{"type": "text", "text": "What is your favourite condiment?"}]},
30
+ {"role": "assistant", "content": [{"type": "text", "text": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"}]},
31
+ {"role": "user", "content": [{"type": "text", "text": "Do you have mayonnaise recipes?"}]}
32
+ ]
33
+
34
+ text = tokenizer.apply_chat_template(
35
+ messages,
36
+ tokenize=False,
37
+ add_generation_prompt=True
38
+ )
39
+
40
+ model_inputs = tokenizer(text, return_tensors="pt").to(model.device)
41
+
42
+ generation_config = GenerationConfig(
43
+ max_new_tokens=20,
44
+ eos_token_id=tokenizer.eos_token_id,
45
+ use_cache=True,
46
+ )
47
+
48
+ generated_ids = model.generate(**model_inputs, generation_config=generation_config)
49
+
50
+ generated_ids = [
51
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
52
+ ]
53
+
54
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
55
+ print(response)
56
+ ```
57
+
58
+ ## ⚡ Performance Optimization
59
+
60
+ ### Speeding up with Flash Attention
61
+
62
+ The code snippet above showcases inference without any optimization tricks. However, one can drastically speed up the model by leveraging [Flash Attention](../perf_train_gpu_one#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.
63
+
64
+ First, make sure to install the latest version of Flash Attention 2:
65
+
66
+ ```bash
67
+ pip install -U flash-attn --no-build-isolation
68
+ ```
69
+
70
+ Also make sure that you have hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of the [Flash Attention repository](https://github.com/Dao-AILab/flash-attention). Additionally, ensure you load your model in half-precision (e.g. `torch.float16`).
71
+
72
+ To load and run a model using Flash Attention-2, refer to the snippet below:
73
+
74
+ ```python
75
+ import torch
76
+ from transformers import AutoModelForCausalLM, AutoTokenizer
77
+
78
+ MODEL_PATH = "{MODEL_PATH}"
79
+ model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, trust_remote_code=True, torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto")
80
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
81
+
82
+ prompt = "My favourite condiment is"
83
+
84
+ model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
85
+ generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
86
+ response = tokenizer.batch_decode(generated_ids)[0]
87
+ print(response)
88
+ ```
89
+
90
+ ## 📮 Getting Support
91
+
92
+ If you encounter any issues while deploying the MiniMax-M1 model:
93
+ - Please check our official documentation
94
+ - Contact our technical support team through official channels
95
+ - Submit an Issue on our GitHub repository
96
+
97
+ We continuously optimize the deployment experience on Transformers and welcome your feedback!
docs/transformers_deployment_guide_cn.md ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚀 MiniMax 模型 Transformers 部署指南
2
+
3
+ ## 📖 简介
4
+
5
+ 本指南将帮助您使用 [Transformers](https://huggingface.co/docs/transformers/index) 库部署 MiniMax-M1 模型。Transformers 是一个广泛使用的深度学习库,提供了丰富的预训练模型和灵活的模型操作接口。
6
+
7
+ ## 🛠️ 环境准备
8
+
9
+ ### 安装 Transformers
10
+
11
+ ```bash
12
+ pip install transformers torch accelerate
13
+ ```
14
+
15
+ ## 📋 基本使用示例
16
+
17
+ 预训练模型可以按照以下方式使用:
18
+
19
+ ```python
20
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
21
+
22
+ MODEL_PATH = "{MODEL_PATH}"
23
+ model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", trust_remote_code=True)
24
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
25
+
26
+ messages = [
27
+ {"role": "user", "content": [{"type": "text", "text": "What is your favourite condiment?"}]},
28
+ {"role": "assistant", "content": [{"type": "text", "text": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"}]},
29
+ {"role": "user", "content": [{"type": "text", "text": "Do you have mayonnaise recipes?"}]}
30
+ ]
31
+
32
+ text = tokenizer.apply_chat_template(
33
+ messages,
34
+ tokenize=False,
35
+ add_generation_prompt=True
36
+ )
37
+
38
+ model_inputs = tokenizer(text, return_tensors="pt").to(model.device)
39
+
40
+ generation_config = GenerationConfig(
41
+ max_new_tokens=20,
42
+ eos_token_id=tokenizer.eos_token_id,
43
+ use_cache=True,
44
+ )
45
+
46
+ generated_ids = model.generate(**model_inputs, generation_config=generation_config)
47
+
48
+ generated_ids = [
49
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
50
+ ]
51
+
52
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
53
+ print(response)
54
+ ```
55
+
56
+ ## ⚡ 性能优化
57
+
58
+ ### 使用 Flash Attention 加速
59
+
60
+ 上面的代码片段展示了不使用任何优化技巧的推理过程。但通过利用 [Flash Attention](../perf_train_gpu_one#flash-attention-2),可以大幅加速模型,因为它提供了模型内部使用的注意力机制的更快实现。
61
+
62
+ 首先,确保安装最新版本的 Flash Attention 2:
63
+
64
+ ```bash
65
+ pip install -U flash-attn --no-build-isolation
66
+ ```
67
+
68
+ 还要确保您拥有与 Flash-Attention 2 兼容的硬件。在[Flash Attention 官方仓库](https://github.com/Dao-AILab/flash-attention)的官方文档中了解更多信息。此外,请确保以半精度(例如 `torch.float16`)加载模型。
69
+
70
+ 要使用 Flash Attention-2 加载和运行模型,请参考以下代码片段:
71
+
72
+ ```python
73
+ import torch
74
+ from transformers import AutoModelForCausalLM, AutoTokenizer
75
+
76
+ MODEL_PATH = "{MODEL_PATH}"
77
+ model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, trust_remote_code=True, torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto")
78
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
79
+
80
+ prompt = "My favourite condiment is"
81
+
82
+ model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
83
+ generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
84
+ response = tokenizer.batch_decode(generated_ids)[0]
85
+ print(response)
86
+ ```
87
+
88
+ ## 📮 获取支持
89
+
90
+ 如果您在部署 MiniMax-M1 模型过程中遇到任何问题:
91
+ - 请查看我们的官方文档
92
+ - 通过官方渠道联系我们的技术支持团队
93
+ - 在我们的 GitHub 仓库提交 Issue
94
+
95
+ 我们会持续优化 Transformers 上的部署体验,欢迎您的反馈!
docs/vllm_deployment_guide.md ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚀 MiniMax Models vLLM Deployment Guide
2
+
3
+ [vLLM中文版部署指南](./vllm_deployment_guide_cn.md)
4
+
5
+ ## 📖 Introduction
6
+
7
+ We recommend using [vLLM](https://docs.vllm.ai/en/latest/) to deploy [MiniMax-M1](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k) model. Based on our testing, vLLM performs excellently when deploying this model, with the following features:
8
+
9
+ - 🔥 Outstanding service throughput performance
10
+ - ⚡ Efficient and intelligent memory management
11
+ - 📦 Powerful batch request processing capability
12
+ - ⚙️ Deeply optimized underlying performance
13
+
14
+ The MiniMax-M1 model can run efficiently on a single server equipped with 8 H800 or 8 H20 GPUs. In terms of hardware configuration, a server with 8 H800 GPUs can process context inputs up to 2 million tokens, while a server equipped with 8 H20 GPUs can support ultra-long context processing capabilities of up to 5 million tokens.
15
+
16
+ ## 💾 Obtaining MiniMax Models
17
+
18
+ ### MiniMax-M1 Model Obtaining
19
+
20
+ You can download the model from our official HuggingFace repository: [MiniMax-M1-40k](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k), [MiniMax-M1-80k](https://huggingface.co/MiniMaxAI/MiniMax-M1-80k)
21
+
22
+ Download command:
23
+ ```
24
+ pip install -U huggingface-hub
25
+ huggingface-cli download MiniMaxAI/MiniMax-M1-40k
26
+ # huggingface-cli download MiniMaxAI/MiniMax-M1-80k
27
+
28
+ # If you encounter network issues, you can set a proxy
29
+ export HF_ENDPOINT=https://hf-mirror.com
30
+ ```
31
+
32
+ Or download using git:
33
+
34
+ ```bash
35
+ git lfs install
36
+ git clone https://huggingface.co/MiniMaxAI/MiniMax-M1-40k
37
+ git clone https://huggingface.co/MiniMaxAI/MiniMax-M1-80k
38
+ ```
39
+
40
+ ⚠️ **Important Note**: Please ensure that [Git LFS](https://git-lfs.github.com/) is installed on your system, which is necessary for completely downloading the model weight files.
41
+
42
+ ## 🛠️ Deployment Options
43
+
44
+ ### Option 1: Deploy Using Docker (Recommended)
45
+
46
+ To ensure consistency and stability of the deployment environment, we recommend using Docker for deployment.
47
+
48
+ ⚠️ **Version Requirements**:
49
+ - MiniMax-M1 model requires vLLM version 0.8.3 or later for full support
50
+ - If you are using a Docker image with vLLM version lower than the required version, you will need to:
51
+ 1. Update to the latest vLLM code
52
+ 2. Recompile vLLM from source. Follow the compilation instructions in Solution 2 of the Common Issues section
53
+ - Special Note: For vLLM versions between 0.8.3 and 0.9.2, you need to modify the model configuration:
54
+ 1. Open `config.json`
55
+ 2. Change `config['architectures'] = ["MiniMaxM1ForCausalLM"]` to `config['architectures'] = ["MiniMaxText01ForCausalLM"]`
56
+
57
+ 1. Get the container image:
58
+ ```bash
59
+ docker pull vllm/vllm-openai:v0.8.3
60
+ ```
61
+
62
+ 2. Run the container:
63
+ ```bash
64
+ # Set environment variables
65
+ IMAGE=vllm/vllm-openai:v0.8.3
66
+ MODEL_DIR=<model storage path>
67
+ CODE_DIR=<code path>
68
+ NAME=MiniMaxImage
69
+
70
+ # Docker run configuration
71
+ DOCKER_RUN_CMD="--network=host --privileged --ipc=host --ulimit memlock=-1 --shm-size=2gb --rm --gpus all --ulimit stack=67108864"
72
+
73
+ # Start the container
74
+ sudo docker run -it \
75
+ -v $MODEL_DIR:$MODEL_DIR \
76
+ -v $CODE_DIR:$CODE_DIR \
77
+ --name $NAME \
78
+ $DOCKER_RUN_CMD \
79
+ $IMAGE /bin/bash
80
+ ```
81
+
82
+
83
+ ### Option 2: Direct Installation of vLLM
84
+
85
+ If your environment meets the following requirements:
86
+
87
+ - CUDA 12.1
88
+ - PyTorch 2.1
89
+
90
+ You can directly install vLLM
91
+
92
+ Installation command:
93
+ ```bash
94
+ pip install vllm
95
+ ```
96
+
97
+ 💡 If you are using other environment configurations, please refer to the [vLLM Installation Guide](https://docs.vllm.ai/en/latest/getting_started/installation.html)
98
+
99
+ ## 🚀 Starting the Service
100
+
101
+ ### Launch MiniMax-M1 Service
102
+
103
+ ```bash
104
+ export SAFETENSORS_FAST_GPU=1
105
+ export VLLM_USE_V1=0
106
+ python3 -m vllm.entrypoints.openai.api_server \
107
+ --model <model storage path> \
108
+ --tensor-parallel-size 8 \
109
+ --trust-remote-code \
110
+ --quantization experts_int8 \
111
+ --max_model_len 4096 \
112
+ --dtype bfloat16
113
+ ```
114
+
115
+ ### API Call Example
116
+
117
+ ```bash
118
+ curl http://localhost:8000/v1/chat/completions \
119
+ -H "Content-Type: application/json" \
120
+ -d '{
121
+ "model": "MiniMaxAI/MiniMax-M1",
122
+ "messages": [
123
+ {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
124
+ {"role": "user", "content": [{"type": "text", "text": "Who won the world series in 2020?"}]}
125
+ ]
126
+ }'
127
+ ```
128
+
129
+ ## ❗ Common Issues
130
+
131
+ ### Module Loading Problems
132
+ If you encounter the following error:
133
+ ```
134
+ import vllm._C # noqa
135
+ ModuleNotFoundError: No module named 'vllm._C'
136
+ ```
137
+
138
+ Or
139
+
140
+ ```
141
+ MiniMax-M1 model is not currently supported
142
+ ```
143
+
144
+ We provide two solutions:
145
+
146
+ #### Solution 1: Copy Dependency Files
147
+ ```bash
148
+ cd <working directory>
149
+ git clone https://github.com/vllm-project/vllm.git
150
+ cd vllm
151
+ cp /usr/local/lib/python3.12/dist-packages/vllm/*.so vllm
152
+ cp -r /usr/local/lib/python3.12/dist-packages/vllm/vllm_flash_attn/* vllm/vllm_flash_attn
153
+ ```
154
+
155
+ #### Solution 2: Install from Source
156
+ ```bash
157
+ cd <working directory>
158
+ git clone https://github.com/vllm-project/vllm.git
159
+
160
+ cd vllm/
161
+ pip install -e .
162
+ ```
163
+
164
+ ## 📮 Getting Support
165
+
166
+ If you encounter any issues while deploying MiniMax-M1 model:
167
+ - Please check our official documentation
168
+ - Contact our technical support team through official channels
169
+ - Submit an [Issue](https://github.com/MiniMax-AI/MiniMax-M1/issues) on our GitHub repository
170
+
171
+ We will continuously optimize the deployment experience of this model and welcome your feedback!
docs/vllm_deployment_guide_cn.md ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚀 MiniMax 模型 vLLM 部署指南
2
+
3
+ ## 📖 简介
4
+
5
+ 我们推荐使用 [vLLM](https://docs.vllm.ai/en/latest/) 来部署 [MiniMax-M1](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k) 模型。经过我们的测试,vLLM 在部署这个模型时表现出色,具有以下特点:
6
+
7
+ - 🔥 卓越的服务吞吐量性能
8
+ - ⚡ 高效智能的内存管理机制
9
+ - 📦 强大的批量请求处理能力
10
+ - ⚙️ 深度优化的底层性能
11
+
12
+ MiniMax-M1 模型可在单台配备8个H800或8个H20 GPU的服务器上高效运行。在硬件配置方面,搭载8个H800 GPU的服务器可处理长达200万token的上下文输入,而配备8个H20 GPU的服务器则能够支持高达500万token的超长上下文处理能力。
13
+
14
+ ## 💾 获取 MiniMax 模型
15
+
16
+ ### MiniMax-M1 模型获取
17
+
18
+ 您可以从我们的官方 HuggingFace 仓库下载模型:[MiniMax-M1-40k](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k)、[MiniMax-M1-80k](https://huggingface.co/MiniMaxAI/MiniMax-M1-80k)
19
+
20
+ 下载命令:
21
+ ```
22
+ pip install -U huggingface-hub
23
+ huggingface-cli download MiniMaxAI/MiniMax-M1-40k
24
+ # huggingface-cli download MiniMaxAI/MiniMax-M1-80k
25
+
26
+ # 如果遇到网络问题,可以设置代理
27
+ export HF_ENDPOINT=https://hf-mirror.com
28
+ ```
29
+
30
+ 或者使用 git 下载:
31
+
32
+ ```bash
33
+ git lfs install
34
+ git clone https://huggingface.co/MiniMaxAI/MiniMax-M1-40k
35
+ git clone https://huggingface.co/MiniMaxAI/MiniMax-M1-80k
36
+ ```
37
+
38
+ ⚠️ **重要提示**:请确保系统已安装 [Git LFS](https://git-lfs.github.com/),这对于完整下载模型权重文件是必需的。
39
+
40
+ ## 🛠️ 部署方案
41
+
42
+ ### 方案一:使用 Docker 部署(推荐)
43
+
44
+ 为确保部署环境的一致性和稳定性,我们推荐使用 Docker 进行部署。
45
+
46
+ ⚠️ **版本要求**:
47
+ - 基础要求:vLLM 版本必须 ≥ 0.8.3,以确保对 MiniMax-M1 模型的完整支持
48
+ - 特殊说明:如果使用 vLLM 0.8.3 至 0.9.2 之间的版本,需要修改模型配置文件:
49
+ - 打开 `config.json`
50
+ - 将 `config['architectures'] = ["MiniMaxM1ForCausalLM"]` 修改为 `config['architectures'] = ["MiniMaxText01ForCausalLM"]`
51
+
52
+ 1. 获取容器镜像:
53
+ ```bash
54
+ docker pull vllm/vllm-openai:v0.8.3
55
+ ```
56
+
57
+ 2. 运行容器:
58
+ ```bash
59
+ # 设置环境变量
60
+ IMAGE=vllm/vllm-openai:v0.8.3
61
+ MODEL_DIR=<模型存放路径>
62
+ CODE_DIR=<代码路径>
63
+ NAME=MiniMaxImage
64
+
65
+ # Docker运行配置
66
+ DOCKER_RUN_CMD="--network=host --privileged --ipc=host --ulimit memlock=-1 --shm-size=2gb --rm --gpus all --ulimit stack=67108864"
67
+
68
+ # 启动容器
69
+ sudo docker run -it \
70
+ -v $MODEL_DIR:$MODEL_DIR \
71
+ -v $CODE_DIR:$CODE_DIR \
72
+ --name $NAME \
73
+ $DOCKER_RUN_CMD \
74
+ $IMAGE /bin/bash
75
+ ```
76
+
77
+
78
+ ### 方案二:直接安装 vLLM
79
+
80
+ 如果您的环境满足以下要求:
81
+
82
+ - CUDA 12.1
83
+ - PyTorch 2.1
84
+
85
+ 可以直接安装 vLLM
86
+
87
+ 安装命令:
88
+ ```bash
89
+ pip install vllm
90
+ ```
91
+
92
+ 💡 如果您使用其他环境配置,请参考 [vLLM 安装指南](https://docs.vllm.ai/en/latest/getting_started/installation.html)
93
+
94
+ ## 🚀 启动服务
95
+
96
+ ### 启动 MiniMax-M1 服务
97
+
98
+ ```bash
99
+ export SAFETENSORS_FAST_GPU=1
100
+ export VLLM_USE_V1=0
101
+ python3 -m vllm.entrypoints.openai.api_server \
102
+ --model <模型存放路径> \
103
+ --tensor-parallel-size 8 \
104
+ --trust-remote-code \
105
+ --quantization experts_int8 \
106
+ --max_model_len 4096 \
107
+ --dtype bfloat16
108
+ ```
109
+
110
+ ### API 调用示例
111
+
112
+ ```bash
113
+ curl http://localhost:8000/v1/chat/completions \
114
+ -H "Content-Type: application/json" \
115
+ -d '{
116
+ "model": "MiniMaxAI/MiniMax-M1",
117
+ "messages": [
118
+ {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
119
+ {"role": "user", "content": [{"type": "text", "text": "Who won the world series in 2020?"}]}
120
+ ]
121
+ }'
122
+ ```
123
+
124
+ ## ❗ 常见问题
125
+
126
+ ### 模块加载问题
127
+ 如果遇到以下错误:
128
+ ```
129
+ import vllm._C # noqa
130
+ ModuleNotFoundError: No module named 'vllm._C'
131
+ ```
132
+
133
+
134
+
135
+ ```
136
+ 当前并不支持 MiniMax-M1 模型
137
+ ```
138
+
139
+ 我们提供两种解决方案:
140
+
141
+ #### 解决方案一:复制依赖文件
142
+ ```bash
143
+ cd <工作目录>
144
+ git clone https://github.com/vllm-project/vllm.git
145
+ cd vllm
146
+ cp /usr/local/lib/python3.12/dist-packages/vllm/*.so vllm
147
+ cp -r /usr/local/lib/python3.12/dist-packages/vllm/vllm_flash_attn/* vllm/vllm_flash_attn
148
+ ```
149
+
150
+ #### 解决方案二:从源码安装
151
+ ```bash
152
+ cd <工作目录>
153
+ git clone https://github.com/vllm-project/vllm.git
154
+
155
+ cd vllm/
156
+ pip install -e .
157
+ ```
158
+
159
+ ## 📮 获取支持
160
+
161
+ 如果您在部署 MiniMax-M1 模型过程中遇到任何问题:
162
+ - 请查看我们的官方文档
163
+ - 通过官方渠道联系我们的技术支持团队
164
+ - 在我们的 GitHub 仓库提交 [Issue](https://github.com/MiniMax-AI/MiniMax-M1/issues)
165
+
166
+ 我们会持续优化模型的部署体验,欢迎您的反馈!
figures/TextBench.png ADDED
figures/wechat-qrcode.jpeg ADDED
main.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, QuantoConfig, GenerationConfig
2
+ import torch
3
+ import argparse
4
+
5
+ """
6
+ usage:
7
+ export SAFETENSORS_FAST_GPU=1
8
+ python main.py --quant_type int8 --world_size 8 --model_id <model_path>
9
+ """
10
+
11
+ def generate_quanto_config(hf_config: AutoConfig, quant_type: str):
12
+ QUANT_TYPE_MAP = {
13
+ "default": None,
14
+ "int8": QuantoConfig(
15
+ weights="int8",
16
+ modules_to_not_convert=[
17
+ "lm_head",
18
+ "embed_tokens",
19
+ ] + [f"model.layers.{i}.coefficient" for i in range(hf_config.num_hidden_layers)]
20
+ + [f"model.layers.{i}.block_sparse_moe.gate" for i in range(hf_config.num_hidden_layers)]
21
+ ),
22
+ }
23
+ return QUANT_TYPE_MAP[quant_type]
24
+
25
+
26
+ def parse_args():
27
+ parser = argparse.ArgumentParser()
28
+ parser.add_argument("--quant_type", type=str, default="default", choices=["default", "int8"])
29
+ parser.add_argument("--model_id", type=str, required=True)
30
+ parser.add_argument("--world_size", type=int, required=True)
31
+ return parser.parse_args()
32
+
33
+
34
+ def check_params(args, hf_config: AutoConfig):
35
+ if args.quant_type == "int8":
36
+ assert args.world_size >= 8, "int8 weight-only quantization requires at least 8 GPUs"
37
+
38
+ assert hf_config.num_hidden_layers % args.world_size == 0, f"num_hidden_layers({hf_config.num_hidden_layers}) must be divisible by world_size({args.world_size})"
39
+
40
+
41
+ @torch.no_grad()
42
+ def main():
43
+ args = parse_args()
44
+ print("\n=============== Argument ===============")
45
+ for key in vars(args):
46
+ print(f"{key}: {vars(args)[key]}")
47
+ print("========================================")
48
+
49
+ model_id = args.model_id
50
+
51
+ hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
52
+ check_params(args, hf_config)
53
+ quantization_config = generate_quanto_config(hf_config, args.quant_type)
54
+
55
+ device_map = {
56
+ 'model.embed_tokens': 'cuda:0',
57
+ 'model.norm': f'cuda:{args.world_size - 1}',
58
+ 'lm_head': f'cuda:{args.world_size - 1}'
59
+ }
60
+ layers_per_device = hf_config.num_hidden_layers // args.world_size
61
+ for i in range(args.world_size):
62
+ for j in range(layers_per_device):
63
+ device_map[f'model.layers.{i * layers_per_device + j}'] = f'cuda:{i}'
64
+
65
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
66
+ message = [
67
+ {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
68
+ {"role": "user", "content": [{"type": "text", "text": "Hello, what is the weather today?"}]}
69
+ ]
70
+ tools = [
71
+ {"name": "get_location", "description": "Get the location of the user.", "parameters": {"type": "object", "properties": {}}},
72
+ {"name": "get_weather", "description": "Get the weather of a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The name of the city"}}}},
73
+ {"name": "get_news", "description": "Get the news.", "parameters": {"type": "object", "properties": {"domain": {"type": "string", "description": "The domain of the news"}}}}
74
+ ]
75
+ text = tokenizer.apply_chat_template(
76
+ message,
77
+ tools,
78
+ tokenize=False,
79
+ add_generation_prompt=True
80
+ )
81
+ model_inputs = tokenizer(text, return_tensors="pt").to("cuda")
82
+ quantized_model = AutoModelForCausalLM.from_pretrained(
83
+ model_id,
84
+ torch_dtype="bfloat16",
85
+ device_map=device_map,
86
+ quantization_config=quantization_config,
87
+ trust_remote_code=True,
88
+ offload_buffers=True,
89
+ )
90
+ generation_config = GenerationConfig(
91
+ max_new_tokens=20,
92
+ eos_token_id=200020,
93
+ use_cache=True,
94
+ )
95
+ generated_ids = quantized_model.generate(**model_inputs, generation_config=generation_config)
96
+ print(f"generated_ids: {generated_ids}")
97
+ generated_ids = [
98
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
99
+ ]
100
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
101
+ print(response)
102
+
103
+ if __name__ == "__main__":
104
+ main()
105
+
106
+
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_minimax_m1.py ADDED
@@ -0,0 +1,1701 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ PyTorch MiniMaxM1 model."""
2
+ import inspect
3
+ import math
4
+ import warnings
5
+ from typing import List, Optional, Tuple, Union
6
+ import os
7
+ import copy
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint
11
+ from torch import nn
12
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
13
+ from einops import rearrange, repeat
14
+ from transformers.activations import ACT2FN
15
+ from transformers.cache_utils import Cache, DynamicCache
16
+ from transformers.modeling_attn_mask_utils import (
17
+ _prepare_4d_causal_attention_mask,
18
+ )
19
+ from transformers.modeling_outputs import (
20
+ MoeCausalLMOutputWithPast,
21
+ MoeModelOutputWithPast,
22
+ SequenceClassifierOutputWithPast,
23
+ )
24
+ from transformers.modeling_utils import PreTrainedModel
25
+ from transformers.utils import (
26
+ add_start_docstrings,
27
+ add_start_docstrings_to_model_forward,
28
+ is_flash_attn_2_available,
29
+ is_flash_attn_greater_or_equal_2_10,
30
+ logging,
31
+ replace_return_docstrings,
32
+ )
33
+ from transformers.utils.import_utils import is_torch_fx_available
34
+ from .configuration_minimax_m1 import MiniMaxM1Config
35
+
36
+ if is_flash_attn_2_available():
37
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
38
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
39
+
40
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
41
+
42
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
43
+ # It means that the function will not be traced through and simply appear as a node in the graph.
44
+ if is_torch_fx_available():
45
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
46
+
47
+ use_triton = eval(os.environ.get("use_triton", default="False"))
48
+ debug = eval(os.environ.get("debug", default="False"))
49
+ do_eval = eval(os.environ.get("do_eval", default="False"))
50
+ eval_and_not_generate = eval(os.environ.get("eval_and_not_generate", default="False"))
51
+ BLOCK = 256
52
+
53
+ logger = logging.get_logger(__name__)
54
+
55
+ _CONFIG_FOR_DOC = "MiniMaxM1Config"
56
+
57
+
58
+ def get_activation_fn(activation):
59
+ if debug:
60
+ logger.info(f"activation: {activation}")
61
+ if activation == "gelu":
62
+ return F.gelu
63
+ elif activation == "relu":
64
+ return F.relu
65
+ elif activation == "elu":
66
+ return F.elu
67
+ elif activation == "sigmoid":
68
+ return F.sigmoid
69
+ elif activation == "exp":
70
+
71
+ def f(x):
72
+ with torch.no_grad():
73
+ x_max = torch.max(x, dim=-1, keepdims=True).values
74
+ y = torch.exp(x - x_max)
75
+
76
+ return y
77
+
78
+ return f
79
+ elif activation == "leak":
80
+ return F.leaky_relu
81
+ elif activation == "1+elu":
82
+
83
+ def f(x):
84
+ return 1 + F.elu(x)
85
+
86
+ return f
87
+ elif activation == "2+elu":
88
+
89
+ def f(x):
90
+ return 2 + F.elu(x)
91
+
92
+ return f
93
+ elif activation == "silu" or activation == "swish":
94
+ return F.silu
95
+ elif activation == "sine":
96
+ return torch.sin
97
+ else:
98
+ logger.info(
99
+ f"activation: does not support {activation}, use Identity!!!")
100
+ return lambda x: x
101
+
102
+
103
+ def load_balancing_loss_func(
104
+ gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2,
105
+ attention_mask: Optional[torch.Tensor] = None
106
+ ) -> float:
107
+ r"""
108
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
109
+
110
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
111
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
112
+ experts is too unbalanced.
113
+
114
+ Args:
115
+ gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
116
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
117
+ shape [batch_size X sequence_length, num_experts].
118
+ attention_mask (`torch.Tensor`, None):
119
+ The attention_mask used in forward function
120
+ shape [batch_size X sequence_length] if not None.
121
+ num_experts (`int`, *optional*):
122
+ Number of experts
123
+
124
+ Returns:
125
+ The auxiliary loss.
126
+ """
127
+ if gate_logits is None or not isinstance(gate_logits, tuple):
128
+ return 0
129
+
130
+ if isinstance(gate_logits, tuple):
131
+ compute_device = gate_logits[0].device
132
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
133
+
134
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
135
+
136
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
137
+
138
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
139
+
140
+ if attention_mask is None:
141
+ # Compute the percentage of tokens routed to each experts
142
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
143
+
144
+ # Compute the average probability of routing to these experts
145
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
146
+ else:
147
+ batch_size, sequence_length = attention_mask.shape
148
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
149
+
150
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
151
+ expert_attention_mask = (
152
+ attention_mask[None, :, :, None, None]
153
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
154
+ .reshape(-1, top_k, num_experts)
155
+ .to(compute_device)
156
+ )
157
+
158
+ # Compute the percentage of tokens routed to each experts
159
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
160
+ expert_attention_mask, dim=0
161
+ )
162
+
163
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
164
+ router_per_expert_attention_mask = (
165
+ attention_mask[None, :, :, None]
166
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
167
+ .reshape(-1, num_experts)
168
+ .to(compute_device)
169
+ )
170
+
171
+ # Compute the average probability of routing to these experts
172
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
173
+ router_per_expert_attention_mask, dim=0
174
+ )
175
+
176
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
177
+ return overall_loss * num_experts
178
+
179
+
180
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
181
+ def _get_unpad_data(attention_mask):
182
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
183
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
184
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
185
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
186
+ return (
187
+ indices,
188
+ cu_seqlens,
189
+ max_seqlen_in_batch,
190
+ )
191
+
192
+
193
+ class GLU(nn.Module):
194
+
195
+ def __init__(self, d1, d2, bias=False):
196
+ super().__init__()
197
+
198
+ self.l1 = nn.Linear(d1, d2, bias=bias)
199
+ self.l2 = nn.Linear(d1, d2, bias=bias)
200
+ self.l3 = nn.Linear(d2, d1, bias=bias)
201
+
202
+ def forward(self, x):
203
+ o1 = self.l1(x)
204
+ o2 = self.l2(x)
205
+ output = o1 * o2
206
+ output = self.l3(output)
207
+ return output
208
+
209
+
210
+ class MiniMaxM1LightningAttention(nn.Module):
211
+ def __init__(self, config: MiniMaxM1Config, layer_idx: Optional[int] = None):
212
+ super().__init__()
213
+ bias = False
214
+ self.hidden_size = config.hidden_size
215
+ self.num_heads = config.num_attention_heads
216
+ self.head_dim = getattr(config, 'head_dim', self.hidden_size // self.num_heads)
217
+
218
+ self.out_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=bias)
219
+ self.act = get_activation_fn(config.hidden_act)
220
+ self.norm = MiniMaxM1RMSNorm(self.head_dim * self.num_heads)
221
+
222
+ self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias)
223
+ self.output_gate = nn.Linear(self.hidden_size, self.head_dim * self.num_heads, bias=bias)
224
+
225
+ # for inference only
226
+ self.offset = 0
227
+ self.layer_idx = layer_idx
228
+
229
+ def forward(
230
+ self,
231
+ hidden_states,
232
+ attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m)
233
+ output_attentions: bool = False,
234
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
235
+ use_cache: bool = False,
236
+ slope_rate: Optional[torch.Tensor] = None,
237
+ **kwargs
238
+ ):
239
+ if (not self.training) and (not do_eval):
240
+ return self.inference(
241
+ hidden_states,
242
+ attn_mask,
243
+ output_attentions,
244
+ past_key_value,
245
+ use_cache,
246
+ slope_rate,
247
+ )
248
+
249
+ def inference(
250
+ self,
251
+ x,
252
+ attn_mask: Optional[torch.Tensor] = None, # (b, n)
253
+ output_attentions: bool = False,
254
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
255
+ use_cache: bool = False,
256
+ slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
257
+ ):
258
+ # x: b n d
259
+ b, n, d = x.shape
260
+ # linear map
261
+ qkv = self.act(self.qkv_proj(x))
262
+ new_shape = qkv.size()[:-1] + (self.num_heads, -1)
263
+ qkv = qkv.view(*new_shape)
264
+ q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3)
265
+ q = q.transpose(1, 2)
266
+ k = k.transpose(1, 2)
267
+ v = v.transpose(1, 2)
268
+
269
+ if past_key_value is None:
270
+ self.offset = q.shape[-2]
271
+ else:
272
+ self.offset += 1
273
+
274
+ # for align with metaseq
275
+ ratio = torch.exp(-slope_rate)
276
+
277
+ # only use for the first time
278
+ if past_key_value is None:
279
+ slope_rate = slope_rate.to(torch.float32)
280
+ if attn_mask is not None:
281
+ v = v.masked_fill((1 - attn_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0)
282
+ NUM_BLOCK = (n + BLOCK - 1) // BLOCK
283
+ b, h, n, d = q.shape
284
+ e = v.shape[-1]
285
+ # other
286
+ array = torch.arange(BLOCK).to(q) + 1
287
+ q_decay = torch.exp(-slope_rate * array.reshape(-1, 1))
288
+ k_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1)))
289
+ index = array[:, None] - array[None, :]
290
+ s_index = slope_rate * index[
291
+ None,
292
+ None,
293
+ ]
294
+ s_index = torch.where(index >= 0, -s_index, float("-inf"))
295
+ diag_decay = torch.exp(s_index)
296
+
297
+ kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device)
298
+ output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
299
+ for i in range(NUM_BLOCK):
300
+ si = i * BLOCK
301
+ ei = min(si + BLOCK, n)
302
+ m = ei - si
303
+ qi = q[:, :, si:ei].contiguous()
304
+ ki = k[:, :, si:ei].contiguous()
305
+ vi = v[:, :, si:ei].contiguous()
306
+ qkv_none_diag = torch.matmul(qi * q_decay[:, :m], kv).to(torch.float32)
307
+
308
+ # diag
309
+ qk = torch.matmul(qi, ki.transpose(-1, -2)).to(torch.float32) * diag_decay[:, :, :m, :m]
310
+ qkv_diag = torch.matmul(qk, vi.to(torch.float32))
311
+ block_decay = torch.exp(-slope_rate * m)
312
+ output[:, :, si:ei] = qkv_none_diag + qkv_diag
313
+ kv = block_decay * kv + torch.matmul((ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi)
314
+
315
+ else:
316
+ kv = past_key_value
317
+ output = []
318
+ for i in range(n):
319
+ kv = ratio * kv + torch.einsum(
320
+ "... n d, ... n e -> ... d e",
321
+ k[:, :, i:i + 1],
322
+ v[:, :, i:i + 1],
323
+ )
324
+ qkv = torch.einsum("... n e, ... e d -> ... n d", q[:, :, i:i + 1], kv.to(q.dtype))
325
+ output.append(qkv)
326
+ output = torch.concat(output, dim=-2)
327
+ # reshape
328
+ output = rearrange(output, "b h n d -> b n (h d)")
329
+ # normalize
330
+ output = self.norm(output)
331
+ # gate
332
+ output = F.sigmoid(self.output_gate(x)) * output
333
+ # outproj
334
+ output = self.out_proj(output)
335
+
336
+ attn_weights = None
337
+
338
+ return output, attn_weights, kv
339
+
340
+
341
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxM1
342
+ class MiniMaxM1RMSNorm(nn.Module):
343
+ def __init__(self, hidden_size, eps=1e-6):
344
+ """
345
+ MiniMaxM1RMSNorm is equivalent to T5LayerNorm
346
+ """
347
+ super().__init__()
348
+ self.weight = nn.Parameter(torch.ones(hidden_size))
349
+ self.variance_epsilon = eps
350
+
351
+ def forward(self, hidden_states):
352
+ input_dtype = hidden_states.dtype
353
+ hidden_states = hidden_states.to(torch.float32)
354
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
355
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
356
+ return self.weight * hidden_states.to(input_dtype)
357
+
358
+
359
+ # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->MiniMaxM1
360
+ class MiniMaxM1RotaryEmbedding(nn.Module):
361
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
362
+ super().__init__()
363
+
364
+ self.dim = dim
365
+ self.max_position_embeddings = max_position_embeddings
366
+ self.base = base
367
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
368
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
369
+
370
+ # Build here to make `torch.jit.trace` work.
371
+ self._set_cos_sin_cache(
372
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
373
+ )
374
+
375
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
376
+ self.max_seq_len_cached = seq_len
377
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
378
+
379
+ freqs = torch.outer(t, self.inv_freq)
380
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
381
+ emb = torch.cat((freqs, freqs), dim=-1)
382
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
383
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
384
+
385
+ def forward(self, x, seq_len=None):
386
+ # x: [bs, num_attention_heads, seq_len, head_size]
387
+ if seq_len > self.max_seq_len_cached:
388
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)
389
+
390
+ return (
391
+ self.cos_cached[:seq_len].to(dtype=torch.float32),
392
+ self.sin_cached[:seq_len].to(dtype=torch.float32),
393
+ )
394
+
395
+
396
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
397
+ def rotate_half(x):
398
+ """Rotates half the hidden dims of the input."""
399
+ x1 = x[..., : x.shape[-1] // 2]
400
+ x2 = x[..., x.shape[-1] // 2:]
401
+ return torch.cat((-x2, x1), dim=-1)
402
+
403
+
404
+ # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
405
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
406
+ """Applies Rotary Position Embedding to the query and key tensors.
407
+
408
+ Args:
409
+ q (`torch.Tensor`): The query tensor.
410
+ k (`torch.Tensor`): The key tensor.
411
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
412
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
413
+ position_ids (`torch.Tensor`):
414
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
415
+ used to pass offsetted position ids when working with a KV-cache.
416
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
417
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
418
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
419
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
420
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
421
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
422
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
423
+ Returns:
424
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
425
+ """
426
+ dtype = q.dtype
427
+ rot_dim = cos.shape[-1]
428
+ q_, q_pass = q[..., :rot_dim], q[..., rot_dim:]
429
+ k_, k_pass = k[..., :rot_dim], k[..., rot_dim:]
430
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
431
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
432
+ q_embed = (q_ * cos) + (rotate_half(q_) * sin)
433
+ k_embed = (k_ * cos) + (rotate_half(k_) * sin)
434
+ return torch.cat((q_embed, q_pass), dim=-1).to(dtype), torch.cat((k_embed, k_pass), dim=-1).to(dtype)
435
+
436
+
437
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
438
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
439
+ """
440
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
441
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
442
+ """
443
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
444
+ if n_rep == 1:
445
+ return hidden_states
446
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
447
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
448
+
449
+
450
+ # Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->MiniMaxM1
451
+ class MiniMaxM1Attention(nn.Module):
452
+ """
453
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
454
+ and "Generating Long Sequences with Sparse Transformers".
455
+ """
456
+
457
+ def __init__(self, config: MiniMaxM1Config, layer_idx: Optional[int] = None):
458
+ super().__init__()
459
+ self.config = config
460
+ self.layer_idx = layer_idx
461
+ if layer_idx is None:
462
+ logger.warning_once(
463
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
464
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
465
+ "when creating this class."
466
+ )
467
+
468
+ self.hidden_size = config.hidden_size
469
+ self.num_heads = config.num_attention_heads
470
+ self.head_dim = getattr(config, 'head_dim', self.hidden_size // self.num_heads)
471
+ self.num_key_value_heads = config.num_key_value_heads
472
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
473
+ self.max_position_embeddings = config.max_position_embeddings
474
+ self.rope_theta = config.rope_theta
475
+ self.is_causal = True
476
+ self.attention_dropout = config.attention_dropout
477
+
478
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
479
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
480
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
481
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
482
+ self.rotary_dim = getattr(config, 'rotary_dim', self.head_dim)
483
+
484
+ self.rotary_emb = MiniMaxM1RotaryEmbedding(
485
+ self.rotary_dim,
486
+ max_position_embeddings=self.max_position_embeddings,
487
+ base=self.rope_theta,
488
+ )
489
+
490
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
491
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
492
+
493
+ def forward(
494
+ self,
495
+ hidden_states: torch.Tensor,
496
+ attention_mask: Optional[torch.Tensor] = None,
497
+ position_ids: Optional[torch.LongTensor] = None,
498
+ past_key_value: Optional[Cache] = None,
499
+ output_attentions: bool = False,
500
+ use_cache: bool = False,
501
+ **kwargs,
502
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
503
+ if "padding_mask" in kwargs:
504
+ warnings.warn(
505
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
506
+ )
507
+ bsz, q_len, _ = hidden_states.size()
508
+
509
+ query_states = self.q_proj(hidden_states)
510
+ key_states = self.k_proj(hidden_states)
511
+ value_states = self.v_proj(hidden_states)
512
+
513
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
514
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
515
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
516
+
517
+ kv_seq_len = key_states.shape[-2]
518
+ if past_key_value is not None:
519
+ if self.layer_idx is None:
520
+ raise ValueError(
521
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
522
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
523
+ "with a layer index."
524
+ )
525
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
526
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
527
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
528
+
529
+ if past_key_value is not None:
530
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
531
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
532
+
533
+ # repeat k/v heads if n_kv_heads < n_heads
534
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
535
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
536
+
537
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
538
+
539
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
540
+ raise ValueError(
541
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
542
+ f" {attn_weights.size()}"
543
+ )
544
+
545
+ if attention_mask is not None:
546
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
547
+ raise ValueError(
548
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
549
+ )
550
+
551
+ attn_weights = attn_weights + attention_mask
552
+
553
+ # upcast attention to fp32
554
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
555
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
556
+ attn_output = torch.matmul(attn_weights, value_states)
557
+
558
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
559
+ raise ValueError(
560
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
561
+ f" {attn_output.size()}"
562
+ )
563
+
564
+ attn_output = attn_output.transpose(1, 2).contiguous()
565
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
566
+
567
+ attn_output = self.o_proj(attn_output)
568
+
569
+ if not output_attentions:
570
+ attn_weights = None
571
+
572
+ return attn_output, attn_weights, past_key_value
573
+
574
+
575
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->MiniMaxM1
576
+ class MiniMaxM1FlashAttention2(MiniMaxM1Attention):
577
+ """
578
+ MiniMaxM1 flash attention module. This module inherits from `MiniMaxM1Attention` as the weights of the module stays
579
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
580
+ flash attention and deal with padding tokens in case the input contains any of them.
581
+ """
582
+
583
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
584
+ def __init__(self, *args, **kwargs):
585
+ super().__init__(*args, **kwargs)
586
+
587
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
588
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
589
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
590
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
591
+
592
+ def forward(
593
+ self,
594
+ hidden_states: torch.Tensor,
595
+ attention_mask: Optional[torch.Tensor] = None,
596
+ position_ids: Optional[torch.LongTensor] = None,
597
+ past_key_value: Optional[Union[Cache, Tuple[torch.Tensor]]] = None,
598
+ output_attentions: bool = False,
599
+ use_cache: bool = False,
600
+ **kwargs,
601
+ ):
602
+ if "padding_mask" in kwargs:
603
+ warnings.warn(
604
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
605
+ )
606
+
607
+ # overwrite attention_mask with padding_mask
608
+ attention_mask = kwargs.pop("padding_mask")
609
+ bsz, q_len, _ = hidden_states.size()
610
+
611
+ query_states = self.q_proj(hidden_states)
612
+ key_states = self.k_proj(hidden_states)
613
+ value_states = self.v_proj(hidden_states)
614
+
615
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
616
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
617
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
618
+
619
+ kv_seq_len = key_states.shape[-2]
620
+ if past_key_value is not None:
621
+ kv_seq_len += past_key_value[0].shape[-3]
622
+
623
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
624
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
625
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
626
+
627
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
628
+
629
+ use_sliding_windows = (
630
+ _flash_supports_window_size
631
+ and getattr(self.config, "sliding_window", None) is not None
632
+ and kv_seq_len > self.config.sliding_window
633
+ )
634
+
635
+ if not _flash_supports_window_size:
636
+ logger.warning_once(
637
+ "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
638
+ " make sure to upgrade flash-attn library."
639
+ )
640
+
641
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
642
+
643
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
644
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
645
+ # cast them back in float16 just to be sure everything works as expected.
646
+ input_dtype = query_states.dtype
647
+ if input_dtype == torch.float32:
648
+ if torch.is_autocast_enabled():
649
+ target_dtype = torch.get_autocast_gpu_dtype()
650
+ # Handle the case where the model is quantized
651
+ elif hasattr(self.config, "_pre_quantization_dtype"):
652
+ target_dtype = self.config._pre_quantization_dtype
653
+ else:
654
+ target_dtype = self.q_proj.weight.dtype
655
+
656
+ logger.warning_once(
657
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
658
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
659
+ f" {target_dtype}."
660
+ )
661
+
662
+ query_states = query_states.to(target_dtype)
663
+ key_states = key_states.to(target_dtype)
664
+ value_states = value_states.to(target_dtype)
665
+
666
+ # Reshape to the expected shape for Flash Attention
667
+ query_states = query_states.transpose(1, 2)
668
+ key_states = key_states.transpose(1, 2)
669
+ value_states = value_states.transpose(1, 2)
670
+
671
+ if past_key_value is not None:
672
+ # reuse k, v, for evaluation only
673
+ key_states = torch.cat([past_key_value[0], key_states], dim=-3)
674
+ value_states = torch.cat([past_key_value[1], value_states], dim=-3)
675
+
676
+ past_key_value = (key_states, value_states) if use_cache else None
677
+
678
+ attn_output = self._flash_attention_forward(
679
+ query_states,
680
+ key_states,
681
+ value_states,
682
+ attention_mask,
683
+ q_len,
684
+ dropout=dropout_rate,
685
+ use_sliding_windows=use_sliding_windows,
686
+ )
687
+
688
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
689
+ attn_output = self.o_proj(attn_output)
690
+
691
+ if not output_attentions:
692
+ attn_weights = None
693
+
694
+ return attn_output, attn_weights, past_key_value
695
+
696
+ def _flash_attention_forward(
697
+ self,
698
+ query_states,
699
+ key_states,
700
+ value_states,
701
+ attention_mask,
702
+ query_length,
703
+ dropout=0.0,
704
+ softmax_scale=None,
705
+ use_sliding_windows=False,
706
+ ):
707
+ """
708
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
709
+ first unpad the input, then computes the attention scores and pad the final attention scores.
710
+
711
+ Args:
712
+ query_states (`torch.Tensor`):
713
+ Input query states to be passed to Flash Attention API
714
+ key_states (`torch.Tensor`):
715
+ Input key states to be passed to Flash Attention API
716
+ value_states (`torch.Tensor`):
717
+ Input value states to be passed to Flash Attention API
718
+ attention_mask (`torch.Tensor`):
719
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
720
+ position of padding tokens and 1 for the position of non-padding tokens.
721
+ dropout (`float`):
722
+ Attention dropout
723
+ softmax_scale (`float`, *optional*):
724
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
725
+ use_sliding_windows (`bool`, *optional*):
726
+ Whether to activate sliding window attention.
727
+ """
728
+ if not self._flash_attn_uses_top_left_mask:
729
+ causal = self.is_causal
730
+ else:
731
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
732
+ causal = self.is_causal and query_length != 1
733
+
734
+ # Contains at least one padding token in the sequence
735
+ if attention_mask is not None:
736
+ batch_size = query_states.shape[0]
737
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
738
+ query_states, key_states, value_states, attention_mask, query_length
739
+ )
740
+
741
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
742
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
743
+
744
+ if not use_sliding_windows:
745
+ attn_output_unpad = flash_attn_varlen_func(
746
+ query_states,
747
+ key_states,
748
+ value_states,
749
+ cu_seqlens_q=cu_seqlens_q,
750
+ cu_seqlens_k=cu_seqlens_k,
751
+ max_seqlen_q=max_seqlen_in_batch_q,
752
+ max_seqlen_k=max_seqlen_in_batch_k,
753
+ dropout_p=dropout,
754
+ softmax_scale=softmax_scale,
755
+ causal=causal,
756
+ )
757
+ else:
758
+ attn_output_unpad = flash_attn_varlen_func(
759
+ query_states,
760
+ key_states,
761
+ value_states,
762
+ cu_seqlens_q=cu_seqlens_q,
763
+ cu_seqlens_k=cu_seqlens_k,
764
+ max_seqlen_q=max_seqlen_in_batch_q,
765
+ max_seqlen_k=max_seqlen_in_batch_k,
766
+ dropout_p=dropout,
767
+ softmax_scale=softmax_scale,
768
+ causal=causal,
769
+ window_size=(self.config.sliding_window, self.config.sliding_window),
770
+ )
771
+
772
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
773
+ else:
774
+ if not use_sliding_windows:
775
+ attn_output = flash_attn_func(
776
+ query_states,
777
+ key_states,
778
+ value_states,
779
+ dropout,
780
+ softmax_scale=softmax_scale,
781
+ causal=causal,
782
+ )
783
+ else:
784
+ attn_output = flash_attn_func(
785
+ query_states,
786
+ key_states,
787
+ value_states,
788
+ dropout,
789
+ softmax_scale=softmax_scale,
790
+ causal=causal,
791
+ window_size=(self.config.sliding_window, self.config.sliding_window),
792
+ )
793
+
794
+ return attn_output
795
+
796
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
797
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
798
+
799
+ # On the first iteration we need to properly re-create the padding mask
800
+ # by slicing it on the proper place
801
+ if kv_seq_len != attention_mask.shape[-1]:
802
+ attention_mask_num_tokens = attention_mask.shape[-1]
803
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len:]
804
+
805
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
806
+
807
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
808
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
809
+
810
+ if query_length == kv_seq_len:
811
+ query_layer = index_first_axis(
812
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
813
+ )
814
+ cu_seqlens_q = cu_seqlens_k
815
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
816
+ indices_q = indices_k
817
+ elif query_length == 1:
818
+ max_seqlen_in_batch_q = 1
819
+ cu_seqlens_q = torch.arange(
820
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
821
+ ) # There is a memcpy here, that is very bad.
822
+ indices_q = cu_seqlens_q[:-1]
823
+ query_layer = query_layer.squeeze(1)
824
+ else:
825
+ # The -q_len: slice assumes left padding.
826
+ attention_mask = attention_mask[:, -query_length:]
827
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
828
+
829
+ return (
830
+ query_layer,
831
+ key_layer,
832
+ value_layer,
833
+ indices_q,
834
+ (cu_seqlens_q, cu_seqlens_k),
835
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
836
+ )
837
+
838
+
839
+ class MiniMaxM1MLP(nn.Module):
840
+ def __init__(self, config):
841
+ super().__init__()
842
+ self.config = config
843
+ self.hidden_size = config.hidden_size
844
+ self.intermediate_size = config.intermediate_size
845
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
846
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
847
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
848
+ self.act_fn = ACT2FN[config.hidden_act]
849
+
850
+ def forward(self, x):
851
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
852
+ return down_proj
853
+
854
+
855
+ class MiniMaxM1BlockSparseTop2MLP(nn.Module):
856
+ def __init__(self, config: MiniMaxM1Config):
857
+ super().__init__()
858
+ self.ffn_dim = config.intermediate_size
859
+ self.hidden_dim = config.hidden_size
860
+
861
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
862
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
863
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
864
+
865
+ self.act_fn = ACT2FN[config.hidden_act]
866
+
867
+ def forward(self, hidden_states):
868
+ current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
869
+ current_hidden_states = self.w2(current_hidden_states)
870
+ return current_hidden_states
871
+
872
+
873
+ class MiniMaxM1BLockSparseTop2MLP(MiniMaxM1BlockSparseTop2MLP):
874
+ def __init__(self, *args, **kwargs):
875
+ logger.warning_once(
876
+ "MiniMaxM1BLockSparseTop2MLP is deprecated by MiniMaxM1BlockSparseTop2MLP and will be removed in v4.40."
877
+ )
878
+ super().__init__(*args, **kwargs)
879
+
880
+
881
+ class MiniMaxM1SparseMoeBlock(nn.Module):
882
+ """
883
+ This implementation is
884
+ strictly equivalent to standard MoE with full capacity (no
885
+ dropped tokens). It's faster since it formulates MoE operations
886
+ in terms of block-sparse operations to accomodate imbalanced
887
+ assignments of tokens to experts, whereas standard MoE either
888
+ (1) drop tokens at the cost of reduced performance or (2) set
889
+ capacity factor to number of experts and thus waste computation
890
+ and memory on padding.
891
+ """
892
+
893
+ def __init__(self, config):
894
+ super().__init__()
895
+ self.hidden_dim = config.hidden_size
896
+ self.ffn_dim = config.intermediate_size
897
+ self.num_experts = config.num_local_experts
898
+ self.top_k = config.num_experts_per_tok
899
+
900
+ # gating
901
+ self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
902
+
903
+ self.experts = nn.ModuleList([MiniMaxM1BlockSparseTop2MLP(config) for _ in range(self.num_experts)])
904
+
905
+ # Jitter parameters
906
+ self.jitter_noise = config.router_jitter_noise
907
+
908
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
909
+ """ """
910
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
911
+ if self.training and self.jitter_noise > 0:
912
+ hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
913
+ hidden_states = hidden_states.view(-1, hidden_dim)
914
+ # router_logits: (batch * sequence_length, n_experts)
915
+ router_logits = self.gate(hidden_states)
916
+
917
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
918
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
919
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
920
+ # we cast back to the input dtype
921
+ routing_weights = routing_weights.to(hidden_states.dtype)
922
+
923
+ final_hidden_states = torch.zeros(
924
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
925
+ )
926
+
927
+ # One hot encode the selected experts to create an expert mask
928
+ # this will be used to easily index which expert is going to be sollicitated
929
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
930
+
931
+ # Loop over all available experts in the model and perform the computation on each expert
932
+ for expert_idx in range(self.num_experts):
933
+ expert_layer = self.experts[expert_idx]
934
+ idx, top_x = torch.where(expert_mask[expert_idx])
935
+
936
+ # Index the correct hidden states and compute the expert hidden state for
937
+ # the current expert. We need to make sure to multiply the output hidden
938
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
939
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
940
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
941
+
942
+ # However `index_add_` only support torch tensors for indexing so we'll use
943
+ # the `top_x` tensor here.
944
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
945
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
946
+ return final_hidden_states, router_logits
947
+
948
+
949
+ class MiniMaxM1DecoderLayer(nn.Module):
950
+ def __init__(self, config: MiniMaxM1Config, layer_idx: int):
951
+ super().__init__()
952
+ self.config = config
953
+ self.hidden_size = config.hidden_size
954
+
955
+ self.self_attn = self.build_attn(config, layer_idx)
956
+
957
+ self.layer_idx = layer_idx
958
+
959
+ self.block_sparse_moe = MiniMaxM1SparseMoeBlock(config)
960
+ self.input_layernorm = MiniMaxM1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
961
+ self.post_attention_layernorm = MiniMaxM1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
962
+
963
+ self.postnorm = getattr(config, 'postnorm', False)
964
+ self.layernorm_attention_alpha = getattr(config, 'layernorm_linear_attention_alpha', 1) \
965
+ if config.attention_type == 0 else getattr(config, 'layernorm_full_attention_alpha', 1)
966
+ self.layernorm_attention_beta = getattr(config, 'layernorm_linear_attention_beta', 1) \
967
+ if config.attention_type == 0 else getattr(config, 'layernorm_full_attention_beta', 1)
968
+ self.layernorm_mlp_alpha = getattr(config, 'layernorm_mlp_alpha', 1)
969
+ self.layernorm_mlp_beta = getattr(config, 'layernorm_mlp_beta', 1)
970
+
971
+ shared_intermediate = getattr(config, 'shared_intermediate_size', 0)
972
+ self.shared_moe = False
973
+ if shared_intermediate > 0:
974
+ self.shared_moe = True
975
+ self.shared_mlp = MiniMaxM1MLP(config)
976
+ self.coefficient = torch.nn.Linear(self.hidden_size, 1, bias=False)
977
+
978
+ def build_attn(self, config, layer_idx):
979
+ if config.attention_type == 0:
980
+ Attention_module = MiniMaxM1LightningAttention
981
+ else:
982
+ Attention_module = MiniMaxM1FlashAttention2
983
+
984
+ return Attention_module(
985
+ config,
986
+ layer_idx
987
+ )
988
+
989
+ def forward(
990
+ self,
991
+ hidden_states: torch.Tensor,
992
+ attention_mask: Optional[torch.Tensor] = None,
993
+ position_ids: Optional[torch.LongTensor] = None,
994
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
995
+ output_attentions: Optional[bool] = False,
996
+ output_router_logits: Optional[bool] = False,
997
+ use_cache: Optional[bool] = False,
998
+ slope_rate: Optional[float] = None,
999
+ **kwargs,
1000
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1001
+ if "padding_mask" in kwargs:
1002
+ warnings.warn(
1003
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
1004
+ )
1005
+ """
1006
+ Args:
1007
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1008
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1009
+ `(batch, sequence_length)` where padding elements are indicated by 0.
1010
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1011
+ output_attentions (`bool`, *optional*):
1012
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1013
+ returned tensors for more detail.
1014
+ output_router_logits (`bool`, *optional*):
1015
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
1016
+ should not be returned during inference.
1017
+ use_cache (`bool`, *optional*):
1018
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1019
+ (see `past_key_values`).
1020
+ """
1021
+
1022
+ residual = hidden_states
1023
+
1024
+ hidden_states = self.input_layernorm(hidden_states)
1025
+ if self.postnorm:
1026
+ residual = hidden_states
1027
+
1028
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1029
+ hidden_states=hidden_states,
1030
+ position_ids=position_ids,
1031
+ attn_mask=attention_mask,
1032
+ past_key_value=past_key_value,
1033
+ output_attentions=output_attentions,
1034
+ use_cache=use_cache,
1035
+ slope_rate=slope_rate,
1036
+ )
1037
+
1038
+ hidden_states = residual * self.layernorm_attention_alpha \
1039
+ + hidden_states * self.layernorm_attention_beta
1040
+
1041
+ # Fully Connected
1042
+ residual = hidden_states
1043
+ hidden_states = self.post_attention_layernorm(hidden_states)
1044
+ if self.postnorm:
1045
+ residual = hidden_states
1046
+
1047
+ moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
1048
+ if self.shared_moe:
1049
+ output_mlp = self.shared_mlp(hidden_states)
1050
+ weight_fp32 = self.coefficient.weight.float()
1051
+ coef = hidden_states.to(torch.float32) @ weight_fp32.T
1052
+ coef = torch.nn.functional.sigmoid(coef).to(hidden_states.dtype)
1053
+ hidden_states = moe_hidden_states * (1 - coef) + output_mlp * coef
1054
+ else:
1055
+ hidden_states = moe_hidden_states
1056
+
1057
+ hidden_states = residual * self.layernorm_mlp_alpha \
1058
+ + hidden_states * self.layernorm_mlp_beta
1059
+
1060
+ outputs = (hidden_states,)
1061
+
1062
+ if output_attentions:
1063
+ outputs += (self_attn_weights,)
1064
+
1065
+ if use_cache:
1066
+ outputs += (present_key_value,)
1067
+
1068
+ if output_router_logits:
1069
+ outputs += (router_logits,)
1070
+
1071
+ return outputs
1072
+
1073
+
1074
+ MIXTRAL_START_DOCSTRING = r"""
1075
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1076
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1077
+ etc.)
1078
+
1079
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1080
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1081
+ and behavior.
1082
+
1083
+ Parameters:
1084
+ config ([`MiniMaxM1Config`]):
1085
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1086
+ load the weights associated with the model, only the configuration. Check out the
1087
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1088
+ """
1089
+
1090
+
1091
+ @add_start_docstrings(
1092
+ "The bare MiniMaxM1 Model outputting raw hidden-states without any specific head on top.",
1093
+ MIXTRAL_START_DOCSTRING,
1094
+ )
1095
+ # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->MiniMaxM1
1096
+ class MiniMaxM1PreTrainedModel(PreTrainedModel):
1097
+ config_class = MiniMaxM1Config
1098
+ base_model_prefix = "model"
1099
+ supports_gradient_checkpointing = True
1100
+ _no_split_modules = ["MiniMaxM1DecoderLayer"]
1101
+ _skip_keys_device_placement = "past_key_values"
1102
+ _supports_flash_attn_2 = True
1103
+ _supports_sdpa = True
1104
+
1105
+ def _init_weights(self, module):
1106
+ std = self.config.initializer_range
1107
+ if isinstance(module, nn.Linear):
1108
+ module.weight.data.normal_(mean=0.0, std=std)
1109
+ if module.bias is not None:
1110
+ module.bias.data.zero_()
1111
+ elif isinstance(module, nn.Embedding):
1112
+ module.weight.data.normal_(mean=0.0, std=std)
1113
+ if module.padding_idx is not None:
1114
+ module.weight.data[module.padding_idx].zero_()
1115
+
1116
+
1117
+ MIXTRAL_INPUTS_DOCSTRING = r"""
1118
+ Args:
1119
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1120
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1121
+ it.
1122
+
1123
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1124
+ [`PreTrainedTokenizer.__call__`] for details.
1125
+
1126
+ [What are input IDs?](../glossary#input-ids)
1127
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1128
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1129
+
1130
+ - 1 for tokens that are **not masked**,
1131
+ - 0 for tokens that are **masked**.
1132
+
1133
+ [What are attention masks?](../glossary#attention-mask)
1134
+
1135
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1136
+ [`PreTrainedTokenizer.__call__`] for details.
1137
+
1138
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
1139
+ `past_key_values`).
1140
+
1141
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1142
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1143
+ information on the default strategy.
1144
+
1145
+ - 1 indicates the head is **not masked**,
1146
+ - 0 indicates the head is **masked**.
1147
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1148
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1149
+ config.n_positions - 1]`.
1150
+
1151
+ [What are position IDs?](../glossary#position-ids)
1152
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1153
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
1154
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
1155
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1156
+
1157
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1158
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1159
+
1160
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1161
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1162
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1163
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1164
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1165
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1166
+ model's internal embedding lookup matrix.
1167
+ use_cache (`bool`, *optional*):
1168
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1169
+ `past_key_values`).
1170
+ output_attentions (`bool`, *optional*):
1171
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1172
+ tensors for more detail.
1173
+ output_hidden_states (`bool`, *optional*):
1174
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1175
+ more detail.
1176
+ output_router_logits (`bool`, *optional*):
1177
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
1178
+ should not be returned during inference.
1179
+ return_dict (`bool`, *optional*):
1180
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1181
+ """
1182
+
1183
+
1184
+ @add_start_docstrings(
1185
+ "The bare MiniMaxM1 Model outputting raw hidden-states without any specific head on top.",
1186
+ MIXTRAL_START_DOCSTRING,
1187
+ )
1188
+ # Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->MiniMaxM1
1189
+ class MiniMaxM1Model(MiniMaxM1PreTrainedModel):
1190
+ """
1191
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniMaxM1DecoderLayer`]
1192
+
1193
+ Args:
1194
+ config: MiniMaxM1Config
1195
+ """
1196
+
1197
+ def __init__(self, config: MiniMaxM1Config):
1198
+ super().__init__(config)
1199
+ self.padding_idx = config.pad_token_id
1200
+ self.vocab_size = config.vocab_size
1201
+
1202
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1203
+ self.attn_type_list = config.attn_type_list
1204
+ config_copy = copy.deepcopy(config)
1205
+
1206
+ self.layers = nn.ModuleList([])
1207
+ for i in range(config.num_hidden_layers):
1208
+ _config = copy.deepcopy(config)
1209
+ if self.attn_type_list[i] == 0:
1210
+ _config._attn_implementation = 'linear_attention'
1211
+ _config.attention_type = 0
1212
+ else:
1213
+ _config._attn_implementation = config_copy._attn_implementation
1214
+ _config.attention_type = 1
1215
+ self.layers.append(MiniMaxM1DecoderLayer(_config, i))
1216
+
1217
+ self._attn_implementation = config_copy._attn_implementation
1218
+ self.norm = MiniMaxM1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1219
+
1220
+ self.gradient_checkpointing = False
1221
+ self.slopes = self._build_slope_tensor(config.num_attention_heads)
1222
+ # mask
1223
+ self._linear_attn_mask = torch.empty(0)
1224
+
1225
+ # Initialize weights and apply final processing
1226
+ self.post_init()
1227
+
1228
+ def get_input_embeddings(self):
1229
+ return self.embed_tokens
1230
+
1231
+ def set_input_embeddings(self, value):
1232
+ self.embed_tokens = value
1233
+
1234
+ @staticmethod
1235
+ def _build_slope_tensor(n_attention_heads: int):
1236
+
1237
+ def get_slopes(n):
1238
+
1239
+ def get_slopes_power_of_2(n):
1240
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
1241
+ ratio = start
1242
+ return [start * ratio ** i for i in range(n)]
1243
+
1244
+ if math.log2(n).is_integer():
1245
+ return get_slopes_power_of_2(
1246
+ n) # In the paper, we only train models that have 2^a heads for some a. This function has
1247
+ else: # some good properties that only occur when the input is a power of 2. To maintain that even
1248
+ closest_power_of_2 = 2 ** math.floor(
1249
+ math.log2(n)) # when the number of heads is not a power of 2, we use this workaround.
1250
+ return (get_slopes_power_of_2(closest_power_of_2)
1251
+ + get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
1252
+
1253
+ # h, 1, 1
1254
+ slopes = torch.tensor(get_slopes(n_attention_heads), dtype=torch.float32).reshape(n_attention_heads, 1, 1)
1255
+
1256
+ return slopes
1257
+
1258
+ # Ignore copy
1259
+ @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
1260
+ def forward(
1261
+ self,
1262
+ input_ids: torch.LongTensor = None,
1263
+ attention_mask: Optional[torch.Tensor] = None,
1264
+ position_ids: Optional[torch.LongTensor] = None,
1265
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1266
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1267
+ use_cache: Optional[bool] = None,
1268
+ output_attentions: Optional[bool] = None,
1269
+ output_hidden_states: Optional[bool] = None,
1270
+ output_router_logits: Optional[bool] = None,
1271
+ return_dict: Optional[bool] = None,
1272
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
1273
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1274
+ output_router_logits = (
1275
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
1276
+ )
1277
+ output_hidden_states = (
1278
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1279
+ )
1280
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1281
+
1282
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1283
+
1284
+ # retrieve input_ids and inputs_embeds
1285
+ if input_ids is not None and inputs_embeds is not None:
1286
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1287
+ elif input_ids is not None:
1288
+ batch_size, seq_length = input_ids.shape
1289
+ default_device = input_ids.device
1290
+ elif inputs_embeds is not None:
1291
+ batch_size, seq_length, _ = inputs_embeds.shape
1292
+ default_device = inputs_embeds.device
1293
+ else:
1294
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1295
+
1296
+ past_key_values_length = 0
1297
+
1298
+ if self.gradient_checkpointing and self.training:
1299
+ if use_cache:
1300
+ logger.warning_once(
1301
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1302
+ )
1303
+ use_cache = False
1304
+
1305
+ seq_length_with_past = seq_length
1306
+ if past_key_values is not None:
1307
+ for idx in range(len(past_key_values)):
1308
+ if self.attn_type_list[idx] == 1:
1309
+ past_key_values_length = past_key_values[idx][0].shape[-3]
1310
+ seq_length_with_past = seq_length_with_past + past_key_values_length
1311
+ break
1312
+
1313
+ if position_ids is None:
1314
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1315
+ position_ids = torch.arange(
1316
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1317
+ )
1318
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1319
+ else:
1320
+ position_ids = position_ids.view(-1, seq_length).long()
1321
+
1322
+ if inputs_embeds is None:
1323
+ inputs_embeds = self.embed_tokens(input_ids)
1324
+
1325
+ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
1326
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1327
+ if is_padding_right:
1328
+ raise ValueError(
1329
+ "You are attempting to perform batched generation with padding_side='right'"
1330
+ " this may lead to unexpected behaviour for Flash Attention version of MiniMaxM1. Make sure to "
1331
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1332
+ )
1333
+ slope_rates = [self.slopes.to(default_device) for _ in range(len(self.layers))]
1334
+ hidden_states = inputs_embeds
1335
+ # decoder layers
1336
+ all_hidden_states = () if output_hidden_states else None
1337
+ all_self_attns = () if output_attentions else None
1338
+ all_router_logits = () if output_router_logits else None
1339
+ next_decoder_cache = () if use_cache else None
1340
+
1341
+ for idx, decoder_layer in enumerate(self.layers):
1342
+ if output_hidden_states:
1343
+ all_hidden_states += (hidden_states,)
1344
+
1345
+ past_key_value = (past_key_values[idx] if past_key_values is not None else None)
1346
+ attn_mask = attention_mask
1347
+ slope_rate = slope_rates[idx]
1348
+ slope_rate = slope_rate * (1 - idx / (len(self.layers) - 1) + 1e-5)
1349
+ if self.gradient_checkpointing and self.training:
1350
+ layer_outputs = self._gradient_checkpointing_func(
1351
+ decoder_layer.__call__,
1352
+ hidden_states,
1353
+ attention_mask,
1354
+ position_ids,
1355
+ past_key_values,
1356
+ output_attentions,
1357
+ output_router_logits,
1358
+ use_cache,
1359
+ )
1360
+ else:
1361
+ layer_outputs = decoder_layer(
1362
+ hidden_states,
1363
+ attention_mask=attn_mask,
1364
+ position_ids=position_ids,
1365
+ past_key_value=past_key_value,
1366
+ output_attentions=output_attentions,
1367
+ output_router_logits=output_router_logits,
1368
+ use_cache=use_cache,
1369
+ slope_rate=slope_rate
1370
+ )
1371
+
1372
+ hidden_states = layer_outputs[0]
1373
+
1374
+ if use_cache:
1375
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
1376
+
1377
+ if output_attentions:
1378
+ all_self_attns += (layer_outputs[1],)
1379
+
1380
+ if output_router_logits:
1381
+ all_router_logits += (layer_outputs[-1],)
1382
+
1383
+ hidden_states = self.norm(hidden_states)
1384
+
1385
+ # add hidden states from the last decoder layer
1386
+ if output_hidden_states:
1387
+ all_hidden_states += (hidden_states,)
1388
+ next_cache = next_decoder_cache if use_cache else None
1389
+ if not return_dict:
1390
+ return tuple(
1391
+ v
1392
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
1393
+ if v is not None
1394
+ )
1395
+ return MoeModelOutputWithPast(
1396
+ last_hidden_state=hidden_states,
1397
+ past_key_values=next_cache,
1398
+ hidden_states=all_hidden_states,
1399
+ attentions=all_self_attns,
1400
+ router_logits=all_router_logits,
1401
+ )
1402
+
1403
+
1404
+ class MiniMaxM1ForCausalLM(MiniMaxM1PreTrainedModel):
1405
+ _tied_weights_keys = ["lm_head.weight"]
1406
+
1407
+ def __init__(self, config):
1408
+ super().__init__(config)
1409
+ self.model = MiniMaxM1Model(config)
1410
+ self.vocab_size = config.vocab_size
1411
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1412
+ self.router_aux_loss_coef = config.router_aux_loss_coef
1413
+ self.num_experts = config.num_local_experts
1414
+ self.num_experts_per_tok = config.num_experts_per_tok
1415
+ # Initialize weights and apply final processing
1416
+ self.post_init()
1417
+
1418
+ def get_input_embeddings(self):
1419
+ return self.model.embed_tokens
1420
+
1421
+ def set_input_embeddings(self, value):
1422
+ self.model.embed_tokens = value
1423
+
1424
+ def get_output_embeddings(self):
1425
+ return self.lm_head
1426
+
1427
+ def set_output_embeddings(self, new_embeddings):
1428
+ self.lm_head = new_embeddings
1429
+
1430
+ def set_decoder(self, decoder):
1431
+ self.model = decoder
1432
+
1433
+ def get_decoder(self):
1434
+ return self.model
1435
+
1436
+ @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
1437
+ @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1438
+ # Ignore copy
1439
+ def forward(
1440
+ self,
1441
+ input_ids: torch.LongTensor = None,
1442
+ attention_mask: Optional[torch.Tensor] = None,
1443
+ position_ids: Optional[torch.LongTensor] = None,
1444
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1445
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1446
+ labels: Optional[torch.LongTensor] = None,
1447
+ use_cache: Optional[bool] = None,
1448
+ output_attentions: Optional[bool] = None,
1449
+ output_hidden_states: Optional[bool] = None,
1450
+ output_router_logits: Optional[bool] = None,
1451
+ return_dict: Optional[bool] = None,
1452
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1453
+ r"""
1454
+ Args:
1455
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1456
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1457
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1458
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1459
+
1460
+ Returns:
1461
+
1462
+ Example:
1463
+
1464
+ ```python
1465
+ >>> from transformers import AutoTokenizer, MiniMaxM1ForCausalLM
1466
+
1467
+ >>> model = MiniMaxM1ForCausalLM.from_pretrained(PATH_TO_WEIGHTS)
1468
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_WEIGHTS)
1469
+
1470
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1471
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1472
+
1473
+ >>> # Generate
1474
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1475
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1476
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1477
+ ```"""
1478
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1479
+ output_router_logits = (
1480
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
1481
+ )
1482
+
1483
+ output_hidden_states = (
1484
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1485
+ )
1486
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1487
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1488
+ outputs = self.model(
1489
+ input_ids=input_ids,
1490
+ attention_mask=attention_mask,
1491
+ position_ids=position_ids,
1492
+ past_key_values=past_key_values,
1493
+ inputs_embeds=inputs_embeds,
1494
+ use_cache=use_cache,
1495
+ output_attentions=output_attentions,
1496
+ output_hidden_states=output_hidden_states,
1497
+ output_router_logits=output_router_logits,
1498
+ return_dict=return_dict,
1499
+ )
1500
+
1501
+ hidden_states = outputs[0]
1502
+ logits = self.lm_head(hidden_states)
1503
+ logits = logits.float()
1504
+
1505
+ loss = None
1506
+ if labels is not None:
1507
+ # Shift so that tokens < n predict n
1508
+ shift_logits = logits[..., :-1, :].contiguous()
1509
+ shift_labels = labels[..., 1:].contiguous()
1510
+ # Flatten the tokens
1511
+ loss_fct = CrossEntropyLoss()
1512
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1513
+ shift_labels = shift_labels.view(-1)
1514
+ # Enable model parallelism
1515
+ shift_labels = shift_labels.to(shift_logits.device)
1516
+ loss = loss_fct(shift_logits, shift_labels)
1517
+
1518
+ aux_loss = None
1519
+ if output_router_logits:
1520
+ aux_loss = load_balancing_loss_func(
1521
+ outputs.router_logits if return_dict else outputs[-1],
1522
+ self.num_experts,
1523
+ self.num_experts_per_tok,
1524
+ attention_mask,
1525
+ )
1526
+ if labels is not None:
1527
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
1528
+
1529
+ if not return_dict:
1530
+ output = (logits,) + outputs[1:]
1531
+ if output_router_logits:
1532
+ output = (aux_loss,) + output
1533
+ return (loss,) + output if loss is not None else output
1534
+
1535
+ torch.cuda.empty_cache()
1536
+ return MoeCausalLMOutputWithPast(
1537
+ loss=loss,
1538
+ aux_loss=aux_loss,
1539
+ logits=logits,
1540
+ past_key_values=outputs.past_key_values,
1541
+ hidden_states=outputs.hidden_states,
1542
+ attentions=outputs.attentions,
1543
+ router_logits=outputs.router_logits,
1544
+ )
1545
+
1546
+ def prepare_inputs_for_generation(
1547
+ self,
1548
+ input_ids,
1549
+ past_key_values=None,
1550
+ attention_mask=None,
1551
+ inputs_embeds=None,
1552
+ **kwargs,
1553
+ ):
1554
+ if past_key_values:
1555
+ input_ids = input_ids[:, -1:]
1556
+
1557
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1558
+ if inputs_embeds is not None and past_key_values is None:
1559
+ model_inputs = {"inputs_embeds": inputs_embeds}
1560
+ else:
1561
+ model_inputs = {"input_ids": input_ids}
1562
+
1563
+ model_inputs.update({
1564
+ "past_key_values": past_key_values,
1565
+ "use_cache": kwargs.get("use_cache"),
1566
+ "attention_mask": attention_mask,
1567
+ })
1568
+ return model_inputs
1569
+
1570
+ @staticmethod
1571
+ def _reorder_cache(past_key_values, beam_idx):
1572
+ reordered_past = ()
1573
+ for layer_past in past_key_values:
1574
+ reordered_past += (
1575
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1576
+ )
1577
+ return reordered_past
1578
+
1579
+
1580
+ @add_start_docstrings(
1581
+ """
1582
+ The MiniMaxM1 Model transformer with a sequence classification head on top (linear layer).
1583
+
1584
+ [`MiniMaxM1ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1585
+ (e.g. GPT-2) do.
1586
+
1587
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1588
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1589
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1590
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1591
+ each row of the batch).
1592
+ """,
1593
+ MIXTRAL_START_DOCSTRING,
1594
+ )
1595
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->MiniMaxM1, LLAMA->MIXTRAL
1596
+ class MiniMaxM1ForSequenceClassification(MiniMaxM1PreTrainedModel):
1597
+ def __init__(self, config):
1598
+ super().__init__(config)
1599
+ self.num_labels = config.num_labels
1600
+ self.model = MiniMaxM1Model(config)
1601
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1602
+
1603
+ # Initialize weights and apply final processing
1604
+ self.post_init()
1605
+
1606
+ def get_input_embeddings(self):
1607
+ return self.model.embed_tokens
1608
+
1609
+ def set_input_embeddings(self, value):
1610
+ self.model.embed_tokens = value
1611
+
1612
+ @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
1613
+ def forward(
1614
+ self,
1615
+ input_ids: torch.LongTensor = None,
1616
+ attention_mask: Optional[torch.Tensor] = None,
1617
+ position_ids: Optional[torch.LongTensor] = None,
1618
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1619
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1620
+ labels: Optional[torch.LongTensor] = None,
1621
+ use_cache: Optional[bool] = None,
1622
+ output_attentions: Optional[bool] = None,
1623
+ output_hidden_states: Optional[bool] = None,
1624
+ return_dict: Optional[bool] = None,
1625
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1626
+ r"""
1627
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1628
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1629
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1630
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1631
+ """
1632
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1633
+
1634
+ transformer_outputs = self.model(
1635
+ input_ids,
1636
+ attention_mask=attention_mask,
1637
+ position_ids=position_ids,
1638
+ past_key_values=past_key_values,
1639
+ inputs_embeds=inputs_embeds,
1640
+ use_cache=use_cache,
1641
+ output_attentions=output_attentions,
1642
+ output_hidden_states=output_hidden_states,
1643
+ return_dict=return_dict,
1644
+ )
1645
+ hidden_states = transformer_outputs[0]
1646
+ logits = self.score(hidden_states)
1647
+
1648
+ if input_ids is not None:
1649
+ batch_size = input_ids.shape[0]
1650
+ else:
1651
+ batch_size = inputs_embeds.shape[0]
1652
+
1653
+ if self.config.pad_token_id is None and batch_size != 1:
1654
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1655
+ if self.config.pad_token_id is None:
1656
+ sequence_lengths = -1
1657
+ else:
1658
+ if input_ids is not None:
1659
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1660
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1661
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1662
+ sequence_lengths = sequence_lengths.to(logits.device)
1663
+ else:
1664
+ sequence_lengths = -1
1665
+
1666
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1667
+
1668
+ loss = None
1669
+ if labels is not None:
1670
+ labels = labels.to(logits.device)
1671
+ if self.config.problem_type is None:
1672
+ if self.num_labels == 1:
1673
+ self.config.problem_type = "regression"
1674
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1675
+ self.config.problem_type = "single_label_classification"
1676
+ else:
1677
+ self.config.problem_type = "multi_label_classification"
1678
+
1679
+ if self.config.problem_type == "regression":
1680
+ loss_fct = MSELoss()
1681
+ if self.num_labels == 1:
1682
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1683
+ else:
1684
+ loss = loss_fct(pooled_logits, labels)
1685
+ elif self.config.problem_type == "single_label_classification":
1686
+ loss_fct = CrossEntropyLoss()
1687
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1688
+ elif self.config.problem_type == "multi_label_classification":
1689
+ loss_fct = BCEWithLogitsLoss()
1690
+ loss = loss_fct(pooled_logits, labels)
1691
+ if not return_dict:
1692
+ output = (pooled_logits,) + transformer_outputs[1:]
1693
+ return ((loss,) + output) if loss is not None else output
1694
+
1695
+ return SequenceClassifierOutputWithPast(
1696
+ loss=loss,
1697
+ logits=pooled_logits,
1698
+ past_key_values=transformer_outputs.past_key_values,
1699
+ hidden_states=transformer_outputs.hidden_states,
1700
+ attentions=transformer_outputs.attentions,
1701
+ )
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": "<beginning_of_sentence>",
4
+ "clean_up_tokenization_spaces": false,
5
+ "eos_token": "<end_of_sentence>",
6
+ "model_max_length": 40960000,
7
+ "tokenizer_class": "GPT2Tokenizer",
8
+ "unk_token": "<end_of_document>",
9
+ "chat_template": "{{ '<begin_of_document>' -}}{% set ns = namespace(system_prompt='') -%}{% for message in messages -%}{% if message['role'] == 'system' -%}{% set ns.system_prompt = ns.system_prompt + message['content'][0]['text'] -%}{% endif -%}{%- endfor -%}{% if ns.system_prompt != '' -%}{{ '<beginning_of_sentence>system ai_setting=assistant\n' + ns.system_prompt + '<end_of_sentence>\n' -}}{%- endif -%}{% if tools -%}{{ '<beginning_of_sentence>system tool_setting=tools\nYou are provided with these tools:\n<tools>\n' -}}{% for tool in tools -%}{{ tool | tojson ~ '\n' -}}{%- endfor -%}{{ '</tools>\n\nIf you need to call tools, please respond with <tool_calls></tool_calls> XML tags, and provide tool-name and json-object of arguments, following the format below:\n<tool_calls>\n{''name'': <tool-name-1>, ''arguments'': <args-json-object-1>}\n...\n</tool_calls><end_of_sentence>\n' -}}{%- endif -%}{% for message in messages -%}{% if message['role'] == 'user' -%}{{ '<beginning_of_sentence>user name=user\n' + message['content'][0]['text'] + '<end_of_sentence>\n' -}}{% elif message['role'] == 'assistant' -%}{{ '<beginning_of_sentence>ai name=assistant\n' -}}{% for content in message['content'] | selectattr('type', 'equalto', 'text') -%}{{ content['text'] -}}{%- endfor -%}{{ '<end_of_sentence>\n' -}}{% elif message['role'] == 'tool' -%}{{ '<beginning_of_sentence>tool name=tools\n' }} {%- for content in message['content'] -%}{{- 'tool name: ' + content['name'] + '\n' + 'tool result: ' + content['text'] + '\n\n' -}} {%- endfor -%}{{- '<end_of_sentence>\n' -}}{% endif -%}{%- endfor -%}{% if add_generation_prompt -%}{{ '<beginning_of_sentence>ai name=assistant\n' -}}{%- endif -%}"
10
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff