Harsh2001 committed commit b722c79 (verified)
Parent: 185e9a0

Update utils.py

Files changed (1): utils.py (+16 −17)
utils.py CHANGED
@@ -10,6 +10,7 @@ from langchain.llms import HuggingFacePipeline
 from sentence_transformers import SentenceTransformer, util
 from langchain.chains.question_answering import load_qa_chain
 from transformers import StoppingCriteria, StoppingCriteriaList
+from transformers import T5Tokenizer, T5ForConditionalGeneration, pipeline
 
 
 hf_auth = os.getenv('hf_auth')
@@ -18,25 +19,23 @@ model_id = 'google-t5/t5-base'
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-model_config = transformers.AutoConfig.from_pretrained(
-    model_id,
-    use_auth_token=hf_auth
-)
-
-llm_model = transformers.AutoModelForCausalLM.from_pretrained(
-    model_id,
-    trust_remote_code=True,
-    config=model_config,
-    device_map='auto',
-    use_auth_token=hf_auth
-)
+tokenizer = T5Tokenizer.from_pretrained(model_id)
 
-# enable evaluation mode to allow model inference
-model.eval()
+# Load model
+llm_model = T5ForConditionalGeneration.from_pretrained(model_id)
+llm_model.to(device)
+llm_model.eval()  # Set model to evaluation mode
 
-tokenizer = transformers.AutoTokenizer.from_pretrained(
-    model_id,
-    use_auth_token=hf_auth
+# Define the text generation pipeline
+generate_text = pipeline(
+    'text2text-generation',
+    model=llm_model,
+    tokenizer=tokenizer,
+    device=0 if torch.cuda.is_available() else -1,
+    return_full_text=True,  # Ensure the full text is returned
+    temperature=0.1,  # Control the randomness of the output
+    max_length=512,  # Maximum length of the generated sequence
+    repetition_penalty=1.1  # Penalty to prevent repetition
 )
 
 stop_list = ['\nHuman:', '\n```\n']
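
The imports retained in the hunk headers (HuggingFacePipeline, load_qa_chain, StoppingCriteria) suggest how generate_text and stop_list are meant to be wired downstream, but the rest of utils.py is not shown in this commit. The sketch below is therefore an assumption about that wiring, not part of the diff: StopOnTokens, stopping_criteria, llm, qa_chain, docs, and query are all hypothetical names.

import torch
from transformers import StoppingCriteria, StoppingCriteriaList
from langchain.llms import HuggingFacePipeline
from langchain.chains.question_answering import load_qa_chain

# Assumption: the stop strings are turned into token-id sequences for
# token-level stopping criteria (stop_list and tokenizer come from utils.py).
stop_token_ids = [tokenizer(s, add_special_tokens=False)['input_ids'] for s in stop_list]

class StopOnTokens(StoppingCriteria):  # hypothetical helper, not in the diff
    """Stop generation once the output ends with any stop sequence."""
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_ids in stop_token_ids:
            if input_ids[0][-len(stop_ids):].tolist() == stop_ids:
                return True
        return False

stopping_criteria = StoppingCriteriaList([StopOnTokens()])
# stopping_criteria would then be forwarded to generate, e.g.
# generate_text(prompt, stopping_criteria=stopping_criteria)

# Wrap the transformers pipeline so LangChain can drive it, then build a
# question-answering chain ('stuff' packs all retrieved docs into one prompt).
llm = HuggingFacePipeline(pipeline=generate_text)
qa_chain = load_qa_chain(llm, chain_type='stuff')

# Hypothetical call site: docs from a retriever, query from the user.
# answer = qa_chain.run(input_documents=docs, question=query)

Note that switching from AutoModelForCausalLM to T5ForConditionalGeneration is what forces the pipeline task to 'text2text-generation': 'google-t5/t5-base' is an encoder-decoder model, so the causal-LM loader in the removed code was the wrong model class for it.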