File size: 4,982 Bytes
ec17e66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import base64
import json
import ntpath
import os
import time

import gradio as gr
import requests
from google.cloud import storage

from base_task_executor import BaseTaskExecutor

# ---
enc = "utf-8"  # text encoding used for both the base64 wrapper and the JSON payload


def decode(string):
    """Decode a base64 string whose payload is JSON and return the parsed object.

    Args:
        string: base64 text (encoded with ``enc``) wrapping a JSON document.

    Returns:
        The Python object produced by ``json.loads`` on the decoded payload.
    """
    payload_bytes = base64.b64decode(string.encode(enc))
    payload_text = payload_bytes.decode(enc)
    return json.loads(payload_text)


def get_storage_client_from_env():
    """Create a Cloud Storage client from the ``GCP_API_KEY`` environment variable.

    The variable is expected to hold a base64-encoded JSON document, which is
    decoded and passed to ``storage.Client.from_service_account_info``.

    Raises:
        KeyError: if ``GCP_API_KEY`` is not set.
    """
    encoded_key = os.environ["GCP_API_KEY"]
    service_account_info = decode(encoded_key)
    return storage.Client.from_service_account_info(service_account_info)


def get_name_ext(filepath):
    """Split the last component of ``filepath`` into ``(stem, extension)``.

    The path is made absolute (and therefore normalized) first, so a trailing
    separator is dropped before the last component is taken.

    Returns:
        A ``(stem, extension)`` tuple as produced by ``os.path.splitext``;
        ``extension`` keeps its leading dot and is ``""`` when absent.
    """
    basename = os.path.basename(os.path.abspath(filepath))
    stem, extension = os.path.splitext(basename)
    return stem, extension


def make_remote_media_path(request_id, media_path):
    """Build the bucket object name under which ``media_path`` is uploaded.

    ``request_id`` is split into a 3-character source id, a 3-character slot
    id and the remaining suffix, giving ``"<src>/<slot>/<suffix>/<file name>"``.

    Args:
        request_id: task identifier; must be longer than 6 characters.
        media_path: local file that will be uploaded; must exist.

    Returns:
        The object name, always "/"-separated regardless of the host OS.

    Raises:
        ValueError: if ``request_id`` has 6 characters or fewer.
        FileNotFoundError: if ``media_path`` does not exist.
    """
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if len(request_id) <= 6:
        raise ValueError(
            f"request_id must be longer than 6 characters, got {request_id!r}"
        )
    if not os.path.exists(media_path):
        raise FileNotFoundError(f"media file not found: {media_path!r}")

    src_id = request_id[:3]
    slot_id = request_id[3:6]
    request_suffix = request_id[6:]
    # Same value as name + ext from get_name_ext: the last path component.
    file_name = os.path.basename(os.path.abspath(media_path))
    # Cloud Storage object names use "/"; os.path.join would insert "\\" on Windows.
    return "/".join((src_id, slot_id, request_suffix, file_name))


def copy_file_to_gcloud(bucket, local_file_path, remote_file_path):
    """Upload the file at ``local_file_path`` to ``remote_file_path`` in ``bucket``."""
    bucket.blob(remote_file_path).upload_from_filename(local_file_path)


def copy_to_gcloud(storage_client, local_media_path, bucket_name, remote_media_path):
    """Upload ``local_media_path`` to ``remote_media_path`` in bucket ``bucket_name``.

    The bucket is fetched with ``storage_client.get_bucket`` and the upload is
    delegated to ``copy_file_to_gcloud``.
    """
    copy_file_to_gcloud(
        storage_client.get_bucket(bucket_name),
        local_media_path,
        remote_media_path,
    )


# ---


class CloudTaskExecutor(BaseTaskExecutor):
    """Run generation tasks on the remote Sutra Avatar service.

    Local input media is uploaded to a Cloud Storage bucket, a task is
    submitted over HTTP, and the task status endpoint is polled until the
    task reaches a terminal state or the wait budget runs out.

    Configuration comes from the environment: ``SUTRA_AVATAR_BASE_URL``,
    ``SUTRA_AVATAR_API_KEY``, ``SUTRA_AVATAR_BUCKET_NAME`` and, through
    ``get_storage_client_from_env``, ``GCP_API_KEY``.
    """

    # Task states after which polling stops.
    COMPLETION_STATUSES = frozenset({"Succeeded", "Cancelled", "Failed", "NotFound"})
    # Base polling budget in seconds; the server's wait estimate is added on top.
    BASE_TIMEOUT_SECONDS = 240
    # Delay between two status polls, in seconds.
    POLL_INTERVAL_SECONDS = 3
    # Per-request HTTP timeout in seconds; requests has no timeout by default,
    # so without it a stalled connection would block the caller indefinitely.
    HTTP_TIMEOUT_SECONDS = 60
    TIMEOUT_MESSAGE = (
        "The task did not complete within the timeout period.\n"
        " The server is very busy serving other requests.\n"
        " Please try again."
    )

    def __init__(self):
        super().__init__()
        self.base_url = os.getenv("SUTRA_AVATAR_BASE_URL")
        # NOTE: an unset API key is sent as the literal string "None".
        self.headers = {
            "Authorization": f'{os.getenv("SUTRA_AVATAR_API_KEY")}',
            "Content-Type": "application/json",
        }
        self.bucket_name = os.getenv("SUTRA_AVATAR_BUCKET_NAME")
        self.storage_client = get_storage_client_from_env()

    @staticmethod
    def _json_or_raise(response):
        """Return the decoded JSON body of a 200 response, raising otherwise.

        Raises:
            requests.HTTPError: for 4xx/5xx (via ``raise_for_status``) and for
                any other non-200 status, which previously fell through and
                returned ``None``.
        """
        if response.status_code == 200:
            return response.json()
        response.raise_for_status()
        raise requests.HTTPError(
            f"Unexpected HTTP status {response.status_code} from {response.url}",
            response=response,
        )

    def submit_task(self, submit_request):
        """POST ``submit_request`` to ``<base_url>/task/submit`` and return the JSON reply.

        Raises:
            requests.HTTPError: if the server does not answer with status 200.
            requests.Timeout: if no response arrives within HTTP_TIMEOUT_SECONDS.
        """
        url = f"{self.base_url}/task/submit"
        response = requests.post(
            url,
            json=submit_request,
            headers=self.headers,
            timeout=self.HTTP_TIMEOUT_SECONDS,
        )
        return self._json_or_raise(response)

    def get_task_status(self, request_id):
        """GET ``<base_url>/task/status?rid=<request_id>`` and return the JSON reply.

        Raises:
            requests.HTTPError: if the server does not answer with status 200.
            requests.Timeout: if no response arrives within HTTP_TIMEOUT_SECONDS.
        """
        url = f"{self.base_url}/task/status"
        response = requests.get(
            url,
            params={"rid": request_id},
            headers=self.headers,
            timeout=self.HTTP_TIMEOUT_SECONDS,
        )
        return self._json_or_raise(response)

    @staticmethod
    def _remote_name(path):
        """Return the bare file name sent to the server for ``path`` ("" if unset).

        ``ntpath.basename`` splits on both "/" and "\\", so Windows and POSIX
        style paths both reduce to the file name.
        """
        return ntpath.basename(path) if path else ""

    def _upload_inputs(self, request_id, media_paths):
        """Upload every non-empty path in ``media_paths`` to the configured bucket."""
        for media_path in media_paths:
            if media_path:
                remote_media_path = make_remote_media_path(request_id, media_path)
                copy_to_gcloud(
                    self.storage_client, media_path, self.bucket_name, remote_media_path
                )

    def _wait_for_completion(self, request_id, timeout):
        """Poll the task status until it is terminal or ``timeout`` seconds elapse.

        Returns:
            ``(status_reply, timed_out)``: the last status reply received and
            whether polling stopped because the deadline passed.
        """
        # Monotonic clock: wall-clock adjustments must not shorten or extend the wait.
        deadline = time.monotonic() + timeout
        while True:
            status_reply = self.get_task_status(request_id)
            if status_reply["taskStatus"] in self.COMPLETION_STATUSES:
                return status_reply, False
            if time.monotonic() > deadline:
                return status_reply, True
            time.sleep(self.POLL_INTERVAL_SECONDS)

    def generate(
        self,
        input_base_path,
        input_driving_path,
        base_motion_expression,
        input_driving_audio_path,
        output_video_path,
        request_id,
    ):
        """Upload inputs, submit a generation task and wait for its outcome.

        Args:
            input_base_path: local base media file; uploaded when truthy.
            input_driving_path: accepted for interface compatibility but not
                used; the request always sends an empty ``input_driving_path``.
            base_motion_expression: forwarded unchanged in the submit request.
            input_driving_audio_path: local driving audio file; uploaded when truthy.
            output_video_path: requested output; only its file name is sent.
            request_id: task identifier, also used to build the remote
                storage prefix via ``make_remote_media_path``.

        Returns:
            ``(result, output_video_path)``. ``result`` is a dict with
            ``"success"`` and ``"messages"``. On a ``Succeeded`` task the second
            item is the server's ``videoURL``; otherwise it is the
            ``output_video_path`` argument unchanged.
        """
        self._upload_inputs(request_id, (input_base_path, input_driving_audio_path))

        submit_request = {
            "requestId": request_id,
            "input_base_path": self._remote_name(input_base_path),
            "input_driving_path": "",
            "base_motion_expression": base_motion_expression,
            "input_driving_audio_path": self._remote_name(input_driving_audio_path),
            "output_video_path": self._remote_name(output_video_path),
        }
        submit_reply = self.submit_task(submit_request)

        # Extend the polling budget by the server's estimate when it is numeric.
        timeout = self.BASE_TIMEOUT_SECONDS
        estimated_wait_seconds = submit_reply.get("estimatedWaitSeconds")
        if isinstance(estimated_wait_seconds, (int, float)):
            timeout += estimated_wait_seconds

        result = {"messages": ""}
        status_reply, timed_out = self._wait_for_completion(request_id, timeout)
        if timed_out:
            # The message reaches the user only through result["messages"];
            # gr.Error is an exception class and has no effect unless raised.
            result["success"] = False
            result["messages"] = self.TIMEOUT_MESSAGE

        if status_reply["taskStatus"] == "Succeeded":
            pipe_reply = status_reply["pipeReply"]
            result["success"] = pipe_reply["status"] == "success"
            result["messages"] = pipe_reply["messages"]
            output_video_path = status_reply["videoURL"]
        else:
            result["success"] = False
            if "pipeReply" in status_reply:
                # Appended so a timeout message set above is kept.
                result["messages"] += status_reply["pipeReply"]["messages"]
        return result, output_video_path