ranamhamoud committed on
Commit
1e69bad
·
verified ·
1 Parent(s): ce9212e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -31
app.py CHANGED
@@ -61,9 +61,7 @@ def process_text(text):
61
  text = re.sub(r'\[.*?\]', '', text, flags=re.DOTALL)
62
 
63
  return text
64
-
65
- import re
66
-
67
  @spaces.GPU
68
  def generate(
69
  message: str,
@@ -85,42 +83,31 @@ def generate(
85
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
86
 
87
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=False)
88
- generate_kwargs = {
89
- "input_ids": input_ids,
90
- "streamer": streamer,
91
- "max_new_tokens": max_new_tokens,
92
- "do_sample": True,
93
- "top_p": top_p,
94
- "top_k": top_k,
95
- "temperature": temperature,
96
- "num_beams": 1,
97
- "repetition_penalty": repetition_penalty,
98
- }
99
  t = Thread(target=model.generate, kwargs=generate_kwargs)
100
  t.start()
101
 
102
  outputs = []
103
- last_sentence_buffer = ""
104
  for text in streamer:
105
  processed_text = process_text(text)
106
- sentences = re.split(r'(?<=\.)\s', processed_text)
107
-
108
- if len(sentences) > 1:
109
- # Join all but the last sentence and buffer the last one
110
- ready_to_stream = "".join(sentences[:-1])
111
- if last_sentence_buffer:
112
- yield last_sentence_buffer + ready_to_stream
113
- last_sentence_buffer = sentences[-1]
114
- else:
115
- # No full sentences yet, buffer everything
116
- last_sentence_buffer += processed_text
117
 
118
  final_story = "".join(outputs)
119
- if last_sentence_buffer:
120
- final_story += last_sentence_buffer # Add the last buffer if it's a complete sentence
121
-
122
- # Optional: Save the final story without the last sentence
123
  final_story_trimmed = remove_last_sentence(final_story)
 
124
  try:
125
  saved_story = Story(message=message, content=final_story_trimmed).save()
126
  yield f"{final_story_trimmed}\n\n Story saved with ID: {saved_story.story_id}"
@@ -128,7 +115,7 @@ def generate(
128
  yield f"Failed to save story: {str(e)}"
129
 
130
  def remove_last_sentence(text):
131
- # Split sentences and remove the last one
132
  sentences = re.split(r'(?<=\.)\s', text)
133
  return ' '.join(sentences[:-1]) if sentences else text
134
 
 
61
  text = re.sub(r'\[.*?\]', '', text, flags=re.DOTALL)
62
 
63
  return text
64
+
 
 
65
  @spaces.GPU
66
  def generate(
67
  message: str,
 
83
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
84
 
85
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=False)
86
+ generate_kwargs = dict(
87
+ {"input_ids": input_ids},
88
+ streamer=streamer,
89
+ max_new_tokens=max_new_tokens,
90
+ do_sample=True,
91
+ top_p=top_p,
92
+ top_k=top_k,
93
+ temperature=temperature,
94
+ num_beams=1,
95
+ repetition_penalty=repetition_penalty,
96
+ )
97
  t = Thread(target=model.generate, kwargs=generate_kwargs)
98
  t.start()
99
 
100
  outputs = []
 
101
  for text in streamer:
102
  processed_text = process_text(text)
103
+ outputs.append(processed_text)
104
+ output = "".join(outputs)
105
+ yield output
 
 
 
 
 
 
 
 
106
 
107
  final_story = "".join(outputs)
108
+ # Remove the last sentence
 
 
 
109
  final_story_trimmed = remove_last_sentence(final_story)
110
+
111
  try:
112
  saved_story = Story(message=message, content=final_story_trimmed).save()
113
  yield f"{final_story_trimmed}\n\n Story saved with ID: {saved_story.story_id}"
 
115
  yield f"Failed to save story: {str(e)}"
116
 
117
def remove_last_sentence(text):
    """Return *text* with its final sentence removed.

    A sentence boundary is a period immediately followed by whitespace.
    Note: if *text* contains no such boundary (including the empty
    string, or a single sentence with no trailing whitespace), the one
    remaining fragment counts as the "last sentence" and the result is
    an empty string.
    """
    # re.split always yields at least one element (even for ""), so the
    # former `if sentences` guard was dead code and has been removed.
    sentences = re.split(r'(?<=\.)\s', text)
    return ' '.join(sentences[:-1])
121