MWilinski committed on
Commit a0fb532 · 1 Parent(s): a516e7a
Files changed (1)
  1. qa_engine/qa_engine.py +2 -2
qa_engine/qa_engine.py CHANGED
@@ -66,7 +66,7 @@ class TransformersPipelineModel(LLM):
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         if "AWQ" in model_id:
             model = AutoAWQForCausalLM.from_quantized(model_id, fuse_layers=True,
-                trust_remote_code=False, safetensors=True).model
+                trust_remote_code=False, safetensors=True, torch_dtype=torch.bfloat16).model
         else:
             model = AutoModelForCausalLM.from_pretrained(
                 model_id,
@@ -80,7 +80,7 @@ class TransformersPipelineModel(LLM):
             'text-generation',
             model=model,
             tokenizer=tokenizer,
-            torch_dtype=torch.float16,
+            torch_dtype=torch.bfloat16,
             device_map='auto',
             eos_token_id=tokenizer.eos_token_id,
             pad_token_id=tokenizer.eos_token_id,
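For reference, a minimal standalone sketch of the loading path as it stands after this commit. The model id below is a placeholder, the surrounding TransformersPipelineModel class is omitted, and the non-AWQ branch elides kwargs that the hunk does not show; this mirrors the calls in the diff rather than defining the project's API.

# Sketch of the post-commit loading path; model_id is a placeholder, not from the diff.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from awq import AutoAWQForCausalLM

model_id = "TheBloke/Mistral-7B-Instruct-v0.1-AWQ"  # hypothetical AWQ checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
if "AWQ" in model_id:
    # Load the quantized weights in bfloat16 and unwrap the underlying HF model.
    model = AutoAWQForCausalLM.from_quantized(
        model_id, fuse_layers=True, trust_remote_code=False,
        safetensors=True, torch_dtype=torch.bfloat16,
    ).model
else:
    # The original call passes additional kwargs not visible in this hunk.
    model = AutoModelForCausalLM.from_pretrained(model_id)

# Pipeline kwargs mirror the source file, including the switch to bfloat16.
pipe = pipeline(
    'text-generation',
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    device_map='auto',
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)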