habdine commited on
Commit
d68ef8f
1 Parent(s): 4051681

Upload 6 files

Browse files
Files changed (6) hide show
  1. README.md +5 -6
  2. app.py +119 -0
  3. gitattributes +35 -0
  4. pre-commit-config.yaml +60 -0
  5. requirements.txt +241 -0
  6. style.css +11 -0
README.md CHANGED
@@ -1,14 +1,13 @@
1
  ---
2
- title: Esm2Text
3
- emoji: 👀
4
- colorFrom: yellow
5
- colorTo: green
6
  sdk: gradio
7
  sdk_version: 5.1.0
8
  app_file: app.py
9
  pinned: false
10
- license: cc-by-nc-4.0
11
- short_description: Generate a protein's function using its amino acid sequence
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: ESM2Text
3
+ emoji: 😻
4
+ colorFrom: indigo
5
+ colorTo: pink
6
  sdk: gradio
7
  sdk_version: 5.1.0
8
  app_file: app.py
9
  pinned: false
10
+ short_description: Chatbot
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from threading import Thread
3
+ from typing import Iterator
4
+
5
+ import gradio as gr
6
+ import spaces
7
+ import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
+
10
+ DESCRIPTION = """\
11
+ # ESM2Text Demo
12
+ """
13
+
14
+ MAX_MAX_NEW_TOKENS = 256
15
+ DEFAULT_MAX_NEW_TOKENS = 100
16
+
17
+
18
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19
+
20
+
21
+ tokenizer = AutoTokenizer.from_pretrained('habdine/Esm2Text-Base-v1-1',
22
+ trust_remote_code=True)
23
+ model = AutoModelForCausalLM.from_pretrained('habdine/Esm2Text-Base-v1-1',
24
+ device_map="auto",
25
+ trust_remote_code=True)
26
+ model.eval()
27
+
28
+
29
+ @spaces.GPU(duration=90)
30
+ def generate(
31
+ message: str,
32
+ max_new_tokens: int = 1024,
33
+ do_sample: bool = False,
34
+ temperature: float = 0.6,
35
+ top_p: float = 0.9,
36
+ top_k: int = 50,
37
+ repetition_penalty: float = 1.2,
38
+ ) -> Iterator[str]:
39
+
40
+
41
+ streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
42
+ generate_kwargs = dict(
43
+ protein_sequence=message,
44
+ tokenizer=tokenizer,
45
+ device=device,
46
+ streamer=streamer,
47
+ max_new_tokens=max_new_tokens,
48
+ do_sample=do_sample,
49
+ top_p=top_p,
50
+ top_k=top_k,
51
+ temperature=temperature,
52
+ num_beams=1,
53
+ repetition_penalty=repetition_penalty,
54
+ )
55
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
56
+ t.start()
57
+
58
+ outputs = []
59
+ for text in streamer:
60
+ outputs.append(text)
61
+ yield "".join(outputs)
62
+
63
+
64
+ chat_interface = gr.ChatInterface(
65
+ fn=generate,
66
+ additional_inputs=[
67
+ gr.Slider(
68
+ label="Max new tokens",
69
+ minimum=1,
70
+ maximum=MAX_MAX_NEW_TOKENS,
71
+ step=1,
72
+ value=DEFAULT_MAX_NEW_TOKENS,
73
+ ),
74
+ gr.Checkbox(label="Do Sample"),
75
+ gr.Slider(
76
+ label="Temperature",
77
+ minimum=0.1,
78
+ maximum=4.0,
79
+ step=0.1,
80
+ value=0.6,
81
+ ),
82
+ gr.Slider(
83
+ label="Top-p (nucleus sampling)",
84
+ minimum=0.05,
85
+ maximum=1.0,
86
+ step=0.05,
87
+ value=0.9,
88
+ ),
89
+ gr.Slider(
90
+ label="Top-k",
91
+ minimum=1,
92
+ maximum=1000,
93
+ step=1,
94
+ value=50,
95
+ ),
96
+ gr.Slider(
97
+ label="Repetition penalty",
98
+ minimum=1.0,
99
+ maximum=2.0,
100
+ step=0.05,
101
+ value=1.0,
102
+ ),
103
+ ],
104
+ stop_btn=None,
105
+ examples=[
106
+ ['AEQAERYEEMVEFMEKL'],
107
+ ["MAVVLPAVVEELLSEMAAAVQESARIPDEYLLSLKFLFGSSATQALDLVDRQSITLISSPSGRRVYQVLGSSSKTYTCLASCHYCSCPAFAFSVLRKSDSILCKHLLAVYLSQVMRTCQQLSVSDKQLTDILLMEKKQEA"],
108
+ ],
109
+ cache_examples=False,
110
+ type="messages",
111
+ )
112
+
113
+ with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
114
+ gr.Markdown(DESCRIPTION)
115
+ gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
116
+ chat_interface.render()
117
+
118
+ if __name__ == "__main__":
119
+ demo.queue(max_size=20).launch()
gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
pre-commit-config.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.6.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: end-of-file-fixer
12
+ - id: mixed-line-ending
13
+ args: ["--fix=lf"]
14
+ - id: requirements-txt-fixer
15
+ - id: trailing-whitespace
16
+ - repo: https://github.com/myint/docformatter
17
+ rev: v1.7.5
18
+ hooks:
19
+ - id: docformatter
20
+ args: ["--in-place"]
21
+ - repo: https://github.com/pycqa/isort
22
+ rev: 5.13.2
23
+ hooks:
24
+ - id: isort
25
+ args: ["--profile", "black"]
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v1.10.1
28
+ hooks:
29
+ - id: mypy
30
+ args: ["--ignore-missing-imports"]
31
+ additional_dependencies:
32
+ [
33
+ "types-python-slugify",
34
+ "types-requests",
35
+ "types-PyYAML",
36
+ "types-pytz",
37
+ ]
38
+ - repo: https://github.com/psf/black
39
+ rev: 24.4.2
40
+ hooks:
41
+ - id: black
42
+ language_version: python3.10
43
+ args: ["--line-length", "119"]
44
+ - repo: https://github.com/kynan/nbstripout
45
+ rev: 0.7.1
46
+ hooks:
47
+ - id: nbstripout
48
+ args:
49
+ [
50
+ "--extra-keys",
51
+ "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
52
+ ]
53
+ - repo: https://github.com/nbQA-dev/nbQA
54
+ rev: 1.8.5
55
+ hooks:
56
+ - id: nbqa-black
57
+ - id: nbqa-pyupgrade
58
+ args: ["--py37-plus"]
59
+ - id: nbqa-isort
60
+ args: ["--float-to-top"]
requirements.txt ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile pyproject.toml -o requirements.txt
3
+ accelerate==1.0.0
4
+ # via gemma-2-9b-it (pyproject.toml)
5
+ aiofiles==23.2.1
6
+ # via gradio
7
+ annotated-types==0.7.0
8
+ # via pydantic
9
+ anyio==4.6.0
10
+ # via
11
+ # gradio
12
+ # httpx
13
+ # starlette
14
+ certifi==2024.8.30
15
+ # via
16
+ # httpcore
17
+ # httpx
18
+ # requests
19
+ charset-normalizer==3.3.2
20
+ # via requests
21
+ click==8.1.7
22
+ # via
23
+ # typer
24
+ # uvicorn
25
+ exceptiongroup==1.2.2
26
+ # via anyio
27
+ fastapi==0.115.0
28
+ # via gradio
29
+ ffmpy==0.4.0
30
+ # via gradio
31
+ filelock==3.16.1
32
+ # via
33
+ # huggingface-hub
34
+ # torch
35
+ # transformers
36
+ # triton
37
+ fsspec==2024.9.0
38
+ # via
39
+ # gradio-client
40
+ # huggingface-hub
41
+ # torch
42
+ gradio==5.0.1
43
+ # via
44
+ # gemma-2-9b-it (pyproject.toml)
45
+ # spaces
46
+ gradio-client==1.4.0
47
+ # via gradio
48
+ h11==0.14.0
49
+ # via
50
+ # httpcore
51
+ # uvicorn
52
+ hf-transfer==0.1.8
53
+ # via gemma-2-9b-it (pyproject.toml)
54
+ httpcore==1.0.5
55
+ # via httpx
56
+ httpx==0.27.2
57
+ # via
58
+ # gradio
59
+ # gradio-client
60
+ # spaces
61
+ huggingface-hub==0.25.1
62
+ # via
63
+ # accelerate
64
+ # gradio
65
+ # gradio-client
66
+ # tokenizers
67
+ # transformers
68
+ idna==3.10
69
+ # via
70
+ # anyio
71
+ # httpx
72
+ # requests
73
+ jinja2==3.1.4
74
+ # via
75
+ # gradio
76
+ # torch
77
+ markdown-it-py==3.0.0
78
+ # via rich
79
+ markupsafe==2.1.5
80
+ # via
81
+ # gradio
82
+ # jinja2
83
+ mdurl==0.1.2
84
+ # via markdown-it-py
85
+ mpmath==1.3.0
86
+ # via sympy
87
+ networkx==3.3
88
+ # via torch
89
+ numpy==2.1.1
90
+ # via
91
+ # accelerate
92
+ # gradio
93
+ # pandas
94
+ # transformers
95
+ nvidia-cublas-cu12==12.1.3.1
96
+ # via
97
+ # nvidia-cudnn-cu12
98
+ # nvidia-cusolver-cu12
99
+ # torch
100
+ nvidia-cuda-cupti-cu12==12.1.105
101
+ # via torch
102
+ nvidia-cuda-nvrtc-cu12==12.1.105
103
+ # via torch
104
+ nvidia-cuda-runtime-cu12==12.1.105
105
+ # via torch
106
+ nvidia-cudnn-cu12==9.1.0.70
107
+ # via torch
108
+ nvidia-cufft-cu12==11.0.2.54
109
+ # via torch
110
+ nvidia-curand-cu12==10.3.2.106
111
+ # via torch
112
+ nvidia-cusolver-cu12==11.4.5.107
113
+ # via torch
114
+ nvidia-cusparse-cu12==12.1.0.106
115
+ # via
116
+ # nvidia-cusolver-cu12
117
+ # torch
118
+ nvidia-nccl-cu12==2.20.5
119
+ # via torch
120
+ nvidia-nvjitlink-cu12==12.6.68
121
+ # via
122
+ # nvidia-cusolver-cu12
123
+ # nvidia-cusparse-cu12
124
+ nvidia-nvtx-cu12==12.1.105
125
+ # via torch
126
+ orjson==3.10.7
127
+ # via gradio
128
+ packaging==24.1
129
+ # via
130
+ # accelerate
131
+ # gradio
132
+ # gradio-client
133
+ # huggingface-hub
134
+ # spaces
135
+ # transformers
136
+ pandas==2.2.3
137
+ # via gradio
138
+ pillow==10.4.0
139
+ # via gradio
140
+ psutil==5.9.8
141
+ # via
142
+ # accelerate
143
+ # spaces
144
+ pydantic==2.9.2
145
+ # via
146
+ # fastapi
147
+ # gradio
148
+ # spaces
149
+ pydantic-core==2.23.4
150
+ # via pydantic
151
+ pydub==0.25.1
152
+ # via gradio
153
+ pygments==2.18.0
154
+ # via rich
155
+ python-dateutil==2.9.0.post0
156
+ # via pandas
157
+ python-multipart==0.0.12
158
+ # via gradio
159
+ pytz==2024.2
160
+ # via pandas
161
+ pyyaml==6.0.2
162
+ # via
163
+ # accelerate
164
+ # gradio
165
+ # huggingface-hub
166
+ # transformers
167
+ regex==2024.9.11
168
+ # via transformers
169
+ requests==2.32.3
170
+ # via
171
+ # huggingface-hub
172
+ # spaces
173
+ # transformers
174
+ rich==13.8.1
175
+ # via typer
176
+ ruff==0.6.8
177
+ # via gradio
178
+ safetensors==0.4.5
179
+ # via
180
+ # accelerate
181
+ # transformers
182
+ semantic-version==2.10.0
183
+ # via gradio
184
+ shellingham==1.5.4
185
+ # via typer
186
+ six==1.16.0
187
+ # via python-dateutil
188
+ sniffio==1.3.1
189
+ # via
190
+ # anyio
191
+ # httpx
192
+ spaces==0.30.3
193
+ # via gemma-2-9b-it (pyproject.toml)
194
+ starlette==0.38.6
195
+ # via fastapi
196
+ sympy==1.13.3
197
+ # via torch
198
+ tokenizers==0.20.0
199
+ # via transformers
200
+ tomlkit==0.12.0
201
+ # via gradio
202
+ torch==2.4.0
203
+ # via
204
+ # gemma-2-9b-it (pyproject.toml)
205
+ # accelerate
206
+ tqdm==4.66.5
207
+ # via
208
+ # huggingface-hub
209
+ # transformers
210
+ transformers==4.45.2
211
+ # via gemma-2-9b-it (pyproject.toml)
212
+ triton==3.0.0
213
+ # via torch
214
+ typer==0.12.5
215
+ # via gradio
216
+ typing-extensions==4.12.2
217
+ # via
218
+ # anyio
219
+ # fastapi
220
+ # gradio
221
+ # gradio-client
222
+ # huggingface-hub
223
+ # pydantic
224
+ # pydantic-core
225
+ # spaces
226
+ # torch
227
+ # typer
228
+ # uvicorn
229
+ tzdata==2024.2
230
+ # via pandas
231
+ urllib3==2.2.3
232
+ # via requests
233
+ uvicorn==0.31.0
234
+ # via gradio
235
+ websockets==12.0
236
+ # via gradio-client
237
+ torch_geometric
238
+ torch_scatter
239
+ torch_sparse
240
+ torch_cluster
241
+ torch_spline_conv
style.css ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ display: block;
4
+ }
5
+
6
+ #duplicate-button {
7
+ margin: auto;
8
+ color: #fff;
9
+ background: #1565c0;
10
+ border-radius: 100vh;
11
+ }