[bug]
Browse files- .gitignore +2 -0
- main.py +12 -5
.gitignore
CHANGED
|
@@ -4,3 +4,5 @@
|
|
| 4 |
|
| 5 |
**/flagged/
|
| 6 |
**/__pycache__/
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
**/flagged/
|
| 6 |
**/__pycache__/
|
| 7 |
+
|
| 8 |
+
trained_models/
|
main.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
import argparse
|
| 4 |
from collections import defaultdict
|
| 5 |
import os
|
|
|
|
| 6 |
|
| 7 |
import gradio as gr
|
| 8 |
from threading import Thread
|
|
@@ -71,7 +72,6 @@ def main():
|
|
| 71 |
input_ids = torch.tensor([input_ids], dtype=torch.long)
|
| 72 |
input_ids = input_ids.to(device)
|
| 73 |
|
| 74 |
-
output: str = ""
|
| 75 |
streamer = TextIteratorStreamer(tokenizer=tokenizer)
|
| 76 |
|
| 77 |
generation_kwargs = dict(
|
|
@@ -88,17 +88,24 @@ def main():
|
|
| 88 |
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
| 89 |
thread.start()
|
| 90 |
|
|
|
|
|
|
|
| 91 |
for output_ in streamer:
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
| 94 |
output_ = output_.replace("[SEP]", "\n")
|
| 95 |
output_ = output_.replace("[UNK]", "")
|
| 96 |
-
output_ = output_.replace(
|
| 97 |
|
| 98 |
output += output_.strip()
|
| 99 |
output_text_box.value += output
|
| 100 |
yield output
|
| 101 |
|
|
|
|
|
|
|
| 102 |
demo = gr.Interface(
|
| 103 |
fn=fn_stream,
|
| 104 |
inputs=[
|
|
@@ -107,7 +114,7 @@ def main():
|
|
| 107 |
gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p"),
|
| 108 |
gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature"),
|
| 109 |
gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty"),
|
| 110 |
-
gr.Dropdown(choices=
|
| 111 |
gr.Checkbox(value=True, label="is_chat")
|
| 112 |
],
|
| 113 |
outputs=[output_text_box],
|
|
|
|
| 3 |
import argparse
|
| 4 |
from collections import defaultdict
|
| 5 |
import os
|
| 6 |
+
import platform
|
| 7 |
|
| 8 |
import gradio as gr
|
| 9 |
from threading import Thread
|
|
|
|
| 72 |
input_ids = torch.tensor([input_ids], dtype=torch.long)
|
| 73 |
input_ids = input_ids.to(device)
|
| 74 |
|
|
|
|
| 75 |
streamer = TextIteratorStreamer(tokenizer=tokenizer)
|
| 76 |
|
| 77 |
generation_kwargs = dict(
|
|
|
|
| 88 |
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
| 89 |
thread.start()
|
| 90 |
|
| 91 |
+
output: str = ""
|
| 92 |
+
first_answer = True
|
| 93 |
for output_ in streamer:
|
| 94 |
+
if first_answer:
|
| 95 |
+
first_answer = False
|
| 96 |
+
continue
|
| 97 |
+
# output_ = output_.replace(text, "")
|
| 98 |
+
# output_ = output_.replace("[CLS]", "")
|
| 99 |
output_ = output_.replace("[SEP]", "\n")
|
| 100 |
output_ = output_.replace("[UNK]", "")
|
| 101 |
+
output_ = output_.replace(" ", "")
|
| 102 |
|
| 103 |
output += output_.strip()
|
| 104 |
output_text_box.value += output
|
| 105 |
yield output
|
| 106 |
|
| 107 |
+
model_name_choices = ["trained_models/lib_service_4chan"] \
|
| 108 |
+
if platform.system() == "Windows" else ["qgyd2021/lib_service_4chan"]
|
| 109 |
demo = gr.Interface(
|
| 110 |
fn=fn_stream,
|
| 111 |
inputs=[
|
|
|
|
| 114 |
gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p"),
|
| 115 |
gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature"),
|
| 116 |
gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty"),
|
| 117 |
+
gr.Dropdown(choices=model_name_choices, value=model_name_choices[0], label="model_name"),
|
| 118 |
gr.Checkbox(value=True, label="is_chat")
|
| 119 |
],
|
| 120 |
outputs=[output_text_box],
|