import sys
from pathlib import Path
from typing import Dict, Iterable, NamedTuple, Sequence

import torch

import pickle_inspector

# Suffixes of the files that get loaded and inspected.
EXTENSIONS = {'.pt', '.bin', '.ckpt'}

# Module names whose calls ("<name>.something") are reported as library calls.
# Ordered tuples rather than sets: matching stops at the first hit, so a fixed
# order keeps both the counts and the printed report identical from run to run.
BAD_CALLS = ('os', 'shutil', 'sys', 'requests', 'net')

# Substrings that are reported as malicious signals wherever they appear.
BAD_SIGNAL = ('rm ', 'cat ', 'nc ', '/bin/sh ')

# Calls that start with one of these prefixes are not reported as non-standard.
ALLOWED_PREFIXES = ('numpy.', '_codecs.', 'collections.', 'torch.')


class ScanReport(NamedTuple):
    """Findings for the calls recorded while loading one file."""

    calls: Dict[str, int]    # BAD_CALLS entry -> calls starting with "<entry>."
    signals: Dict[str, int]  # BAD_SIGNAL entry -> calls credited to that signal
    other: int               # calls outside ALLOWED_PREFIXES
    total: int               # sum of every finding above
    output: str              # human-readable detail text, printed in debug mode


def _format_finding(title: str, call: str) -> str:
    """Return the detail-text entry for one finding."""
    return "\n--- " + title + " ---\n" + call + "\n---------------\n"


def analyze_calls(calls: Iterable[str]) -> ScanReport:
    """Classify recorded call strings and count the suspicious ones.

    Each call is checked three independent ways, so a single call can add up
    to three findings to the total:

    * it starts with ``"<name>."`` for a name in ``BAD_CALLS`` (first match
      in ``BAD_CALLS`` order wins);
    * it contains a ``BAD_SIGNAL`` substring (first match in ``BAD_SIGNAL``
      order wins, so at most one signal is credited per call);
    * it does not start with any of ``ALLOWED_PREFIXES``.

    Args:
        calls: The call strings to classify.

    Returns:
        A ``ScanReport`` with per-category counts, the grand total, and the
        concatenated detail text.
    """
    lib_hits = dict.fromkeys(BAD_CALLS, 0)
    signal_hits = dict.fromkeys(BAD_SIGNAL, 0)
    other = 0
    details = []

    for call in calls:
        for lib in BAD_CALLS:
            if call.startswith(lib + "."):
                lib_hits[lib] += 1
                details.append(
                    _format_finding("found lib call (" + lib + ")", call))
                break

        for signal in BAD_SIGNAL:
            if signal in call:
                signal_hits[signal] += 1
                details.append(
                    _format_finding(
                        "found malicious signal (" + signal + ")", call))
                break

        if not call.startswith(ALLOWED_PREFIXES):
            other += 1
            details.append(
                _format_finding("found non-standard lib call", call))

    total = sum(lib_hits.values()) + sum(signal_hits.values()) + other
    return ScanReport(lib_hits, signal_hits, other, total, "".join(details))


def scan_file(path: Path, debug: bool = False) -> ScanReport:
    """Load one file through ``pickle_inspector`` and print its scan result.

    Args:
        path: File to load.
        debug: When true, also print the detail text for a failed scan.

    Returns:
        The ``ScanReport`` computed from the calls recorded during loading.
    """
    print("")
    print("..." + path.as_posix())

    # NOTE(security): torch.load unpickles the file. This relies on
    # pickle_inspector.pickle recording calls instead of executing them --
    # presumably it does; verify before pointing this at untrusted files.
    result = torch.load(path.as_posix(), pickle_module=pickle_inspector.pickle)

    # result.calls is assumed to be an iterable of call strings.
    report = analyze_calls(result.calls)

    if report.total > 0:
        for lib in BAD_CALLS:
            print("library call (" + lib + ".): " + str(report.calls[lib]))
        for signal in BAD_SIGNAL:
            print("malicious signal (" + signal + "): "
                  + str(report.signals[signal]))
        print("non-standard calls: " + str(report.other))
        print("total: " + str(report.total))
        print("")
        print("SCAN FAILED")

        if debug:
            print(report.output)
    else:
        print("SCAN PASSED!")

    return report


def main(argv: Sequence[str]) -> None:
    """Scan every matching file below the directory named in ``argv[1]``.

    Debug output is enabled when exactly one extra argument follows the
    directory (``len(argv) == 3``); the value of that argument is ignored.

    Args:
        argv: Command-line arguments, including the program name.

    Raises:
        SystemExit: If no directory argument was supplied.
    """
    if len(argv) < 2:
        prog = argv[0] if argv else "scan"
        raise SystemExit("usage: " + prog + " <dir> [debug]")

    debug = len(argv) == 3
    target = argv[1]
    print("checking dir: " + target)

    for path in Path(target).glob('**/*'):
        # The glob also yields directories; only regular files can be loaded.
        if path.suffix in EXTENSIONS and path.is_file():
            scan_file(path, debug)


if __name__ == "__main__":
    main(sys.argv)