qgyd2021 committed on
Commit
ab4a005
·
1 Parent(s): 6a8e34d

[update]edit main

Browse files
Files changed (1) hide show
  1. main.py +24 -24
main.py CHANGED
@@ -53,8 +53,6 @@ examples = [
53
  def main():
54
  args = get_args()
55
 
56
- use_cpu = os.environ.get("USE_CPU", "all")
57
-
58
  tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, trust_remote_code=True)
59
  # QWenTokenizer比较特殊, pad_token_id, bos_token_id, eos_token_id 均 为None. eod_id对应的token为<|endoftext|>
60
  if tokenizer.__class__.__name__ == "QWenTokenizer":
@@ -62,27 +60,21 @@ def main():
62
  tokenizer.bos_token_id = tokenizer.eod_id
63
  tokenizer.eos_token_id = tokenizer.eod_id
64
 
65
- if not use_cpu:
66
- model = AutoModel.from_pretrained(
67
- args.pretrained_model_name_or_path,
68
- trust_remote_code=True
69
- ).half().cuda()
70
- else:
71
- model = AutoModelForCausalLM.from_pretrained(
72
- args.pretrained_model_name_or_path,
73
- trust_remote_code=True,
74
- low_cpu_mem_usage=True,
75
- torch_dtype=torch.bfloat16,
76
- device_map="auto",
77
- offload_folder="./offload",
78
- offload_state_dict=True,
79
- # load_in_4bit=True,
80
- )
81
  model = model.eval()
82
 
83
- def fn(inputs: str):
84
  input_ids = tokenizer(
85
- inputs,
86
  return_tensors="pt",
87
  add_special_tokens=False,
88
  ).input_ids.to(args.device)
@@ -97,20 +89,28 @@ def main():
97
  response = response.strip().replace(tokenizer.eos_token, "").strip()
98
  return response
99
 
 
 
 
 
 
 
 
 
100
  with gr.Blocks() as blocks:
101
  gr.Markdown(value=description)
102
 
103
  chatbot = gr.Chatbot([], elem_id="chatbot").style(height=400)
104
  with gr.Row():
105
  with gr.Column(scale=4):
106
- text = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)
107
  with gr.Column(scale=1):
108
  button = gr.Button("Generate")
109
 
110
- gr.Examples(examples, text)
111
 
112
- text.submit(fn, [text], [chatbot])
113
- button.click(fn, [text], [chatbot])
114
 
115
  blocks.queue().launch()
116
 
 
53
  def main():
54
  args = get_args()
55
 
 
 
56
  tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, trust_remote_code=True)
57
  # QWenTokenizer比较特殊, pad_token_id, bos_token_id, eos_token_id 均 为None. eod_id对应的token为<|endoftext|>
58
  if tokenizer.__class__.__name__ == "QWenTokenizer":
 
60
  tokenizer.bos_token_id = tokenizer.eod_id
61
  tokenizer.eos_token_id = tokenizer.eod_id
62
 
63
+ model = AutoModelForCausalLM.from_pretrained(
64
+ args.pretrained_model_name_or_path,
65
+ trust_remote_code=True,
66
+ low_cpu_mem_usage=True,
67
+ torch_dtype=torch.float16,
68
+ device_map="auto",
69
+ offload_folder="./offload",
70
+ offload_state_dict=True,
71
+ # load_in_4bit=True,
72
+ )
 
 
 
 
 
 
73
  model = model.eval()
74
 
75
+ def fn(text: str):
76
  input_ids = tokenizer(
77
+ text,
78
  return_tensors="pt",
79
  add_special_tokens=False,
80
  ).input_ids.to(args.device)
 
89
  response = response.strip().replace(tokenizer.eos_token, "").strip()
90
  return response
91
 
92
+ def fn_stream(text: str,
93
+ max_new_tokens: int = 200,
94
+ top_p: float = 0.85,
95
+ temperature: float = 0.35,
96
+ repetition_penalty: float = 1.2
97
+ ):
98
+ return
99
+
100
  with gr.Blocks() as blocks:
101
  gr.Markdown(value=description)
102
 
103
  chatbot = gr.Chatbot([], elem_id="chatbot").style(height=400)
104
  with gr.Row():
105
  with gr.Column(scale=4):
106
+ input_text_box = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)
107
  with gr.Column(scale=1):
108
  button = gr.Button("Generate")
109
 
110
+ gr.Examples(examples, input_text_box)
111
 
112
+ input_text_box.submit(fn, [input_text_box], [chatbot])
113
+ button.click(fn, [input_text_box], [chatbot])
114
 
115
  blocks.queue().launch()
116