ai-forever commited on
Commit
fec3cbb
·
verified ·
1 Parent(s): d7c1097

Update src/kandinsky.py

Browse files
Files changed (1) hide show
  1. src/kandinsky.py +95 -57
src/kandinsky.py CHANGED
@@ -16,75 +16,113 @@ from fastapi import FastAPI, HTTPException
16
  from pydantic import BaseModel
17
  import base64
18
  import requests
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- device_map = torch.device('cuda:0')
21
- dtype_map = {
22
- 'unet': torch.float32,
23
- 'text_encoder': torch.float16,
24
- 'movq': torch.float32,
25
- }
26
 
27
- # Initialize the FastAPI app
28
- app = FastAPI()
29
 
30
  # Define the request model
31
- class GenerateImageRequest(BaseModel):
32
- prompt: str
33
- width: Optional[int] = 1024
34
- height: Optional[int] = 1024
35
 
36
  # Define the response model
37
- class GenerateImageResponse(BaseModel):
38
- image_base64: str
39
 
40
  # Define the endpoint
41
- @app.post("/k31/", response_model=GenerateImageResponse)
42
- async def generate_image(request: GenerateImageRequest):
43
- try:
44
- # Generate the image using the pipeline
45
- pil_image = t2i_pipe(request.prompt, width=request.width, height=request.height, steps=50)[0]
46
-
47
- # Resize the image if necessary
48
- if pil_image.size != (request.width, request.height):
49
- pil_image = pil_image.resize((request.width, request.height))
50
 
51
- # Convert the PIL image to base64
52
- buffered = BytesIO()
53
- pil_image.save(buffered, format="PNG")
54
- image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
55
 
56
- # Return the response
57
- return GenerateImageResponse(image_base64=image_base64)
58
-
59
- except Exception as e:
60
- raise HTTPException(status_code=500, detail=str(e))
61
-
62
- def api_k31_generate(prompt, width=1024, height=1024, url = "http://0.0.0.0:8188/k31/"):
63
- # Define the text message and image parameters
64
- data = {
65
- "prompt": prompt,
66
- "width": width,
67
- "height": height
68
- }
69
 
70
- # Send the POST request
71
- response = requests.post(url, json=data)
72
 
73
- # Check if the request was successful
74
- if response.status_code == 200:
75
- # Extract the base64 encoded image from the response
76
- image_base64 = response.json()["image_base64"]
77
 
78
- # You can further process the image here, for example, decode it from base64
79
- decoded_image = Image.open(BytesIO(base64.b64decode(image_base64)))
80
 
81
- return decoded_image
82
- else:
83
- print("Error:", response.text)
84
 
85
- # Run the FastAPI app
86
- if __name__ == "__main__":
87
- t2i_pipe = get_T2I_pipeline(
88
- device_map, dtype_map,
89
- )
90
- uvicorn.run(app, host="0.0.0.0", port=8188)
 
16
  from pydantic import BaseModel
17
  import base64
18
  import requests
19
+ import spaces #[uncomment to use ZeroGPU]
20
+ from diffusers import DiffusionPipeline
21
+ import torch
22
+
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+ model_repo_id = "kandinsky-community/kandinsky-3" #"stabilityai/sdxl-turbo" #Replace to the model you would like to use
25
+
26
+ if torch.cuda.is_available():
27
+ torch_dtype = torch.float16
28
+ else:
29
+ torch_dtype = torch.float32
30
+
31
+ pipe = DiffusionPipeline.from_pretrained(model_repo_id, variant="fp16", torch_dtype=torch_dtype)
32
+ pipe = pipe.to(device)
33
+
34
+ MAX_SEED = np.iinfo(np.int32).max
35
+ MAX_IMAGE_SIZE = 1024
36
+
37
+ @spaces.GPU #[uncomment to use ZeroGPU]
38
+ def generate_image(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
39
+
40
+ if randomize_seed:
41
+ seed = random.randint(0, MAX_SEED)
42
+
43
+ generator = torch.Generator().manual_seed(seed)
44
+
45
+ image = pipe(
46
+ prompt = prompt,
47
+ negative_prompt = negative_prompt,
48
+ guidance_scale = guidance_scale,
49
+ num_inference_steps = num_inference_steps,
50
+ width = width,
51
+ height = height,
52
+ generator = generator
53
+ ).images[0]
54
+
55
+ return image, seed
56
+
57
 
58
+ # device_map = torch.device('cuda:0')
59
+ # dtype_map = {
60
+ # 'unet': torch.float32,
61
+ # 'text_encoder': torch.float16,
62
+ # 'movq': torch.float32,
63
+ # }
64
 
65
+ # # Initialize the FastAPI app
66
+ # app = FastAPI()
67
 
68
  # Define the request model
69
+ # class GenerateImageRequest(BaseModel):
70
+ # prompt: str
71
+ # width: Optional[int] = 1024
72
+ # height: Optional[int] = 1024
73
 
74
  # Define the response model
75
+ # class GenerateImageResponse(BaseModel):
76
+ # image_base64: str
77
 
78
  # Define the endpoint
79
+ # @app.post("/k31/", response_model=GenerateImageResponse)
80
+ # async def generate_image(request: GenerateImageRequest):
81
+ # try:
82
+ # # Generate the image using the pipeline
83
+ # pil_image = t2i_pipe(request.prompt, width=request.width, height=request.height, steps=50)[0]
84
+
85
+ # # Resize the image if necessary
86
+ # if pil_image.size != (request.width, request.height):
87
+ # pil_image = pil_image.resize((request.width, request.height))
88
 
89
+ # # Convert the PIL image to base64
90
+ # buffered = BytesIO()
91
+ # pil_image.save(buffered, format="PNG")
92
+ # image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
93
 
94
+ # # Return the response
95
+ # return GenerateImageResponse(image_base64=image_base64)
96
+
97
+ # except Exception as e:
98
+ # raise HTTPException(status_code=500, detail=str(e))
99
+
100
+ # def api_k31_generate(prompt, width=1024, height=1024, url = "http://0.0.0.0:8188/k31/"):
101
+ # # Define the text message and image parameters
102
+ # data = {
103
+ # "prompt": prompt,
104
+ # "width": width,
105
+ # "height": height
106
+ # }
107
 
108
+ # # Send the POST request
109
+ # response = requests.post(url, json=data)
110
 
111
+ # # Check if the request was successful
112
+ # if response.status_code == 200:
113
+ # # Extract the base64 encoded image from the response
114
+ # image_base64 = response.json()["image_base64"]
115
 
116
+ # # You can further process the image here, for example, decode it from base64
117
+ # decoded_image = Image.open(BytesIO(base64.b64decode(image_base64)))
118
 
119
+ # return decoded_image
120
+ # else:
121
+ # print("Error:", response.text)
122
 
123
+ # # Run the FastAPI app
124
+ # if __name__ == "__main__":
125
+ # t2i_pipe = get_T2I_pipeline(
126
+ # device_map, dtype_map,
127
+ # )
128
+ # uvicorn.run(app, host="0.0.0.0", port=8188)