File size: 2,271 Bytes
5afd101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# copyright zxix 2022
# https://creativecommons.org/licenses/by-nc-sa/4.0/
import torch
import pickle_inspector
import sys
from pathlib import Path

debug = len(sys.argv) == 3

dir = sys.argv[1]
print("checking dir: " + dir)

BASE_DIR = Path(dir)
EXTENSIONS = {'.pt', '.bin', '.ckpt'}
BAD_CALLS = {'os', 'shutil', 'sys', 'requests', 'net'}
BAD_SIGNAL = {'rm ', 'cat ', 'nc ', '/bin/sh '}

for path in BASE_DIR.glob(r'**/*'):
  if path.suffix in EXTENSIONS:
    print("")
    print("..." + path.as_posix())
    result = torch.load(path.as_posix(), pickle_module=pickle_inspector.pickle)
    result_total = 0
    result_other = 0
    result_calls = {}
    result_signals = {}
    result_output = ""

    for call in BAD_CALLS:
      result_calls[call] = 0

    for signal in BAD_SIGNAL:
      result_signals[signal] = 0

    for c in result.calls:
      for call in BAD_CALLS:
        if (c.find(call + ".") == 0):
          result_calls[call] += 1
          result_total += 1
          result_output += "\n--- found lib call (" + call + ") ---\n"
          result_output += c
          result_output += "\n---------------\n"
          break
      for signal in BAD_SIGNAL:
        if (c.find(signal) > -1):
          result_signals[signal] += 1
          result_total += 1
          result_output += "\n--- found malicious signal (" + signal + ") ---\n"
          result_output += c
          result_output += "\n---------------\n"
          break

      if (
        c.find("numpy.") != 0 and 
        c.find("_codecs.") != 0 and 
        c.find("collections.") != 0 and 
        c.find("torch.") != 0):
        result_total += 1
        result_other += 1
        result_output += "\n--- found non-standard lib call ---\n"
        result_output += c
        result_output += "\n---------------\n"

    if (result_total > 0):
      for call in BAD_CALLS:
        print("library call (" + call + ".): " + str(result_calls[call]))
      for signal in BAD_SIGNAL:
        print("malicious signal (" + signal + "): " + str(result_signals[signal]))
      print("non-standard calls: " + str(result_other))
      print("total: " + str(result_total))
      print("")
      print("SCAN FAILED")

      if (debug):
        print(result_output)
    else:
      print("SCAN PASSED!")