Severian committed · verified · Commit 25974dd · 1 Parent(s): 3dfb6a3

Update README.md

Files changed (1):
  1. README.md +37 -0
README.md CHANGED
@@ -33,6 +33,43 @@ Since this is a base model the IKM dataset greatly affects the output. The IKM d
  ```
  ---

+ ## Inference
+
+ ```
+ !pip install -qqq "transformers>=4.39.0" mamba-ssm "causal-conv1d>=1.2.0" accelerate bitsandbytes --progress-bar off
+ !pip install flash-attn --no-build-isolation
+
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+
+ # Load the model in 8-bit precision; the Mamba blocks are skipped
+ # during quantization and stay in full precision
+ quantization_config = BitsAndBytesConfig(
+     load_in_8bit=True,
+     llm_int8_skip_modules=["mamba"]
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+     "ai21labs/Jamba-v0.1",
+     trust_remote_code=True,
+     torch_dtype=torch.bfloat16,
+     attn_implementation="flash_attention_2",
+     quantization_config=quantization_config
+ )
+ tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
+
+ # Tokenize the input prompt
+ prompt = """How could we use cheese to reignite the sun? Answer:"""
+ input_ids = tokenizer(
+     prompt,
+     return_tensors='pt'
+ ).to(model.device)["input_ids"]
+
+ # Generate an answer
+ outputs = model.generate(input_ids, max_new_tokens=216)
+
+ # Print the decoded output
+ print(tokenizer.batch_decode(outputs))
+ ```
+
  ```
  [3731/5850 3:38:52 < 2:04:22, 0.28 it/s, Epoch 6.37/10]
  Step Training Loss
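
One note on the snippet this commit adds: `tokenizer.batch_decode(outputs)` returns the prompt echoed back together with the completion, special tokens included. Below is a minimal sketch of isolating just the generated text, assuming the `tokenizer`, `input_ids`, and `outputs` variables from the README code above; the slicing is illustrative and not part of the commit:

```
# Decode only the newly generated tokens, dropping the echoed prompt
# and any special tokens. Assumes `tokenizer`, `input_ids`, and
# `outputs` exist as defined in the README snippet above.
completion = tokenizer.decode(
    outputs[0][input_ids.shape[-1]:],
    skip_special_tokens=True
)
print(completion)
```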