from PIL import Image
import numpy as np
from fastapi.responses import StreamingResponse
from io import BytesIO
from matplotlib import pyplot
import matplotlib


class FaceGen:
    """Generate images from a latent-space model and serve them as PNG responses.

    The wrapped ``model`` must expose a ``predict`` method that maps a batch of
    latent vectors to a batch of images (see ``generate_image``).
    """

    def __init__(self, model, device):
        """Store the generator model and its associated device.

        Args:
            model: object with a ``predict(latent_points)`` method.
            device: stored as ``self.device``; no method in this class reads
                it (NOTE(review): confirm whether external callers rely on it).
        """
        self.model = model
        self.device = device

    def generate_latent_points(self, latent_dim, n_samples, seed):
        """Reshape ``seed`` into a batch of latent vectors.

        Args:
            latent_dim: length of each latent vector.
            n_samples: number of latent vectors in the batch.
            seed: array-like containing exactly ``n_samples * latent_dim``
                values.

        Returns:
            numpy.ndarray of shape ``(n_samples, latent_dim)``.

        Raises:
            ValueError: if ``seed`` does not hold ``n_samples * latent_dim``
                elements.
        """
        # np.asarray lets callers pass plain lists/tuples too; for an ndarray
        # it returns the same array, so existing callers see no change.
        return np.asarray(seed).reshape(n_samples, latent_dim)

    def plot_generated(self, examples, n):
        """Render the first ``n * n`` images of ``examples`` as an n-by-n PNG grid.

        Args:
            examples: indexable as ``examples[i, :, :]`` with images along the
                first axis (presumably float RGB in [0, 1] -- TODO confirm).
                Must contain at least ``n * n`` images.
            n: side length of the square grid.

        Returns:
            StreamingResponse streaming the PNG bytes with media type
            ``image/png``.

        Raises:
            IndexError: if ``examples`` holds fewer than ``n * n`` images.
        """
        buf = BytesIO()
        # Use a dedicated figure and always close it. Drawing on pyplot's
        # implicit "current" figure without closing it leaks that global
        # figure across calls and leaves earlier images/axes behind.
        fig = pyplot.figure()
        try:
            for i in range(n * n):
                ax = fig.add_subplot(n, n, 1 + i)
                ax.axis('off')
                ax.imshow(examples[i, :, :])
            fig.savefig(buf, format='png', transparent=True)
        finally:
            pyplot.close(fig)
        # Rewind so the response streams from the start of the PNG data.
        buf.seek(0)
        return StreamingResponse(buf, media_type="image/png")

    def generate_image(self, latent_dim, n_samples, seed):
        """Generate images from ``seed`` and return the first one as a PNG response.

        Args:
            latent_dim: length of each latent vector.
            n_samples: number of latent vectors built from ``seed``.
            seed: array-like with ``n_samples * latent_dim`` values.

        Returns:
            StreamingResponse containing a 1x1 PNG grid, i.e. only the first
            generated image, even when ``n_samples > 1``.
        """
        latent_points = self.generate_latent_points(latent_dim, n_samples, seed)
        generated_images = self.model.predict(latent_points)
        # Map [-1, 1] to [0, 1] for imshow -- assumes a tanh-like model output
        # range; TODO confirm against the model.
        X = (generated_images + 1) / 2.0
        return self.plot_generated(X, 1)