# Author: mgbam
# Commit 9c7387c: "Add untracked files and synchronize with remote"
import matplotlib.pyplot as plt
import io
import base64
def generate_federated_learning_plot(client_accuracies, *, ylim=(0, 1)):
    """Render a bar chart of per-client accuracies as an inline HTML <img> tag.

    This is a placeholder visualization: integrate it with your federated
    learning framework by recording each client's accuracy during training
    and passing the collected values here.

    Args:
        client_accuracies: Mapping of client ID to accuracy value. IDs are
            rendered with ``str()`` so every client gets its own labelled bar
            regardless of the ID's type.
        ylim: ``(bottom, top)`` limits for the accuracy axis. Defaults to
            ``(0, 1)`` for fractional accuracies; pass e.g. ``(0, 100)`` for
            percentages, or ``None`` to let matplotlib autoscale.

    Returns:
        An HTML ``<img>`` element whose ``src`` is a base64-encoded PNG data
        URI of the chart.
    """
    # Treat IDs as categorical labels. Numeric IDs would otherwise be used as
    # x positions (sparse IDs -> misplaced bars, meaningless ticks), and
    # mixed-type IDs would make matplotlib's category axis raise.
    client_ids = [str(client_id) for client_id in client_accuracies]
    accuracies = list(client_accuracies.values())

    # Use an explicit Figure instead of pyplot's implicit "current figure" so
    # that exactly this figure is closed, even if plotting or saving fails.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.bar(client_ids, accuracies, color='skyblue')
        ax.set_xlabel('Client ID')
        ax.set_ylabel('Accuracy')
        ax.set_title('Federated Learning: Client Accuracies')
        if ylim is not None:
            ax.set_ylim(*ylim)
        # Same effect as plt.xticks(rotation=45, ha='right'), scoped to this Axes.
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        fig.tight_layout()

        # Serialize the chart to an in-memory PNG.
        img_buf = io.BytesIO()
        fig.savefig(img_buf, format='png')
    finally:
        plt.close(fig)  # drop pyplot's reference so the figure can be freed

    img_data = base64.b64encode(img_buf.getvalue()).decode('utf-8')
    return f'<img src="data:image/png;base64,{img_data}" alt="Federated Learning Plot"/>'