File size: 1,957 Bytes
621acf4 e776bd4 621acf4 313ae49 621acf4 b3c44e2 621acf4 e776bd4 621acf4 079f878 e776bd4 621acf4 079f878 621acf4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import os
import pty
import subprocess
import select
import shlex
import base64
class EndpointHandler:
    """Inference-endpoint handler that executes the command given in ``inputs``.

    SECURITY NOTE(review): this runs arbitrary, caller-supplied commands on
    the host. Only deploy it on a private / authenticated endpoint.
    """

    def __init__(self, *args, **kwargs):
        # No state is needed; whatever the hosting framework passes is ignored.
        pass

    @staticmethod
    def _read_available(fd, idle_timeout=0.1):
        """Read from *fd* until no data arrives for *idle_timeout* seconds."""
        chunks = []
        while True:
            ready, _, _ = select.select([fd], [], [], idle_timeout)
            if fd not in ready:
                break
            data = os.read(fd, 4096)
            if not data:
                break
            chunks.append(data)
        return b"".join(chunks)

    def run_command(self, command):
        """Run *command* inside a pseudo-terminal and return its output.

        The command line is tokenized with ``shlex.split`` and executed
        directly (no shell). stdin, stdout and stderr are all attached to the
        slave side of a new pty, so stdout and stderr come back interleaved
        (and the terminal typically turns ``\\n`` into ``\\r\\n``).

        NOTE(review): nothing is ever written to the pty, so a command that
        reads stdin will block indefinitely.

        :param command: command line to execute.
        :return: everything the process wrote, decoded as UTF-8; bytes that
            are not valid UTF-8 are replaced with U+FFFD instead of raising.
        :raises ValueError: if *command* cannot be tokenized (e.g. an
            unbalanced quote) or contains no tokens.
        :raises OSError: if the program cannot be started (e.g. not found).
        """
        argv = shlex.split(command)
        if not argv:
            raise ValueError("command is empty")

        master_fd, slave_fd = pty.openpty()
        process = None
        chunks = []
        try:
            process = subprocess.Popen(
                argv,
                stdin=slave_fd,
                stdout=slave_fd,
                stderr=slave_fd,
            )
            while process.poll() is None:
                chunks.append(self._read_available(master_fd))
            # The parent still holds slave_fd, so the master never reports EOF;
            # one final drain collects output written just before the exit.
            chunks.append(self._read_available(master_fd))
            process.wait()
        finally:
            if process is not None and process.poll() is None:
                # Only reached if reading failed: don't leave the child running.
                process.kill()
                process.wait()
            os.close(master_fd)
            os.close(slave_fd)
        return b"".join(chunks).decode("utf-8", errors="replace")

    def _decode_base64(self, command):
        """Return *command* base64-decoded to text, or unchanged if that fails.

        NOTE(review): this is a heuristic. ``b64decode`` is called without
        ``validate=True``, so non-alphabet characters (e.g. spaces) are
        discarded, and a plain command that happens to decode to valid UTF-8
        would be replaced by that decoding.
        """
        try:
            return base64.b64decode(command).decode()
        except (ValueError, TypeError):
            # binascii.Error and UnicodeDecodeError are ValueError subclasses;
            # TypeError covers non-string input such as None.
            return command

    def __call__(self, data):
        """
        :param data: input data from the inference endpoint REST API; the
            command (plain or base64-encoded) is read from ``data["inputs"]``.
        :return: ``{"result": output}`` on success, or ``{"error": message}``
            when ``inputs`` is missing/not a string or the command cannot be run.
        """
        # Read without popping so the caller's payload is left intact.
        command = self._decode_base64(data.get("inputs"))
        if not isinstance(command, str):
            return {"error": "inputs attribute is required"}
        try:
            result = self.run_command(command)
        except (ValueError, OSError) as exc:
            # Bad quoting, empty command or missing program: report it in the
            # same {"error": ...} shape instead of failing the request.
            return {"error": str(exc)}
        return {"result": result}