MWilinski committed
Commit 8559b37 · 1 Parent(s): 58526f3
Files changed (1):
  1. qa_engine/qa_engine.py +13 -8
qa_engine/qa_engine.py CHANGED
@@ -15,6 +15,7 @@ from langchain.llms.base import LLM
 from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceHubEmbeddings, HuggingFaceInstructEmbeddings
 from langchain.vectorstores import FAISS
 from sentence_transformers import CrossEncoder
+from awq import AutoAWQForCausalLM
 
 from qa_engine import logger
 from qa_engine.response import Response
@@ -63,14 +64,18 @@ class TransformersPipelineModel(LLM):
         self.model_id = model_id
 
         tokenizer = AutoTokenizer.from_pretrained(model_id)
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            torch_dtype=torch.bfloat16,
-            trust_remote_code=True,
-            load_in_8bit=False,
-            device_map='auto',
-            resume_download=True,
-        )
+        if "AWQ" in model_id:
+            model = AutoAWQForCausalLM.from_quantized(model_id, fuse_layers=True,
+                                                      trust_remote_code=False, safetensors=True)
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                torch_dtype=torch.bfloat16,
+                trust_remote_code=True,
+                load_in_8bit=False,
+                device_map='auto',
+                resume_download=True,
+            )
         self.pipeline = transformers.pipeline(
             'text-generation',
             model=model,
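
For context, a minimal sketch of what the changed constructor now does, lifted out of the class: if the checkpoint name contains "AWQ", the model is loaded as a pre-quantized AWQ checkpoint via autoawq's AutoAWQForCausalLM.from_quantized; otherwise it falls back to the standard transformers loader, exactly as in the diff above. The build_pipeline helper and the model id in the usage comment are hypothetical illustrations, not names from this commit.

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from awq import AutoAWQForCausalLM

def build_pipeline(model_id: str) -> transformers.Pipeline:
    # Mirrors the branching added in this commit: AWQ checkpoints
    # (identified by "AWQ" in the repo id) go through autoawq,
    # everything else through the regular transformers loader.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if "AWQ" in model_id:
        model = AutoAWQForCausalLM.from_quantized(
            model_id,
            fuse_layers=True,         # fuse layers for faster AWQ inference
            trust_remote_code=False,
            safetensors=True,         # expect .safetensors weight files
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            load_in_8bit=False,
            device_map='auto',
            resume_download=True,
        )
    # The commit passes the loaded model straight into a text-generation pipeline.
    return transformers.pipeline('text-generation', model=model, tokenizer=tokenizer)

# Hypothetical usage with an example AWQ repo id:
# pipe = build_pipeline('TheBloke/Mistral-7B-Instruct-v0.1-AWQ')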