Committing all changes before LFS migration

- __pycache__/utils.cpython-311.pyc +0 -0
- chat_cli.py +39 -0
- merge_script.py +63 -0
- test.py +12 -0
- utils.py +17 -0
__pycache__/utils.cpython-311.pyc
ADDED
Binary file (1.09 kB)
chat_cli.py
ADDED
@@ -0,0 +1,39 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def chat_with_model(model_path: str):
    # Ensure CUDA is available and set the device to use the first GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Wrap the model with DataParallel to use multiple GPUs
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        model = torch.nn.DataParallel(model)

    print("You're now chatting with the model. Type 'quit' to exit.")

    while True:
        # Get user input
        input_text = input("You: ")
        if input_text.lower() == 'quit':
            break

        # Encode the input text
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

        # Generate a response (unwrap DataParallel first, since it does not expose generate())
        with torch.no_grad():
            generated_text_samples = (model.module if isinstance(model, torch.nn.DataParallel) else model).generate(input_ids, max_length=50, pad_token_id=tokenizer.eos_token_id)

        # Decode and print the model's response
        response_text = tokenizer.decode(generated_text_samples[0], skip_special_tokens=True)
        print("AI:", response_text)

if __name__ == "__main__":
    model_path = '/home/energyxadmin/UI2/merge'
    chat_with_model(model_path)
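The script is a standalone CLI; starting a chat session from the repository root would look roughly like this (the merged-model path is hardcoded at the bottom of the file):

    python chat_cli.py
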
merge_script.py
ADDED
@@ -0,0 +1,63 @@
import torch
from peft import PeftModel  # Ensure you have the 'peft' library or modify according to your setup
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import argparse
from utils import get_logger  # Ensure this is implemented in your environment
import json

logger = get_logger("merge", "info")

def smart_tokenizer_and_embedding_resize(tokenizer, model, custom_tokens_path=None):
    """Resize tokenizer and embedding to accommodate new tokens."""
    special_tokens_dict = {
        "pad_token": "[PAD]",
        "eos_token": "</s>",
        "bos_token": "<s>",
        "unk_token": "<unk>"
    }

    # Load custom tokens if specified
    custom_tokens = []
    if custom_tokens_path is not None:
        with open(custom_tokens_path, 'r') as file:
            custom_tokens = [line.strip() for line in file.readlines()]

    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    if custom_tokens:
        num_added_toks += tokenizer.add_tokens(custom_tokens, special_tokens=True)

    model.resize_token_embeddings(len(tokenizer))
    logger.info(f"Resized tokenizer and model embeddings. Added {num_added_toks} tokens.")

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-bm", "--base_model", type=str, default="meta-llama/Llama-2-7b-chat-hf", help="Base model name or path")
    parser.add_argument("-lm", "--lora_model", type=str, required=True, help="Path to the LoRA model directory")
    parser.add_argument("-o", "--output", type=str, required=True, help="Output directory for the merged model")
    parser.add_argument("--custom_tokens", type=str, default=None, help="Path to a file containing custom tokens")
    args = parser.parse_args()

    if not os.path.exists(args.lora_model):
        raise FileNotFoundError(f"LoRA model directory {args.lora_model} not found.")

    os.makedirs(args.output, exist_ok=True)

    # Load the base model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(args.base_model)
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)

    # Adjust tokenizer and model for any additional tokens
    smart_tokenizer_and_embedding_resize(tokenizer, model, args.custom_tokens)

    # Load the LoRA adapter and fold its weights into the base model
    logger.info("Loading and merging the LoRA model...")
    lora_model = PeftModel.from_pretrained(model, args.lora_model).merge_and_unload()

    # Save the merged model and tokenizer
    lora_model.save_pretrained(args.output)
    tokenizer.save_pretrained(args.output)
    logger.info(f"Merged model saved to {args.output}")

if __name__ == "__main__":
    main()
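A typical invocation of the merge script might look like the following; the adapter directory and custom-tokens file paths are illustrative placeholders, not values taken from this commit:

    python merge_script.py -bm meta-llama/Llama-2-7b-chat-hf -lm ./lora-adapter -o ./merge --custom_tokens ./custom_tokens.txt
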
test.py
ADDED
@@ -0,0 +1,12 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = '/home/energyxadmin/UI2/merge'

model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Example text generation
input_ids = tokenizer.encode("What song did Eric Pask write or was a part of", return_tensors="pt")
generated_text_samples = model.generate(input_ids, max_length=1000)

print("Generated text:", tokenizer.decode(generated_text_samples[0], skip_special_tokens=True))
utils.py
ADDED
@@ -0,0 +1,17 @@
import logging
from typing_extensions import Literal
from rich.logging import RichHandler


def get_logger(name: str, level: Literal["info", "warning", "debug"]) -> logging.Logger:
    # Handler level follows the requested level so that "debug" records are actually emitted
    rich_handler = RichHandler(level=logging._nameToLevel[level.upper()], rich_tracebacks=True, markup=True)

    logger = logging.getLogger(name)
    logger.setLevel(logging._nameToLevel[level.upper()])

    if not logger.handlers:
        logger.addHandler(rich_handler)

    logger.propagate = False

    return logger