# LLäMmlein 1B Chat

This is a chat adapter for the German TinyLlama 1B language model.
Find more details on our [page](https://www.informatik.uni-wuerzburg.de/datascience/projects/nlp/llammlein/) and our [preprint](https://arxiv.org/abs/2411.11171)!

## Run it
```py
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.manual_seed(42)

# script config
base_model_name = "LSX-UniWue/llammchen_1b"
chat_adapter_name = "LSX-UniWue/LLaMmlein_1B_chat_all"
device = "mps"  # or "cuda"

# chat history
messages = [
    {
        "role": "user",
        "content": """Na wie geht's?""",
    },
]

# load the base model and wrap it with the chat adapter
config = PeftConfig.from_pretrained(chat_adapter_name)  # optional: inspect the adapter configuration
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    attn_implementation="flash_attention_2" if device == "cuda" else None,
    torch_dtype=torch.bfloat16,
    device_map=device,
)
base_model.resize_token_embeddings(32064)  # match the adapter's extended vocabulary
model = PeftModel.from_pretrained(base_model, chat_adapter_name)
tokenizer = AutoTokenizer.from_pretrained(chat_adapter_name)

# encode the chat history in "ChatML" format
chat = tokenizer.apply_chat_template(
    messages,
    return_tensors="pt",
    add_generation_prompt=True,
).to(device)

# generate the response
print(
    tokenizer.decode(
        model.generate(
            chat,
            max_new_tokens=300,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )[0],
        skip_special_tokens=False,
    )
)
```
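
If you want to see the exact prompt string the chat template produces (rather than token IDs), you can render it without tokenizing. A minimal sketch, reusing the `messages` and `tokenizer` objects from the snippet above:

```py
# render the ChatML-formatted prompt as plain text for inspection
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)
```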
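
For a multi-turn conversation, append the generated reply as an assistant turn and add your next user message before re-applying the template. A sketch reusing `model`, `tokenizer`, `chat`, and `messages` from above; the follow-up question is just an example:

```py
# continue the conversation: keep only the newly generated tokens,
# append them as the assistant turn, then add a follow-up user message
output_ids = model.generate(
    chat,
    max_new_tokens=300,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
reply = tokenizer.decode(output_ids[0, chat.shape[1]:], skip_special_tokens=True)
messages.append({"role": "assistant", "content": reply})
messages.append({"role": "user", "content": "Erzähl mir mehr!"})  # example follow-up
# then re-apply the chat template and generate again, as in the snippet above
```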
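
If you prefer a standalone checkpoint without the PEFT wrapper at inference time, one option is to merge the adapter weights into the base model. A sketch assuming the adapter is LoRA-style (which PEFT can merge) and reusing `model` and `tokenizer` from above; the output path is illustrative:

```py
# merge the adapter weights into the base model and save a standalone copy
merged_model = model.merge_and_unload()
merged_model.save_pretrained("llammlein-1b-chat-merged")  # example output path
tokenizer.save_pretrained("llammlein-1b-chat-merged")
```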