ranamhamoud committed
Commit ce9212e · verified · 1 Parent(s): 2db8be7

Update app.py

Files changed (1): app.py (+38 -17)
app.py CHANGED

@@ -62,7 +62,8 @@ def process_text(text):
 
     return text
 
-# Gradio Function
+import re
+
 @spaces.GPU
 def generate(
     message: str,
@@ -84,34 +85,54 @@ def generate(
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=False)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
-    )
+    generate_kwargs = {
+        "input_ids": input_ids,
+        "streamer": streamer,
+        "max_new_tokens": max_new_tokens,
+        "do_sample": True,
+        "top_p": top_p,
+        "top_k": top_k,
+        "temperature": temperature,
+        "num_beams": 1,
+        "repetition_penalty": repetition_penalty,
+    }
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
     outputs = []
+    last_sentence_buffer = ""
     for text in streamer:
         processed_text = process_text(text)
-        outputs.append(processed_text)
-        output = "".join(outputs)
-        yield output
+        sentences = re.split(r'(?<=\.)\s', processed_text)
+
+        if len(sentences) > 1:
+            # Join all but the last sentence and buffer the last one
+            ready_to_stream = "".join(sentences[:-1])
+            if last_sentence_buffer:
+                yield last_sentence_buffer + ready_to_stream
+            last_sentence_buffer = sentences[-1]
+        else:
+            # No full sentences yet, buffer everything
+            last_sentence_buffer += processed_text
 
     final_story = "".join(outputs)
+    if last_sentence_buffer:
+        final_story += last_sentence_buffer  # Add the last buffer if it's a complete sentence
+
+    # Optional: Save the final story without the last sentence
+    final_story_trimmed = remove_last_sentence(final_story)
     try:
-        saved_story = Story(message=message, content=final_story).save()
-        yield f"{final_story}\n\n Story saved with ID: {saved_story.story_id}"
+        saved_story = Story(message=message, content=final_story_trimmed).save()
+        yield f"{final_story_trimmed}\n\n Story saved with ID: {saved_story.story_id}"
     except Exception as e:
         yield f"Failed to save story: {str(e)}"
 
+def remove_last_sentence(text):
+    # Split sentences and remove the last one
+    sentences = re.split(r'(?<=\.)\s', text)
+    return ' '.join(sentences[:-1]) if sentences else text
+
+
 # Gradio Interface Setup
 chat_interface = gr.ChatInterface(
     fn=generate,
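The core of this change is sentence-boundary streaming: rather than yielding every raw chunk from TextIteratorStreamer, the loop splits each processed chunk on whitespace that follows a period (re.split(r'(?<=\.)\s', ...)) and holds back the trailing, possibly incomplete sentence until a later chunk completes it. Below is a minimal standalone sketch of that buffering idea, with a plain list of chunks standing in for the streamer; the stream_by_sentence name and the sample data are illustrative, and unlike the committed loop the sketch also flushes completed sentences when nothing is buffered yet and rejoins multi-sentence chunks with a space.

import re

def stream_by_sentence(chunks):
    """Yield text at sentence boundaries only, buffering the trailing
    (possibly incomplete) sentence until a later chunk completes it."""
    last_sentence_buffer = ""
    for chunk in chunks:
        # Same pattern as the commit: split on whitespace that follows
        # a period, so the period stays attached to its sentence.
        sentences = re.split(r'(?<=\.)\s', chunk)
        if len(sentences) > 1:
            # Every piece except the last ends at a period: flush them,
            # together with whatever was buffered from earlier chunks.
            yield last_sentence_buffer + " ".join(sentences[:-1])
            last_sentence_buffer = sentences[-1]
        else:
            # No sentence boundary in this chunk: keep buffering.
            last_sentence_buffer += chunk
    if last_sentence_buffer:
        yield last_sentence_buffer  # flush the remainder at end of stream

# Chunks arrive mid-word; output arrives one completed sentence at a time.
for piece in stream_by_sentence(["Once upon a ti", "me. The dragon sle", "pt. The end."]):
    print(repr(piece))
# 'Once upon a time.'
# 'The dragon slept.'
# 'The end.'

One caveat if this feeds gr.ChatInterface: Gradio treats each yielded value as the full response so far, so per-sentence pieces like these would replace, not extend, the text already shown.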
 
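Two behaviors of the saving path are easy to miss. First, nothing appends to outputs inside the new loop, so final_story is built from last_sentence_buffer alone; presumably outputs.append(processed_text) was meant to stay alongside the buffering. Second, re.split always returns at least one element, so the if-sentences guard in remove_last_sentence never falls through to its else branch, and a text with no period-plus-space boundary collapses to the empty string. A quick check of the helper as committed (the sample strings are illustrative):

import re

def remove_last_sentence(text):
    # As committed: split at whitespace that follows a period,
    # drop the final piece, and rejoin with single spaces.
    sentences = re.split(r'(?<=\.)\s', text)
    return ' '.join(sentences[:-1]) if sentences else text

print(remove_last_sentence("The knight rode north. The castle fell. To be contin"))
# -> 'The knight rode north. The castle fell.'  (the cut-off trailer is dropped)
print(remove_last_sentence("No sentence boundary here"))
# -> ''  (one element in, so sentences[:-1] is empty)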