jaxmetaverse's picture
Upload folder using huggingface_hub
82ea528 verified
import json
import os
import yaml
import requests
import pathlib
from aiohttp import web
from server import PromptServer
from .image import tensor2pil, pil2tensor, image2base64, pil2byte
from .log import log_node_error
# Directory three levels above this module; config.yaml is looked up there.
_module_file = pathlib.Path(__file__)
root_path = _module_file.parent.parent.parent

# Location of the YAML file that stores the Stability API keys.
config_path = os.path.join(root_path, "config.yaml")

# Key list written to config.yaml when no keys have been configured yet.
default_key = [dict(name="Default", key="")]
class StabilityAPI:
    """Small client for the Stability AI HTTP API.

    API keys are kept as a list of ``{'name': ..., 'key': ...}`` entries in
    ``config.yaml`` under ``STABILITY_API_KEY``; the index of the active entry
    is stored under ``STABILITY_API_DEFAULT``.
    """

    def __init__(self):
        self.api_url = "https://api.stability.ai"
        # List of {'name', 'key'} entries; filled in by getAPIKeys().
        self.api_keys = None
        # Index into api_keys of the key used for requests.
        self.api_current = 0
        # Cache of account responses, keyed by the key entry's name.
        self.user_info = {}
        self.getAPIKeys()

    def _read_config(self):
        """Return the parsed contents of config.yaml, or ``{}`` if it is missing or empty."""
        if not os.path.isfile(config_path):
            return {}
        with open(config_path, 'r') as f:
            # NOTE(review): FullLoader is kept from the original code. If this
            # file only ever holds plain data, yaml.safe_load is the stricter choice.
            data = yaml.load(f, Loader=yaml.FullLoader)
        return data or {}

    def _write_config(self, data):
        """Overwrite config.yaml with ``data``."""
        with open(config_path, 'w') as f:
            yaml.dump(data, f)

    def getErrors(self, code):
        """Map an HTTP status code to a short human-readable message.

        Returns "Unknown Error" for codes that are not in the table.
        """
        errors = {
            400: "Bad Request",
            403: "ApiKey Forbidden",
            413: "Your request was larger than 10MiB.",
            429: "You have made more than 150 requests in 10 seconds.",
            500: "Internal Server Error",
        }
        return errors.get(code, "Unknown Error")

    def getAPIKeys(self):
        """Load the key list and active index from config.yaml.

        If the file is missing, empty, or has no ``STABILITY_API_KEY`` entry,
        the defaults are written to it first. In every case ``self.api_keys``
        and ``self.api_current`` are updated, so later lookups never see the
        initial ``None``.

        Returns:
            The list of key entries now stored in ``self.api_keys``.
        """
        data = self._read_config()
        if 'STABILITY_API_KEY' not in data:
            # Copy the entries so later edits to the key list cannot mutate
            # the module-level default.
            data['STABILITY_API_KEY'] = [dict(entry) for entry in default_key]
            data['STABILITY_API_DEFAULT'] = 0
            self._write_config(data)
        self.api_keys = data['STABILITY_API_KEY']
        # Tolerate a config that has keys but no stored default index.
        self.api_current = data.get('STABILITY_API_DEFAULT', 0)
        return self.api_keys

    def setAPIKeys(self, api_keys):
        """Replace the stored key list and persist it to config.yaml.

        Args:
            api_keys: New list of key entries. An empty (or ``None``) value is
                ignored so the existing keys are never wiped.

        Returns:
            True if the keys were stored, False if the input was ignored.
        """
        if not api_keys:
            return False
        self.api_keys = api_keys
        data = self._read_config()
        data['STABILITY_API_KEY'] = api_keys
        self._write_config(data)
        return True

    def setAPIDefault(self, current):
        """Select the active key index and persist it to config.yaml.

        Args:
            current: Index into the key list; ``None`` is ignored.

        Returns:
            True if the index was stored, False if the input was ignored.
        """
        if current is None:
            return False
        self.api_current = current
        data = self._read_config()
        data['STABILITY_API_DEFAULT'] = current
        self._write_config(data)
        return True

    def generate_sd3_image(self, prompt, negative_prompt, aspect_ratio, model, seed, mode='text-to-image', image=None, strength=1, output_format='png', node_name='easy stableDiffusion3API'):
        """Request an image from the ``/v2beta/stable-image/generate/sd3`` endpoint.

        Args:
            prompt: Text prompt sent with the request.
            negative_prompt: Sent only when ``model == 'sd3'``.
            aspect_ratio: Sent only in 'text-to-image' mode.
            model: Model name forwarded to the API.
            seed: Seed forwarded to the API.
            mode: 'text-to-image' or 'image-to-image'.
            image: Source image for 'image-to-image'; it is passed through
                ``tensor2pil`` and ``pil2byte`` and uploaded as PNG.
            strength: Sent only in 'image-to-image' mode.
            output_format: Output format forwarded to the API.
            node_name: Name used when logging API errors.

        Returns:
            The result of ``pil2tensor`` applied to the decoded response image.

        Raises:
            Exception: If the API responds with a status other than 200.
        """
        url = f"{self.api_url}/v2beta/stable-image/generate/sd3"
        api_key = self.api_keys[self.api_current]['key']
        files = None
        data = {
            "prompt": prompt,
            "mode": mode,
            "model": model,
            "seed": seed,
            "output_format": output_format,
        }
        if model == 'sd3':
            data['negative_prompt'] = negative_prompt

        if mode == 'text-to-image':
            # A placeholder part makes requests encode the body as multipart form data.
            files = {"none": ''}
            data['aspect_ratio'] = aspect_ratio
        elif mode == 'image-to-image':
            pil_image = tensor2pil(image)
            image_byte = pil2byte(pil_image)
            files = {"image": ("output.png", image_byte, 'image/png')}
            data['strength'] = strength

        response = requests.post(
            url,
            headers={"authorization": f"{api_key}", "accept": "application/json"},
            files=files,
            data=data,
        )
        if response.status_code == 200:
            PromptServer.instance.send_sync('stable-diffusion-api-generate-succeed', {"model": model})
            json_data = response.json()
            image_base64 = json_data['image']
            # presumably decodes the base64 payload into a PIL image — confirm in .image
            image_data = image2base64(image_base64)
            output_t = pil2tensor(image_data)
            return output_t

        # A missing Content-Type header must not mask the real HTTP error.
        if 'application/json' in response.headers.get('Content-Type', ''):
            error_info = response.json()
            log_node_error(node_name, error_info.get('name', 'No name provided'))
            log_node_error(node_name, error_info.get('errors', ['No details provided']))
        error_status_text = self.getErrors(response.status_code)
        PromptServer.instance.send_sync('easyuse-toast', {"type": "error", "content": error_status_text})
        raise Exception(f"Failed to generate image: {error_status_text}")

    # get user account
    async def getUserAccount(self, cache=True):
        """Fetch ``/v1/user/account`` for the active key.

        Args:
            cache: When True, return the previously fetched result stored
                under the active key's name, if there is one.

        Returns:
            The decoded JSON response, or None if the request did not return
            200 (an error toast is sent in that case).
        """
        url = f"{self.api_url}/v1/user/account"
        api_key = self.api_keys[self.api_current]['key']
        name = self.api_keys[self.api_current]['name']
        if cache and name in self.user_info:
            return self.user_info[name]

        # NOTE(review): requests.get is synchronous, so this coroutine blocks
        # while the HTTP call is in flight.
        response = requests.get(url, headers={"Authorization": f"Bearer {api_key}"})
        if response.status_code == 200:
            user_info = response.json()
            self.user_info[name] = user_info
            return user_info

        PromptServer.instance.send_sync('easyuse-toast', {'type': 'error', 'content': self.getErrors(response.status_code)})
        return None

    # get user balance
    async def getUserBalance(self):
        """Fetch ``/v1/user/balance`` for the active key.

        Returns:
            The decoded JSON response, or None if the request did not return
            200 (an error toast is sent in that case).
        """
        url = f"{self.api_url}/v1/user/balance"
        api_key = self.api_keys[self.api_current]['key']
        # NOTE(review): synchronous HTTP call inside a coroutine; see getUserAccount.
        response = requests.get(url, headers={
            "Authorization": f"Bearer {api_key}"
        })
        if response.status_code == 200:
            return response.json()

        PromptServer.instance.send_sync('easyuse-toast', {'type': 'error', 'content': self.getErrors(response.status_code)})
        return None
# Shared module-level client used by the HTTP route handlers in this file.
# Constructing it runs getAPIKeys(), which reads config.yaml or creates it with defaults.
stableAPI = StabilityAPI()
@PromptServer.instance.routes.get("/easyuse/stability/api_keys")
async def get_stability_api_keys(request):
    """Return the configured key list and the index of the active key as JSON."""
    # Reload from config.yaml first so the response reflects what is on disk.
    stableAPI.getAPIKeys()
    payload = {"keys": stableAPI.api_keys, "current": stableAPI.api_current}
    return web.json_response(payload)
@PromptServer.instance.routes.post("/easyuse/stability/set_api_keys")
async def set_stability_api_keys(request):
    """Store a new key list and, optionally, a new active key index.

    Form fields:
        api_keys: JSON-encoded list of key entries (required).
        current: Index of the key to activate (optional).

    Responds with 400 when ``api_keys`` is missing, is not valid JSON, or is
    not a list, and when ``current`` is not an integer inside the resulting
    key list. When ``current`` is given, the response carries the account and
    balance of the newly selected key; otherwise it is ``{'status': 'ok'}``.
    """
    post = await request.post()
    raw_keys = post.get("api_keys")
    current = post.get('current')
    if raw_keys is None:
        return web.Response(status=400)

    try:
        api_keys = json.loads(raw_keys)
    except (TypeError, ValueError):
        # Malformed JSON (or a non-text field) is a client error, not a server crash.
        return web.Response(status=400)
    if not isinstance(api_keys, list):
        return web.Response(status=400)

    index = None
    if current is not None:
        try:
            index = int(current)
        except (TypeError, ValueError):
            return web.Response(status=400)
        # setAPIKeys ignores an empty list, so the existing keys stay in effect then.
        effective_keys = api_keys or stableAPI.api_keys or []
        # Validate before persisting so a bad index never reaches config.yaml.
        if not 0 <= index < len(effective_keys):
            return web.Response(status=400)

    stableAPI.setAPIKeys(api_keys)
    if index is None:
        return web.json_response({'status': 'ok'})

    stableAPI.setAPIDefault(index)
    account = await stableAPI.getUserAccount()
    balance = await stableAPI.getUserBalance()
    return web.json_response({'account': account, 'balance': balance})
@PromptServer.instance.routes.post("/easyuse/stability/set_apikey_default")
async def set_stability_api_default(request):
    """Switch the active API key for the running process.

    Form fields:
        current: Index into the configured key list.

    Responds with ``{'status': 'ok'}`` on success and 400 when ``current`` is
    missing, not an integer, or outside the key list. Only the in-memory
    index is changed; config.yaml is not rewritten here.
    """
    post = await request.post()
    current = post.get("current")
    try:
        # Form values arrive as strings; comparing one with len() raises
        # TypeError, and a string cannot be used as a list index later on.
        index = int(current)
    except (TypeError, ValueError):
        return web.Response(status=400)

    if not 0 <= index < len(stableAPI.api_keys or []):
        return web.Response(status=400)

    stableAPI.api_current = index
    return web.json_response({'status': 'ok'})
@PromptServer.instance.routes.get("/easyuse/stability/user_info")
async def get_account_info(request):
    """Return the account details and balance of the active API key as JSON."""
    account_info = await stableAPI.getUserAccount()
    balance_info = await stableAPI.getUserBalance()
    payload = {'account': account_info, 'balance': balance_info}
    return web.json_response(payload)
@PromptServer.instance.routes.get("/easyuse/stability/balance")
async def get_balance_info(request):
    """Return the balance of the active API key as JSON."""
    payload = {'balance': await stableAPI.getUserBalance()}
    return web.json_response(payload)