genaforvena committed on
Commit 1445a16 · 1 Parent(s): 6581160
Files changed (2)
  1. app.py +56 -36
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,4 +1,4 @@
-from transformers import GPT2LMHeadModel, AutoTokenizer, pipeline
+from transformers import GPT2LMHeadModel, AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
 import gradio as gr
 
@@ -20,66 +20,86 @@ def generate_text_stream(model, tokenizer, prompt, max_new_tokens, temperature):
     inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
     generated_tokens = 0
     output_text = ""
-
     while generated_tokens < max_new_tokens:
-        # Generate one token at a time
         outputs = model.generate(
             inputs,
-            max_new_tokens=1, # Generate one token at a time
+            max_new_tokens=1,
             do_sample=True,
             top_p=0.95,
             top_k=50,
             temperature=temperature,
             pad_token_id=tokenizer.eos_token_id,
         )
-        # Decode the new token
         new_token = tokenizer.decode(outputs[0, -1], skip_special_tokens=True)
         output_text += new_token
         generated_tokens += 1
-
-        # Yield the updated output text
         yield output_text
-
-        # Update inputs for the next iteration
         inputs = outputs
 
-# Function to summarize text
-def summarize(text, summarizer, max_length, min_length):
-    summary = summarizer(text, max_length=max_length, min_length=min_length, do_sample=False)
-    return summary[0]['summary_text']
+# Load BART model and tokenizer for summarization
+summarization_model_name = "facebook/bart-large-cnn"
+summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(summarization_model_name).to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) # Use GPU if available
+summarizer_tokenizer = AutoTokenizer.from_pretrained(summarization_model_name)
+
+# Function to generate summary with manual streaming
+def generate_summary_stream(model, tokenizer, text, max_length, min_length):
+    inputs = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True).to(model.device)
+    generated_tokens = 0
+    summary_text = ""
+    decoder_start_token_id = model.config.decoder_start_token_id
+    decoder_input_ids = torch.tensor([[decoder_start_token_id]], device=model.device)
+    past_key_values = None
+    while True:
+        outputs = model.generate(
+            inputs,
+            decoder_input_ids=decoder_input_ids,
+            max_new_tokens=1,
+            min_length=min_length,
+            do_sample=False,
+            past_key_values=past_key_values,
+            output_hidden_states=True,
+            return_dict_in_generate=True
+        )
+        next_token_id = outputs.sequences[0][-1]
+        if next_token_id == tokenizer.eos_token_id or len(decoder_input_ids[0]) >= max_length:
+            break
+        next_token = tokenizer.decode(next_token_id, skip_special_tokens=True)
+        summary_text += next_token
+        yield summary_text
+        decoder_input_ids = torch.cat([decoder_input_ids, torch.tensor([[next_token_id]], device=model.device)], dim=-1)
+        past_key_values = outputs.past_key_values
 
 def reply(prompt):
     if len(prompt) == 0:
         prompt = "class holocaust"
-    # Stream output from the first model (deleuze)
-    output1 = ""
-    for text in generate_text_stream(deleuze, tokenizer1, prompt, max_new_tokens=500, temperature=0.9):
-        output1 = text
-        yield output1
 
-    # Stream output from the second model (scum)
-    output2 = ""
-    for text in generate_text_stream(scum, tokenizer2, output1, max_new_tokens=200, temperature=1.7):
-        output2 = text
-        yield output2
+    # --- Phase 1: Generate and Stream Combined Output ---
+    combined_output = ""
+    scum_output = ""
+    gospel_output = ""
 
-    # Stream output from the third model (gospel)
-    output3 = ""
-    for text in generate_text_stream(gospel, tokenizer3, prompt, max_new_tokens=200, temperature=1.0):
-        output3 = text
-        yield output3
+    # Stream deleuze output
+    for text in generate_text_stream(deleuze, tokenizer1, prompt, max_new_tokens=500, temperature=0.9):
+        combined_output = text
+        yield combined_output
 
-    # Combine outputs for summarization
-    final_output = output2 + " " + output3
 
-    # Initialize summarizer
-    summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
+    # Stream scum output (appending to the existing combined output)
+    for text in generate_text_stream(scum, tokenizer3, combined_output, max_new_tokens=200, temperature=1.7):
+        scum_output = text
+        combined_output = text
+        yield combined_output
 
-    # Generate the final summary
-    summary = summarize(final_output, summarizer, max_length=500, min_length=150)
+    # Stream gospel output (appending to the existing combined output)
+    for text in generate_text_stream(gospel, tokenizer2, prompt, max_new_tokens=200, temperature=1.0):
+        gospel_output = text
+        combined_output = text
+        yield combined_output
 
-    # Display the final summary
-    yield summary
+    # --- Phase 2: Generate and Stream Summary (Replacing Combined Output) ---
+    final_output_for_summary = scum_output + " " + gospel_output # Use scum and gospel only for summarization
+    for text in generate_summary_stream(summarizer_model, summarizer_tokenizer, final_output_for_summary, max_length=500, min_length=150):
+        yield text
 
 # Gradio interface
 iface = gr.Interface(fn=reply, inputs="text", outputs="text")
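For reference, both sides of this diff rely on the same streaming pattern: each generator yields the full text accumulated so far, and reply() is itself a generator, so Gradio can stream successive yields to the output textbox. Below is a minimal sketch (not part of the commit) of driving generate_text_stream directly; the "gpt2" checkpoint and the test prompt are stand-ins, since the app's fine-tuned models (deleuze, scum, gospel) are loaded in a part of app.py not shown in this diff.

# Minimal sketch, assuming generate_text_stream as defined in app.py above.
# "gpt2" is a stand-in checkpoint; app.py loads its own fine-tuned GPT-2 models.
from transformers import GPT2LMHeadModel, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

for partial_text in generate_text_stream(model, tokenizer, "a short test prompt",
                                         max_new_tokens=20, temperature=0.9):
    print(partial_text)  # each yield is the whole text generated so far

generate_summary_stream is consumed the same way; the difference is that it advances BART's decoder one token at a time via decoder_input_ids instead of growing the prompt.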
requirements.txt CHANGED
@@ -2,4 +2,4 @@ huggingface_hub==0.25.2
 gradio
 transformers
 bs4
-torch
+torch