File size: 2,424 Bytes
70e3e61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c681f7
70e3e61
 
 
 
 
 
 
 
 
 
 
 
1c681f7
 
 
 
 
 
 
 
 
70e3e61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c681f7
70e3e61
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# pylint: disable=line-too-long
"""Generate an image using the Stability AI API.

Keyword arguments:
prompt -- The prompt to generate the image from.
Return: The path to the generated image, saved as a .png file.
"""

import os
import io
import warnings
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
from PIL import Image
from config import STABILITY_API_KEY


def generate_image_with_stability(prompt, seed=42, steps=50, cfg_scale=7.0, width=1024, height=1024, samples=1, api_key=STABILITY_API_KEY, output_path="output_img/sd_generated_img.png"):
    """
    Generate an image from a text prompt with the Stability API and save it as a PNG.

    Only the first image artifact returned by the API is saved; when ``samples``
    is greater than 1 the remaining images are ignored.

    :param prompt: The prompt to generate the image from.
    :param seed: Seed for deterministic generation.
    :param steps: Number of inference steps.
    :param cfg_scale: CFG scale for prompt guidance.
    :param width: Width of the generated image.
    :param height: Height of the generated image.
    :param samples: Number of images to request from the API.
    :param api_key: Stability API key; defaults to ``config.STABILITY_API_KEY``.
    :param output_path: File path the PNG is written to; missing parent
        directories are created.
    :return: The path (``output_path``) of the saved image file.
    :raises ValueError: If no API key is provided, or if the API response
        contained no image artifact.
    """
    # Fail early with a clear message instead of a TypeError from os.environ
    # (which only accepts str values) or an authentication error from the API.
    if not api_key:
        raise ValueError("A Stability API key is required.")

    # NOTE(review): these environment variables are kept for backward
    # compatibility; presumably other code reads them — confirm before removing.
    os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'
    os.environ['STABILITY_KEY'] = api_key

    # Set up our connection to the Stability API, passing host and key
    # explicitly so the configured values are the ones actually used.
    stability_api = client.StabilityInference(
        host=os.environ['STABILITY_HOST'],
        key=api_key,
        verbose=True,
        engine="stable-diffusion-xl-1024-v1-0",
    )

    print("Creating Stability Image...")
    answers = stability_api.generate(
        prompt=prompt,
        seed=seed,
        steps=steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        samples=samples,
        # sampler=generation.SAMPLER_K_DPMPP_2M # default: auto
    )

    # Walk the streamed responses and save the first image artifact.
    for resp in answers:
        for artifact in resp.artifacts:
            if artifact.finish_reason == generation.FILTER:
                # The request tripped the safety filter; warn but still fall
                # through, since an image artifact may accompany this result.
                warnings.warn(
                    "Your request activated the API's safety filters and could not be processed. "
                    "Please modify the prompt and try again.",
                    stacklevel=2,
                )
            if artifact.type == generation.ARTIFACT_IMAGE:
                # Decode through PIL (validates the bytes) and write as PNG.
                img = Image.open(io.BytesIO(artifact.binary))
                output_dir = os.path.dirname(output_path)
                if output_dir:
                    # Avoid FileNotFoundError when the directory doesn't exist yet.
                    os.makedirs(output_dir, exist_ok=True)
                img.save(output_path)
                print(f"Image saved in {output_path}")
                return output_path

    raise ValueError("No image was generated.")