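"""Evaluate every tagged epoch of each distributed-training model on a random
slice of the evaluation dataset, recording the average loss per epoch in
results.json."""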
import json
import random

import torch
from distributed_training.data.dataset import DataLoader
from huggingface_hub import list_repo_refs
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"
test_indices_length = 10

models = [
    "distributed/optimized-gpt2-250m",
    "distributed/optimized-gpt2-250m-v0.1.1",
    "distributed/gpt2-94m",
]

# Resume from any previously recorded results; start fresh if the file
# does not exist yet (the original crashed on a first run without it).
try:
    with open("results.json", "r") as file:
        results = json.load(file)
except FileNotFoundError:
    results = {}
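# results.json maps model name -> epoch (as a string) -> [average_loss],
# e.g. {"distributed/gpt2-94m": {"0": [3.21]}} (illustrative values, not real data).
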
for model_name in models:
    if model_name not in results:
        results[model_name] = {}

    # Loaded alongside the model, though the pre-tokenized batches below do not use it.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    refs = list_repo_refs(model_name, repo_type="model")
    global_epoch = max(int(tag.name) for tag in refs.tags) if refs.tags else None
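    # Checkpoints are assumed to be tagged on the Hub with integer epoch
    # numbers, so the highest tag is the latest completed epoch.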
    if global_epoch is None:
        continue  # no tagged checkpoints to evaluate yet

    for epoch in range(global_epoch):
        if str(epoch) in results[model_name]:
            continue  # this epoch was already scored on a previous run

        model = AutoModelForCausalLM.from_pretrained(
            model_name, revision=str(epoch), trust_remote_code=True
        )
        model = model.to(device).eval()  # eval mode: disable dropout for a stable loss
        # Pick a random contiguous window of test_indices_length dataset pages.
        search_start = random.randrange(DataLoader.max_pages - test_indices_length + 1)
        group = list(range(search_start, search_start + test_indices_length))
        dataloader = DataLoader(
            batch_size=1,
            sequence_length=1024,
            rows=group,
        )
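        # Assumption based on the usage below: each batch is a pair of tensors,
        # batch[0] = input_ids and batch[1] = labels for the selected rows.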
        total_loss = 0
        index = 0
        # Score one pass over the sampled rows; this is evaluation only,
        # so no gradients (and no backward pass) are needed.
        with torch.no_grad():
            for index, batch in enumerate(dataloader):
                inputs = batch[0].to(device)
                labels = batch[1].to(device)

                if len(inputs[0]) != len(labels[0]):
                    breakpoint()  # debug guard: inputs and labels should align

                if "optimized" in model_name:
                    # The "optimized" models return the loss at index 1 of the output.
                    outputs = model(input_ids=inputs, labels=labels)
                    loss = outputs[1]
                else:
                    outputs = model(input_ids=inputs, labels=inputs)
                    loss = outputs.loss

                # Accumulate the total loss across batches.
                total_loss += loss.detach().item()
        average_loss = total_loss / (index + 1)
        results[model_name][str(epoch)] = [average_loss]
        print(f"Epoch: {epoch} Average Loss: {average_loss:.2f}")

        # Checkpoint results after every epoch so progress survives interruptions.
        with open("results.json", "w") as outfile:
            json.dump(results, outfile, indent=4)
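
# Illustrative follow-up (not part of the original evaluation loop): report the
# best-scoring epoch recorded for each model, using the results layout built above.
for model_name, epochs in results.items():
    if epochs:
        best = min(epochs, key=lambda e: epochs[e][0])
        print(f"{model_name}: best epoch {best} (avg loss {epochs[best][0]:.2f})")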