hysts HF Staff committed on
Commit
731a1ef
·
1 Parent(s): f4580f0
Files changed (11) hide show
  1. .pre-commit-config.yaml +33 -0
  2. .python-version +1 -0
  3. .vscode/extensions.json +8 -0
  4. .vscode/settings.json +17 -0
  5. LICENSE +21 -0
  6. README.md +3 -3
  7. app.py +65 -0
  8. pyproject.toml +63 -0
  9. requirements.txt +258 -0
  10. style.css +11 -0
  11. uv.lock +0 -0
.pre-commit-config.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v5.0.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: end-of-file-fixer
12
+ - id: mixed-line-ending
13
+ args: ["--fix=lf"]
14
+ - id: requirements-txt-fixer
15
+ - id: trailing-whitespace
16
+ - repo: https://github.com/astral-sh/ruff-pre-commit
17
+ rev: v0.9.6
18
+ hooks:
19
+ - id: ruff
20
+ args: ["--fix"]
21
+ - id: ruff-format
22
+ - repo: https://github.com/pre-commit/mirrors-mypy
23
+ rev: v1.15.0
24
+ hooks:
25
+ - id: mypy
26
+ args: ["--ignore-missing-imports"]
27
+ additional_dependencies:
28
+ [
29
+ "types-python-slugify",
30
+ "types-pytz",
31
+ "types-PyYAML",
32
+ "types-requests",
33
+ ]
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.10
.vscode/extensions.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "recommendations": [
3
+ "ms-python.python",
4
+ "charliermarsh.ruff",
5
+ "streetsidesoftware.code-spell-checker",
6
+ "tamasfe.even-better-toml"
7
+ ]
8
+ }
.vscode/settings.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "editor.formatOnSave": true,
3
+ "files.insertFinalNewline": false,
4
+ "[python]": {
5
+ "editor.defaultFormatter": "charliermarsh.ruff",
6
+ "editor.formatOnType": true,
7
+ "editor.codeActionsOnSave": {
8
+ "source.fixAll.ruff": "explicit",
9
+ "source.organizeImports": "explicit"
10
+ }
11
+ },
12
+ "[jupyter]": {
13
+ "files.insertFinalNewline": false
14
+ },
15
+ "notebook.output.scrolling": true,
16
+ "notebook.formatOnSave.enabled": true
17
+ }
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 hysts
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: Zamba2 7B Instruct
3
- emoji: 🏢
4
- colorFrom: blue
5
- colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 5.16.0
8
  app_file: app.py
 
1
  ---
2
  title: Zamba2 7B Instruct
3
+ emoji:
4
+ colorFrom: red
5
+ colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.16.0
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python
"""Gradio demo Space for the Zyphra Zamba2-7B-instruct chat model."""

from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hard cap on the number of prompt tokens fed to the model; longer
# conversations are trimmed from the left (oldest tokens dropped) in generate().
MAX_INPUT_TOKEN_LENGTH = 4096

model_id = "Zyphra/Zamba2-7B-instruct"
# bfloat16 weights placed on CUDA — NOTE(review): assumes a GPU runtime is
# available at import time (this is a ZeroGPU Space; confirm device_map works there).
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
16
+
17
+
18
@spaces.GPU(duration=90)
def generate(
    message: str,
    chat_history: list[dict],
) -> Iterator[str]:
    """Stream the model's reply to *message*, given the prior *chat_history*.

    Args:
        message: The latest user turn.
        chat_history: Prior turns as ``{"role": ..., "content": ...}`` dicts
            (Gradio "messages" format).

    Yields:
        The partial response text, growing with each decoded chunk.
    """
    conversation = [*chat_history, {"role": "user", "content": message}]

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens so the prompt fits the model window.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # skip_prompt avoids re-emitting the input; timeout guards against a hung
    # generation thread leaving the iterator blocked forever.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        # NOTE(review): reuses the *input* length cap as the output budget;
        # kept as-is for behavioral compatibility, but a separate
        # MAX_NEW_TOKENS constant would be clearer.
        "max_new_tokens": MAX_INPUT_TOKEN_LENGTH,
        "do_sample": False,
        "num_beams": 1,
    }
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs: list[str] = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # Reap the worker once the streamer is exhausted (generation is done by
    # then, so this does not block the UI).
    t.join()
46
+
47
+
48
# Example prompts shown beneath the chat box.
_EXAMPLE_PROMPTS = [
    ["Hello there! How are you doing?"],
    ["Can you explain briefly to me what is the Python programming language?"],
    ["Explain the plot of Cinderella in a sentence."],
    ["How many hours does it take a man to eat a Helicopter?"],
    ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
]

# Chat UI wired to the streaming generator above; history is exchanged in
# the "messages" (role/content dict) format that generate() expects.
demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    stop_btn=None,
    examples=_EXAMPLE_PROMPTS,
    cache_examples=False,
    description="# Zamba2-7B-instruct",
    css_paths="style.css",
)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
pyproject.toml ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "zamba2-7b-instruct"
3
+ version = "0.1.0"
4
+ description = ""
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ dependencies = [
8
+ "accelerate>=1.3.0",
9
+ "causal-conv1d",
10
+ "gradio>=5.16.0",
11
+ "hf-transfer>=0.1.9",
12
+ "mamba-ssm",
13
+ "spaces>=0.32.0",
14
+ "torch==2.4.0",
15
+ "transformers",
16
+ ]
17
+
18
+ [tool.ruff]
19
+ line-length = 119
20
+
21
+ [tool.ruff.lint]
22
+ select = ["ALL"]
23
+ ignore = [
24
+ "COM812", # missing-trailing-comma
25
+ "D203", # one-blank-line-before-class
26
+ "D213", # multi-line-summary-second-line
27
+ "E501", # line-too-long
28
+ "SIM117", # multiple-with-statements
29
+ ]
30
+ extend-ignore = [
31
+ "D100", # undocumented-public-module
32
+ "D101", # undocumented-public-class
33
+ "D102", # undocumented-public-method
34
+ "D103", # undocumented-public-function
35
+ "D104", # undocumented-public-package
36
+ "D105", # undocumented-magic-method
37
+ "D107", # undocumented-public-init
38
+ "EM101", # raw-string-in-exception
39
+ "FBT001", # boolean-type-hint-positional-argument
40
+ "FBT002", # boolean-default-value-positional-argument
41
+ "PD901", # pandas-df-variable-name
42
+ "PGH003", # blanket-type-ignore
43
+ "PLR0913", # too-many-arguments
44
+ "PLR0915", # too-many-statements
45
+ "TRY003", # raise-vanilla-args
46
+ ]
47
+ unfixable = [
48
+ "F401", # unused-import
49
+ ]
50
+
51
+ [tool.ruff.lint.pydocstyle]
52
+ convention = "google"
53
+
54
+ [tool.ruff.lint.per-file-ignores]
55
+ "*.ipynb" = ["T201"]
56
+
57
+ [tool.ruff.format]
58
+ docstring-code-format = true
59
+
60
+ [tool.uv.sources]
61
+ transformers = { git = "https://github.com/huggingface/transformers" }
62
+ causal-conv1d = { url = "https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.5.0.post8/causal_conv1d-1.5.0.post8+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl" }
63
+ mamba-ssm = { url = "https://github.com/state-spaces/mamba/releases/download/v2.2.4/mamba_ssm-2.2.4+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl" }
requirements.txt ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile pyproject.toml -o requirements.txt
3
+ accelerate==1.3.0
4
+ # via zamba2-7b-instruct (pyproject.toml)
5
+ aiofiles==23.2.1
6
+ # via gradio
7
+ annotated-types==0.7.0
8
+ # via pydantic
9
+ anyio==4.8.0
10
+ # via
11
+ # gradio
12
+ # httpx
13
+ # starlette
14
+ causal-conv1d @ https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.5.0.post8/causal_conv1d-1.5.0.post8+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
15
+ # via zamba2-7b-instruct (pyproject.toml)
16
+ certifi==2025.1.31
17
+ # via
18
+ # httpcore
19
+ # httpx
20
+ # requests
21
+ charset-normalizer==3.4.1
22
+ # via requests
23
+ click==8.1.8
24
+ # via
25
+ # typer
26
+ # uvicorn
27
+ einops==0.8.1
28
+ # via mamba-ssm
29
+ exceptiongroup==1.2.2
30
+ # via anyio
31
+ fastapi==0.115.8
32
+ # via gradio
33
+ ffmpy==0.5.0
34
+ # via gradio
35
+ filelock==3.17.0
36
+ # via
37
+ # huggingface-hub
38
+ # torch
39
+ # transformers
40
+ # triton
41
+ fsspec==2025.2.0
42
+ # via
43
+ # gradio-client
44
+ # huggingface-hub
45
+ # torch
46
+ gradio==5.16.0
47
+ # via
48
+ # zamba2-7b-instruct (pyproject.toml)
49
+ # spaces
50
+ gradio-client==1.7.0
51
+ # via gradio
52
+ h11==0.14.0
53
+ # via
54
+ # httpcore
55
+ # uvicorn
56
+ hf-transfer==0.1.9
57
+ # via zamba2-7b-instruct (pyproject.toml)
58
+ httpcore==1.0.7
59
+ # via httpx
60
+ httpx==0.28.1
61
+ # via
62
+ # gradio
63
+ # gradio-client
64
+ # safehttpx
65
+ # spaces
66
+ huggingface-hub==0.28.1
67
+ # via
68
+ # accelerate
69
+ # gradio
70
+ # gradio-client
71
+ # tokenizers
72
+ # transformers
73
+ idna==3.10
74
+ # via
75
+ # anyio
76
+ # httpx
77
+ # requests
78
+ jinja2==3.1.5
79
+ # via
80
+ # gradio
81
+ # torch
82
+ mamba-ssm @ https://github.com/state-spaces/mamba/releases/download/v2.2.4/mamba_ssm-2.2.4+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
83
+ # via zamba2-7b-instruct (pyproject.toml)
84
+ markdown-it-py==3.0.0
85
+ # via rich
86
+ markupsafe==2.1.5
87
+ # via
88
+ # gradio
89
+ # jinja2
90
+ mdurl==0.1.2
91
+ # via markdown-it-py
92
+ mpmath==1.3.0
93
+ # via sympy
94
+ networkx==3.4.2
95
+ # via torch
96
+ ninja==1.11.1.3
97
+ # via
98
+ # causal-conv1d
99
+ # mamba-ssm
100
+ numpy==2.2.3
101
+ # via
102
+ # accelerate
103
+ # gradio
104
+ # pandas
105
+ # transformers
106
+ nvidia-cublas-cu12==12.1.3.1
107
+ # via
108
+ # nvidia-cudnn-cu12
109
+ # nvidia-cusolver-cu12
110
+ # torch
111
+ nvidia-cuda-cupti-cu12==12.1.105
112
+ # via torch
113
+ nvidia-cuda-nvrtc-cu12==12.1.105
114
+ # via torch
115
+ nvidia-cuda-runtime-cu12==12.1.105
116
+ # via torch
117
+ nvidia-cudnn-cu12==9.1.0.70
118
+ # via torch
119
+ nvidia-cufft-cu12==11.0.2.54
120
+ # via torch
121
+ nvidia-curand-cu12==10.3.2.106
122
+ # via torch
123
+ nvidia-cusolver-cu12==11.4.5.107
124
+ # via torch
125
+ nvidia-cusparse-cu12==12.1.0.106
126
+ # via
127
+ # nvidia-cusolver-cu12
128
+ # torch
129
+ nvidia-nccl-cu12==2.20.5
130
+ # via torch
131
+ nvidia-nvjitlink-cu12==12.8.61
132
+ # via
133
+ # nvidia-cusolver-cu12
134
+ # nvidia-cusparse-cu12
135
+ nvidia-nvtx-cu12==12.1.105
136
+ # via torch
137
+ orjson==3.10.15
138
+ # via gradio
139
+ packaging==24.2
140
+ # via
141
+ # accelerate
142
+ # causal-conv1d
143
+ # gradio
144
+ # gradio-client
145
+ # huggingface-hub
146
+ # mamba-ssm
147
+ # spaces
148
+ # transformers
149
+ pandas==2.2.3
150
+ # via gradio
151
+ pillow==11.1.0
152
+ # via gradio
153
+ psutil==5.9.8
154
+ # via
155
+ # accelerate
156
+ # spaces
157
+ pydantic==2.10.6
158
+ # via
159
+ # fastapi
160
+ # gradio
161
+ # spaces
162
+ pydantic-core==2.27.2
163
+ # via pydantic
164
+ pydub==0.25.1
165
+ # via gradio
166
+ pygments==2.19.1
167
+ # via rich
168
+ python-dateutil==2.9.0.post0
169
+ # via pandas
170
+ python-multipart==0.0.20
171
+ # via gradio
172
+ pytz==2025.1
173
+ # via pandas
174
+ pyyaml==6.0.2
175
+ # via
176
+ # accelerate
177
+ # gradio
178
+ # huggingface-hub
179
+ # transformers
180
+ regex==2024.11.6
181
+ # via transformers
182
+ requests==2.32.3
183
+ # via
184
+ # huggingface-hub
185
+ # spaces
186
+ # transformers
187
+ rich==13.9.4
188
+ # via typer
189
+ ruff==0.9.6
190
+ # via gradio
191
+ safehttpx==0.1.6
192
+ # via gradio
193
+ safetensors==0.5.2
194
+ # via
195
+ # accelerate
196
+ # transformers
197
+ semantic-version==2.10.0
198
+ # via gradio
199
+ setuptools==75.8.0
200
+ # via mamba-ssm
201
+ shellingham==1.5.4
202
+ # via typer
203
+ six==1.17.0
204
+ # via python-dateutil
205
+ sniffio==1.3.1
206
+ # via anyio
207
+ spaces==0.32.0
208
+ # via zamba2-7b-instruct (pyproject.toml)
209
+ starlette==0.45.3
210
+ # via
211
+ # fastapi
212
+ # gradio
213
+ sympy==1.13.3
214
+ # via torch
215
+ tokenizers==0.21.0
216
+ # via transformers
217
+ tomlkit==0.13.2
218
+ # via gradio
219
+ torch==2.4.0
220
+ # via
221
+ # zamba2-7b-instruct (pyproject.toml)
222
+ # accelerate
223
+ # causal-conv1d
224
+ # mamba-ssm
225
+ tqdm==4.67.1
226
+ # via
227
+ # huggingface-hub
228
+ # transformers
229
+ transformers @ git+https://github.com/huggingface/transformers@336dc69d63d56f232a183a3e7f52790429b871ef
230
+ # via
231
+ # zamba2-7b-instruct (pyproject.toml)
232
+ # mamba-ssm
233
+ triton==3.0.0
234
+ # via torch
235
+ typer==0.15.1
236
+ # via gradio
237
+ typing-extensions==4.12.2
238
+ # via
239
+ # anyio
240
+ # fastapi
241
+ # gradio
242
+ # gradio-client
243
+ # huggingface-hub
244
+ # pydantic
245
+ # pydantic-core
246
+ # rich
247
+ # spaces
248
+ # torch
249
+ # typer
250
+ # uvicorn
251
+ tzdata==2025.1
252
+ # via pandas
253
+ urllib3==2.3.0
254
+ # via requests
255
+ uvicorn==0.34.0
256
+ # via gradio
257
+ websockets==14.2
258
+ # via gradio-client
style.css ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ display: block;
4
+ }
5
+
6
+ #duplicate-button {
7
+ margin: auto;
8
+ color: #fff;
9
+ background: #1565c0;
10
+ border-radius: 100vh;
11
+ }
uv.lock ADDED
The diff for this file is too large to render. See raw diff