import sys
sys.path.append('Kandinsky-3')

import base64
import time
from io import BytesIO
from typing import Optional

import numpy as np
import requests
import torch
import uvicorn
import gradio as gr
import spaces  # [uncomment to use ZeroGPU]
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from PIL import Image
from diffusers import DiffusionPipeline
from kandinsky3 import get_T2I_pipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "kandinsky-community/kandinsky-3"  # replace with the model you would like to use, e.g. "stabilityai/sdxl-turbo"

if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

# Load the fp16 weight variant and move the pipeline to the selected device
pipe = DiffusionPipeline.from_pretrained(model_repo_id, variant="fp16", torch_dtype=torch_dtype)
pipe = pipe.to(device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

@spaces.GPU  # [uncomment to use ZeroGPU]
def generate_image(prompt, width, height, guidance_scale=3, num_inference_steps=50, progress=gr.Progress(track_tqdm=True)):
    # Fixed seed keeps outputs reproducible; swap in a random seed for varied results
    generator = torch.Generator().manual_seed(0)

    image = pipe(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]

    return image
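
# Example usage (an illustrative assumption, not part of the original file), kept as a
# comment so importing this module does not trigger a generation run:
# img = generate_image("a red fox in the snow, photorealistic", width=1024, height=1024)
# img.save("fox.png")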


# device_map = torch.device('cuda:0')
# dtype_map = {
#     'unet': torch.float32,
#     'text_encoder': torch.float16,
#     'movq': torch.float32,
# }

# # Initialize the FastAPI app
# app = FastAPI()

# Define the request model
# class GenerateImageRequest(BaseModel):
#     prompt: str
#     width: Optional[int] = 1024
#     height: Optional[int] = 1024

# Define the response model
# class GenerateImageResponse(BaseModel):
#     image_base64: str

# Define the endpoint
# @app.post("/k31/", response_model=GenerateImageResponse)
# async def generate_image(request: GenerateImageRequest):
#     try:
#         # Generate the image using the pipeline
#         pil_image = t2i_pipe(request.prompt, width=request.width, height=request.height, steps=50)[0]

#         # Resize the image if necessary
#         if pil_image.size != (request.width, request.height):
#             pil_image = pil_image.resize((request.width, request.height))
        
#         # Convert the PIL image to base64
#         buffered = BytesIO()
#         pil_image.save(buffered, format="PNG")
#         image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
        
#         # Return the response
#         return GenerateImageResponse(image_base64=image_base64)

#     except Exception as e:
#         raise HTTPException(status_code=500, detail=str(e))

# def api_k31_generate(prompt, width=1024, height=1024, url = "http://0.0.0.0:8188/k31/"):
#     # Define the text message and image parameters
#     data = {
#         "prompt": prompt,
#         "width": width,
#         "height": height
#     }
    
#     # Send the POST request
#     response = requests.post(url, json=data)
    
#     # Check if the request was successful
#     if response.status_code == 200:
#         # Extract the base64 encoded image from the response
#         image_base64 = response.json()["image_base64"]
        
#         # You can further process the image here, for example, decode it from base64
#         decoded_image = Image.open(BytesIO(base64.b64decode(image_base64)))
        
#         return decoded_image
#     else:
#         print("Error:", response.text)
        
# # Run the FastAPI app
# if __name__ == "__main__":
#     t2i_pipe = get_T2I_pipeline(
#         device_map, dtype_map,
#     )
#     uvicorn.run(app, host="0.0.0.0", port=8188)
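
# A minimal sketch of wiring generate_image into a Gradio UI; this block is an
# assumption rather than part of the original app, and the components and defaults
# shown here are illustrative only.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=generate_image,
        inputs=[
            gr.Textbox(label="Prompt"),
            gr.Number(label="Width", value=1024, precision=0),
            gr.Number(label="Height", value=1024, precision=0),
        ],
        outputs=gr.Image(label="Generated image"),
    )
    demo.launch()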