kz209 committed on
Commit c89910e
1 Parent(s): 05479ce
pages/summarization_playground.py CHANGED
@@ -58,7 +58,7 @@ def generate_answer(sources, model_name, prompt):
 
     content = [prompt + '\n{' + sources + '}\n\nsummary:']
 
-    answer = model[model_name].gen(content, streaming=False)
+    answer = model[model_name].gen(content)
 
     return answer
 
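For context, a minimal sketch of the updated call site, assuming `model[model_name]` is a `Model` instance from `utils/model.py` (the `summarizer` name below is illustrative): `gen` no longer accepts a `streaming` flag and always returns the batch-decoded output, a list with one string per prompt.

```python
# Illustrative only: `summarizer` stands in for model[model_name] in the playground.
prompt = "Summarize the following source."        # hypothetical prompt
sources = "Some article text to be summarized."   # hypothetical source text

content = [prompt + '\n{' + sources + '}\n\nsummary:']

# gen() now always runs non-streaming and returns a list of decoded strings,
# one per entry in `content`.
answer = summarizer.gen(content)
```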
utils/model.py CHANGED
@@ -53,56 +53,59 @@ class Model(torch.nn.Module):
         return self.tokenizer
 
     def return_model(self):
-        return self.pipeline
-
-    def gen(self, content_list, temp=0.001, max_length=500, streaming=False):
+        return self.model
+
+    def streaming(self, content_list, temp=0.001, max_length=500):
         # Convert list of texts to input IDs
         input_ids = self.tokenizer(content_list, return_tensors="pt", padding=True, truncation=True).input_ids.to(self.model.device)
 
-        if streaming:
-            # Set up the initial generation parameters
-            gen_kwargs = {
-                "input_ids": input_ids,
-                "do_sample": True,
-                "temperature": temp,
-                "eos_token_id": self.tokenizer.eos_token_id,
-                "max_new_tokens": 1,  # Generate one token at a time
-                "return_dict_in_generate": True,
-                "output_scores": True
-            }
-
-            # Generate and yield tokens one by one
-            generated_tokens = 0
-            batch_size = input_ids.shape[0]
-            active_sequences = torch.arange(batch_size)
-
-            while generated_tokens < max_length and len(active_sequences) > 0:
-                with torch.no_grad():
-                    output = self.model.generate(**gen_kwargs)
-
-                next_tokens = output.sequences[:, -1].unsqueeze(-1)
-
-                # Yield the newly generated tokens for each sequence in the batch
-                for i, token in zip(active_sequences, next_tokens):
-                    yield i, self.tokenizer.decode(token[0], skip_special_tokens=True)
-
-                # Update input_ids for the next iteration
-                gen_kwargs["input_ids"] = torch.cat([gen_kwargs["input_ids"], next_tokens], dim=-1)
-                generated_tokens += 1
-
-                # Check for completed sequences
-                completed = (next_tokens.squeeze(-1) == self.tokenizer.eos_token_id).nonzero().squeeze(-1)
-                active_sequences = torch.tensor([i for i in active_sequences if i not in completed])
-                if len(active_sequences) > 0:
-                    gen_kwargs["input_ids"] = gen_kwargs["input_ids"][active_sequences]
+        # Set up the initial generation parameters
+        gen_kwargs = {
+            "input_ids": input_ids,
+            "do_sample": True,
+            "temperature": temp,
+            "eos_token_id": self.tokenizer.eos_token_id,
+            "max_new_tokens": 1,  # Generate one token at a time
+            "return_dict_in_generate": True,
+            "output_scores": True
+        }
+
+        # Generate and yield tokens one by one
+        generated_tokens = 0
+        batch_size = input_ids.shape[0]
+        active_sequences = torch.arange(batch_size)
+
+        while generated_tokens < max_length and len(active_sequences) > 0:
+            with torch.no_grad():
+                output = self.model.generate(**gen_kwargs)
+
+            next_tokens = output.sequences[:, -1].unsqueeze(-1)
+
+            # Yield the newly generated tokens for each sequence in the batch
+            for i, token in zip(active_sequences, next_tokens):
+                yield i, self.tokenizer.decode(token[0], skip_special_tokens=True)
+
+            # Update input_ids for the next iteration
+            gen_kwargs["input_ids"] = torch.cat([gen_kwargs["input_ids"], next_tokens], dim=-1)
+            generated_tokens += 1
+
+            # Check for completed sequences
+            completed = (next_tokens.squeeze(-1) == self.tokenizer.eos_token_id).nonzero().squeeze(-1)
+            active_sequences = torch.tensor([i for i in active_sequences if i not in completed])
+            if len(active_sequences) > 0:
+                gen_kwargs["input_ids"] = gen_kwargs["input_ids"][active_sequences]
+
+
+    def gen(self, content_list, temp=0.001, max_length=500):
+        # Convert list of texts to input IDs
+        input_ids = self.tokenizer(content_list, return_tensors="pt", padding=True, truncation=True).input_ids.to(self.model.device)
 
-        else:
-            # Non-streaming generation (unchanged)
-            outputs = self.model.generate(
-                input_ids,
-                max_new_tokens=max_length,
-                do_sample=True,
-                temperature=temp,
-                eos_token_id=self.tokenizer.eos_token_id,
-            )
-            return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        # Non-streaming generation (unchanged)
+        outputs = self.model.generate(
+            input_ids,
+            max_new_tokens=max_length,
+            do_sample=True,
+            temperature=temp,
+            eos_token_id=self.tokenizer.eos_token_id,
+        )
+        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
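Taken together, the refactor splits the old dual-purpose `gen` into two entry points: `gen` runs a single batched `generate` call and returns decoded strings, while `streaming` is a generator that yields `(sequence_index, token_text)` pairs one token at a time. A minimal consumption sketch, assuming an already-constructed `Model` instance named `m` (its constructor is not shown in this diff, so the setup here is hypothetical):

```python
# Hedged sketch of how a caller might use the two entry points after this commit.
prompts = [
    "Summarize: the cat sat on the mat.",
    "Summarize: it rained for a week straight.",
]

# One-shot generation: returns a list of decoded strings, one per prompt.
answers = m.gen(prompts, temp=0.001, max_length=500)

# Incremental generation: accumulate yielded tokens per sequence index.
partial = ["" for _ in prompts]
for seq_idx, token_text in m.streaming(prompts, temp=0.001, max_length=500):
    partial[int(seq_idx)] += token_text  # seq_idx comes from torch.arange(batch_size)
```

Note also that `return_model` now returns `self.model` instead of `self.pipeline`, so callers expecting the old pipeline object would see a different value.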
utils/multiple_stream.py CHANGED
@@ -20,7 +20,7 @@ def stream_data(content_list, model):
     outputs = ["" for _ in content_list]
 
     # Use the gen method to handle batch generation
-    generator = model.gen(content_list, streaming=True)
+    generator = model.streaming(content_list)
 
     while True:
         updated = False
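The hunk only shows where the generator is created; the body of the consumption loop lies outside the diff context. As an illustration of the pattern this API enables (not the file's actual implementation), a loop like the following could drain `model.streaming(...)` and keep one growing output buffer per input:

```python
# Illustrative consumption pattern for model.streaming(); hypothetical, not the
# actual body of stream_data.
def stream_data_sketch(content_list, model):
    outputs = ["" for _ in content_list]

    # Each iteration yields (sequence_index, token_text) for one batch entry.
    for seq_idx, token_text in model.streaming(content_list):
        outputs[int(seq_idx)] += token_text
        yield outputs  # emit a snapshot of all partial outputs after every new token
```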