import sys
sys.path.append('Kandinsky-3')  # make the local Kandinsky-3 checkout importable

import base64
from io import BytesIO
from typing import Optional

import requests
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel

from kandinsky3 import get_T2I_pipeline

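# Place the pipeline on the first GPU; the text encoder runs in float16,
# the U-Net and the MoVQ image autoencoder in float32
device_map = torch.device('cuda:0')
dtype_map = {
    'unet': torch.float32,
    'text_encoder': torch.float16,
    'movq': torch.float32,
}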

# Initialize the FastAPI app
app = FastAPI()

# Define the request model
class GenerateImageRequest(BaseModel):
    prompt: str
    width: Optional[int] = 1024
    height: Optional[int] = 1024

# Define the response model
class GenerateImageResponse(BaseModel):
    image_base64: str

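# Define the endpoint.
# Note: t2i_pipe is a module-level global created in the __main__ block at the
# bottom of this file, so start the server with `python <this file>`; launching
# it with `uvicorn module:app` would leave t2i_pipe undefined.
@app.post("/k31/", response_model=GenerateImageResponse)
async def generate_image(request: GenerateImageRequest):
    try:
        # Generate the image with 50 sampling steps. The pipeline call is
        # blocking, so it holds the event loop and requests are served one at a time.
        pil_image = t2i_pipe(request.prompt, width=request.width, height=request.height, steps=50)[0]

        # Resize if the pipeline returned a different size than requested
        if pil_image.size != (request.width, request.height):
            pil_image = pil_image.resize((request.width, request.height))

        # Encode the PIL image as a base64 PNG string
        buffered = BytesIO()
        pil_image.save(buffered, format="PNG")
        image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')

        # Return the response
        return GenerateImageResponse(image_base64=image_base64)

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))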
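
# Example request from a shell (a sketch; assumes the server is running
# locally on port 8188, and the prompt is illustrative):
#
#   curl -X POST http://127.0.0.1:8188/k31/ \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "a red fox in the snow", "width": 1024, "height": 1024}'
#
# A successful response is JSON of the form {"image_base64": "..."}, where the
# value is a base64-encoded PNG. Because width and height default to 1024 in
# GenerateImageRequest, a body containing only {"prompt": "..."} is also valid.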

def api_k31_generate(prompt, width=1024, height=1024, url="http://127.0.0.1:8188/k31/", timeout=None):
    """Request an image from the /k31/ endpoint and return it as a PIL image.

    Prints the server's error text and returns None if the request fails.
    `timeout` (seconds) is passed to requests.post; None means wait indefinitely.
    """
    # Build the JSON request body
    data = {
        "prompt": prompt,
        "width": width,
        "height": height
    }

    # Send the POST request (0.0.0.0 is only the server's bind address,
    # so the client connects to the loopback address instead)
    response = requests.post(url, json=data, timeout=timeout)

    # Check if the request was successful
    if response.status_code == 200:
        # Extract the base64-encoded PNG from the response and decode it
        image_base64 = response.json()["image_base64"]
        decoded_image = Image.open(BytesIO(base64.b64decode(image_base64)))
        return decoded_image
    else:
        print("Error:", response.text)
        return None
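
# Example client usage (a sketch; assumes the server is already running on the
# same machine, and the prompt and output filename are illustrative):
#
#   image = api_k31_generate("A cute corgi lives in a house made out of sushi.")
#   if image is not None:
#       image.save("corgi.png")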
        
# Run the FastAPI app
if __name__ == "__main__":
    # Load the Kandinsky 3 text-to-image pipeline once; the endpoint uses it as a global
    t2i_pipe = get_T2I_pipeline(device_map, dtype_map)
    uvicorn.run(app, host="0.0.0.0", port=8188)