tangzhy committed on
Commit
5c20eaf
·
verified ·
1 Parent(s): 7ff1593

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -14
app.py CHANGED
@@ -12,9 +12,6 @@ from transformers import (
12
  TextIteratorStreamer,
13
  )
14
 
15
- import subprocess
16
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
17
-
18
  DESCRIPTION = """\
19
  # ORLM LLaMA-3-8B
20
  Hello! I'm ORLM-LLaMA-3-8B, here to automate your optimization modeling tasks! Check our [repo](https://github.com/Cardinal-Operations/ORLM) and [paper](https://arxiv.org/abs/2405.17743)!
@@ -33,22 +30,14 @@ MAX_MAX_NEW_TOKENS = 4096
33
  DEFAULT_MAX_NEW_TOKENS = 4096
34
  MAX_INPUT_TOKEN_LENGTH = 2048
35
 
36
- # quantization_config = BitsAndBytesConfig(
37
- # load_in_4bit=True,
38
- # bnb_4bit_compute_dtype=torch.bfloat16,
39
- # bnb_4bit_use_double_quant=True,
40
- # bnb_4bit_quant_type= "nf4")
41
- # quantization_config = BitsAndBytesConfig(load_in_8bit=True)
42
-
43
  model_id = "CardinalOperations/ORLM-LLaMA-3-8B"
44
- tokenizer = AutoTokenizer.from_pretrained(model_id)
45
  model = AutoModelForCausalLM.from_pretrained(
46
  model_id,
47
  device_map="auto",
48
- torch_dtype=torch.bfloat16,
49
- attn_implementation="flash_attention_2",
50
- # quantization_config=quantization_config,
51
  )
 
52
  model.eval()
53
 
54
 
 
12
  TextIteratorStreamer,
13
  )
14
 
 
 
 
15
  DESCRIPTION = """\
16
  # ORLM LLaMA-3-8B
17
  Hello! I'm ORLM-LLaMA-3-8B, here to automate your optimization modeling tasks! Check our [repo](https://github.com/Cardinal-Operations/ORLM) and [paper](https://arxiv.org/abs/2405.17743)!
 
30
  DEFAULT_MAX_NEW_TOKENS = 4096
31
  MAX_INPUT_TOKEN_LENGTH = 2048
32
 
 
 
 
 
 
 
 
33
  model_id = "CardinalOperations/ORLM-LLaMA-3-8B"
34
+ tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
35
  model = AutoModelForCausalLM.from_pretrained(
36
  model_id,
37
  device_map="auto",
38
+ quantization_config=BitsAndBytesConfig(load_in_8bit=True),
 
 
39
  )
40
+ model.config.sliding_window = 4096
41
  model.eval()
42
 
43