# FreeLighting / app.py
import hashlib
import time
import uuid
from urllib.parse import urlencode
import json
import requests
from PIL import Image, ImageOps
import os
import gradio as gr
import ops
from watermark import WatermarkApp
import db_examples
class ApiClient(object):
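    # Thin HTTP client for the relighting backend.
    # Every request carries three per-call headers prepared by _prepare_headers():
    #   requestId  - a fresh UUID
    #   timestamp  - current time in milliseconds
    #   sign       - MD5(appKey + accessKeyId + timestamp + accessKeySecret), see _get_sign()
    # The appKey and accessKey headers are set once on the shared requests.Session.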
def __init__(self, app_key: str, access_key_id: str, access_key_secret: str, endpoint: str):
self.app_key = app_key
self.access_key_id = access_key_id
self.access_key_secret = access_key_secret
self.endpoint = endpoint
self.base_url = 'http://' + self.endpoint + '/api'
        self.timeout = 8000  # note: requests interprets this value as seconds, so it effectively never triggers
self.session = requests.Session()
self.session.headers.update(
{
"Content-Type": "application/json;charset=utf-8",
"accessKey": self.access_key_id,
"appKey": self.app_key
}
)
def send_get(self, headers=None, params=None):
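        # Signed GET: per-request headers are filled in by _prepare_headers() and the
        # params dict is URL-encoded (None values dropped) before being sent to base_url.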
if headers is None:
headers = {}
if params is None:
params = {}
args = self.cleanNoneValue(
{
"url": self.base_url,
"headers": self._prepare_headers(headers),
"params": self._prepare_params(params),
"timeout": self.timeout
}
)
response = self._dispatch_request("GET")(**args)
self._handle_exception(response)
try:
data = response.json()
except ValueError:
data = response.text
return data
def send_post(self, headers=None, params=None):
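        # Signed POST: same header handling as send_get, but the params dict is sent
        # as a JSON body instead of being URL-encoded.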
if headers is None:
headers = {}
if params is None:
params = {}
args = self.cleanNoneValue(
{
"url": self.base_url,
"headers": self._prepare_headers(headers),
"json": params,
"timeout": self.timeout
}
)
response = self._dispatch_request("POST")(**args)
self._handle_exception(response)
try:
data = response.json()
except ValueError:
data = response.text
return data
def _prepare_headers(self, headers):
headers['requestId'] = str(uuid.uuid1())
timestamp = int(round(time.time() * 1000))
headers['timestamp'] = str(timestamp)
headers['sign'] = self._get_sign(timestamp)
return headers
def _prepare_params(self, params):
return self.encoded_string(self.cleanNoneValue(params))
def _get_sign(self, timestamp):
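        # The request signature is the hex MD5 digest of
        # appKey + accessKeyId + timestamp + accessKeySecret.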
key = self.app_key + self.access_key_id + str(timestamp) + self.access_key_secret
        md5hash = hashlib.md5(key.encode('utf-8'))
return md5hash.hexdigest()
def _dispatch_request(self, http_method):
return {
"GET": self.session.get,
"DELETE": self.session.delete,
"PUT": self.session.put,
"POST": self.session.post,
        }.get(http_method, self.session.get)  # fall back to the GET method itself, not the string "GET"
def _handle_exception(self, response):
status_code = response.status_code
if status_code < 400:
return
raise Exception(response.text)
def encoded_string(self, query):
return urlencode(query, True).replace("%40", "@")
def cleanNoneValue(self, d) -> dict:
        # drop entries whose value is None so they are never sent to the API
        return {k: v for k, v in d.items() if v is not None}
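# Required configuration is read from environment variables:
#   APIKEY, AK, SK     - credentials used to sign requests (app key, access key id/secret)
#   ENDPOINT           - API host, used to build http://<ENDPOINT>/api
#   APINAME            - apiName header for submitting relighting tasks
#   CALLBACKAPINAME    - apiName header for polling task status and results
#   OSSSIGNAPINAME     - apiName header for fetching OSS upload signatures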
api_key = os.environ['APIKEY']
ak = os.environ['AK']
sk = os.environ['SK']
endpoint = os.environ['ENDPOINT']
apiname = os.environ['APINAME']
callbackapiname = os.environ['CALLBACKAPINAME']
osssigneapiname = os.environ['OSSSIGNAPINAME']
client = ApiClient(api_key, ak, sk, endpoint)
watermark_app = WatermarkApp()
def upload_to_oss(file_path, accessid, policy, signature, host, key, callback):
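    # Upload a local file to OSS using a pre-signed POST policy: the form fields
    # (OSSAccessKeyId, policy, Signature, key, callback) all come from the backend's
    # signing endpoint. Returns the object key on success, None on failure.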
with open(file_path, 'rb') as f:
files = {'file': (key, f)}
data = {
'OSSAccessKeyId': accessid,
'policy': policy,
'Signature': signature,
'key': key,
'callback': callback,
}
response = requests.post(host, files=files, data=data, timeout=20)
if response.status_code == 204 or response.ok:
return key
else:
print(f"file upload failed: {response.text}")
return None
def upload_oss_bucket(image_pil):
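    # Fetch a signed upload policy from the backend (fileType '1'), save the PIL image
    # to a local JPEG, push it to OSS and return the resulting object key.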
headers = {
'apiName': osssigneapiname
}
params = {
'fileType': '1'
}
result = client.send_get(headers, params)
try:
result = result['data']
    except Exception as e:
        raise ValueError('oss sign error') from e
accessid = result['accessid']
policy = result['policy']
signature = result['signature']
host = result['host']
callback = ''
oss_key = os.path.join(result['dir'], '{}.jpg'.format(result['expire']))
    # use a unique local filename so concurrent requests do not overwrite each other
    file_path = '{}.jpg'.format(uuid.uuid4().hex)
    image_pil.save(file_path)
    oss_key = upload_to_oss(file_path, accessid, policy, signature, host, oss_key, callback)
    os.remove(file_path)
    if oss_key is None:
        raise ValueError('oss upload error')
    return oss_key
def call_text_guided_relighting(image, mode, prompt, seed, steps):
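    # Text-guided relighting: normalise the input image, upload it to OSS, submit a
    # 'free_txt2bg_gen' task, then poll the callback API until it finishes.
    # Returns (watermarked relit image, segmentation mask) decoded from the result URLs.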
headers = {
'apiName': apiname
}
image = ImageOps.exif_transpose(image).convert('RGB')
image = ops.resize_keep_hw_rate(image, tar_res=1280)
image = upload_oss_bucket(image)
ops.print_with_datetime('start call_text_guided_relighting')
ops.print_with_datetime(prompt)
params = {
'image': image,
'inference_mode': 'free_txt2bg_gen',
'mode': mode,
'prompt': prompt,
'seed': seed,
'steps': steps,
}
ops.print_with_datetime(f'length params, {len(json.dumps(params))}')
task_id = client.send_post(headers, params)['data']
time.sleep(10)
headers = {
'apiName': callbackapiname
}
params = {
'id': task_id
}
    # poll the callback API until the task leaves the "in progress" state (status == 1)
    while True:
        result = client.send_get(headers, params)['data']
        if result['status'] != 1:
            break
        time.sleep(10)
    # status 2 appears to indicate success; anything else is treated as a failure
    if result['status'] != 2:
        raise ValueError('relighting task did not finish successfully (status={})'.format(result['status']))
result_1 = result['sasMyCreationPicVOs'][0]['picUrl']
result_2 = result['sasMyCreationPicVOs'][0]['maskUrl']
result_1 = ops.decode_img_from_url(result_1)
result_1 = watermark_app.process_image(result_1)
result_2 = ops.decode_img_from_url(result_2)
return result_1, result_2
def call_image_guided_relighting(image, ref_img, seed, steps):
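    # Image-guided relighting: upload both the source image and the reference image to
    # OSS, submit a 'replica_gen' task, then poll the callback API in the same way as
    # the text-guided path. Returns (watermarked relit image, segmentation mask).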
headers = {
'apiName': apiname
}
image = ImageOps.exif_transpose(image).convert('RGB')
image = ops.resize_keep_hw_rate(image, tar_res=1280)
image = upload_oss_bucket(image)
ref_img = ImageOps.exif_transpose(ref_img).convert('RGB')
ref_img = ops.resize_keep_hw_rate(ref_img, tar_res=1280)
ref_img = upload_oss_bucket(ref_img)
ops.print_with_datetime('start call_image_guided_relighting')
params = {
'image': image,
'inference_mode': 'replica_gen',
'mode': 'normal',
'ref_img': ref_img,
'seed': seed,
'steps': steps,
}
ops.print_with_datetime(f'length params, {len(json.dumps(params))}')
    response = client.send_post(headers, params)
    ops.print_with_datetime(f'task submitted, response: {response}')
    task_id = response['data']
time.sleep(10)
headers = {
'apiName': callbackapiname
}
params = {
'id': task_id
}
    # poll the callback API until the task leaves the "in progress" state (status == 1)
    while True:
        result = client.send_get(headers, params)['data']
        if result['status'] != 1:
            break
        time.sleep(10)
    # status 2 appears to indicate success; anything else is treated as a failure
    if result['status'] != 2:
        raise ValueError('relighting task did not finish successfully (status={})'.format(result['status']))
result_1 = result['sasMyCreationPicVOs'][0]['picUrl']
result_2 = result['sasMyCreationPicVOs'][0]['maskUrl']
result_1 = ops.decode_img_from_url(result_1)
result_1 = watermark_app.process_image(result_1)
result_2 = ops.decode_img_from_url(result_2)
return result_1, result_2
quick_prompts = [
'warm lighting',
'sunshine from window',
'neon lighting',
    'at noon, bright sunlight',
'at dusk',
'golden time',
'natural lighting',
'shadow from window',
'soft studio lighting',
'red lighting',
'purple lighting'
]
quick_prompts = [[x] for x in quick_prompts]
quick_content_prompts = [
'by the sea',
'in the forest',
'on the snow mountain',
'by the city street',
'on the grassy field',
'cityscape',
    'in the desert',
'in the living room',
]
quick_content_prompts = [[x] for x in quick_content_prompts]
quick_subjects = [
'portrait photography of a woman',
    'portrait photography of a man',
'product photography',
]
quick_subjects = [[x] for x in quick_subjects]
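# gr.Dataset expects samples as a list of lists, one inner value per component; each
# prompt string is therefore wrapped in a single-element list to match the lone prompt textbox.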
with gr.Blocks().queue() as demo:
gr.HTML("""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
        FreeLighting: a relighting model with both text and image conditioning
</div>
""")
with gr.Row():
gr.Markdown("See more information in https://github.com/liuyuxuan3060/FreeLighting")
with gr.Row():
gr.Markdown("We use a open source segmentation model to generate image mask")
with gr.Tabs():
with gr.TabItem("text-guided relighting") as t2v_tab:
with gr.Row():
with gr.Column():
with gr.Row():
image = gr.Image(label="original_image", type="pil", height=480)
image_mask = gr.Image(label="image_mask", type="pil", height=480)
with gr.Row():
prompt = gr.Textbox(value="", label="text prompt")
with gr.Row():
example_quick_subjects = gr.Dataset(samples=quick_subjects, label='Subject Quick Prompt List. Click to choose a quick subject description', samples_per_page=1000, components=[prompt])
with gr.Row():
example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Lighting Quick Prompt List. Click to choose a quick lighting description', samples_per_page=1000, components=[prompt])
with gr.Row():
example_quick_content_prompts = gr.Dataset(samples=quick_content_prompts, label='Content Quick Prompt List. Click to choose a quick content description', samples_per_page=1000, components=[prompt])
with gr.Row():
                        mode = gr.Radio(choices=["normal", "uniform-lit"], value='normal', label="uniform-lit mode takes roughly twice as long", type='value')
seed = gr.Number(value=12345, label="random seed", precision=0)
steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
button = gr.Button("generate")
with gr.Column():
relighting_image = gr.Image(label="relighted_image", type="pil", height=480)
with gr.Row():
gr.Examples(
examples=db_examples.text_guided_examples,
inputs=[
image, mode, prompt, seed, steps
, relighting_image
],
outputs=[relighting_image],
examples_per_page=1024
)
button.click(
fn=call_text_guided_relighting,
inputs=[
image, mode, prompt, seed, steps
],
outputs=[relighting_image, image_mask],
)
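            # Quick-prompt behaviour: clicking a subject replaces the whole prompt, while
            # clicking a lighting or content sample keeps the first two comma-separated
            # segments of the current prompt and appends the clicked text.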
example_quick_content_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_content_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
with gr.TabItem("image-guided relighting") as i2v_tab:
with gr.Row():
with gr.Column():
with gr.Row():
image = gr.Image(label="original_image", type="pil", height=480)
ref_img = gr.Image(label="reference_image", type="pil", height=480)
image_mask = gr.Image(label="image_mask", type="pil", height=480)
with gr.Row():
seed = gr.Number(value=12345, label="random seed", precision=0)
steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
button = gr.Button("generate")
with gr.Column():
relighting_image = gr.Image(label="relighted_image", type="pil", height=480)
with gr.Row():
gr.Examples(
examples=db_examples.image_guided_examples,
inputs=[
image, ref_img, seed, steps
, relighting_image
],
outputs=[relighting_image],
examples_per_page=1024
)
button.click(
fn=call_image_guided_relighting,
inputs=[
image, ref_img, seed, steps
],
outputs=[relighting_image, image_mask],
)
demo.launch()