|
import os |
|
import json |
|
|
|
def create_graph(lora_path, lora_name):
    """Plot learning rate and loss against epoch and save the chart as a PNG.

    Reads ``training_graph.json`` from ``lora_path`` and writes
    ``training_graph.png`` into the same directory. The JSON document must be
    a sequence of objects that each provide the keys ``epoch``,
    ``learning_rate`` and ``loss``.

    The function is best-effort: if the JSON file is missing, or matplotlib
    cannot be imported, it prints a message and returns without raising.

    Args:
        lora_path: Directory that holds ``training_graph.json`` and receives
            ``training_graph.png``.
        lora_name: Name shown in the chart title.

    Returns:
        None. Progress and problems are reported with ``print``.

    Raises:
        KeyError: If a log entry lacks one of the required keys.
        json.JSONDecodeError: If the log file is not valid JSON.
    """
    try:
        # Imported lazily so matplotlib stays an optional dependency.
        import matplotlib.pyplot as plt
        from matplotlib.ticker import ScalarFormatter

        peft_model_path = f'{lora_path}/training_graph.json'
        image_model_path = f'{lora_path}/training_graph.png'

        if not os.path.exists(peft_model_path):
            print(f"File 'training_graph.json' does not exist in the {lora_path}")
            return

        # JSON is UTF-8 by specification; do not rely on the locale default.
        with open(peft_model_path, 'r', encoding='utf-8') as file:
            data = json.load(file)

        x = [item['epoch'] for item in data]
        y1 = [item['learning_rate'] for item in data]
        y2 = [item['loss'] for item in data]

        fig, ax1 = plt.subplots(figsize=(10, 6))
        try:
            # Left axis: learning rate.
            ax1.plot(x, y1, 'b-', label='Learning Rate')
            ax1.set_xlabel('Epoch')
            ax1.set_ylabel('Learning Rate', color='b')
            ax1.tick_params('y', colors='b')

            # Right axis shares the x axis and carries the loss curve.
            ax2 = ax1.twinx()
            ax2.plot(x, y2, 'r-', label='Loss')
            ax2.set_ylabel('Loss', color='r')
            ax2.tick_params('y', colors='r')

            # Learning rates are tiny; force scientific notation on that axis.
            ax1.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
            ax1.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))

            ax1.grid(True)

            # Merge the handles of both axes into a single legend.
            lines, labels = ax1.get_legend_handles_labels()
            lines2, labels2 = ax2.get_legend_handles_labels()
            ax2.legend(lines + lines2, labels + labels2, loc='best')

            # Address the figure/axes explicitly instead of relying on
            # pyplot's implicit "current" figure and axes.
            ax1.set_title(f'{lora_name} LR and Loss vs Epoch')
            fig.savefig(image_model_path)
        finally:
            # pyplot keeps every figure alive until it is closed; without
            # this, each call would leak one figure.
            plt.close(fig)

        print(f"Graph saved in {image_model_path}")

    except ImportError:
        print("matplotlib is not installed. Please install matplotlib to create PNG graphs")