Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -61,9 +61,7 @@ def process_text(text):
|
|
61 |
text = re.sub(r'\[.*?\]', '', text, flags=re.DOTALL)
|
62 |
|
63 |
return text
|
64 |
-
|
65 |
-
import re
|
66 |
-
|
67 |
@spaces.GPU
|
68 |
def generate(
|
69 |
message: str,
|
@@ -85,42 +83,31 @@ def generate(
|
|
85 |
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
|
86 |
|
87 |
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=False)
|
88 |
-
generate_kwargs =
|
89 |
-
"input_ids": input_ids,
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
100 |
t.start()
|
101 |
|
102 |
outputs = []
|
103 |
-
last_sentence_buffer = ""
|
104 |
for text in streamer:
|
105 |
processed_text = process_text(text)
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
# Join all but the last sentence and buffer the last one
|
110 |
-
ready_to_stream = "".join(sentences[:-1])
|
111 |
-
if last_sentence_buffer:
|
112 |
-
yield last_sentence_buffer + ready_to_stream
|
113 |
-
last_sentence_buffer = sentences[-1]
|
114 |
-
else:
|
115 |
-
# No full sentences yet, buffer everything
|
116 |
-
last_sentence_buffer += processed_text
|
117 |
|
118 |
final_story = "".join(outputs)
|
119 |
-
|
120 |
-
final_story += last_sentence_buffer # Add the last buffer if it's a complete sentence
|
121 |
-
|
122 |
-
# Optional: Save the final story without the last sentence
|
123 |
final_story_trimmed = remove_last_sentence(final_story)
|
|
|
124 |
try:
|
125 |
saved_story = Story(message=message, content=final_story_trimmed).save()
|
126 |
yield f"{final_story_trimmed}\n\n Story saved with ID: {saved_story.story_id}"
|
@@ -128,7 +115,7 @@ def generate(
|
|
128 |
yield f"Failed to save story: {str(e)}"
|
129 |
|
130 |
def remove_last_sentence(text):
|
131 |
-
#
|
132 |
sentences = re.split(r'(?<=\.)\s', text)
|
133 |
return ' '.join(sentences[:-1]) if sentences else text
|
134 |
|
|
|
61 |
text = re.sub(r'\[.*?\]', '', text, flags=re.DOTALL)
|
62 |
|
63 |
return text
|
64 |
+
|
|
|
|
|
65 |
@spaces.GPU
|
66 |
def generate(
|
67 |
message: str,
|
|
|
83 |
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
|
84 |
|
85 |
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=False)
|
86 |
+
generate_kwargs = dict(
|
87 |
+
{"input_ids": input_ids},
|
88 |
+
streamer=streamer,
|
89 |
+
max_new_tokens=max_new_tokens,
|
90 |
+
do_sample=True,
|
91 |
+
top_p=top_p,
|
92 |
+
top_k=top_k,
|
93 |
+
temperature=temperature,
|
94 |
+
num_beams=1,
|
95 |
+
repetition_penalty=repetition_penalty,
|
96 |
+
)
|
97 |
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
98 |
t.start()
|
99 |
|
100 |
outputs = []
|
|
|
101 |
for text in streamer:
|
102 |
processed_text = process_text(text)
|
103 |
+
outputs.append(processed_text)
|
104 |
+
output = "".join(outputs)
|
105 |
+
yield output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
final_story = "".join(outputs)
|
108 |
+
# Remove the last sentence
|
|
|
|
|
|
|
109 |
final_story_trimmed = remove_last_sentence(final_story)
|
110 |
+
|
111 |
try:
|
112 |
saved_story = Story(message=message, content=final_story_trimmed).save()
|
113 |
yield f"{final_story_trimmed}\n\n Story saved with ID: {saved_story.story_id}"
|
|
|
115 |
yield f"Failed to save story: {str(e)}"
|
116 |
|
117 |
def remove_last_sentence(text):
|
118 |
+
# Assuming sentences end with a period followed by space or end of string
|
119 |
sentences = re.split(r'(?<=\.)\s', text)
|
120 |
return ' '.join(sentences[:-1]) if sentences else text
|
121 |
|