lamhieu committed
Commit 0a30342 • 1 Parent(s): 8fee735

chore: enable flash attention 2

Files changed (2):
  1. README.md +1 -1
  2. app.py +14 -2
README.md CHANGED
@@ -1,6 +1,6 @@
 ---
 title: Ghost 8B Beta
-emoji: 👻
+emoji: 👻 / 🥸
 colorFrom: indigo
 colorTo: pink
 sdk: gradio
app.py CHANGED
@@ -1,3 +1,13 @@
+# pylint: skip-file
+
+import subprocess
+
+subprocess.run(
+    f"pip install flash-attn --no-build-isolation",
+    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+    shell=True,
+)
+
 import os
 from threading import Thread
 from typing import Iterator
@@ -7,6 +17,7 @@ import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
+
 MAX_MAX_NEW_TOKENS = 4096
 DEFAULT_MAX_NEW_TOKENS = 1536
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
@@ -18,7 +29,7 @@ DESCRIPTION = """\
 
 The languages supported are 🇺🇸 English, 🇫🇷 French, 🇮🇹 Italian, 🇪🇸 Spanish, 🇵🇹 Portuguese, 🇩🇪 German, 🇻🇳 Vietnamese, 🇰🇷 Korean and 🇨🇳 Chinese.
 
-📋 Note: current model version is "disl-0x5-8k" (10 Jul 2024), context length 8k and current status is "moderating / previewing". For detailed information about the model, see [here](https://ghost-x.org/docs/models/ghost-8b-beta/). Try to experience it the way you want!
+📋 Note: current model version is "disl-0x5" (10 Jul 2024), context length 8k (8192 tokens) and current status is "moderating / previewing". For detailed information about the model, see [here](https://ghost-x.org/docs/models/ghost-8b-beta/). Try to experience it the way you want!
 """
 
 
@@ -241,6 +252,8 @@ if torch.cuda.is_available():
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
         device_map="auto",
+        torch_dtype=torch.bfloat16,
+        attn_implementation="flash_attention_2",
         trust_remote_code=True,
         token=model_tk,
     )
@@ -363,4 +376,3 @@ with gr.Blocks(fill_height=True, css="style.css") as demo:
 
 if __name__ == "__main__":
     demo.queue(max_size=20).launch(share=True)
-# demo.launch(share=True)
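
For readers reproducing this change: setting FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE makes the flash-attn install skip compiling its CUDA kernels, so the runtime pip install only succeeds when a prebuilt wheel matches the environment; attn_implementation="flash_attention_2" in turn requires an Ampere-or-newer GPU and half-precision weights, which is why the commit also passes torch_dtype=torch.bfloat16. Two caveats in the commit as written: subprocess.run is not given check=True, so a failed install only surfaces later when from_pretrained requests flash_attention_2, and env={...} replaces the child's entire environment rather than extending os.environ. A minimal defensive sketch — hypothetical, not part of this commit, with pick_attn_implementation and the model id as placeholders for the Space's own code — would fall back to PyTorch's built-in SDPA attention when flash-attn is unavailable:

# Hypothetical sketch (not in this commit): use FlashAttention-2 only when
# the flash_attn package is importable and a CUDA GPU is present; otherwise
# fall back to PyTorch's scaled-dot-product attention ("sdpa").
import importlib.util

import torch
from transformers import AutoModelForCausalLM


def pick_attn_implementation() -> str:
    # flash-attn ships CUDA-only kernels, so require a GPU first.
    if torch.cuda.is_available() and importlib.util.find_spec("flash_attn"):
        return "flash_attention_2"
    return "sdpa"  # native PyTorch attention, accepted by transformers


model = AutoModelForCausalLM.from_pretrained(
    "ghost-x/ghost-8b-beta",  # placeholder for this Space's model_id
    device_map="auto",
    torch_dtype=torch.bfloat16,  # FlashAttention-2 needs fp16/bf16 weights
    attn_implementation=pick_attn_implementation(),
)

With a guard like this the Space still boots on hardware where the wheel install fails, at the cost of slower attention.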