Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 4,982 Bytes
ec17e66 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import base64
import json
import ntpath
import os
import time
import gradio as gr
import requests
from google.cloud import storage
from base_task_executor import BaseTaskExecutor
# ---
# Text encoding used for both the base64 text and the JSON payload it wraps.
enc = "utf-8"


def decode(string):
    """Turn base64 text that wraps a JSON document back into a Python object.

    ``string`` is encoded with ``enc``, base64-decoded, the resulting bytes
    are decoded with ``enc`` again, and the text is parsed with ``json.loads``.
    """
    json_bytes = base64.b64decode(string.encode(enc))
    json_text = json_bytes.decode(enc)
    return json.loads(json_text)
def get_storage_client_from_env():
    """Create a Google Cloud Storage client from the ``GCP_API_KEY`` env var.

    The variable must hold base64-encoded service-account JSON; it is
    unpacked with ``decode`` before being handed to the client factory.
    Raises ``KeyError`` when the variable is not set.
    """
    service_account_info = decode(os.environ["GCP_API_KEY"])
    client = storage.Client.from_service_account_info(service_account_info)
    return client
def get_name_ext(filepath):
    """Split the last component of ``filepath`` into ``(stem, extension)``.

    The path is made absolute first, so a trailing separator is dropped
    before the final component is taken. ``extension`` keeps its leading
    dot, or is ``""`` when there is none.
    """
    last_component = os.path.basename(os.path.abspath(filepath))
    return os.path.splitext(last_component)
def make_remote_media_path(request_id, media_path):
    """Build the bucket object key under which a local media file is stored.

    The key is ``<id[:3]>/<id[3:6]>/<id[6:]>/<file name>``; for example
    request id ``"abcdef123"`` and ``/tmp/clip.mp4`` give
    ``"abc/def/123/clip.mp4"``.

    Raises:
        ValueError: if ``request_id`` has 6 characters or fewer.
        FileNotFoundError: if ``media_path`` does not exist locally.
    """
    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    if len(request_id) <= 6:
        raise ValueError(
            f"request_id must be longer than 6 characters, got {request_id!r}"
        )
    if not os.path.exists(media_path):
        raise FileNotFoundError(media_path)

    src_id = request_id[:3]
    slot_id = request_id[3:6]
    request_suffix = request_id[6:]
    # Last component of the absolute path (stem + extension).
    file_name = os.path.basename(os.path.abspath(media_path))
    # Object keys always use "/" regardless of the local OS separator; a plain
    # join also avoids os.path.join discarding earlier parts when a component
    # happens to start with a separator.
    return "/".join((src_id, slot_id, request_suffix, file_name))
def copy_file_to_gcloud(bucket, local_file_path, remote_file_path):
    """Upload the file at ``local_file_path`` into ``bucket`` as ``remote_file_path``."""
    bucket.blob(remote_file_path).upload_from_filename(local_file_path)
def copy_to_gcloud(storage_client, local_media_path, bucket_name, remote_media_path):
    """Upload a local file into the bucket called ``bucket_name``.

    The bucket is fetched through ``storage_client.get_bucket`` and the
    upload itself is delegated to ``copy_file_to_gcloud``.
    """
    target_bucket = storage_client.get_bucket(bucket_name)
    copy_file_to_gcloud(target_bucket, local_media_path, remote_media_path)
# ---
class CloudTaskExecutor(BaseTaskExecutor):
    """Run generation tasks on the remote Sutra Avatar service.

    Input media is uploaded to a Google Cloud Storage bucket, a task is
    submitted over HTTP, and the task-status endpoint is polled until the task
    reaches a terminal state or a deadline passes.

    ``__init__`` reads ``SUTRA_AVATAR_BASE_URL``, ``SUTRA_AVATAR_API_KEY`` and
    ``SUTRA_AVATAR_BUCKET_NAME`` from the environment and builds the storage
    client with ``get_storage_client_from_env``.
    """

    # Per-HTTP-request timeout in seconds. Without one, a stalled connection
    # blocks forever and the polling deadline in ``generate`` is never reached.
    request_timeout = 60
    # Seconds to sleep between task-status polls.
    poll_interval = 3
    # Base polling deadline in seconds; a numeric server estimate is added.
    base_wait_timeout = 240
    # Task states after which polling stops.
    completion_statuses = frozenset({"Succeeded", "Cancelled", "Failed", "NotFound"})

    def __init__(self):
        super().__init__()
        self.base_url = os.getenv("SUTRA_AVATAR_BASE_URL")
        # NOTE(review): if SUTRA_AVATAR_API_KEY is unset this sends the literal
        # string "None" as the Authorization header.
        self.headers = {
            "Authorization": f'{os.getenv("SUTRA_AVATAR_API_KEY")}',
            "Content-Type": "application/json",
        }
        self.bucket_name = os.getenv("SUTRA_AVATAR_BUCKET_NAME")
        self.storage_client = get_storage_client_from_env()

    def submit_task(self, submit_request):
        """POST ``submit_request`` to ``<base_url>/task/submit``; return the decoded JSON reply.

        Raises:
            requests.HTTPError: for a 4xx/5xx response.
            requests.Timeout: if the server does not respond within ``request_timeout``.
        """
        url = f"{self.base_url}/task/submit"
        response = requests.post(
            url, json=submit_request, headers=self.headers, timeout=self.request_timeout
        )
        # raise_for_status() is a no-op for non-error codes, so checking only
        # for 200 used to return None on e.g. 201/202.
        response.raise_for_status()
        return response.json()

    def get_task_status(self, request_id):
        """GET ``<base_url>/task/status?rid=<request_id>``; return the decoded JSON reply.

        Raises:
            requests.HTTPError: for a 4xx/5xx response.
            requests.Timeout: if the server does not respond within ``request_timeout``.
        """
        url = f"{self.base_url}/task/status"
        response = requests.get(
            url,
            params={"rid": request_id},
            headers=self.headers,
            timeout=self.request_timeout,
        )
        response.raise_for_status()
        return response.json()

    @staticmethod
    def _remote_name(path):
        """Return the file-name part of ``path``, or ``""`` when no path was given."""
        # ntpath.basename splits on both "/" and "\\".
        return ntpath.basename(path) if path else ""

    def _upload_inputs(self, request_id, media_paths):
        """Upload every non-empty local path to the bucket under this request's prefix."""
        for media_path in media_paths:
            if media_path:
                remote_media_path = make_remote_media_path(request_id, media_path)
                copy_to_gcloud(self.storage_client, media_path, self.bucket_name, remote_media_path)

    def generate(
        self,
        input_base_path,
        input_driving_path,
        base_motion_expression,
        input_driving_audio_path,
        output_video_path,
        request_id,
    ):
        """Run one task remotely and block until it finishes or times out.

        Args:
            input_base_path: local base media path; uploaded when truthy.
            input_driving_path: kept for interface compatibility; not used
                (the request always sends ``""`` for it).
            base_motion_expression: forwarded unchanged in the request.
            input_driving_audio_path: local driving-audio path; uploaded when truthy.
            output_video_path: only its file name is sent; returned unchanged
                unless the task succeeds.
            request_id: task id; also determines the upload location.

        Returns:
            ``(result, output_video_path)``. ``result["success"]`` is a bool and
            ``result["messages"]`` carries the service's (or the timeout)
            messages. On ``"Succeeded"`` the second element is the service's
            ``videoURL``. On timeout the messages are the timeout notice
            followed by any ``pipeReply`` messages from the last status poll.
        """
        self._upload_inputs(request_id, [input_base_path, input_driving_audio_path])

        submit_request = {
            "requestId": request_id,
            "input_base_path": self._remote_name(input_base_path),
            "input_driving_path": "",
            "base_motion_expression": base_motion_expression,
            "input_driving_audio_path": self._remote_name(input_driving_audio_path),
            "output_video_path": self._remote_name(output_video_path),
        }
        submit_reply = self.submit_task(submit_request)

        timeout = self.base_wait_timeout
        estimated_wait = submit_reply.get("estimatedWaitSeconds", "unknown")
        # Only extend the deadline for a numeric estimate (bool is an int subclass).
        if isinstance(estimated_wait, (int, float)) and not isinstance(estimated_wait, bool):
            timeout += estimated_wait

        result = {"messages": ''}
        # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
        deadline = time.monotonic() + timeout
        while True:
            status_reply = self.get_task_status(request_id)
            if status_reply["taskStatus"] in self.completion_statuses:
                break
            if time.monotonic() > deadline:
                msg = "The task did not complete within the timeout period.\n The server is very busy serving other requests.\n Please try again."
                result["success"] = False
                result["messages"] = msg
                # The failure reaches the caller through ``result``; a bare
                # ``gr.Error(msg)`` without ``raise`` would have no effect.
                break
            time.sleep(self.poll_interval)

        task_status = status_reply["taskStatus"]
        if task_status == "Succeeded":
            pipe_reply = status_reply["pipeReply"]
            result["success"] = pipe_reply["status"] == "success"
            result["messages"] = pipe_reply["messages"]
            output_video_path = status_reply["videoURL"]
        else:
            messages = ""
            if "pipeReply" in status_reply:
                messages = status_reply["pipeReply"]["messages"]
            result["success"] = False
            result["messages"] += messages
        return result, output_video_path
|