manu committed
Commit 7ff4f07 · 1 Parent(s): f9b1082

Update README.md

Files changed (1)
  1. README.md +82 -1
README.md CHANGED
@@ -43,4 +43,85 @@ https://twitter.com/ManuelFaysse/status/1706949891358859624
 
  ### Usage
 
- Let's figure this out together in the discussion ! Probably the same conversion scripts from llama2 (non-hf) versions should be used !
+ Probably something like the Llama 2 models from the non-HF release! The model is likely not identical to Llama, though, so conversion to HF will not be completely straightforward.
+ Let's figure this out together!
+
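+ If the architecture turns out to be close enough to Llama, one starting point could be the stock `transformers` conversion script. This is only a hedged guess at this stage: the script and flags below are the standard Llama 2 ones, the paths are placeholders, and the size flag would need to match this model.
+
+ ```bash
+ # Untested sketch: reuse the Llama 2 -> HF converter shipped with transformers
+ python src/transformers/models/llama/convert_llama_weights_to_hf.py \
+     --input_dir path/to/original/checkpoints \
+     --model_size 7B \
+     --output_dir path/to/hf/model
+ ```
+
+ For the original (non-HF) checkpoints, inference should look roughly like the Llama 2 reference example: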
+ ```bash
+ torchrun --nproc_per_node 1 example_text_completion.py \
+     --ckpt_dir llama-2-7b/ \
+     --tokenizer_path tokenizer.model \
+     --max_seq_len 128 --max_batch_size 4
+ ```
+
+ `example_text_completion.py`
+ ```python
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+ import fire
+
+ from llama import Llama
+ from typing import List
+
+ def main(
+     ckpt_dir: str,
+     tokenizer_path: str,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     max_seq_len: int = 128,
+     max_gen_len: int = 64,
+     max_batch_size: int = 4,
+ ):
+     """
+     Entry point of the program for generating text using a pretrained model.
+
+     Args:
+         ckpt_dir (str): The directory containing checkpoint files for the pretrained model.
+         tokenizer_path (str): The path to the tokenizer model used for text encoding/decoding.
+         temperature (float, optional): The temperature value for controlling randomness in generation.
+             Defaults to 0.6.
+         top_p (float, optional): The top-p sampling parameter for controlling diversity in generation.
+             Defaults to 0.9.
+         max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 128.
+         max_gen_len (int, optional): The maximum length of generated sequences. Defaults to 64.
+         max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 4.
+     """
+     generator = Llama.build(
+         ckpt_dir=ckpt_dir,
+         tokenizer_path=tokenizer_path,
+         max_seq_len=max_seq_len,
+         max_batch_size=max_batch_size,
+     )
+
+     prompts: List[str] = [
+         # For these prompts, the expected answer is the natural continuation of the prompt
+         "I believe the meaning of life is",
+         "Simply put, the theory of relativity states that ",
+         """A brief message congratulating the team on the launch:
+
+         Hi everyone,
+
+         I just """,
+         # Few shot prompt (providing a few examples before asking model to complete more);
+         """Translate English to French:
+
+         sea otter => loutre de mer
+         peppermint => menthe poivrée
+         plush girafe => girafe peluche
+         cheese =>""",
+     ]
+     results = generator.text_completion(
+         prompts,
+         max_gen_len=max_gen_len,
+         temperature=temperature,
+         top_p=top_p,
+     )
+     for prompt, result in zip(prompts, results):
+         print(prompt)
+         print(f"> {result['generation']}")
+         print("\n==================================\n")
+
+
+ if __name__ == "__main__":
+     fire.Fire(main)
+ ```