Transformers
English
shing12345 committed on
Commit
c7d0202
·
verified ·
1 Parent(s): e75c76c

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +125 -0
model.py CHANGED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: cc-by-nc-2.0
3
+ datasets:
4
+ - wikimedia/wikipedia
5
+ - open-llm-leaderboard-old/details_microsoft__DialoGPT-large
6
+ - li2017dailydialog/daily_dialog
7
+ - google/Synthetic-Persona-Chat
8
+ - kanhatakeyama/CommonCrawl-RAG-QA-Calm3-22b-chat
9
+ - yuyijiong/Multi-doc-QA-CommonCrawl
10
+ - Skylion007/openwebtext
11
+ - Bingsu/openwebtext_20p
12
+ - segyges/OpenWebText2
13
+ language:
14
+ - en
15
+ library_name: transformers
16
+ ---
17
+ import torch
18
+ from transformers import BartForConditionalGeneration, BartTokenizer, GPT2LMHeadModel, GPT2Tokenizer
19
+ import argparse
20
+ import sys
21
+
22
class AdvancedSummarizer:
    """Abstractive summarizer built on a BART conditional-generation model.

    The model and tokenizer are loaded once in the constructor. The model is
    placed on CUDA when it is available and on the CPU otherwise.
    """

    def __init__(self, model_name="facebook/bart-large-cnn"):
        """Load the model and tokenizer identified by *model_name*."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = BartForConditionalGeneration.from_pretrained(model_name).to(self.device)
        self.tokenizer = BartTokenizer.from_pretrained(model_name)

    def summarize(self, text, max_length=150, min_length=50, length_penalty=2.0,
                  num_beams=4, max_input_length=1024):
        """Return a beam-search summary of *text*.

        Args:
            text: The document to summarize.
            max_length: Upper bound on the generated summary length (tokens).
            min_length: Lower bound on the generated summary length (tokens).
            length_penalty: Length penalty forwarded to ``generate``.
            num_beams: Number of beams used for beam search.
            max_input_length: Inputs are truncated to this many tokens before
                being fed to the model.

        Returns:
            The decoded summary with special tokens removed, or an empty
            string when *text* is empty or whitespace-only.
        """
        # Nothing to summarize: avoid forcing the model to produce at least
        # ``min_length`` tokens out of an empty document.
        if not text or not text.strip():
            return ""

        inputs = self.tokenizer(
            [text],
            max_length=max_input_length,
            return_tensors="pt",
            truncation=True,
        ).to(self.device)

        summary_ids = self.model.generate(
            inputs["input_ids"],
            # Pass the mask explicitly instead of letting generate() infer it.
            attention_mask=inputs["attention_mask"],
            num_beams=num_beams,
            max_length=max_length,
            min_length=min_length,
            length_penalty=length_penalty,
        )

        # Only one document was encoded, so only the first sequence is relevant.
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
42
+
43
def main_summarizer():
    """Summarize a placeholder text with the default model and print the result."""
    # Placeholder input; replace with the text to summarize.
    text = """..."""
    result = AdvancedSummarizer().summarize(text)
    print("Summary:", result, sep="\n")
50
+
51
class AdvancedTextGenerator:
    """Sampling-based text generator built on a GPT-2 language-model head.

    The model and tokenizer are loaded once in the constructor. The model is
    placed on CUDA when it is available and on the CPU otherwise.
    """

    def __init__(self, model_name="gpt2-medium"):
        """Load the model and tokenizer identified by *model_name*.

        Any failure during loading is reported and the process exits with
        status 1.
        """
        try:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"Using device: {self.device}")
            self.model = GPT2LMHeadModel.from_pretrained(model_name).to(self.device)
            self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        except Exception as e:
            # Diagnostics go to stderr so they do not mix with generated text.
            print(f"Error initializing the model: {e}", file=sys.stderr)
            sys.exit(1)

    def generate_text(self, prompt, max_length=100, num_return_sequences=1,
                      temperature=1.0, top_k=50, top_p=0.95, repetition_penalty=1.0):
        """Sample continuations of *prompt*.

        Args:
            prompt: Text the model continues from.
            max_length: Maximum number of tokens generated after the prompt.
            num_return_sequences: How many independent samples to return.
            temperature: Sampling temperature forwarded to ``generate``.
            top_k: Top-k filtering value forwarded to ``generate``.
            top_p: Nucleus-sampling value forwarded to ``generate``.
            repetition_penalty: Repetition penalty forwarded to ``generate``.

        Returns:
            A list with one string per sample, each holding only the newly
            generated text (the prompt is not repeated). If anything raises
            during generation, a single-element list containing the error
            message is returned instead.
        """
        try:
            encoded = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            input_ids = encoded["input_ids"]
            prompt_length = len(input_ids[0])

            output_sequences = self.model.generate(
                input_ids=input_ids,
                attention_mask=encoded["attention_mask"],
                # max_length counts the prompt too, so add its token count.
                max_length=max_length + prompt_length,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                num_return_sequences=num_return_sequences,
                # Give generate() an explicit padding id for samples that
                # finish at different lengths.
                pad_token_id=self.tokenizer.eos_token_id,
            )

            # Drop the prompt by token count rather than by the character
            # length of a separately decoded prompt, which can misalign after
            # tokenization clean-up. Special tokens (e.g. EOS padding) are
            # omitted from the returned text.
            return [
                self.tokenizer.decode(
                    sequence[prompt_length:],
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                )
                for sequence in output_sequences
            ]
        except Exception as e:
            # Existing contract: report failures through the returned list.
            return [f"Error during text generation: {e}"]
87
+
88
def main_generator():
    """Command-line entry point: parse sampling options, generate, and print.

    When ``--prompt`` is missing or empty, the prompt is read from standard
    input instead.
    """
    cli = argparse.ArgumentParser(description="Advanced Text Generator")
    cli.add_argument("--prompt", type=str, help="Starting prompt for text generation")
    # The remaining options all follow the same (flag, type, default, help) shape.
    numeric_options = (
        ("--max_length", int, 100, "Maximum length of generated text"),
        ("--num_sequences", int, 1, "Number of sequences to generate"),
        ("--temperature", float, 1.0, "Temperature for sampling"),
        ("--top_k", int, 50, "Top-k sampling parameter"),
        ("--top_p", float, 0.95, "Top-p sampling parameter"),
        ("--repetition_penalty", float, 1.0, "Repetition penalty"),
    )
    for flag, value_type, default, help_text in numeric_options:
        cli.add_argument(flag, type=value_type, default=default, help=help_text)
    options = cli.parse_args()

    text_generator = AdvancedTextGenerator()

    prompt = options.prompt
    if not prompt:
        print("Please enter the prompt for text generation:")
        prompt = input().strip()

    sampling = {
        "max_length": options.max_length,
        "num_return_sequences": options.num_sequences,
        "temperature": options.temperature,
        "top_k": options.top_k,
        "top_p": options.top_p,
        "repetition_penalty": options.repetition_penalty,
    }
    outputs = text_generator.generate_text(prompt, **sampling)

    print("\nGenerated Text(s):")
    for index, sequence in enumerate(outputs, start=1):
        print(f"\n--- Sequence {index} ---")
        print(sequence)
122
+
123
if __name__ == "__main__":
    # Run the summarizer demo first, then the text-generator command line.
    for entry_point in (main_summarizer, main_generator):
        entry_point()