Spaces:
Running
Running
import hashlib | |
import time | |
import uuid | |
from urllib.parse import urlencode | |
import json | |
import requests | |
from PIL import Image, ImageOps | |
import os | |
import gradio as gr | |
import ops | |
from watermark import WatermarkApp | |
import db_examples | |
class ApiClient(object):
    """HTTP client for the signed gateway API used by this app.

    Every request goes to ``http://<endpoint>/api``. It carries the session-wide
    ``Content-Type``/``accessKey``/``appKey`` headers plus per-request
    ``requestId``, ``timestamp`` and ``sign`` headers (see ``_prepare_headers``).
    """

    def __init__(self, app_key: str, access_key_id: str, access_key_secret: str, endpoint: str):
        self.app_key = app_key
        self.access_key_id = access_key_id
        self.access_key_secret = access_key_secret
        self.endpoint = endpoint
        # NOTE(review): plain http -- the access key header and signature travel
        # unencrypted; confirm the endpoint is only reachable on a trusted network.
        self.base_url = 'http://' + self.endpoint + '/api'
        # NOTE(review): requests interprets `timeout` in seconds, so this is about
        # 2.2 hours; it reads like a millisecond value -- confirm intent.
        self.timeout = 8000
        self.session = requests.Session()
        self.session.headers.update(
            {
                "Content-Type": "application/json;charset=utf-8",
                "accessKey": self.access_key_id,
                "appKey": self.app_key,
            }
        )

    def send_get(self, headers=None, params=None):
        """Send a signed GET request; ``params`` are URL-encoded into the query string.

        Returns:
            The decoded JSON body, or the raw response text if it is not JSON.

        Raises:
            Exception: carrying the response body, when the HTTP status is >= 400.
        """
        args = self.cleanNoneValue(
            {
                "url": self.base_url,
                "headers": self._prepare_headers(headers),
                "params": self._prepare_params(params or {}),
                "timeout": self.timeout,
            }
        )
        return self._send("GET", args)

    def send_post(self, headers=None, params=None):
        """Send a signed POST request; ``params`` are sent as the JSON body.

        Returns:
            The decoded JSON body, or the raw response text if it is not JSON.

        Raises:
            Exception: carrying the response body, when the HTTP status is >= 400.
        """
        args = self.cleanNoneValue(
            {
                "url": self.base_url,
                "headers": self._prepare_headers(headers),
                "json": params if params is not None else {},
                "timeout": self.timeout,
            }
        )
        return self._send("POST", args)

    def _send(self, http_method, args):
        """Dispatch a prepared request, raise on HTTP errors, and decode the body."""
        response = self._dispatch_request(http_method)(**args)
        self._handle_exception(response)
        try:
            return response.json()
        except ValueError:
            # Body is not JSON: hand back the raw text instead.
            return response.text

    def _prepare_headers(self, headers):
        """Return a copy of ``headers`` with ``requestId``, ``timestamp`` (ms) and ``sign`` added.

        The caller's dict is not modified.
        """
        prepared = dict(headers or {})
        prepared['requestId'] = str(uuid.uuid1())
        timestamp = int(round(time.time() * 1000))
        prepared['timestamp'] = str(timestamp)
        prepared['sign'] = self._get_sign(timestamp)
        return prepared

    def _prepare_params(self, params):
        """Drop ``None`` values and URL-encode the rest into a query string."""
        return self.encoded_string(self.cleanNoneValue(params))

    def _get_sign(self, timestamp):
        """Return the MD5 hex digest of app_key + access_key_id + timestamp + access_key_secret."""
        key = self.app_key + self.access_key_id + str(timestamp) + self.access_key_secret
        return hashlib.md5(key.encode('utf-8')).hexdigest()

    def _dispatch_request(self, http_method):
        """Return the session method for ``http_method``; unknown methods fall back to GET."""
        methods = {
            "GET": self.session.get,
            "DELETE": self.session.delete,
            "PUT": self.session.put,
            "POST": self.session.post,
        }
        # Fall back to the bound GET method itself -- callers invoke the result,
        # so the fallback must be callable, not the string "GET".
        return methods.get(http_method, self.session.get)

    def _handle_exception(self, response):
        """Raise ``Exception(response.text)`` when the HTTP status code is >= 400."""
        if response.status_code >= 400:
            raise Exception(response.text)

    def encoded_string(self, query):
        """URL-encode ``query`` (sequence values expanded), leaving ``@`` unescaped.

        NOTE(review): the literal ``@`` is presumably expected by the gateway -- confirm.
        """
        return urlencode(query, True).replace("%40", "@")

    def cleanNoneValue(self, d) -> dict:
        """Return a new dict with every ``None``-valued entry of ``d`` removed."""
        return {key: value for key, value in d.items() if value is not None}
# Gateway credentials and API names come from the environment, read in this
# order; a missing variable raises KeyError while the module is imported.
api_key, ak, sk, endpoint, apiname, callbackapiname, osssigneapiname = (
    os.environ[env_var]
    for env_var in (
        'APIKEY',
        'AK',
        'SK',
        'ENDPOINT',
        'APINAME',
        'CALLBACKAPINAME',
        'OSSSIGNAPINAME',
    )
)

# Shared, module-level collaborators used by the request handlers below.
client = ApiClient(api_key, ak, sk, endpoint)
watermark_app = WatermarkApp()
def upload_to_oss(file_path, accessid, policy, signature, host, key, callback, timeout=20):
    """Upload a local file to OSS with a signed POST-policy form.

    Args:
        file_path: Local path of the file to upload.
        accessid, policy, signature: Signed form credentials from the signing API.
        host: Upload URL the form is posted to.
        key: Object key to store the file under (also used as the multipart filename).
        callback: Value for the form's ``callback`` field.
        timeout: requests timeout in seconds (default 20).

    Returns:
        ``key`` when the response is successful (``response.ok``, i.e. status < 400),
        otherwise ``None`` after printing the response body.
    """
    form_fields = {
        'OSSAccessKeyId': accessid,
        'policy': policy,
        'Signature': signature,
        'key': key,
        'callback': callback,
    }
    # The request must be sent while the file handle is still open.
    with open(file_path, 'rb') as file_obj:
        response = requests.post(
            host,
            files={'file': (key, file_obj)},
            data=form_fields,
            timeout=timeout,
        )
    # `response.ok` already covers 204 No Content.
    if response.ok:
        return key
    print(f"file upload failed: {response.text}")
    return None
def upload_oss_bucket(image_pil):
    """Save a PIL image to a temporary file and upload it to OSS with a freshly signed policy.

    The object key is ``<dir>/<expire>.jpg``, built from the signing response.

    Returns:
        The OSS object key of the uploaded image.

    Raises:
        ValueError: ``'oss sign error'`` if the signing response lacks the expected
            ``data`` payload; ``'oss upload error'`` if the upload fails.
    """
    import posixpath
    import tempfile

    result = client.send_get({'apiName': osssigneapiname}, {'fileType': '1'})
    try:
        sign = result['data']
        accessid = sign['accessid']
        policy = sign['policy']
        signature = sign['signature']
        host = sign['host']
        # OSS keys always use '/', so build the key with posixpath instead of
        # os.path (which would insert '\\' on Windows).
        oss_key = posixpath.join(sign['dir'], '{}.jpg'.format(sign['expire']))
    except (KeyError, TypeError) as err:
        # TypeError covers a non-dict (e.g. plain-text) response from send_get.
        raise ValueError('oss sign error') from err

    # NOTE(review): the object name depends only on `dir` and `expire`; if the
    # signing API returns the same pair for back-to-back calls (e.g. image and
    # ref_img), the second upload would overwrite the first -- confirm.

    # A unique temp file per call keeps concurrent requests from clobbering a
    # shared file, and the finally block guarantees it is cleaned up.
    fd, tmp_path = tempfile.mkstemp(suffix='.jpg')
    os.close(fd)
    try:
        image_pil.save(tmp_path)
        uploaded_key = upload_to_oss(tmp_path, accessid, policy, signature, host, oss_key, '')
    finally:
        os.remove(tmp_path)

    if uploaded_key is None:
        raise ValueError('oss upload error')
    return uploaded_key
def _upload_input_image(img):
    """Apply EXIF orientation, convert to RGB, resize (tar_res=1280) and upload to OSS; return the OSS key."""
    img = ImageOps.exif_transpose(img).convert('RGB')
    img = ops.resize_keep_hw_rate(img, tar_res=1280)
    return upload_oss_bucket(img)


def _run_relighting_task(params):
    """Submit a relighting task, poll it until done, and download its outputs.

    The task is polled every 10 seconds while its ``status`` is 1; only
    status 2 is treated as success.

    Returns:
        (watermarked relit image, mask image) decoded from the task's first
        ``sasMyCreationPicVOs`` entry.

    Raises:
        ValueError: if the task finishes with a status other than 2.
    """
    ops.print_with_datetime(f'length params, {len(json.dumps(params))}')
    submit_response = client.send_post({'apiName': apiname}, params)
    ops.print_with_datetime(f'submit response, {submit_response}')
    task_id = submit_response['data']
    time.sleep(10)

    poll_headers = {'apiName': callbackapiname}
    poll_params = {'id': task_id}
    # NOTE(review): there is no upper bound on polling; a task stuck in
    # status 1 keeps this worker busy indefinitely.
    while True:
        task = client.send_get(poll_headers, poll_params)['data']
        if task['status'] != 1:
            break
        time.sleep(10)
    if task['status'] != 2:
        raise ValueError('something wrong in the process')

    outputs = task['sasMyCreationPicVOs'][0]
    relit_image = watermark_app.process_image(ops.decode_img_from_url(outputs['picUrl']))
    mask_image = ops.decode_img_from_url(outputs['maskUrl'])
    return relit_image, mask_image


def call_text_guided_relighting(image, mode, prompt, seed, steps):
    """Relight ``image`` from a text prompt (inference mode ``free_txt2bg_gen``).

    Args:
        image: PIL image from the UI, or ``None`` if nothing was uploaded.
        mode: ``'normal'`` or ``'uniform-lit'`` (choices offered by the UI).
        prompt: Text description of the desired lighting/background.
        seed: Random seed forwarded to the backend.
        steps: Number of steps forwarded to the backend.

    Returns:
        (relit image with watermark, image mask).

    Raises:
        gr.Error: if no image was provided.
        ValueError: if the backend task fails.
    """
    if image is None:
        raise gr.Error('Please upload an original image first.')
    image_key = _upload_input_image(image)
    ops.print_with_datetime('start call_text_guided_relighting')
    ops.print_with_datetime(prompt)
    params = {
        'image': image_key,
        'inference_mode': 'free_txt2bg_gen',
        'mode': mode,
        'prompt': prompt,
        'seed': seed,
        'steps': steps,
    }
    return _run_relighting_task(params)


def call_image_guided_relighting(image, ref_img, seed, steps):
    """Relight ``image`` to match a reference image (inference mode ``replica_gen``).

    Args:
        image: PIL image to relight, or ``None`` if nothing was uploaded.
        ref_img: PIL reference image, or ``None`` if nothing was uploaded.
        seed: Random seed forwarded to the backend.
        steps: Number of steps forwarded to the backend.

    Returns:
        (relit image with watermark, image mask).

    Raises:
        gr.Error: if either image is missing.
        ValueError: if the backend task fails.
    """
    if image is None or ref_img is None:
        raise gr.Error('Please upload both an original image and a reference image first.')
    image_key = _upload_input_image(image)
    ref_key = _upload_input_image(ref_img)
    ops.print_with_datetime('start call_image_guided_relighting')
    params = {
        'image': image_key,
        'inference_mode': 'replica_gen',
        'mode': 'normal',
        'ref_img': ref_key,
        'seed': seed,
        'steps': steps,
    }
    return _run_relighting_task(params)
# Canned prompt fragments offered as clickable datasets in the UI.
# Each sample is wrapped in a one-element list (one value per component).
quick_prompts = [
    [lighting]
    for lighting in (
        'warm lighting',
        'sunshine from window',
        'neon lighting',
        'at noon. bright sunlight',
        'at dusk',
        'golden time',
        'natural lighting',
        'shadow from window',
        'soft studio lighting',
        'red lighting',
        'purple lighting',
    )
]

quick_content_prompts = [
    [scene]
    for scene in (
        'by the sea',
        'in the forest',
        'on the snow mountain',
        'by the city street',
        'on the grassy field',
        'cityscape',
        'on the desert',
        'in the living room',
    )
]

quick_subjects = [
    [subject]
    for subject in (
        'portrait photography of a woman',
        'portrait photography of man',
        'product photography',
    )
]
def _append_quick_prompt(sample, current_prompt):
    """Return ``current_prompt`` with the clicked quick-prompt fragment appended.

    Only the first two ", "-separated fragments of the current prompt are kept,
    so repeated clicks replace the third fragment. Empty fragments are dropped
    so an empty textbox does not produce a leading ", ".
    """
    kept = [part for part in (current_prompt or '').split(', ')[:2] if part]
    return ', '.join(kept + [sample[0]])


with gr.Blocks().queue() as demo:
    gr.HTML("""
    <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
    FreeLighting: relighting model with both text-condition and image-condition
    </div>
    """)
    with gr.Row():
        gr.Markdown("See more information in https://github.com/liuyuxuan3060/FreeLighting")
    with gr.Row():
        gr.Markdown("We use an open-source segmentation model to generate the image mask")
    with gr.Tabs():
        # --- Text-guided relighting: prompt + mode drive the relighting. ---
        with gr.TabItem("text-guided relighting") as t2v_tab:
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        image = gr.Image(label="original_image", type="pil", height=480)
                        image_mask = gr.Image(label="image_mask", type="pil", height=480)
                    with gr.Row():
                        prompt = gr.Textbox(value="", label="text prompt")
                    with gr.Row():
                        example_quick_subjects = gr.Dataset(samples=quick_subjects, label='Subject Quick Prompt List. Click to choose a quick subject description', samples_per_page=1000, components=[prompt])
                    with gr.Row():
                        example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Lighting Quick Prompt List. Click to choose a quick lighting description', samples_per_page=1000, components=[prompt])
                    with gr.Row():
                        example_quick_content_prompts = gr.Dataset(samples=quick_content_prompts, label='Content Quick Prompt List. Click to choose a quick content description', samples_per_page=1000, components=[prompt])
                    with gr.Row():
                        mode = gr.Radio(choices=["normal", "uniform-lit"], value='normal', label="uniform-lit mode will use double time", type='value')
                        seed = gr.Number(value=12345, label="random seed", precision=0)
                        steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                    button = gr.Button("generate")
                with gr.Column():
                    relighting_image = gr.Image(label="relighted_image", type="pil", height=480)
            with gr.Row():
                # relighting_image is included in `inputs` so clicking an example
                # also shows its precomputed result.
                gr.Examples(
                    examples=db_examples.text_guided_examples,
                    inputs=[image, mode, prompt, seed, steps, relighting_image],
                    outputs=[relighting_image],
                    examples_per_page=1024,
                )
            button.click(
                fn=call_text_guided_relighting,
                inputs=[image, mode, prompt, seed, steps],
                outputs=[relighting_image, image_mask],
            )
            example_quick_content_prompts.click(_append_quick_prompt, inputs=[example_quick_content_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
            example_quick_prompts.click(_append_quick_prompt, inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
            # Choosing a subject starts a fresh prompt.
            example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)

        # --- Image-guided relighting: a reference image drives the relighting. ---
        with gr.TabItem("image-guided relighting") as i2v_tab:
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        image = gr.Image(label="original_image", type="pil", height=480)
                        ref_img = gr.Image(label="reference_image", type="pil", height=480)
                        image_mask = gr.Image(label="image_mask", type="pil", height=480)
                    with gr.Row():
                        seed = gr.Number(value=12345, label="random seed", precision=0)
                        steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                    button = gr.Button("generate")
                with gr.Column():
                    relighting_image = gr.Image(label="relighted_image", type="pil", height=480)
            with gr.Row():
                gr.Examples(
                    examples=db_examples.image_guided_examples,
                    inputs=[image, ref_img, seed, steps, relighting_image],
                    outputs=[relighting_image],
                    examples_per_page=1024,
                )
            button.click(
                fn=call_image_guided_relighting,
                inputs=[image, ref_img, seed, steps],
                outputs=[relighting_image, image_mask],
            )
demo.launch()