from PIL import Image
import numpy as np
from fastapi.responses import StreamingResponse
from io import BytesIO
from matplotlib import pyplot
import matplotlib

class FaceGen:
    """Generate images from a latent-space model and serve them as a PNG response.

    The wrapped ``model`` must expose ``predict(latent_points)``. The
    result is rescaled and rendered with matplotlib into an in-memory PNG.
    That PNG is returned as a FastAPI ``StreamingResponse``.
    """

    def __init__(self, model, device):
        """Store the generator model and the device identifier.

        Args:
            model: object with a ``predict`` method that maps a batch of
                latent vectors to a batch of images.
            device: stored on the instance; not used by the methods in this
                class (presumably consumed elsewhere — TODO confirm).
        """
        self.model = model
        self.device = device

    def generate_latent_points(self, latent_dim, n_samples, seed):
        """Reshape ``seed`` into a batch of latent vectors.

        Args:
            latent_dim: length of each latent vector.
            n_samples: number of latent vectors in the batch.
            seed: array-like holding exactly ``n_samples * latent_dim`` values.
                Objects that already provide ``reshape`` (e.g. numpy arrays)
                are used as-is. Other sequences are converted with
                ``np.asarray``.

        Returns:
            ``seed`` reshaped to ``(n_samples, latent_dim)``.

        Raises:
            ValueError: if ``seed`` does not hold ``n_samples * latent_dim``
                values (raised by ``reshape``).
        """
        if not hasattr(seed, "reshape"):
            # Accept plain Python sequences too; arrays/tensors keep their type.
            seed = np.asarray(seed)
        return seed.reshape(n_samples, latent_dim)

    def plot_generated(self, examples, n):
        """Render the first ``n * n`` images of ``examples`` as an n-by-n PNG grid.

        Args:
            examples: indexable batch of images; ``examples[i, :, :]`` is
                passed to ``imshow`` for each grid cell.
            n: number of rows and columns in the grid.

        Returns:
            A ``StreamingResponse`` streaming the PNG bytes (``image/png``).
        """
        from matplotlib.figure import Figure

        # Use a standalone Figure instead of pyplot's global "current figure".
        # pyplot keeps every figure alive until it is explicitly closed. The
        # previous code never closed it, so each request drew onto the same
        # figure and leaked memory. pyplot's global state is also unsafe in
        # server worker threads. A bare Figure is freed when it goes out of
        # scope and needs no GUI backend.
        fig = Figure()
        for i in range(n * n):
            ax = fig.add_subplot(n, n, 1 + i)
            ax.axis('off')
            ax.imshow(examples[i, :, :])

        buf = BytesIO()
        fig.savefig(buf, format='png', transparent=True)
        buf.seek(0)  # rewind so the response streams from the first byte
        return StreamingResponse(buf, media_type="image/png")

    def generate_image(self, latent_dim, n_samples, seed):
        """Run the model on latent points built from ``seed``; return a PNG response.

        Args:
            latent_dim: length of each latent vector.
            n_samples: number of latent vectors to generate.
            seed: values reshaped into the latent batch (see
                ``generate_latent_points``).

        Returns:
            A ``StreamingResponse`` containing a PNG of the first generated
            image. Only one image is drawn because the grid size is 1.
        """
        latent_points = self.generate_latent_points(latent_dim, n_samples, seed)
        generated_images = self.model.predict(latent_points)
        # Map model output from [-1, 1] to [0, 1] for imshow. This assumes a
        # tanh-style output range — TODO confirm against the model.
        X = (generated_images + 1) / 2.0
        return self.plot_generated(X, 1)