Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -62,7 +62,8 @@ def process_text(text):
|
|
62 |
|
63 |
return text
|
64 |
|
65 |
-
|
|
|
66 |
@spaces.GPU
|
67 |
def generate(
|
68 |
message: str,
|
@@ -84,34 +85,54 @@ def generate(
|
|
84 |
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
|
85 |
|
86 |
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=False)
|
87 |
-
generate_kwargs =
|
88 |
-
|
89 |
-
streamer
|
90 |
-
max_new_tokens
|
91 |
-
do_sample
|
92 |
-
top_p
|
93 |
-
top_k
|
94 |
-
temperature
|
95 |
-
num_beams
|
96 |
-
repetition_penalty
|
97 |
-
|
98 |
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
99 |
t.start()
|
100 |
|
101 |
outputs = []
|
|
|
102 |
for text in streamer:
|
103 |
processed_text = process_text(text)
|
104 |
-
|
105 |
-
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
final_story = "".join(outputs)
|
|
|
|
|
|
|
|
|
|
|
109 |
try:
|
110 |
-
saved_story = Story(message=message, content=
|
111 |
-
yield f"{
|
112 |
except Exception as e:
|
113 |
yield f"Failed to save story: {str(e)}"
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
# Gradio Interface Setup
|
116 |
chat_interface = gr.ChatInterface(
|
117 |
fn=generate,
|
|
|
62 |
|
63 |
return text
|
64 |
|
65 |
+
import re
|
66 |
+
|
67 |
@spaces.GPU
|
68 |
def generate(
|
69 |
message: str,
|
|
|
85 |
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
|
86 |
|
87 |
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=False)
|
88 |
+
generate_kwargs = {
|
89 |
+
"input_ids": input_ids,
|
90 |
+
"streamer": streamer,
|
91 |
+
"max_new_tokens": max_new_tokens,
|
92 |
+
"do_sample": True,
|
93 |
+
"top_p": top_p,
|
94 |
+
"top_k": top_k,
|
95 |
+
"temperature": temperature,
|
96 |
+
"num_beams": 1,
|
97 |
+
"repetition_penalty": repetition_penalty,
|
98 |
+
}
|
99 |
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
100 |
t.start()
|
101 |
|
102 |
outputs = []
|
103 |
+
last_sentence_buffer = ""
|
104 |
for text in streamer:
|
105 |
processed_text = process_text(text)
|
106 |
+
sentences = re.split(r'(?<=\.)\s', processed_text)
|
107 |
+
|
108 |
+
if len(sentences) > 1:
|
109 |
+
# Join all but the last sentence and buffer the last one
|
110 |
+
ready_to_stream = "".join(sentences[:-1])
|
111 |
+
if last_sentence_buffer:
|
112 |
+
yield last_sentence_buffer + ready_to_stream
|
113 |
+
last_sentence_buffer = sentences[-1]
|
114 |
+
else:
|
115 |
+
# No full sentences yet, buffer everything
|
116 |
+
last_sentence_buffer += processed_text
|
117 |
|
118 |
final_story = "".join(outputs)
|
119 |
+
if last_sentence_buffer:
|
120 |
+
final_story += last_sentence_buffer # Add the last buffer if it's a complete sentence
|
121 |
+
|
122 |
+
# Optional: Save the final story without the last sentence
|
123 |
+
final_story_trimmed = remove_last_sentence(final_story)
|
124 |
try:
|
125 |
+
saved_story = Story(message=message, content=final_story_trimmed).save()
|
126 |
+
yield f"{final_story_trimmed}\n\n Story saved with ID: {saved_story.story_id}"
|
127 |
except Exception as e:
|
128 |
yield f"Failed to save story: {str(e)}"
|
129 |
|
130 |
+
def remove_last_sentence(text: str) -> str:
    """Return *text* with everything after its last sentence boundary removed.

    A sentence boundary is a single whitespace character immediately
    preceded by a period. The boundary whitespace itself is dropped, so the
    result ends with the period that closes the last kept sentence.

    The kept portion is sliced out of *text* rather than split and re-joined,
    so the original whitespace between the remaining sentences (newlines,
    tabs, ...) is preserved.

    If *text* contains no sentence boundary there is nothing to trim and it
    is returned unchanged instead of being reduced to an empty string.
    """
    last_boundary = None
    # Walk every boundary; only the final one determines the cut position.
    for last_boundary in re.finditer(r'(?<=\.)\s', text):
        pass

    if last_boundary is None:
        # No ". " style boundary: a single sentence or fragment (or empty
        # input). Keep it rather than discarding the whole text.
        return text

    return text[:last_boundary.start()]
|
134 |
+
|
135 |
+
|
136 |
# Gradio Interface Setup
|
137 |
chat_interface = gr.ChatInterface(
|
138 |
fn=generate,
|