Baicai003 commited on
Commit
8a4c63b
·
1 Parent(s): c9cbbfd

Create README.MD

Browse files
Files changed (1) hide show
  1. README.MD +72 -0
README.MD ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: openrail
3
+ datasets:
4
+ - shareAI/ShareGPT-Chinese-English-90k
5
+ - shareAI/CodeChat
6
+ language:
7
+ - en
8
+ library_name: transformers
9
+ tags:
10
+ - code
11
+ ---
12
+
13
+ Code:
14
+ (just run it, and the model weights will be auto download)
15
+
16
+ Github:https://github.com/CrazyBoyM/CodeLLaMA-chat
17
+
18
+ ```
19
+ # from Firefly
20
+ from transformers import AutoModelForCausalLM, AutoTokenizer
21
+ import torch
22
+
23
+
24
+ def main():
25
+ model_name = 'shareAI/CodeLLaMA-chat-13b-Chinese'
26
+
27
+ device = 'cuda'
28
+ max_new_tokens = 500 # max token for reply.
29
+ history_max_len = 1000 # max token in history
30
+ top_p = 0.9
31
+ temperature = 0.35
32
+ repetition_penalty = 1.0
33
+
34
+ model = AutoModelForCausalLM.from_pretrained(
35
+ model_name,
36
+ trust_remote_code=True,
37
+ low_cpu_mem_usage=True,
38
+ torch_dtype=torch.float16,
39
+ device_map='auto'
40
+ ).to(device).eval()
41
+ tokenizer = AutoTokenizer.from_pretrained(
42
+ model_name,
43
+ trust_remote_code=True,
44
+ use_fast=False
45
+ )
46
+
47
+
48
+ history_token_ids = torch.tensor([[]], dtype=torch.long)
49
+
50
+ user_input = input('User:')
51
+ while True:
52
+ input_ids = tokenizer(user_input, return_tensors="pt", add_special_tokens=False).input_ids
53
+ eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long)
54
+ user_input_ids = torch.concat([input_ids, eos_token_id], dim=1)
55
+ history_token_ids = torch.concat((history_token_ids, user_input_ids), dim=1)
56
+ model_input_ids = history_token_ids[:, -history_max_len:].to(device)
57
+ with torch.no_grad():
58
+ outputs = model.generate(
59
+ input_ids=model_input_ids, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p,
60
+ temperature=temperature, repetition_penalty=repetition_penalty, eos_token_id=tokenizer.eos_token_id
61
+ )
62
+ model_input_ids_len = model_input_ids.size(1)
63
+ response_ids = outputs[:, model_input_ids_len:]
64
+ history_token_ids = torch.concat((history_token_ids, response_ids.cpu()), dim=1)
65
+ response = tokenizer.batch_decode(response_ids)
66
+ print("Bot:" + response[0].strip().replace(tokenizer.eos_token, ""))
67
+ user_input = input('User:')
68
+
69
+
70
+ if __name__ == '__main__':
71
+ main()
72
+ ```