tangzhy commited on
Commit
5420f2a
·
verified ·
1 Parent(s): 5ab915e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -7
app.py CHANGED
@@ -25,23 +25,21 @@ MAX_MAX_NEW_TOKENS = 4096
25
  DEFAULT_MAX_NEW_TOKENS = 4096
26
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
27
 
28
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
29
-
30
  # quantization_config = BitsAndBytesConfig(
31
  # load_in_4bit=True,
32
  # bnb_4bit_compute_dtype=torch.bfloat16,
33
  # bnb_4bit_use_double_quant=True,
34
  # bnb_4bit_quant_type= "nf4")
35
- quantization_config = BitsAndBytesConfig(load_in_8bit=True)
36
 
37
  model_id = "CardinalOperations/ORLM-LLaMA-3-8B"
38
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
39
  model = AutoModelForCausalLM.from_pretrained(
40
  model_id,
41
  device_map="auto",
42
- # torch_dtype=torch.bfloat16,
43
- # attn_implementation="flash_attention_2",
44
- quantization_config=quantization_config,
45
  )
46
  model.eval()
47
 
@@ -63,7 +61,7 @@ def generate(
63
  input_ids = tokenized_example.input_ids
64
  input_ids = input_ids.to(model.device)
65
 
66
- streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
67
  generate_kwargs = dict(
68
  {"input_ids": input_ids},
69
  streamer=streamer,
 
25
  DEFAULT_MAX_NEW_TOKENS = 4096
26
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
27
 
 
 
28
  # quantization_config = BitsAndBytesConfig(
29
  # load_in_4bit=True,
30
  # bnb_4bit_compute_dtype=torch.bfloat16,
31
  # bnb_4bit_use_double_quant=True,
32
  # bnb_4bit_quant_type= "nf4")
33
+ # quantization_config = BitsAndBytesConfig(load_in_8bit=True)
34
 
35
  model_id = "CardinalOperations/ORLM-LLaMA-3-8B"
36
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
37
  model = AutoModelForCausalLM.from_pretrained(
38
  model_id,
39
  device_map="auto",
40
+ torch_dtype=torch.bfloat16,
41
+ attn_implementation="flash_attention_2",
42
+ # quantization_config=quantization_config,
43
  )
44
  model.eval()
45
 
 
61
  input_ids = tokenized_example.input_ids
62
  input_ids = input_ids.to(model.device)
63
 
64
+ streamer = TextIteratorStreamer(tokenizer, timeout=50.0, skip_prompt=True, skip_special_tokens=True)
65
  generate_kwargs = dict(
66
  {"input_ids": input_ids},
67
  streamer=streamer,