Spaces:
Runtime error
Runtime error
[update]edit main
Browse files
main.py
CHANGED
@@ -53,8 +53,6 @@ examples = [
|
|
53 |
def main():
|
54 |
args = get_args()
|
55 |
|
56 |
-
use_cpu = os.environ.get("USE_CPU", "all")
|
57 |
-
|
58 |
tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, trust_remote_code=True)
|
59 |
# QWenTokenizer比较特殊, pad_token_id, bos_token_id, eos_token_id 均 为None. eod_id对应的token为<|endoftext|>
|
60 |
if tokenizer.__class__.__name__ == "QWenTokenizer":
|
@@ -62,27 +60,21 @@ def main():
|
|
62 |
tokenizer.bos_token_id = tokenizer.eod_id
|
63 |
tokenizer.eos_token_id = tokenizer.eod_id
|
64 |
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
torch_dtype=torch.bfloat16,
|
76 |
-
device_map="auto",
|
77 |
-
offload_folder="./offload",
|
78 |
-
offload_state_dict=True,
|
79 |
-
# load_in_4bit=True,
|
80 |
-
)
|
81 |
model = model.eval()
|
82 |
|
83 |
-
def fn(
|
84 |
input_ids = tokenizer(
|
85 |
-
|
86 |
return_tensors="pt",
|
87 |
add_special_tokens=False,
|
88 |
).input_ids.to(args.device)
|
@@ -97,20 +89,28 @@ def main():
|
|
97 |
response = response.strip().replace(tokenizer.eos_token, "").strip()
|
98 |
return response
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
with gr.Blocks() as blocks:
|
101 |
gr.Markdown(value=description)
|
102 |
|
103 |
chatbot = gr.Chatbot([], elem_id="chatbot").style(height=400)
|
104 |
with gr.Row():
|
105 |
with gr.Column(scale=4):
|
106 |
-
|
107 |
with gr.Column(scale=1):
|
108 |
button = gr.Button("Generate")
|
109 |
|
110 |
-
gr.Examples(examples,
|
111 |
|
112 |
-
|
113 |
-
button.click(fn, [
|
114 |
|
115 |
blocks.queue().launch()
|
116 |
|
|
|
53 |
def main():
|
54 |
args = get_args()
|
55 |
|
|
|
|
|
56 |
tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, trust_remote_code=True)
|
57 |
# QWenTokenizer比较特殊, pad_token_id, bos_token_id, eos_token_id 均 为None. eod_id对应的token为<|endoftext|>
|
58 |
if tokenizer.__class__.__name__ == "QWenTokenizer":
|
|
|
60 |
tokenizer.bos_token_id = tokenizer.eod_id
|
61 |
tokenizer.eos_token_id = tokenizer.eod_id
|
62 |
|
63 |
+
model = AutoModelForCausalLM.from_pretrained(
|
64 |
+
args.pretrained_model_name_or_path,
|
65 |
+
trust_remote_code=True,
|
66 |
+
low_cpu_mem_usage=True,
|
67 |
+
torch_dtype=torch.float16,
|
68 |
+
device_map="auto",
|
69 |
+
offload_folder="./offload",
|
70 |
+
offload_state_dict=True,
|
71 |
+
# load_in_4bit=True,
|
72 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
model = model.eval()
|
74 |
|
75 |
+
def fn(text: str):
|
76 |
input_ids = tokenizer(
|
77 |
+
text,
|
78 |
return_tensors="pt",
|
79 |
add_special_tokens=False,
|
80 |
).input_ids.to(args.device)
|
|
|
89 |
response = response.strip().replace(tokenizer.eos_token, "").strip()
|
90 |
return response
|
91 |
|
92 |
+
def fn_stream(text: str,
|
93 |
+
max_new_tokens: int = 200,
|
94 |
+
top_p: float = 0.85,
|
95 |
+
temperature: float = 0.35,
|
96 |
+
repetition_penalty: float = 1.2
|
97 |
+
):
|
98 |
+
return
|
99 |
+
|
100 |
with gr.Blocks() as blocks:
|
101 |
gr.Markdown(value=description)
|
102 |
|
103 |
chatbot = gr.Chatbot([], elem_id="chatbot").style(height=400)
|
104 |
with gr.Row():
|
105 |
with gr.Column(scale=4):
|
106 |
+
input_text_box = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)
|
107 |
with gr.Column(scale=1):
|
108 |
button = gr.Button("Generate")
|
109 |
|
110 |
+
gr.Examples(examples, input_text_box)
|
111 |
|
112 |
+
input_text_box.submit(fn, [input_text_box], [chatbot])
|
113 |
+
button.click(fn, [input_text_box], [chatbot])
|
114 |
|
115 |
blocks.queue().launch()
|
116 |
|