File size: 8,082 Bytes
82ea528 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import json
import os
import yaml
import requests
import pathlib
from aiohttp import web
from server import PromptServer
from .image import tensor2pil, pil2tensor, image2base64, pil2byte
from .log import log_node_error
# Directory three levels above this module's file; the config file lives there.
_module_dir = pathlib.Path(__file__).parent
root_path = _module_dir.parent.parent

# Plain-string path of the YAML file that stores the Stability API keys.
config_path = os.path.join(root_path, "config.yaml")

# Placeholder key list written to the config when no key has been configured.
default_key = [dict(name="Default", key="")]
class StabilityAPI:
    """Client for the Stability AI REST API whose keys live in a YAML config file.

    Instance attributes (all set in ``__init__``):
        api_url: Base URL prepended to every endpoint path.
        api_keys: List of ``{'name': ..., 'key': ...}`` entries loaded from the config.
        api_current: Index into ``api_keys`` of the entry used for requests.
        user_info: Cache of account responses, keyed by the entry's ``name``.
    """

    def __init__(self):
        """Initialise defaults, then load (or create) the key config file."""
        self.api_url = "https://api.stability.ai"
        self.api_keys = None
        self.api_current = 0
        self.user_info = {}
        self.getAPIKeys()

    def getErrors(self, code):
        """Return a human-readable message for HTTP status *code*.

        Unrecognised codes map to ``"Unknown Error"``.
        """
        errors = {
            400: "Bad Request",
            403: "ApiKey Forbidden",
            413: "Your request was larger than 10MiB.",
            429: "You have made more than 150 requests in 10 seconds.",
            500: "Internal Server Error",
        }
        return errors.get(code, "Unknown Error")

    @staticmethod
    def _read_config():
        """Return the parsed YAML config, or None if the file is missing or empty."""
        if not os.path.isfile(config_path):
            return None
        with open(config_path, 'r') as f:
            # NOTE(review): FullLoader accepts more than plain mappings/lists.
            # The config is a local file, but consider yaml.safe_load once it is
            # confirmed that nothing stored here relies on FullLoader-only tags.
            return yaml.load(f, Loader=yaml.FullLoader)

    @staticmethod
    def _write_config(data):
        """Overwrite the YAML config file with *data*."""
        with open(config_path, 'w') as f:
            yaml.dump(data, f)

    @staticmethod
    def _with_defaults(data, fallback_keys):
        """Fill in missing Stability settings and report whether anything changed.

        Args:
            data: Parsed config mapping, or a falsy value for "no config yet".
            fallback_keys: Key entries to install when none are configured.

        Returns:
            ``(config, changed)`` where *config* always contains both
            ``STABILITY_API_KEY`` and ``STABILITY_API_DEFAULT``.
        """
        config = data if data else {}
        changed = False
        if 'STABILITY_API_KEY' not in config:
            # Copy the entries so later edits never leak into the shared fallback list.
            config['STABILITY_API_KEY'] = [dict(entry) for entry in fallback_keys]
            config['STABILITY_API_DEFAULT'] = 0
            changed = True
        elif 'STABILITY_API_DEFAULT' not in config:
            # Keys were configured by hand without a default index.
            config['STABILITY_API_DEFAULT'] = 0
            changed = True
        return config, changed

    def _current_entry(self):
        """Return the ``{'name', 'key'}`` entry selected by ``api_current``."""
        return self.api_keys[self.api_current]

    def getAPIKeys(self):
        """Load the key list and default index from the config file.

        A missing, empty or incomplete config is completed with defaults and
        written back. In every case ``self.api_keys`` and ``self.api_current``
        are updated, so the instance is usable even on first run.

        Returns:
            The list of key entries now stored in ``self.api_keys``.
        """
        data, changed = self._with_defaults(self._read_config(), default_key)
        if changed:
            self._write_config(data)
        self.api_keys = data['STABILITY_API_KEY']
        self.api_current = data['STABILITY_API_DEFAULT']
        return self.api_keys

    def setAPIKeys(self, api_keys):
        """Replace the stored key list and persist it.

        Args:
            api_keys: New list of key entries; an empty list is ignored.

        Returns:
            True when the keys were saved, False when *api_keys* was empty.
        """
        if len(api_keys) > 0:
            self.api_keys = api_keys
            # Read-modify-write so unrelated settings in the file are preserved.
            data, _ = self._with_defaults(self._read_config(), default_key)
            data['STABILITY_API_KEY'] = api_keys
            self._write_config(data)
            return True
        return False

    def setAPIDefault(self, current):
        """Select which key entry is used for requests and persist the choice.

        Args:
            current: Index into the key list; ``None`` is ignored.

        Returns:
            True when the index was saved, False when *current* was None.
        """
        if current is not None:
            self.api_current = current
            # Read-modify-write so unrelated settings in the file are preserved.
            data, _ = self._with_defaults(self._read_config(), default_key)
            data['STABILITY_API_DEFAULT'] = current
            self._write_config(data)
            return True
        return False

    def generate_sd3_image(self, prompt, negative_prompt, aspect_ratio, model, seed, mode='text-to-image', image=None, strength=1, output_format='png', node_name='easy stableDiffusion3API'):
        """Generate an image through the ``/v2beta/stable-image/generate/sd3`` endpoint.

        Args:
            prompt: Text prompt sent to the API.
            negative_prompt: Only sent when ``model == 'sd3'``.
            aspect_ratio: Only sent in ``'text-to-image'`` mode.
            model: Model identifier forwarded to the API.
            seed: Seed forwarded to the API.
            mode: ``'text-to-image'`` or ``'image-to-image'``.
            image: Source image for ``'image-to-image'``; it is passed to
                ``tensor2pil`` (presumably an image tensor - confirm with caller).
            strength: Only sent in ``'image-to-image'`` mode.
            output_format: Output format forwarded to the API.
            node_name: Name used when logging API errors.

        Returns:
            The result of ``pil2tensor`` applied to the decoded response image.

        Raises:
            Exception: If the API answers with any status other than 200.
        """
        url = f"{self.api_url}/v2beta/stable-image/generate/sd3"
        api_key = self._current_entry()['key']
        files = None
        data = {
            "prompt": prompt,
            "mode": mode,
            "model": model,
            "seed": seed,
            "output_format": output_format,
        }
        if model == 'sd3':
            data['negative_prompt'] = negative_prompt
        if mode == 'text-to-image':
            # A dummy file part makes requests encode the body as multipart/form-data.
            files = {"none": ''}
            data['aspect_ratio'] = aspect_ratio
        elif mode == 'image-to-image':
            pil_image = tensor2pil(image)
            image_byte = pil2byte(pil_image)
            files = {"image": ("output.png", image_byte, 'image/png')}
            data['strength'] = strength

        response = requests.post(
            url,
            headers={"authorization": f"{api_key}", "accept": "application/json"},
            files=files,
            data=data,
        )

        if response.status_code == 200:
            PromptServer.instance.send_sync('stable-diffusion-api-generate-succeed', {"model": model})
            json_data = response.json()
            image_base64 = json_data['image']
            # NOTE(review): despite its name, image2base64 receives the base64
            # payload here and its result is fed to pil2tensor - confirm in .image.
            image_data = image2base64(image_base64)
            output_t = pil2tensor(image_data)
            return output_t

        # Error responses are not guaranteed to carry a Content-Type header, so
        # look it up defensively instead of masking the API error with a KeyError.
        if 'application/json' in response.headers.get('Content-Type', ''):
            error_info = response.json()
            log_node_error(node_name, error_info.get('name', 'No name provided'))
            log_node_error(node_name, error_info.get('errors', ['No details provided']))
        error_status_text = self.getErrors(response.status_code)
        PromptServer.instance.send_sync('easyuse-toast', {"type": "error", "content": error_status_text})
        raise Exception(f"Failed to generate image: {error_status_text}")

    async def _get_json(self, path, api_key):
        """GET *path* with bearer auth without blocking the event loop.

        Returns:
            The parsed JSON body on HTTP 200; otherwise an error toast is sent
            to the frontend and None is returned.
        """
        # Function-scope imports keep the class independent of the file's import block.
        import asyncio
        import functools

        blocking_get = functools.partial(
            requests.get,
            f"{self.api_url}{path}",
            headers={"Authorization": f"Bearer {api_key}"},
        )
        # requests is synchronous; running it in the default executor lets the
        # server keep handling other requests while the HTTP call is in flight.
        response = await asyncio.get_running_loop().run_in_executor(None, blocking_get)
        if response.status_code == 200:
            return response.json()
        PromptServer.instance.send_sync('easyuse-toast', {'type': 'error', 'content': self.getErrors(response.status_code)})
        return None

    async def getUserAccount(self, cache=True):
        """Return the account info for the currently selected key.

        Args:
            cache: When True, reuse a previously fetched response stored under
                the current entry's name.

        Returns:
            The ``/v1/user/account`` JSON payload, or None if the request failed.
        """
        entry = self._current_entry()
        api_key = entry['key']
        name = entry['name']
        if cache and name in self.user_info:
            return self.user_info[name]
        user_info = await self._get_json("/v1/user/account", api_key)
        if user_info is not None:
            self.user_info[name] = user_info
        return user_info

    async def getUserBalance(self):
        """Return the ``/v1/user/balance`` JSON payload, or None if the request failed."""
        api_key = self._current_entry()['key']
        return await self._get_json("/v1/user/balance", api_key)
# Module-level client used by the route handlers in this file. Constructing it
# runs getAPIKeys(), which reads the key config file (creating it when absent).
stableAPI = StabilityAPI()
@PromptServer.instance.routes.get("/easyuse/stability/api_keys")
async def get_stability_api_keys(request):
    """Respond with the configured key entries and the index of the active one.

    The client's getAPIKeys() is called first so the response reflects what it
    loads from the config file.
    """
    stableAPI.getAPIKeys()
    payload = {
        "keys": stableAPI.api_keys,
        "current": stableAPI.api_current,
    }
    return web.json_response(payload)
@PromptServer.instance.routes.post("/easyuse/stability/set_api_keys")
async def set_stability_api_keys(request):
    """Store a new key list and, optionally, select the active key.

    Form fields:
        api_keys: JSON-encoded list of key entries (required).
        current: Index of the entry to make active (optional).

    Responses:
        400 when ``api_keys`` is missing, is not a JSON list, or ``current``
        is not an integer index into the resulting key list.
        ``{'status': 'ok'}`` when only the keys were updated.
        ``{'account': ..., 'balance': ...}`` when ``current`` was also given.
    """
    post = await request.post()
    raw_keys = post.get("api_keys")
    raw_current = post.get('current')
    if raw_keys is None:
        return web.Response(status=400)

    # Validate everything before touching the config so that bad input is
    # answered with 400 rather than a half-applied update and a 500.
    try:
        api_keys = json.loads(raw_keys)
    except (TypeError, ValueError):
        # ValueError covers json.JSONDecodeError; TypeError covers file uploads.
        return web.Response(status=400)
    if not isinstance(api_keys, list):
        return web.Response(status=400)

    current = None
    if raw_current is not None:
        try:
            current = int(raw_current)
        except (TypeError, ValueError):
            return web.Response(status=400)
        # An empty list is ignored by setAPIKeys, so the existing keys stay in effect.
        effective_keys = api_keys or stableAPI.api_keys or []
        if not 0 <= current < len(effective_keys):
            return web.Response(status=400)

    stableAPI.setAPIKeys(api_keys)
    if current is None:
        return web.json_response({'status': 'ok'})

    stableAPI.setAPIDefault(current)
    account = await stableAPI.getUserAccount()
    balance = await stableAPI.getUserBalance()
    return web.json_response({'account': account, 'balance': balance})
@PromptServer.instance.routes.post("/easyuse/stability/set_apikey_default")
async def set_stability_api_default(request):
    """Select the active key by index for the running process.

    The index is assigned to ``stableAPI.api_current`` only; this handler does
    not write it to the config file.

    Form fields:
        current: Index into the configured key list.

    Responses:
        ``{'status': 'ok'}`` on success; 400 when ``current`` is missing, not
        an integer, or outside the key list.
    """
    post = await request.post()
    raw_current = post.get("current")
    # Form values arrive as strings, so convert before comparing with len();
    # int(None) raises TypeError, which also covers the missing-field case.
    try:
        current = int(raw_current)
    except (TypeError, ValueError):
        return web.Response(status=400)

    known_keys = stableAPI.api_keys or []
    if not 0 <= current < len(known_keys):
        return web.Response(status=400)

    stableAPI.api_current = current
    return web.json_response({'status': 'ok'})
@PromptServer.instance.routes.get("/easyuse/stability/user_info")
async def get_account_info(request):
    """Respond with the active key's account info and balance as JSON.

    Either value is null in the response when the corresponding lookup
    returned None.
    """
    # Dict entries are evaluated top to bottom: account first, then balance.
    payload = {
        'account': await stableAPI.getUserAccount(),
        'balance': await stableAPI.getUserBalance(),
    }
    return web.json_response(payload)
@PromptServer.instance.routes.get("/easyuse/stability/balance")
async def get_balance_info(request):
    """Respond with the active key's balance as JSON (null if the lookup returned None)."""
    payload = {'balance': await stableAPI.getUserBalance()}
    return web.json_response(payload)
|