File size: 2,312 Bytes
746c674
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# This source file is part of DiffAI
# Copyright (c) 2018 Secure, Reliable, and Intelligent Systems Lab (SRI), ETH Zurich
# This software is distributed under the MIT License: https://opensource.org/licenses/MIT
# SPDX-License-Identifier: MIT
# For more information see https://github.com/eth-sri/diffai

# THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT ANY WARRANTY OF ANY KIND, EITHER
# EXPRESS, IMPLIED OR STATUTORY, INCLUDING BUT NOT LIMITED TO ANY WARRANTY
# THAT THE SOFTWARE WILL CONFORM TO SPECIFICATIONS OR BE ERROR-FREE AND ANY
# IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE,
# TITLE, OR NON-INFRINGEMENT.  IN NO EVENT SHALL ETH ZURICH BE LIABLE FOR ANY     
#  DAMAGES, INCLUDING BUT NOT LIMITED TO DIRECT, INDIRECT,
# SPECIAL OR CONSEQUENTIAL DAMAGES, ARISING OUT OF, RESULTING FROM, OR IN
# ANY WAY CONNECTED WITH THIS SOFTWARE (WHETHER OR NOT BASED UPON WARRANTY,
# CONTRACT, TORT OR OTHERWISE).

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import helpers as h
import domains
from domains import *
import math


# Domains that (presumably) work on concrete points rather than abstract regions:
# every entry of h.getMethods(domains) for which h.hasMethod(entry, "attack") holds,
# followed by the raw tensor types.
POINT_DOMAINS = [
    candidate
    for candidate in h.getMethods(domains)
    if h.hasMethod(candidate, "attack")
] + [torch.FloatTensor, torch.Tensor, torch.cuda.FloatTensor]
# domains.Box plus all point domains. (Name kept as spelled; other code may import it.)
SYMETRIC_DOMAINS = [domains.Box] + POINT_DOMAINS

def domRes(outDom, target, **args): # TODO: make faster again by keeping sparse tensors sparse
    """Lower-bound the margin of the target class over every other class.

    For each batch element this builds a C x C matrix M. Applied through
    ``outDom.bmm(M)``, M yields ``out[target] - out[j]`` in column j for every
    j != target, and 0 in the target column. The lower bound of that result is
    then taken.

    Args:
        outDom: network output (abstract domain or tensor) supporting ``size``,
            ``bmm`` and ``lb``. ``outDom.size()[1]`` is used as the number of
            classes C.
        target: class labels, converted with ``target.data.long()`` before
            one-hot encoding.
        **args: accepted for signature compatibility; not used in this body.

    Returns:
        ``diff.lb() + t``: the per-class margin lower bounds, with 1 added in
        the target column. That column's margin is built as 0, so it becomes
        (presumably) 1 and never makes a "min > 0" check fail.
    """
    # One-hot labels; presumably (batch, C). h.one_hot seems to return a sparse tensor, hence to_dense().
    t = h.one_hot(target.data.long(), outDom.size()[1]).to_dense()
    # Per-sample outer product t t^T: a single 1 at (target, target).
    tmat = t.unsqueeze(2).matmul(t.unsqueeze(1))
    
    # tl[b, i, j] = t[b, i], so the target row is all ones and every other row is zero.
    tl = t.unsqueeze(2).expand(-1, -1, tmat.size()[1])
    
    # Identity (assuming h.eye builds one) with the (target, target) entry removed.
    inv_t = h.eye(tmat.size()[1]).expand(tmat.size()[0], -1, -1)
    inv_t = inv_t - tmat
    
    # Target row: ones everywhere except the target column. Other rows stay zero.
    tl = tl.bmm(inv_t)
    
    # Assuming outDom.bmm(M) multiplies the output as a row vector by M (confirm in domains):
    # fst[:, j] = out[target] and snd[:, j] = out[j] for j != target; both are 0 at j == target.
    fst = outDom.bmm(tl)
    snd = outDom.bmm(inv_t)
    diff = fst - snd
    return diff.lb() + t

def isSafeDom(outDom, target, **args):
    """Return 1 if every lower-bounded margin from ``domRes`` is positive, else 0.

    The smallest margin along dim 1 is taken per example. The result is turned
    into a Python int with ``.item()``, which only succeeds when exactly one
    example is present.
    """
    margins = domRes(outDom, target, **args)
    worst_margin, _ = torch.min(margins, 1)
    all_positive = worst_margin.gt(0.0)
    return all_positive.long().item()


def isSafeBox(target, net, inp, eps, dom):
    """Check whether ``net`` keeps the target class on ``inp`` under domain ``dom``.

    Only the first example of the batch is examined (``[0]`` indexing below).
    Presumably callers pass a batch of one; confirm at call sites.

    Args:
        target: per-example label scores. ``target.argmax(1)`` gives the expected class.
        net: callable model.
        inp: input batch.
        eps: perturbation size, forwarded to ``dom.attack`` or ``dom.box``.
        dom: either an object with an ``attack`` method (a concrete attack is
            run) or a domain exposing ``box`` (an abstract region is propagated).

    Returns:
        bool. For attack domains: whether the prediction on the attacked input
        equals the target class. Otherwise: whether ``isSafeDom`` accepts the
        propagated output.
    """
    atarg = target.argmax(1)[0].unsqueeze(0)
    if hasattr(dom, "attack"):
        x = dom.attack(net, eps, inp, target)
        # Predicted class of the first example = argmax of the output scores.
        pred = net(x).argmax(1)[0]
        return pred.item() == atarg.item()
    outDom = net(dom.box(inp, eps))
    # isSafeDom returns an int 0/1; normalise so both branches return a bool.
    return bool(isSafeDom(outDom, atarg))