import hashlib
import json
import os
import posixpath
import tempfile
import time
import uuid
from urllib.parse import urlencode

import gradio as gr
import requests
from PIL import Image, ImageOps

import db_examples
import ops
from watermark import WatermarkApp


class ApiClient(object):
    """Minimal signed HTTP client for the relighting backend.

    Every request goes to the single ``http://<endpoint>/api`` URL. The session
    always sends ``accessKey``/``appKey``; each request additionally gets a
    fresh ``requestId``, a millisecond ``timestamp`` and an MD5 ``sign`` header.
    The concrete operation is chosen by an ``apiName`` header from the caller.
    """

    def __init__(self, app_key: str, access_key_id: str, access_key_secret: str, endpoint: str):
        self.app_key = app_key
        self.access_key_id = access_key_id
        self.access_key_secret = access_key_secret
        self.endpoint = endpoint
        self.base_url = 'http://' + self.endpoint + '/api'
        # NOTE(review): ``requests`` interprets ``timeout`` in *seconds*, so this
        # is roughly 2.2 hours rather than 8 s. Kept unchanged; confirm the unit.
        self.timeout = 8000
        self.session = requests.Session()
        self.session.headers.update(
            {
                "Content-Type": "application/json;charset=utf-8",
                "accessKey": self.access_key_id,
                "appKey": self.app_key,
            }
        )

    def send_get(self, headers=None, params=None):
        """Send a signed GET with ``params`` URL-encoded into the query string.

        Returns the decoded JSON body, or the raw text if the body is not JSON.
        Raises ``Exception`` (carrying the response text) on HTTP status >= 400.
        """
        if params is None:
            params = {}
        return self._send("GET", headers, {"params": self._prepare_params(params)})

    def send_post(self, headers=None, params=None):
        """Send a signed POST with ``params`` as the JSON body.

        Returns the decoded JSON body, or the raw text if the body is not JSON.
        Raises ``Exception`` (carrying the response text) on HTTP status >= 400.
        """
        if params is None:
            params = {}
        return self._send("POST", headers, {"json": params})

    def _send(self, http_method, headers, payload):
        """Sign, dispatch and decode one request.

        ``payload`` carries the method-specific ``requests`` keyword
        (``params`` for GET, ``json`` for POST).
        """
        args = self.cleanNoneValue(
            {
                "url": self.base_url,
                "headers": self._prepare_headers(headers),
                "timeout": self.timeout,
                **payload,
            }
        )
        response = self._dispatch_request(http_method)(**args)
        self._handle_exception(response)
        try:
            return response.json()
        except ValueError:
            return response.text

    def _prepare_headers(self, headers):
        """Return a copy of ``headers`` with requestId/timestamp/sign added.

        A copy is returned so the caller's dict is never mutated.
        """
        headers = dict(headers) if headers else {}
        timestamp = int(round(time.time() * 1000))
        headers['requestId'] = str(uuid.uuid1())
        headers['timestamp'] = str(timestamp)
        headers['sign'] = self._get_sign(timestamp)
        return headers

    def _prepare_params(self, params):
        """Drop ``None`` values and URL-encode the rest as a query string."""
        return self.encoded_string(self.cleanNoneValue(params))

    def _get_sign(self, timestamp):
        """MD5 hex digest of app_key + access_key_id + timestamp + access_key_secret."""
        key = self.app_key + self.access_key_id + str(timestamp) + self.access_key_secret
        return hashlib.md5(key.encode('utf-8')).hexdigest()

    def _dispatch_request(self, http_method):
        """Map an HTTP verb to the bound session method; unknown verbs fall back to GET."""
        return {
            "GET": self.session.get,
            "DELETE": self.session.delete,
            "PUT": self.session.put,
            "POST": self.session.post,
        }.get(http_method, self.session.get)  # was the string "GET" -> not callable

    def _handle_exception(self, response):
        """Raise ``Exception(response.text)`` for any HTTP status >= 400."""
        if response.status_code < 400:
            return
        raise Exception(response.text)

    def encoded_string(self, query):
        """URL-encode ``query`` (sequence values expanded), leaving ``@`` unescaped."""
        return urlencode(query, True).replace("%40", "@")

    def cleanNoneValue(self, d) -> dict:
        """Return a new dict containing only the items of ``d`` whose value is not ``None``."""
        return {k: v for k, v in d.items() if v is not None}


# Backend configuration; a missing variable fails fast with KeyError at import time.
api_key = os.environ['APIKEY']
ak = os.environ['AK']
sk = os.environ['SK']
endpoint = os.environ['ENDPOINT']
apiname = os.environ['APINAME']
callbackapiname = os.environ['CALLBACKAPINAME']
osssigneapiname = os.environ['OSSSIGNAPINAME']

client = ApiClient(api_key, ak, sk, endpoint)
watermark_app = WatermarkApp()

# Task polling: seconds between status queries, and total time before giving up.
TASK_POLL_INTERVAL_S = 10
TASK_MAX_WAIT_S = 30 * 60


def upload_to_oss(file_path, accessid, policy, signature, host, key, callback):
    """POST a local file to OSS using a pre-signed form policy.

    Returns ``key`` on any successful status (2xx/3xx, which includes 204).
    Otherwise it prints the response text and returns ``None``.
    """
    with open(file_path, 'rb') as f:
        files = {'file': (key, f)}
        data = {
            'OSSAccessKeyId': accessid,
            'policy': policy,
            'Signature': signature,
            'key': key,
            'callback': callback,
        }
        response = requests.post(host, files=files, data=data, timeout=20)
    if response.ok:
        return key
    print(f"file upload failed: {response.text}")
    return None


def upload_oss_bucket(image_pil):
    """Upload a PIL image to the OSS bucket and return its object key.

    Fetches a signed upload policy from the backend (``osssigneapiname``). It
    then saves the image as JPEG in a private temporary directory, which is
    removed afterwards, and posts the file with ``upload_to_oss``.

    Raises:
        ValueError: 'oss sign error' if the signing response lacks the expected
            fields; 'oss upload error' if the upload itself fails.
    """
    headers = {'apiName': osssigneapiname}
    params = {'fileType': '1'}
    result = client.send_get(headers, params)
    try:
        sign = result['data']
        accessid = sign['accessid']
        policy = sign['policy']
        signature = sign['signature']
        host = sign['host']
        oss_dir = sign['dir']
        expire = sign['expire']
    except (KeyError, TypeError) as e:
        raise ValueError('oss sign error') from e
    callback = ''
    # OSS object keys always use '/', independent of the local OS.
    # NOTE(review): the file name derives only from the policy's ``expire``.
    # If two signings return the same dir/expire (e.g. the two back-to-back
    # uploads in image-guided mode), the second upload may overwrite the first.
    # Confirm with the backend whether ``dir`` is unique per request.
    oss_key = posixpath.join(oss_dir, '{}.jpg'.format(expire))
    # A per-call temp dir avoids a shared fixed 'test.jpg' in the CWD, which
    # concurrent requests could clobber and which was never cleaned up.
    with tempfile.TemporaryDirectory() as tmp_dir:
        file_path = os.path.join(tmp_dir, 'upload.jpg')
        image_pil.save(file_path)
        uploaded_key = upload_to_oss(file_path, accessid, policy, signature, host, oss_key, callback)
    if uploaded_key is None:
        raise ValueError('oss upload error')
    return uploaded_key


def _require_image(img, label):
    """Raise a user-visible Gradio error when an image input was left empty."""
    if img is None:
        raise gr.Error(f'Please upload the {label} first.')


def _prepare_and_upload(img):
    """Apply EXIF orientation, convert to RGB, resize, upload; return the OSS key.

    Resizing uses ``ops.resize_keep_hw_rate`` with ``tar_res=1280``.
    """
    img = ImageOps.exif_transpose(img).convert('RGB')
    img = ops.resize_keep_hw_rate(img, tar_res=1280)
    return upload_oss_bucket(img)


def _submit_and_wait(params):
    """Submit a relighting task and poll until its status leaves 1.

    Status 1 is treated as "still running" and status 2 as success; anything
    else raises ValueError. Returns the final ``data`` dict of the status query.

    Raises:
        TimeoutError: if the task is still running after ``TASK_MAX_WAIT_S``.
        ValueError: if the task finished with a status other than 2.
    """
    ops.print_with_datetime(f'length params, {len(json.dumps(params))}')
    task_id = client.send_post({'apiName': apiname}, params)['data']
    ops.print_with_datetime(f'task id, {task_id}')
    deadline = time.monotonic() + TASK_MAX_WAIT_S
    time.sleep(TASK_POLL_INTERVAL_S)
    while True:
        result = client.send_get({'apiName': callbackapiname}, {'id': task_id})['data']
        if result['status'] != 1:
            break
        # Bounded polling: previously an unfinished task hung the worker forever.
        if time.monotonic() >= deadline:
            raise TimeoutError(f'task {task_id} did not finish within {TASK_MAX_WAIT_S} s')
        time.sleep(TASK_POLL_INTERVAL_S)
    if result['status'] != 2:
        raise ValueError('something wrong in the process')
    return result


def _decode_results(result):
    """Fetch the first output's relit picture (watermarked) and its mask."""
    pics = result['sasMyCreationPicVOs']
    if not pics:
        raise ValueError('no output image returned by the relighting service')
    relit = ops.decode_img_from_url(pics[0]['picUrl'])
    relit = watermark_app.process_image(relit)
    mask = ops.decode_img_from_url(pics[0]['maskUrl'])
    return relit, mask


def call_text_guided_relighting(image, mode, prompt, seed, steps):
    """Relight ``image`` from a free-text ``prompt`` ('free_txt2bg_gen').

    Returns:
        (relit_image, mask_image) as decoded from the backend's output URLs;
        the relit image is passed through the watermark step.
    """
    _require_image(image, 'original image')
    image_key = _prepare_and_upload(image)
    ops.print_with_datetime('start call_text_guided_relighting')
    ops.print_with_datetime(prompt)
    params = {
        'image': image_key,
        'inference_mode': 'free_txt2bg_gen',
        'mode': mode,
        'prompt': prompt,
        'seed': seed,
        'steps': steps,
    }
    return _decode_results(_submit_and_wait(params))


def call_image_guided_relighting(image, ref_img, seed, steps):
    """Relight ``image`` to match the reference ``ref_img`` ('replica_gen').

    Returns:
        (relit_image, mask_image) as decoded from the backend's output URLs;
        the relit image is passed through the watermark step.
    """
    _require_image(image, 'original image')
    _require_image(ref_img, 'reference image')
    image_key = _prepare_and_upload(image)
    ref_key = _prepare_and_upload(ref_img)
    ops.print_with_datetime('start call_image_guided_relighting')
    params = {
        'image': image_key,
        'inference_mode': 'replica_gen',
        'mode': 'normal',
        'ref_img': ref_key,
        'seed': seed,
        'steps': steps,
    }
    return _decode_results(_submit_and_wait(params))


quick_prompts = [
    'warm lighting',
    'sunshine from window',
    'neon lighting',
    'at noon. bright sunlight',
    'at dusk',
    'golden time',
    'natural lighting',
    'shadow from window',
    'soft studio lighting',
    'red lighting',
    'purple lighting',
]
quick_prompts = [[x] for x in quick_prompts]

quick_content_prompts = [
    'by the sea',
    'in the forest',
    'on the snow mountain',
    'by the city street',
    'on the grassy field',
    'cityscape',
    'on the desert',
    'in the living room',
]
quick_content_prompts = [[x] for x in quick_content_prompts]

quick_subjects = [
    'portrait photography of a woman',
    'portrait photography of man',
    'product photography',
]
quick_subjects = [[x] for x in quick_subjects]


def _append_quick_prompt(sample, current_prompt):
    """Keep the first two ', '-separated parts of the prompt and append the sample."""
    return ', '.join(current_prompt.split(', ')[:2] + [sample[0]])


def _pick_quick_subject(sample):
    """Replace the prompt with the chosen subject sample."""
    return sample[0]


with gr.Blocks().queue() as demo:
    gr.HTML("""
FreeLighting: A Next-generation Relighting Model with Background Replica from Any Perspective Angle
""")
    with gr.Row():
        gr.Markdown("See more information in https://github.com/liuyuxuan3060/FreeLighting")
    with gr.Row():
        gr.Markdown("We use an open source segmentation model to generate image mask")
    with gr.Tabs():
        with gr.TabItem("image-guided relighting") as i2v_tab:
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        image = gr.Image(label="original_image", type="pil", height=480)
                        ref_img = gr.Image(label="reference_image", type="pil", height=480)
                        image_mask = gr.Image(label="image_mask", type="pil", height=480)
                    with gr.Row():
                        seed = gr.Number(value=12345, label="random seed", precision=0)
                        steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                    button = gr.Button("generate")
                with gr.Column():
                    relighting_image = gr.Image(label="relighted_image", type="pil", height=480)
            with gr.Row():
                gr.Examples(
                    examples=db_examples.image_guided_examples,
                    inputs=[image, ref_img, seed, steps, relighting_image],
                    outputs=[relighting_image],
                    examples_per_page=1024,
                )
            button.click(
                fn=call_image_guided_relighting,
                inputs=[image, ref_img, seed, steps],
                outputs=[relighting_image, image_mask],
            )
        with gr.TabItem("text-guided relighting") as t2v_tab:
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        image = gr.Image(label="original_image", type="pil", height=480)
                        image_mask = gr.Image(label="image_mask", type="pil", height=480)
                    with gr.Row():
                        prompt = gr.Textbox(value="", label="text prompt")
                    with gr.Row():
                        example_quick_subjects = gr.Dataset(
                            samples=quick_subjects,
                            label='Subject Quick Prompt List. Click to choose a quick subject description',
                            samples_per_page=1000,
                            components=[prompt],
                        )
                    with gr.Row():
                        example_quick_prompts = gr.Dataset(
                            samples=quick_prompts,
                            label='Lighting Quick Prompt List. Click to choose a quick lighting description',
                            samples_per_page=1000,
                            components=[prompt],
                        )
                    with gr.Row():
                        example_quick_content_prompts = gr.Dataset(
                            samples=quick_content_prompts,
                            label='Content Quick Prompt List. Click to choose a quick content description',
                            samples_per_page=1000,
                            components=[prompt],
                        )
                    with gr.Row():
                        mode = gr.Radio(
                            choices=["normal", "uniform-lit"],
                            value='normal',
                            label="uniform-lit mode will use double time",
                            type='value',
                        )
                        seed = gr.Number(value=12345, label="random seed", precision=0)
                        steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                    button = gr.Button("generate")
                with gr.Column():
                    relighting_image = gr.Image(label="relighted_image", type="pil", height=480)
            with gr.Row():
                gr.Examples(
                    examples=db_examples.text_guided_examples,
                    inputs=[image, mode, prompt, seed, steps, relighting_image],
                    outputs=[relighting_image],
                    examples_per_page=1024,
                )
            button.click(
                fn=call_text_guided_relighting,
                inputs=[image, mode, prompt, seed, steps],
                outputs=[relighting_image, image_mask],
            )

    # Quick-prompt datasets edit the text prompt in place without queueing.
    example_quick_content_prompts.click(
        _append_quick_prompt,
        inputs=[example_quick_content_prompts, prompt],
        outputs=prompt,
        show_progress=False,
        queue=False,
    )
    example_quick_prompts.click(
        _append_quick_prompt,
        inputs=[example_quick_prompts, prompt],
        outputs=prompt,
        show_progress=False,
        queue=False,
    )
    example_quick_subjects.click(
        _pick_quick_subject,
        inputs=example_quick_subjects,
        outputs=prompt,
        show_progress=False,
        queue=False,
    )

demo.launch()