Update utils.py
Browse files
utils.py
CHANGED
@@ -10,6 +10,7 @@ from langchain.llms import HuggingFacePipeline
|
|
10 |
from sentence_transformers import SentenceTransformer, util
|
11 |
from langchain.chains.question_answering import load_qa_chain
|
12 |
from transformers import StoppingCriteria, StoppingCriteriaList
|
|
|
13 |
|
14 |
|
15 |
# Hugging Face auth token, read from the environment.
# NOTE(review): os.getenv returns None when 'hf_auth' is unset — downstream
# calls that pass this token should tolerate None (public models) — confirm.
hf_auth = os.getenv('hf_auth')
|
@@ -18,25 +19,23 @@ model_id = 'google-t5/t5-base'
|
|
18 |
|
19 |
# Pick the compute device: prefer CUDA when a GPU is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
|
20 |
|
21 |
-
|
22 |
-
model_id,
|
23 |
-
use_auth_token=hf_auth
|
24 |
-
)
|
25 |
-
|
26 |
-
llm_model = transformers.AutoModelForCausalLM.from_pretrained(
|
27 |
-
model_id,
|
28 |
-
trust_remote_code=True,
|
29 |
-
config=model_config,
|
30 |
-
device_map='auto',
|
31 |
-
use_auth_token=hf_auth
|
32 |
-
)
|
33 |
|
34 |
-
#
|
35 |
-
|
|
|
|
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
)
|
41 |
|
42 |
stop_list = ['\nHuman:', '\n```\n']
|
|
|
10 |
from sentence_transformers import SentenceTransformer, util
|
11 |
from langchain.chains.question_answering import load_qa_chain
|
12 |
from transformers import StoppingCriteria, StoppingCriteriaList
|
13 |
+
from transformers import T5Tokenizer, T5ForConditionalGeneration, pipeline
|
14 |
|
15 |
|
16 |
hf_auth = os.getenv('hf_auth')
|
|
|
19 |
|
20 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
21 |
|
22 |
+
# Load the tokenizer for the configured T5 checkpoint (model_id is defined
# earlier in this module, e.g. 'google-t5/t5-base').
tokenizer = T5Tokenizer.from_pretrained(model_id)

# Load model
llm_model = T5ForConditionalGeneration.from_pretrained(model_id)
llm_model.to(device)
llm_model.eval()  # Set model to evaluation mode (disables dropout)

# Define the text generation pipeline.
#
# Fixes vs. the previous version:
# - 'return_full_text' was removed: it is a 'text-generation' (decoder-only)
#   pipeline option; the 'text2text-generation' pipeline does not accept it
#   and would forward it to model.generate() as an unused kwarg.
# - 'do_sample=True' was added: 'temperature' is ignored (with a warning)
#   under the default greedy decoding, so the intended temperature=0.1
#   had no effect without it.
generate_text = pipeline(
    'text2text-generation',
    model=llm_model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
    do_sample=True,          # required for temperature to take effect
    temperature=0.1,         # control the randomness of the output
    max_length=512,          # maximum length of the generated sequence
    repetition_penalty=1.1,  # penalty to prevent repetition
)
|
40 |
|
41 |
# Strings that terminate generation early — presumably consumed below by a
# StoppingCriteria built from the imported StoppingCriteria/StoppingCriteriaList
# (verify against the rest of the file, which is not visible here).
stop_list = ['\nHuman:', '\n```\n']
|