Source code for cleverhans.attacks.saliency_map_method

"""The SaliencyMapMethod attack
"""
# pylint: disable=missing-docstring
import warnings

import numpy as np
from six.moves import xrange
import tensorflow as tf

from cleverhans.attacks.attack import Attack
from cleverhans.compat import reduce_sum, reduce_max, reduce_any

# Module-level TensorFlow dtype (float32) that jsma_symbolic uses when casting
# boolean masks to numbers and when building its constant tensors.
tf_dtype = tf.as_dtype('float32')


class SaliencyMapMethod(Attack):
    """
    The Jacobian-based Saliency Map Method (Papernot et al. 2016).
    Paper link: https://arxiv.org/pdf/1511.07528.pdf

    :param model: cleverhans.model.Model
    :param sess: optional tf.Session
    :param dtypestr: dtype of the data
    :param kwargs: passed through to super constructor

    :note: When not using symbolic implementation in `generate`, `sess` should
           be provided
    """

    def __init__(self, model, sess=None, dtypestr='float32', **kwargs):
        """
        Create a SaliencyMapMethod instance.
        Note: the model parameter should be an instance of the
        cleverhans.model.Model abstraction provided by CleverHans.
        """
        super(SaliencyMapMethod, self).__init__(model, sess, dtypestr, **kwargs)

        # Keyword arguments that may be fed as tensors/arrays at run time.
        self.feedable_kwargs = ('y_target',)
        # Keyword arguments that change the structure of the generated graph.
        self.structural_kwargs = [
            'theta', 'gamma', 'clip_max', 'clip_min', 'symbolic_impl'
        ]

    def generate(self, x, **kwargs):
        """
        Generate symbolic graph for adversarial examples and return.

        :param x: The model's symbolic inputs.
        :param kwargs: See `parse_params`
        :return: a tensor for the adversarial example
        :raises NotImplementedError: if `symbolic_impl` is False
        """
        # Parse and save attack-specific parameters. The call is made outside
        # of the assert statement on purpose: assert statements are stripped
        # under `python -O`, which would otherwise skip the parsing entirely
        # and leave the attack attributes unset.
        params_ok = self.parse_params(**kwargs)
        assert params_ok

        if not self.symbolic_impl:
            raise NotImplementedError(
                "The jsma_batch function has been removed."
                " The symbolic_impl argument to SaliencyMapMethod will be removed"
                " on 2019-07-18 or after. Any code that depends on the non-symbolic"
                " implementation of the JSMA should be revised. Consider using"
                " SaliencyMapMethod.generate_np() instead.")

        # Create random targets if y_target not provided
        if self.y_target is None:
            from random import randint

            def random_targets(gt):
                # Roll every one-hot row by a random non-zero offset so the
                # target class differs from the class encoded in `gt`.
                result = gt.copy()
                nb_s = gt.shape[0]
                nb_classes = gt.shape[1]

                for i in range(nb_s):
                    result[i, :] = np.roll(result[i, :],
                                           randint(1, nb_classes - 1))

                return result

            labels, nb_classes = self.get_or_guess_labels(x, kwargs)
            self.y_target = tf.py_func(random_targets, [labels],
                                       self.tf_dtype)
            # py_func loses static shape information; restore it because
            # jsma_symbolic reads the number of classes from the last dim.
            self.y_target.set_shape([None, nb_classes])

        return jsma_symbolic(
            x,
            model=self.model,
            y_target=self.y_target,
            theta=self.theta,
            gamma=self.gamma,
            clip_min=self.clip_min,
            clip_max=self.clip_max)

    def parse_params(self,
                     theta=1.,
                     gamma=1.,
                     clip_min=0.,
                     clip_max=1.,
                     y_target=None,
                     symbolic_impl=True,
                     **kwargs):
        """
        Take in a dictionary of parameters and applies attack-specific checks
        before saving them as attributes.

        Attack-specific parameters:

        :param theta: (optional float) Perturbation introduced to modified
                      components (can be positive or negative)
        :param gamma: (optional float) Maximum percentage of perturbed features
        :param clip_min: (optional float) Minimum component value for clipping
        :param clip_max: (optional float) Maximum component value for clipping
        :param y_target: (optional) Target tensor if the attack is targeted
        :param symbolic_impl: (optional bool) `generate` only supports True;
                              with False it raises NotImplementedError
        :param kwargs: unused; a warning is emitted if anything is passed
        :return: True
        """
        self.theta = theta
        self.gamma = gamma
        self.clip_min = clip_min
        self.clip_max = clip_max
        self.y_target = y_target
        self.symbolic_impl = symbolic_impl

        if kwargs:
            warnings.warn("kwargs is unused and will be removed on or after "
                          "2019-04-26.")

        return True
def jsma_batch(*args, **kwargs):
    """
    Placeholder for the removed non-symbolic JSMA implementation.

    :raises NotImplementedError: always
    """
    raise NotImplementedError(
        "The jsma_batch function has been removed."
        " Any code that depends on it should be revised.")


def jsma_symbolic(x, y_target, model, theta, gamma, clip_min, clip_max):
    """
    TensorFlow implementation of the JSMA (see https://arxiv.org/abs/1511.07528
    for details about the algorithm design choices).

    :param x: the input placeholder
    :param y_target: the target tensor
    :param model: a cleverhans.model.Model object.
    :param theta: delta for each feature adjustment; a positive value
                  increases the selected features (clipped at clip_max),
                  otherwise theta is added to them and the result is clipped
                  from below at clip_min
    :param gamma: a float between 0 - 1 indicating the maximum distortion
                  percentage
    :param clip_min: minimum value for components of the example returned
    :param clip_max: maximum value for components of the example returned
    :return: a tensor for the adversarial example
    """

    nb_classes = int(y_target.shape[-1].value)
    # np.prod is used instead of np.product: the latter was only an alias of
    # np.prod and has been removed in NumPy 2.0.
    nb_features = int(np.prod(x.shape[1:]).value)

    # NOTE(review): integer labels are only narrowed to int32 here, yet they
    # are multiplied with float tensors further down -- confirm that integer
    # y_target is actually supported by callers.
    if x.dtype == tf.float32 and y_target.dtype == tf.int64:
        y_target = tf.cast(y_target, tf.int32)

    if x.dtype == tf.float32 and y_target.dtype == tf.float64:
        warnings.warn("Downcasting labels---this should be harmless unless"
                      " they are smoothed")
        y_target = tf.cast(y_target, tf.float32)

    # Two features are modified per iteration, hence the division by 2.
    max_iters = np.floor(nb_features * gamma / 2)
    increase = bool(theta > 0)

    # Mask with zeros on the diagonal, used to discard (p, p) feature pairs.
    tmp = np.ones((nb_features, nb_features), int)
    np.fill_diagonal(tmp, 0)
    zero_diagonal = tf.constant(tmp, tf_dtype)

    # Compute our initial search domain. We optimize the initial search domain
    # by removing all features that are already at their maximum values (if
    # increasing input features---otherwise, at their minimum value).
    if increase:
        search_domain = tf.reshape(
            tf.cast(x < clip_max, tf_dtype), [-1, nb_features])
    else:
        search_domain = tf.reshape(
            tf.cast(x > clip_min, tf_dtype), [-1, nb_features])

    # Loop variables
    # x_in: the tensor that holds the latest adversarial outputs that are in
    #       progress.
    # y_in: the tensor for target labels
    # domain_in: the tensor that holds the latest search domain
    # i_in: the iteration counter
    # cond_in: the boolean tensor to show if more iteration is needed for
    #          generating adversarial samples
    def condition(x_in, y_in, domain_in, i_in, cond_in):
        # Repeat the loop until we have achieved misclassification or
        # reaches the maximum iterations
        return tf.logical_and(tf.less(i_in, max_iters), cond_in)

    # Same loop variables as above
    def body(x_in, y_in, domain_in, i_in, cond_in):
        # Create graph for model logits and predictions
        logits = model.get_logits(x_in)
        preds = tf.nn.softmax(logits)
        preds_onehot = tf.one_hot(tf.argmax(preds, axis=1), depth=nb_classes)

        # Create the Jacobian graph: one gradient of the logit w.r.t. the
        # input per class.
        list_derivatives = []
        for class_ind in xrange(nb_classes):
            derivatives = tf.gradients(logits[:, class_ind], x_in)
            list_derivatives.append(derivatives[0])
        grads = tf.reshape(
            tf.stack(list_derivatives), shape=[nb_classes, -1, nb_features])

        # Compute the Jacobian components
        # To help with the computation later, reshape the target_class
        # and other_class to [nb_classes, -1, 1].
        # The last dimension is added to allow broadcasting later.
        target_class = tf.reshape(
            tf.transpose(y_in, perm=[1, 0]), shape=[nb_classes, -1, 1])
        other_classes = tf.cast(tf.not_equal(target_class, 1), tf_dtype)

        grads_target = reduce_sum(grads * target_class, axis=0)
        grads_other = reduce_sum(grads * other_classes, axis=0)

        # Remove the already-used input features from the search space
        # Subtract 2 times the maximum value from those value so that
        # they won't be picked later
        increase_coef = (4 * int(increase) - 2) \
            * tf.cast(tf.equal(domain_in, 0), tf_dtype)

        target_tmp = grads_target
        target_tmp -= increase_coef \
            * reduce_max(tf.abs(grads_target), axis=1, keepdims=True)
        # Pairwise sums over all (feature, feature) combinations.
        target_sum = tf.reshape(target_tmp, shape=[-1, nb_features, 1]) \
            + tf.reshape(target_tmp, shape=[-1, 1, nb_features])

        other_tmp = grads_other
        other_tmp += increase_coef \
            * reduce_max(tf.abs(grads_other), axis=1, keepdims=True)
        other_sum = tf.reshape(other_tmp, shape=[-1, nb_features, 1]) \
            + tf.reshape(other_tmp, shape=[-1, 1, nb_features])

        # Create a mask to only keep features that match conditions
        if increase:
            scores_mask = ((target_sum > 0) & (other_sum < 0))
        else:
            scores_mask = ((target_sum < 0) & (other_sum > 0))

        # Create a tensor of scores for each pair of candidate features;
        # pairs made of the same feature twice are zeroed out.
        scores = tf.cast(scores_mask, tf_dtype) \
            * (-target_sum * other_sum) * zero_diagonal

        # Extract the best two pixels from the flattened pair index
        best = tf.argmax(
            tf.reshape(scores, shape=[-1, nb_features * nb_features]), axis=1)

        p1 = tf.mod(best, nb_features)
        p2 = tf.floordiv(best, nb_features)
        p1_one_hot = tf.one_hot(p1, depth=nb_features)
        p2_one_hot = tf.one_hot(p2, depth=nb_features)

        # Check if more modification is needed for each sample: the target
        # class is not predicted yet and at least two features remain.
        mod_not_done = tf.equal(reduce_sum(y_in * preds_onehot, axis=1), 0)
        cond = mod_not_done & (reduce_sum(domain_in, axis=1) >= 2)

        # Update the search domain
        cond_float = tf.reshape(tf.cast(cond, tf_dtype), shape=[-1, 1])
        to_mod = (p1_one_hot + p2_one_hot) * cond_float

        domain_out = domain_in - to_mod

        # Apply the modification to the images
        to_mod_reshape = tf.reshape(
            to_mod, shape=([-1] + x_in.shape[1:].as_list()))
        if increase:
            x_out = tf.minimum(clip_max, x_in + to_mod_reshape * theta)
        else:
            # theta is <= 0 on this branch, so it must be *added* to lower the
            # selected features; subtracting it would raise them instead and
            # the clip_min bound below would never take effect.
            x_out = tf.maximum(clip_min, x_in + to_mod_reshape * theta)

        # Increase the iterator, and check if all misclassifications are done
        i_out = tf.add(i_in, 1)
        cond_out = reduce_any(cond)

        return x_out, y_in, domain_out, i_out, cond_out

    # Run loop to do JSMA
    x_adv, _, _, _, _ = tf.while_loop(
        condition,
        body, [x, y_target, search_domain, 0, True],
        parallel_iterations=1)

    return x_adv