bhuvanmdev committed
Commit 7b0e091
1 Parent(s): 39fd1db

Update README.md

Files changed (1): README.md (+44 -1)
README.md CHANGED
@@ -1,6 +1,12 @@
 ---
 library_name: peft
 base_model: google-t5/t5-base
+license: apache-2.0
+language:
+- en
+- ja
+- ar
+pipeline_tag: text2text-generation
 ---
 
 # Model Card for Model ID
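
The added `pipeline_tag: text2text-generation` sets the task string the Hub widget and the `pipeline()` API use for this repo. A minimal sketch of what that tag maps to (this loads only the `google-t5/t5-base` base checkpoint; the LoRA adapter from this repo still has to be attached separately, as in the Direct Use snippet below):

```python
# Sketch only: the pipeline task string comes from the new pipeline_tag.
# This loads the plain base model, not the adapter from this repo.
from transformers import pipeline

generator = pipeline("text2text-generation", model="google-t5/t5-base")
print(generator("Title: A big accident occurs in Luxembourg.")[0]["generated_text"])
```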
 
@@ -39,7 +45,44 @@ base_model: google-t5/t5-base
 
 ### Direct Use
 
-<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+```python
+import torch
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig
+from peft import PeftModel
+
+model_id = 'google-t5/t5-base'
+
+# Quantize the base model to 4-bit NF4 with double quantization to cut memory use
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16,
+)
+
+original_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, quantization_config=bnb_config, device_map='auto')
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+tokenizer.pad_token = tokenizer.eos_token
+
+# Attach the fine-tuned LoRA adapter to the quantized base model
+peft_model = PeftModel.from_pretrained(original_model, "bhuvanmdev/t5-base-news-describer")
+
+generation_config = peft_model.generation_config
+generation_config.do_sample = True
+generation_config.max_new_tokens = 100  # maximum number of new tokens to generate
+generation_config.temperature = 0.1
+generation_config.top_p = 0.8
+generation_config.num_return_sequences = 1
+generation_config.pad_token_id = tokenizer.eos_token_id
+generation_config.eos_token_id = tokenizer.eos_token_id
+generation_config.use_cache = True
+
+prompt = "Title: A big accident occurs in Luxembourg."
+
+encoding = tokenizer(prompt, return_tensors="pt").to(peft_model.device)
+with torch.inference_mode():
+    outputs = peft_model.generate(
+        input_ids=encoding.input_ids,
+        attention_mask=encoding.attention_mask,
+        generation_config=generation_config,
+    )
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```
 
 [More Information Needed]
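
As a follow-up to the snippet above, a hedged sketch of an alternative deployment path: fold the adapter into the base weights once so inference needs only plain `transformers`. This assumes a full-precision base model, since merging into quantized weights is lossy or unsupported depending on the peft version.

```python
# Sketch only: load base model + adapter in full precision, then merge the
# LoRA weights into the base so the result is an ordinary T5 checkpoint.
from peft import AutoPeftModelForSeq2SeqLM

model = AutoPeftModelForSeq2SeqLM.from_pretrained("bhuvanmdev/t5-base-news-describer")
merged = model.merge_and_unload()
merged.save_pretrained("t5-base-news-describer-merged")  # hypothetical output directory
```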