zhuohaoyu committed on
Commit
944baa0
·
1 Parent(s): 32e4312
Files changed (1) hide show
  1. app.py +8 -36
app.py CHANGED
@@ -36,40 +36,21 @@ from transformers.generation import GenerationConfig
36
  DEFAULT_CKPT_PATH = 'WisdomShell/CodeShell-7B-Chat'
37
 
38
 
39
- def _get_args():
40
- parser = ArgumentParser()
41
- parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
42
- help="Checkpoint name or path, default to %(default)r")
43
- parser.add_argument("--device", type=str, default="cuda:0", help="GPU device.")
44
-
45
- parser.add_argument("--share", action="store_true", default=False,
46
- help="Create a publicly shareable link for the interface.")
47
- parser.add_argument("--inbrowser", action="store_true", default=False,
48
- help="Automatically launch the interface in a new tab on the default browser.")
49
- parser.add_argument("--server-port", type=int, default=8000,
50
- help="Demo server port.")
51
- parser.add_argument("--server-name", type=str, default="127.0.0.1",
52
- help="Demo server name.")
53
-
54
- args = parser.parse_args()
55
- return args
56
-
57
-
58
  def _load_model_tokenizer(args):
59
  tokenizer = AutoTokenizer.from_pretrained(
60
- args.checkpoint_path, trust_remote_code=True, resume_download=True,
61
  )
62
 
63
  model = AutoModelForCausalLM.from_pretrained(
64
- args.checkpoint_path,
65
- device_map=args.device,
66
  trust_remote_code=True,
67
  resume_download=True,
68
  torch_dtype=torch.bfloat16
69
  ).eval()
70
 
71
  config = GenerationConfig.from_pretrained(
72
- args.checkpoint_path, trust_remote_code=True, resume_download=True,
73
  )
74
 
75
  return model, tokenizer, config
@@ -188,21 +169,12 @@ including hate speech, violence, pornography, deception, etc. \
188
  (注:本演示受CodeShell的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
189
  包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")
190
 
191
- demo.queue().launch(
192
- share=args.share,
193
- inbrowser=args.inbrowser,
194
- server_port=args.server_port,
195
- server_name=args.server_name,
196
- )
197
-
198
 
199
- def main():
200
- args = _get_args()
201
 
202
- model, tokenizer, config = _load_model_tokenizer(args)
203
 
204
- _launch_demo(args, model, tokenizer, config)
205
 
 
206
 
207
- if __name__ == '__main__':
208
- main()
 
36
  DEFAULT_CKPT_PATH = 'WisdomShell/CodeShell-7B-Chat'
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
def _load_model_tokenizer(args):
    """Load the CodeShell chat model, its tokenizer, and its generation config.

    Parameters
    ----------
    args : dict
        Optional overrides. Recognised keys:
          - 'checkpoint_path': model name or local path
            (default: ``DEFAULT_CKPT_PATH``).
          - 'device': device map passed to ``from_pretrained``
            (default: 'cpu', matching this hosted demo).

    Returns
    -------
    tuple
        ``(model, tokenizer, config)`` — the model is set to eval mode.
    """
    # The checkpoint path used to be hard-coded three times and `args` was
    # ignored entirely; read overrides from `args` with the old literals as
    # defaults so the existing call site `_load_model_tokenizer({})` behaves
    # identically.
    checkpoint = args.get('checkpoint_path', DEFAULT_CKPT_PATH)
    device = args.get('device', 'cpu')

    tokenizer = AutoTokenizer.from_pretrained(
        checkpoint, trust_remote_code=True, resume_download=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        device_map=device,
        trust_remote_code=True,
        resume_download=True,
        torch_dtype=torch.bfloat16,
    ).eval()

    config = GenerationConfig.from_pretrained(
        checkpoint, trust_remote_code=True, resume_download=True,
    )

    return model, tokenizer, config
 
169
  (注:本演示受CodeShell的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
170
  包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")
171
 
172
+ demo.queue().launch()
 
 
 
 
 
 
173
 
 
 
174
 
175
# Script entry point for the hosted Space: the CLI argument parser was
# removed in this commit, so an empty mapping stands in for parsed args.
# Load the model artefacts once at import time, then start the Gradio demo.
args = {}
model, tokenizer, config = _load_model_tokenizer(args)
_launch_demo(args, model, tokenizer, config)
180