Spaces:
Paused
Paused
chore: enable flash attention 2
Browse files
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
title: Ghost 8B Beta
|
3 |
-
emoji: 👻
|
4 |
colorFrom: indigo
|
5 |
colorTo: pink
|
6 |
sdk: gradio
|
|
|
1 |
---
|
2 |
title: Ghost 8B Beta
|
3 |
+
emoji: 👻 / 🥸
|
4 |
colorFrom: indigo
|
5 |
colorTo: pink
|
6 |
sdk: gradio
|
app.py
CHANGED
@@ -1,3 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
from threading import Thread
|
3 |
from typing import Iterator
|
@@ -7,6 +17,7 @@ import spaces
|
|
7 |
import torch
|
8 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
9 |
|
|
|
10 |
MAX_MAX_NEW_TOKENS = 4096
|
11 |
DEFAULT_MAX_NEW_TOKENS = 1536
|
12 |
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
|
@@ -18,7 +29,7 @@ DESCRIPTION = """\
|
|
18 |
|
19 |
The languages supported are 🇺🇸 English, 🇫🇷 French, 🇮🇹 Italian, 🇪🇸 Spanish, 🇵🇹 Portuguese, 🇩🇪 German, 🇻🇳 Vietnamese, 🇰🇷 Korean and 🇨🇳 Chinese.
|
20 |
|
21 |
-
๐ Note: current model version is "disl-0x5
|
22 |
"""
|
23 |
|
24 |
|
@@ -241,6 +252,8 @@ if torch.cuda.is_available():
|
|
241 |
model = AutoModelForCausalLM.from_pretrained(
|
242 |
model_id,
|
243 |
device_map="auto",
|
|
|
|
|
244 |
trust_remote_code=True,
|
245 |
token=model_tk,
|
246 |
)
|
@@ -363,4 +376,3 @@ with gr.Blocks(fill_height=True, css="style.css") as demo:
|
|
363 |
|
364 |
if __name__ == "__main__":
|
365 |
demo.queue(max_size=20).launch(share=True)
|
366 |
-
# demo.launch(share=True)
|
|
|
1 |
+
# pylint: skip-file
|
2 |
+
|
3 |
+
import subprocess
|
4 |
+
|
5 |
+
subprocess.run(
|
6 |
+
f"pip install flash-attn --no-build-isolation",
|
7 |
+
env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
|
8 |
+
shell=True,
|
9 |
+
)
|
10 |
+
|
11 |
import os
|
12 |
from threading import Thread
|
13 |
from typing import Iterator
|
|
|
17 |
import torch
|
18 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
19 |
|
20 |
+
|
21 |
MAX_MAX_NEW_TOKENS = 4096
|
22 |
DEFAULT_MAX_NEW_TOKENS = 1536
|
23 |
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
|
|
|
29 |
|
30 |
The languages supported are 🇺🇸 English, 🇫🇷 French, 🇮🇹 Italian, 🇪🇸 Spanish, 🇵🇹 Portuguese, 🇩🇪 German, 🇻🇳 Vietnamese, 🇰🇷 Korean and 🇨🇳 Chinese.
|
31 |
|
32 |
+
📌 Note: current model version is "disl-0x5" (10 Jul 2024), context length 8k (8192 tokens) and current status is "moderating / previewing". For detailed information about the model, see [here](https://ghost-x.org/docs/models/ghost-8b-beta/). Try to experience it the way you want!
|
33 |
"""
|
34 |
|
35 |
|
|
|
252 |
model = AutoModelForCausalLM.from_pretrained(
|
253 |
model_id,
|
254 |
device_map="auto",
|
255 |
+
torch_dtype=torch.bfloat16,
|
256 |
+
attn_implementation="flash_attention_2",
|
257 |
trust_remote_code=True,
|
258 |
token=model_tk,
|
259 |
)
|
|
|
376 |
|
377 |
if __name__ == "__main__":
|
378 |
demo.queue(max_size=20).launch(share=True)
|
|