Understanding Black-box Predictions via Influence Functions

Pang Wei Koh¹ Percy Liang¹

arXiv:1703.04730v2 [stat.ML] 10 Jul 2017

Abstract

How can we explain the predictions of a black-box model? In this paper, we use influence functions, a classic technique from robust statistics, to trace a model's prediction through the learning algorithm and back to its training data, thereby identifying training points most responsible for a given prediction. To scale up influence functions to modern machine learning settings, we develop a simple, efficient implementation that requires only oracle access to gradients and Hessian-vector products. We show that even on non-convex and non-differentiable models where the theory breaks down, approximations to influence functions can still provide valuable information. On linear models and convolutional neural networks, we demonstrate that influence functions are useful for multiple purposes: understanding model behavior, debugging models, detecting dataset errors, and even creating visually-indistinguishable training-set attacks.

1. Introduction

A key question often asked of machine learning systems is "Why did the system make this prediction?" We want models that are not just high-performing but also explainable. By understanding why a model does what it does, we can hope to improve the model (Amershi et al., 2015), discover new science (Shrikumar et al., 2016), and provide end-users with explanations of actions that impact them (Goodman & Flaxman, 2016).

However, the best-performing models in many domains (e.g., deep neural networks for image and speech recognition (Krizhevsky et al., 2012)) are complicated, black-box models whose predictions seem hard to explain. Work on interpreting these black-box models has focused on understanding how a fixed model leads to particular predictions, e.g., by locally fitting a simpler model around the test point (Ribeiro et al., 2016) or by perturbing the test point to see how the prediction changes (Simonyan et al., 2013; Li et al., 2016b; Datta et al., 2016; Adler et al., 2016). These works explain the predictions in terms of the model, but how can we explain where the model came from?

In this paper, we tackle this question by tracing a model's predictions through its learning algorithm and back to the training data, where the model parameters ultimately derive from. To formalize the impact of a training point on a prediction, we ask the counterfactual: what would happen if we did not have this training point, or if the values of this training point were changed slightly?

Answering this question by perturbing the data and retraining the model can be prohibitively expensive. To overcome this problem, we use influence functions, a classic technique from robust statistics (Cook & Weisberg, 1980) that tells us how the model parameters change as we upweight a training point by an infinitesimal amount. This allows us to "differentiate through the training" to estimate in closed-form the effect of a variety of training perturbations.

Despite their rich history in statistics, influence functions have not seen widespread use in machine learning; to the best of our knowledge, the work closest to ours is Wojnowicz et al. (2016), which introduced a method for approximating a quantity related to influence in generalized linear models. One obstacle to adoption is that influence functions require expensive second derivative calculations and assume model differentiability and convexity, which limits their applicability in modern contexts where models are often non-differentiable, non-convex, and high-dimensional. We address these challenges by showing that we can efficiently approximate influence functions using second-order optimization techniques (Pearlmutter, 1994; Martens, 2010; Agarwal et al., 2016), and that they remain accurate even as the underlying assumptions of differentiability and convexity degrade.

Influence functions capture the core idea of studying models through the lens of their training data. We show that they are a versatile tool that can be applied to a wide variety of seemingly disparate tasks: understanding model behavior, debugging models, detecting dataset errors, and creating visually-indistinguishable adversarial training examples that can flip neural network test predictions, the training set analogue of Goodfellow et al. (2015).

¹Stanford University, Stanford, CA. Correspondence to: Pang Wei Koh <pangwei@cs.stanford.edu>, Percy Liang <pliang@cs.stanford.edu>.

Proceedings of the 34th International Conference on Machine Learning, Sydney, Australia, PMLR 70, 2017. Copyright 2017 by the author(s).
2. Approach

Consider a prediction problem from some input space $\mathcal{X}$ (e.g., images) to an output space $\mathcal{Y}$ (e.g., labels). We are given training points $z_1, \ldots, z_n$, where $z_i = (x_i, y_i) \in \mathcal{X} \times \mathcal{Y}$. For a point $z$ and parameters $\theta \in \Theta$, let $L(z, \theta)$ be the loss, and let $\frac{1}{n}\sum_{i=1}^{n} L(z_i, \theta)$ be the empirical risk. The empirical risk minimizer is given by $\hat\theta \stackrel{\text{def}}{=} \arg\min_{\theta \in \Theta} \frac{1}{n}\sum_{i=1}^{n} L(z_i, \theta)$.¹ Assume that the empirical risk is twice-differentiable and strictly convex in $\theta$; in Section 4 we explore relaxing these assumptions.

¹ We fold any regularization terms into $L$.

2.1. Upweighting a training point

Our goal is to understand the effect of training points on a model's predictions. We formalize this goal by asking the counterfactual: how would the model's predictions change if we did not have this training point?

Let us begin by studying the change in model parameters due to removing a point $z$ from the training set. Formally, this change is $\hat\theta_{-z} - \hat\theta$, where $\hat\theta_{-z} \stackrel{\text{def}}{=} \arg\min_{\theta \in \Theta} \sum_{z_i \neq z} L(z_i, \theta)$. However, retraining the model for each removed $z$ is prohibitively slow.

Fortunately, influence functions give us an efficient approximation. The idea is to compute the parameter change if $z$ were upweighted by some small $\epsilon$, giving us new parameters $\hat\theta_{\epsilon,z} \stackrel{\text{def}}{=} \arg\min_{\theta \in \Theta} \frac{1}{n}\sum_{i=1}^{n} L(z_i, \theta) + \epsilon L(z, \theta)$. A classic result (Cook & Weisberg, 1982) tells us that the influence of upweighting $z$ on the parameters $\hat\theta$ is given by

$$\mathcal{I}_{\text{up,params}}(z) \stackrel{\text{def}}{=} \left.\frac{d\hat\theta_{\epsilon,z}}{d\epsilon}\right|_{\epsilon=0} = -H_{\hat\theta}^{-1}\nabla_\theta L(z, \hat\theta), \tag{1}$$

where $H_{\hat\theta} \stackrel{\text{def}}{=} \frac{1}{n}\sum_{i=1}^{n} \nabla_\theta^2 L(z_i, \hat\theta)$ is the Hessian and is positive definite (PD) by assumption. In essence, we form a quadratic approximation to the empirical risk around $\hat\theta$ and take a single Newton step; see appendix A for a derivation. Since removing a point $z$ is the same as upweighting it by $\epsilon = -\frac{1}{n}$, we can linearly approximate the parameter change due to removing $z$ by computing $\hat\theta_{-z} - \hat\theta \approx -\frac{1}{n}\mathcal{I}_{\text{up,params}}(z)$, without retraining the model.

Next, we apply the chain rule to measure how upweighting $z$ changes functions of $\hat\theta$. In particular, the influence of upweighting $z$ on the loss at a test point $z_{\text{test}}$ again has a closed-form expression:

$$\begin{aligned}
\mathcal{I}_{\text{up,loss}}(z, z_{\text{test}}) &\stackrel{\text{def}}{=} \left.\frac{dL(z_{\text{test}}, \hat\theta_{\epsilon,z})}{d\epsilon}\right|_{\epsilon=0} \\
&= \nabla_\theta L(z_{\text{test}}, \hat\theta)^\top \left.\frac{d\hat\theta_{\epsilon,z}}{d\epsilon}\right|_{\epsilon=0} \\
&= -\nabla_\theta L(z_{\text{test}}, \hat\theta)^\top H_{\hat\theta}^{-1}\nabla_\theta L(z, \hat\theta).
\end{aligned} \tag{2}$$

2.2. Perturbing a training input

Let us develop a finer-grained notion of influence by studying a different counterfactual: how would the model's predictions change if a training input were modified?

For a training point $z = (x, y)$, define $z_\delta \stackrel{\text{def}}{=} (x + \delta, y)$. Consider the perturbation $z \to z_\delta$, and let $\hat\theta_{z_\delta,-z}$ be the empirical risk minimizer on the training points with $z_\delta$ in place of $z$. To approximate its effects, define the parameters resulting from moving $\epsilon$ mass from $z$ onto $z_\delta$: $\hat\theta_{\epsilon,z_\delta,-z} \stackrel{\text{def}}{=} \arg\min_{\theta \in \Theta} \frac{1}{n}\sum_{i=1}^{n} L(z_i, \theta) + \epsilon L(z_\delta, \theta) - \epsilon L(z, \theta)$. An analogous calculation to (1) yields:

$$\left.\frac{d\hat\theta_{\epsilon,z_\delta,-z}}{d\epsilon}\right|_{\epsilon=0} = \mathcal{I}_{\text{up,params}}(z_\delta) - \mathcal{I}_{\text{up,params}}(z) = -H_{\hat\theta}^{-1}\left(\nabla_\theta L(z_\delta, \hat\theta) - \nabla_\theta L(z, \hat\theta)\right). \tag{3}$$

As before, we can make the linear approximation $\hat\theta_{z_\delta,-z} - \hat\theta \approx \frac{1}{n}\left(\mathcal{I}_{\text{up,params}}(z_\delta) - \mathcal{I}_{\text{up,params}}(z)\right)$ (moving the full mass of $z$ onto $z_\delta$ corresponds to $\epsilon = \frac{1}{n}$), giving us a closed-form estimate of the effect of $z \to z_\delta$ on the model. Analogous equations also apply for changes in $y$. While influence functions might appear to only work for infinitesimal (therefore continuous) perturbations, it is important to note that this approximation holds for arbitrary $\delta$: the $\epsilon$-upweighting scheme allows us to smoothly interpolate between $z$ and $z_\delta$. This is particularly useful for working with discrete data (e.g., in NLP) or with discrete label changes.

If $x$ is continuous and $\delta$ is small, we can further approximate (3). Assume that the input domain $\mathcal{X} \subseteq \mathbb{R}^d$, the parameter space $\Theta \subseteq \mathbb{R}^p$, and $L$ is differentiable in $\theta$ and $x$. As $\|\delta\| \to 0$, $\nabla_\theta L(z_\delta, \hat\theta) - \nabla_\theta L(z, \hat\theta) \approx [\nabla_x \nabla_\theta L(z, \hat\theta)]\delta$, where $\nabla_x \nabla_\theta L(z, \hat\theta) \in \mathbb{R}^{p \times d}$. Substituting into (3),

$$\left.\frac{d\hat\theta_{\epsilon,z_\delta,-z}}{d\epsilon}\right|_{\epsilon=0} \approx -H_{\hat\theta}^{-1}\left[\nabla_x \nabla_\theta L(z, \hat\theta)\right]\delta. \tag{4}$$

We thus have $\hat\theta_{z_\delta,-z} - \hat\theta \approx -\frac{1}{n}H_{\hat\theta}^{-1}[\nabla_x \nabla_\theta L(z, \hat\theta)]\delta$. Differentiating w.r.t. $\delta$ and applying the chain rule gives us

$$\mathcal{I}_{\text{pert,loss}}(z, z_{\text{test}})^\top \stackrel{\text{def}}{=} \left.\nabla_\delta L(z_{\text{test}}, \hat\theta_{z_\delta,-z})^\top\right|_{\delta=0} = -\frac{1}{n}\nabla_\theta L(z_{\text{test}}, \hat\theta)^\top H_{\hat\theta}^{-1}\nabla_x \nabla_\theta L(z, \hat\theta). \tag{5}$$

$\mathcal{I}_{\text{pert,loss}}(z, z_{\text{test}})^\top\delta$ tells us the approximate effect that $z \to z + \delta$ has on the loss at $z_{\text{test}}$. By setting $\delta$ in the direction of $\mathcal{I}_{\text{pert,loss}}(z, z_{\text{test}})$, we can construct local perturbations of $z$ that maximally increase the loss at $z_{\text{test}}$. In Section 5.2, we will use this to construct training-set attacks. Finally, we note that $\mathcal{I}_{\text{pert,loss}}(z, z_{\text{test}})$ can help us identify the features of $z$ that are most responsible for the prediction on $z_{\text{test}}$.

2.3. Relation to Euclidean distance

To find the training points most relevant to a test point, it is common to look at its nearest neighbors in Euclidean space (e.g., Ribeiro et al. (2016)); if all points have the same norm, this is equivalent to choosing $x$ with the largest $x \cdot x_{\text{test}}$. For intuition, we compare this to $\mathcal{I}_{\text{up,loss}}(z, z_{\text{test}})$ on a logistic regression model and show that influence is much more accurate at accounting for the effect of training.

Let $p(y \mid x) = \sigma(y\theta^\top x)$, with $y \in \{-1, 1\}$ and $\sigma(t) = \frac{1}{1 + \exp(-t)}$. We seek to maximize the probability of the training set. For a training point $z = (x, y)$, $L(z, \theta) = \log(1 + \exp(-y\theta^\top x))$, $\nabla_\theta L(z, \theta) = -\sigma(-y\theta^\top x)\,y\,x$, and $H_\theta = \frac{1}{n}\sum_{i=1}^{n} \sigma(\theta^\top x_i)\sigma(-\theta^\top x_i)\,x_i x_i^\top$. From (2), $\mathcal{I}_{\text{up,loss}}(z, z_{\text{test}})$ is:

$$-y_{\text{test}}\,y \cdot \sigma(-y_{\text{test}}\theta^\top x_{\text{test}}) \cdot \sigma(-y\theta^\top x) \cdot x_{\text{test}}^\top H_{\hat\theta}^{-1} x.$$

We highlight two key differences from $x \cdot x_{\text{test}}$. First, $\sigma(-y\theta^\top x)$ gives points with high training loss more influence, revealing that outliers can dominate the model parameters. Second, the weighted covariance matrix $H_{\hat\theta}^{-1}$ measures the "resistance" of the other training points to the removal of $z$; if $\nabla_\theta L(z, \hat\theta)$ points in a direction of little variation, its influence will be higher, since moving in that direction will not significantly increase the loss on other training points. As we show in Fig 1, these differences mean that influence functions capture the effect of model training much more accurately than nearest neighbors.

Figure 1. Components of influence. (a) What is the effect of the training loss and $H_{\hat\theta}^{-1}$ terms in $\mathcal{I}_{\text{up,loss}}$? Here, we plot $\mathcal{I}_{\text{up,loss}}$ against variants that are missing these terms and show that they are necessary for picking up the truly influential training points. For these calculations, we use logistic regression to distinguish 1's from 7's in MNIST (LeCun et al., 1998), picking an arbitrary test point $z_{\text{test}}$; similar trends hold across other test points. Green dots are train images of the same label as the test image (7) while red dots are 1's. Left: Without the train loss term, we overestimate the influence of many training points: the points near the $y = 0$ line should have $\mathcal{I}_{\text{up,loss}}$ close to 0, but instead have high influence when we remove the train loss term. Mid: Without $H_{\hat\theta}^{-1}$, all green training points are helpful (removing each point increases test loss) and all red points are harmful (removing each point decreases test loss). This is because $\forall x$, $x \succeq 0$ (all pixel values are nonnegative), so $x \cdot x_{\text{test}} \geq 0$, but it is incorrect: many harmful training points actually share the same label as $z_{\text{test}}$. See panel (b). Right: Without training loss or $H_{\hat\theta}^{-1}$, what is left is the scaled Euclidean inner product $y_{\text{test}}\,y \cdot \sigma(-y_{\text{test}}\theta^\top x_{\text{test}}) \cdot x_{\text{test}}^\top x$, which fails to accurately capture influence; the scatter plot deviates quite far from the diagonal. (b) The test image and a harmful training image with the same label. To the model, they look very different, so the presence of the training image makes the model think that the test image is less likely to be a 7. The Euclidean inner product does not pick up on these less intuitive, but important, harmful influences.

3. Efficiently Calculating Influence

There are two computational challenges to using $\mathcal{I}_{\text{up,loss}}(z, z_{\text{test}}) = -\nabla_\theta L(z_{\text{test}}, \hat\theta)^\top H_{\hat\theta}^{-1}\nabla_\theta L(z, \hat\theta)$. First, it requires forming and inverting $H_{\hat\theta} = \frac{1}{n}\sum_{i=1}^{n} \nabla_\theta^2 L(z_i, \hat\theta)$, the Hessian of the empirical risk. With $n$ training points and $\theta \in \mathbb{R}^p$, this requires $O(np^2 + p^3)$ operations, which is too expensive for models like deep neural networks with millions of parameters. Second, we often want to calculate $\mathcal{I}_{\text{up,loss}}(z_i, z_{\text{test}})$ across all training points $z_i$.

The first problem is well-studied in second-order optimization. The idea is to avoid explicitly computing $H_{\hat\theta}^{-1}$; instead, we use implicit Hessian-vector products (HVPs) to efficiently approximate $s_{\text{test}} \stackrel{\text{def}}{=} H_{\hat\theta}^{-1}\nabla_\theta L(z_{\text{test}}, \hat\theta)$ and then compute $\mathcal{I}_{\text{up,loss}}(z, z_{\text{test}}) = -s_{\text{test}} \cdot \nabla_\theta L(z, \hat\theta)$. This also solves the second problem: for each test point of interest, we can precompute $s_{\text{test}}$ and then efficiently compute $-s_{\text{test}} \cdot \nabla_\theta L(z_i, \hat\theta)$ for each training point $z_i$.

We discuss two techniques for approximating $s_{\text{test}}$, both relying on the fact that the HVP of a single term in $H_{\hat\theta}$, $[\nabla_\theta^2 L(z_i, \hat\theta)]v$, can be computed for arbitrary $v$ in the same time that $\nabla_\theta L(z_i, \hat\theta)$ would take, which is typically $O(p)$ (Pearlmutter, 1994).

Conjugate gradients (CG). The first technique is a standard transformation of matrix inversion into an optimization problem. Since $H_{\hat\theta} \succ 0$ by assumption, $H_{\hat\theta}^{-1}v \equiv \arg\min_t \{\frac{1}{2}t^\top H_{\hat\theta}t - v^\top t\}$. We can solve this with CG approaches that only require the evaluation of $H_{\hat\theta}t$, which takes $O(np)$ time, without explicitly forming $H_{\hat\theta}$. While an exact solution takes $p$ CG iterations, in practice we can get a good approximation with fewer iterations; see Martens (2010) for more details.

Stochastic estimation. With large datasets, standard CG can be slow; each iteration still goes through all $n$ training points. We use a method developed by Agarwal et al. (2016) to get an estimator that only samples a single point per iteration, which results in significant speedups.
4. Understanding Black-box Predictions via Influence Functions def j Dropping the θˆ subscript for clarity, let Hj−1 = i=0 (I − H)i , the first j terms in the Taylor expansion of H −1 . Rewrite this recursively as Hj−1 = I + (I − H)Hj−1 −1 . −1 −1 From the validity of the Taylor expansion, Hj → H as j → ∞.2 The key is that at each iteration, we can substi- tute the full H with a draw from any unbiased (and faster- to-compute) estimator of H to form H ˜ −1 ] = ˜ j . Since E[H j Hj−1 , we still have E[H ˜ −1 ] → H −1 . j Figure 2. Influence matches leave-one-out retraining. We arbi- In particular, we can uniformly sample zi and use trarily picked a wrongly-classified test point ztest , but this trend ∇2θ L(zi , θ) ˆ as an unbiased estimator of H. This gives held more broadly. These results are from 10-class MNIST. Left: us the following procedure: uniformly sample t points For each of the 500 training points z with largest Iup,loss (z, ztest ) , ˜ −1 v = we plotted − n1 · Iup,loss (z, ztest ) against the actual change in test zs1 , . . . , zst from the training data; define H 0 −1 ˜ v = v + I − loss after removing that point and retraining. The inverse HVP v; and recursively compute H j was solved exactly with CG. Mid: Same, but with the stochastic ∇2θ L(zsj , θ) ˆ H ˜ −1 v, taking H ˜ t−1 v as our final unbiased es- approximation. Right: The same plot for a CNN, computed on j−1 timate of H −1 v. We pick t to be large enough such that H ˜t the 100 most influential points with CG. For the actual difference stabilizes, and to reduce variance we repeat this procedure in loss, we removed each point and retrained from θ˜ for 30k steps. r times and average results. Empirically, we found this sig- nificantly faster than CG. strictly convex. Here, we empirically show that influence functions are accurate approximations (Section 4.1) that We note that the original method of Agarwal et al. 
(2016) provide useful information even when these assumptions dealt only with generalized linear models, for which are violated (Sections 4.2, 4.3). ˆ can be efficiently computed in O(p) time. [∇2θ L(zi , θ)]v In our case, we rely on Pearlmutter (1994)’s more general 4.1. Influence functions vs. leave-one-out retraining algorithm for fast HVPs, described above, to achieve the same time complexity.3 Influence functions assume that the weight on a training point is changed by an infinitesimally small . To investi- With these techniques, we can compute Iup,loss (zi , ztest ) gate the accuracy of using influence functions to approx- on all training points zi in O(np + rtp) time; we show in imate the effect of removing a training point and retrain- Section 4.1 that empirically, choosing rt = O(n) gives ac- ing, we compared − n1 Iup,loss (z, ztest ) with L(ztest , θˆ−z ) − curate results. Similarly, we compute Ipert,loss (zi , ztest ) = ˆ (i.e., actually doing leave-one-out retraining). L(ztest , θ) − n1 ∇θ L(ztest , θ) ˆ ˆ H −1 ∇x ∇θ L(zi , θ) with two θˆ With a logistic regression model on 10-class MNIST,4 the matrix-vector products: we first compute stest , then predicted and actual changes matched closely (Fig 2-Left). ˆ with the same HVP trick. These stest ∇x ∇θ L(zi , θ), computations are easy to implement in auto-grad systems The stochastic approximation from Agarwal et al. (2016) like TensorFlow (Abadi et al., 2015) and Theano (Theano was also accurate with r = 10 repeats and t = 5, 000 iter- D. Team, 2016), as users need only specify L; the rest is ations (Fig 2-Mid). Since each iteration only requires one automatically handled. ˆ HVP [∇2θ L(zi , θ)]v, this runs quickly: in fact, we accu- rately estimated H −1 v without even looking at every data point, since n = 55, 000 > rt. Surprisingly, even r = 1 4. 
Validation and Extensions worked; while results were noisier, it was still able to iden- Recall that influence functions are asymptotic approxima- tify the most influential points. tions of leave-one-out retraining under the assumptions that (i) the model parameters θˆ minimize the empirical risk, 4.2. Non-convexity and non-convergence and that (ii) the empirical risk is twice-differentiable and In Section 2, we took θˆ as the global minimum. In practice, 2 We assume w.l.o.g. that ∀i, ∇2θ L(zi , θ)ˆ I; if this is not if we obtain our parameters θ˜ by running SGD with early true, we can scale the loss down without affecting the parameters. stopping or on non-convex objectives, θ˜ = θ. ˆ As a result, In some cases, we can get an upper bound on ∇2θ L(zi , θ) ˆ (e.g., for Hθ˜ could have negative eigenvalues. We show that influ- linear models and bounded input), which makes this easy. Other- ence functions on θ˜ still give meaningful results in practice. wise, we treat the scaling as a separate hyperparameter and tune it such that the Taylor expansion converges. Our approach is to form a convex quadratic approxima- 3 To increase stability, especially with non-convex models (see ˜ i.e., L(z, tion of the loss around θ, ˜ θ) = L(z, θ) ˜ + Section 4.2), we can also sample a mini-batch of training points at each iteration, instead of relying on a single training point. 4 We trained with L-BFGS (Liu & Nocedal, 1989), with L2 regularization of 0.01, n = 55, 000, and p = 7, 840 parameters.
5. Understanding Black-box Predictions via Influence Functions Figure 3. Smooth approximations to the hinge loss. (a) By varying t, we can approximate the hinge loss with arbitrary accuracy: the green and blue lines are overlaid on top of each other. (b) Using a random, wrongly-classified test point, we compared the predicted vs. actual differences in loss after leave-one-out retraining on the 100 most influential training points. A similar trend held for other test points. The SVM objective is to minimize 0.005 w 22 + n1 i Hinge(yi w xi ). Left: Influence functions were unable to accurately predict the change, overestimating its magnitude considerably. Mid: Using SmoothHinge(·, 0.001) let us accurately predict the change in the hinge loss after retraining. Right: Correlation remained high over a wide range of t, though it degrades when t is too large. When t = 0.001, Pearson’s R = 0.95; when t = 0.1, Pearson’s R = 0.91. ∇L(z, θ)˜ (θ − θ)+ ˜ (H ˜ +λI)(θ − θ). ˜ 1 (θ − θ) ˜ Here, λ is imizing Hinge(s) = max(0, 1 − s); this simple piece- 2 θ a damping term that we add if Hθ˜ has negative eigenvalues; wise linear function is similar to ReLUs, which cause non- this corresponds to adding L2 regularization on θ. We then differentiability in neural networks. We set the deriva- ˜ If θ˜ is close to a local minimum, calculate Iup,loss using L. tives at the hinge to 0 and calculated Iup,loss . As one this is correlated with the result of taking a Newton step might expect, this was inaccurate (Fig 3b-Left): the sec- from θ˜ after removing weight from z (see appendix B). ond derivative carries no information about how close a support vector z is to the hinge, so the quadratic approx- We checked the behavior of Iup,loss in a non-convergent, ˆ is linear (up to regularization), which imation of L(z, θ) non-convex setting by training a convolutional neural net- leads to Iup,loss (z, ztest ) overestimating the influence of z. 
work for 500k iterations.5 The model had not converged and Hθ˜ was not PD, so we added a damping term with For the purposes of calculating influence, we approximated λ = 0.01. Even in this difficult setting, the predicted and Hinge(s) with SmoothHinge(s, t) = t log(1+exp( 1−s t )), actual changes in loss were highly correlated (Pearson’s R which approaches the hinge loss as t → 0 (Fig 3a). Using = 0.86, Fig 2-Right). the same SVM weights as before, we found that calculat- ing Iup,loss using SmoothHinge(s, 0.001) closely matched 4.3. Non-differentiable losses the actual change due to retraining in the original Hinge(s) (Pearson’s R = 0.95; Fig 3b-Mid) and remained accurate What happens when the derivatives of the loss, ∇θ L and over a wide range of t (Fig 3b-Right). ∇2θ L, do not exist? In this section, we show that in- fluence functions computed on smooth approximations to non-differentiable losses can predict the behavior of the 5. Use Cases of Influence Functions original, non-differentiable loss under leave-one-out re- 5.1. Understanding model behavior training. The robustness of this approximation suggests that we can train non-differentiable models and swap out By telling us the training points “responsible” for a given non-differentiable components for smoothed versions for prediction, influence functions reveal insights about how the purposes of calculating influence. models rely on and extrapolate from the training data. In this section, we show that two models can make the same To see this, we trained a linear SVM on the same 1s correct predictions but get there in very different ways. vs. 7s MNIST task in Section 2.3. This involves min- 5 We compared (a) the state-of-the-art Inception v3 network The network had 7 sets of convolutional layers with tanh(·) non-linearities, modeled after the all-convolutional network from (Szegedy et al., 2016) with all but the top layer frozen6 to (Springenberg et al., 2014). 
For speed, we used 10% of the (b) an SVM with an RBF kernel on a dog vs. fish image MNIST training set and only 2,616 parameters, since repeatedly classification dataset we extracted from ImageNet (Rus- retraining the network was expensive. Training was done with sakovsky et al., 2015), with 900 training examples for each mini-batches of 500 examples and the Adam optimizer (Kingma class. Freezing neural networks in this way is not uncom- & Ba, 2014). The model had not converged after 500k iterations; training it for another 500k iterations, using a full training pass 6 We used pre-trained weights from Keras (Chollet, 2015). for each iteration, reduced train loss from 0.14 to 0.12.
6. Understanding Black-box Predictions via Influence Functions mon in computer vision and is equivalent to training a lo- able from real test images but completely fool a classifier gistic regression model on the bottleneck features (Don- (Goodfellow et al., 2015; Moosavi-Dezfooli et al., 2016). ahue et al., 2014). We picked a test image both models We demonstrate that influence functions can be used to got correct (Fig 4-Top) and used SmoothHinge(·, 0.001) craft adversarial training images that are similarly visually- to compute the influence for the SVM. indistinguishable and can flip a model’s prediction on a sep- arate test image. To the best of our knowledge, this is the As expected, Iup,loss in the RBF SVM varied inversely with first proof-of-concept that visually-indistinguishable train- raw pixel distance, with training images far from the test ing attacks can be executed on otherwise highly-accurate image in pixel space having almost no influence. The In- neural networks. ception influences were much less correlated with distance in pixel space (Fig 4-Left). Looking at the two most help- The key idea is that Ipert,loss (z, ztest ) tells us how to mod- ful images (most positive −Iup,loss ) for each model in Fig ify training point z to most increase the loss on ztest . 4-Right, we see that the Inception network picked up on the Concretely, for a target test image ztest , we can construct distinctive characteristics of clownfish, whereas the RBF z˜i , an adversarial version of a training image zi , by ini- SVM pattern-matched training images superficially. tializing z˜i := zi and then iterating z˜i := Π(˜ zi + α sign(Ipert,loss (˜ zi , ztest ))), where α is the step size and Π Moreover, in the RBF SVM, fish (green points) close to projects onto the set of valid images that share the same 8- the test image were mostly helpful, while dogs (red) were bit representation with zi . 
After each iteration, we retrain mostly harmful, with the RBF acting as a soft nearest the model. This is an iterated, training-set analogue of the neighbor function (Fig 4-Left). In contrast, in the Incep- methods used by, e.g., Goodfellow et al. (2015); Moosavi- tion network, fish and dogs could be helpful or harmful for Dezfooli et al. (2016) for test-set attacks. correctly classifying the test image as a fish; in fact, some of the most helpful training images were dogs that, to the We tested these training attacks on the same Inception net- model, looked very different from the test fish (Fig 4-Top). work on dogs vs. fish from Section 5.1, choosing this pair of animals to provide a stark contrast between the classes. We set α = 0.02 and ran the attack for 100 iterations on each test image. As before, we froze all but the top layer for training; note that computing Ipert,loss still involves differentiating through the entire network. Originally, the model correctly classified 591 / 600 test images. For each of these 591 test images, considered separately, we tried to find a visually-indistinguishable perturbation (i.e., same 8- bit representation) to a single training image, out of 1,800 total training images, that would flip the model’s predic- tion. We were able to do this on 335 (57%) of the 591 test images. By perturbing 2 training images for each test image, we could flip predictions on 77% of the 591 test im- ages; and if we perturbed 10 training images, we could flip all but 1 of the 591. The above results are from attacking each test image separately, i.e., using a different training set to attack each test image. We also tried to attack multiple test images simultaneously by increasing their average loss, Figure 4. Inception vs. RBF SVM. Bottom left: and found that single training image perturbations could si- −Iup,loss (z, ztest ) vs. z − ztest 22 . Green dots are fish and multaneously flip multiple test predictions as well (Fig 5). red dots are dogs. 
Bottom right: The two most helpful training images, for each model, on the test. Top right: An image of a We make three observations about these attacks. First, dog in the training set that helped the Inception model correctly though the change in pixel values is small, the change in classify the test image as a fish. the final Inception feature layer is significantly larger: us- 5.2. Adversarial training examples ing L2 distance in pixel space, the training values change by less than 1% of the mean distance of a training point to In this section, we show that models that place a lot of in- its class centroid, whereas in Inception feature space, the fluence on a small number of points can be vulnerable to change is on the same order as the mean distance. This training input perturbations, posing a serious security risk leaves open the possibility that our attacks, while visually- in real-world ML systems where attackers can influence the imperceptible, can be detected by examining the feature training data (Huang et al., 2011). Recent work has gener- space. Second, the attack tries to perturb the training ex- ated adversarial test images that are visually indistinguish-
ample in a direction of low variance, causing the model to overfit in that direction and consequently incorrectly classify the test images; we expect attacking to be harder as the number of training examples grows. Third, ambiguous or mislabeled training images are effective points to attack: the model has low confidence and thus high loss on them, making them highly influential (recall Section 2.3). For example, the image in Fig 5 contains both a dog and a fish and is highly ambiguous; as a result, it is the training example that the model is least confident on (with a confidence of 77%, compared to the next lowest confidence of 90%).

Figure 5. Training-set attacks. We targeted a set of 30 test images featuring the first author's dog in a variety of poses and backgrounds. By maximizing the average loss over these 30 images, we created a visually-imperceptible change to the particular training image (shown on top) that flipped predictions on 16 test images.

This attack is mathematically equivalent to the gradient-based training-set attacks explored by Biggio et al. (2012), Mei & Zhu (2015b), and others in the context of different models. Biggio et al. (2012) constructed a dataset poisoning attack against a linear SVM on a two-class MNIST task, but had to modify the training points in an obviously distinguishable way to be effective. Measuring the magnitude of $\mathcal{I}_{\text{pert,loss}}$ gives model developers a way of quantifying how vulnerable their models are to training-set attacks.

5.3. Debugging domain mismatch

Domain mismatch, where the training distribution does not match the test distribution, can cause models with high training accuracy to do poorly on test data (Ben-David et al., 2010). We show that influence functions can identify the training examples most responsible for the errors, helping model developers identify domain mismatch.

As a case study, we predicted whether a patient would be readmitted to hospital. Domain mismatches are common in biomedical data; e.g., different hospitals serve different populations, and models trained on one population can do poorly on another (Kansagara et al., 2011). We used logistic regression to predict readmission with a balanced training dataset of 20K diabetic patients from 100+ US hospitals, each represented by 127 features (Strack et al., 2014).⁷

Of the 24 children under age 10 in this dataset, 3 were re-admitted. To induce a domain mismatch, we filtered out 20 children who were not re-admitted, leaving 3 out of 4 re-admitted. This caused the model to wrongly classify many children in the test set. Our aim is to identify the 4 children in the training set as being "responsible" for these errors.

As a baseline, we tried the common practice of looking at the learned parameters $\hat\theta$ to see if the indicator variable for being a child was obviously different. However, this did not work: 14 of the 127 features had a larger coefficient.

Picking a random child $z_{\text{test}}$ that the model got wrong, we calculated $-\mathcal{I}_{\text{up,loss}}(z_i, z_{\text{test}})$ for each training point $z_i$. This clearly highlighted the 4 training children, each of whom was 30-40 times as influential as the next most influential examples. The 1 child in the training set who was not readmitted had a very positive influence, while the other 3 had very negative influences. Moreover, calculating $\mathcal{I}_{\text{pert,loss}}$ on these 4 children showed that the 'child' indicator variable contributed significantly to the magnitude of $\mathcal{I}_{\text{up,loss}}$.

5.4. Fixing mislabeled examples

Labels in the real world are often noisy, especially if crowdsourced (Frénay & Verleysen, 2014), and can even be adversarially corrupted. Even if a human expert could recognize wrongly labeled examples, it is impossible in many applications to manually review all of the training data. We show that influence functions can help human experts prioritize their attention, allowing them to inspect only the examples that actually matter.

The key idea is to flag the training points that exert the most influence on the model. Because we do not have access to the test set, we measure the influence of $z_i$ with $\mathcal{I}_{\text{up,loss}}(z_i, z_i)$, which approximates the error incurred on $z_i$ if we remove $z_i$ from the training set.

⁷ Hospital readmission was defined as whether a patient would be readmitted within the next 30 days. Features were demographic (e.g., age, race, gender), administrative (e.g., length of hospital stay), or medical (e.g., test results).
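The two scores used in Sections 5.3 and 5.4 can be illustrated on a small problem. The following is a minimal NumPy sketch under our own assumptions, not the authors' released implementation: L2-regularized binary logistic regression with labels in {-1, +1}, synthetic data with 10% of labels flipped, and an explicitly formed Hessian (feasible only for a handful of parameters; at scale one would use Hessian-vector products instead). It computes $\mathcal{I}_{\text{up,loss}}(z_i, z_{\text{test}}) = -\nabla_\theta L(z_{\text{test}}, \hat\theta)^\top H_{\hat\theta}^{-1} \nabla_\theta L(z_i, \hat\theta)$ for every training point and the self-influence $\mathcal{I}_{\text{up,loss}}(z_i, z_i)$ used to prioritize inspection. All function and variable names are hypothetical.

```python
import numpy as np

# Assumptions (ours, for illustration): labels y_i in {-1, +1}, logistic loss,
# and an L2 penalty lam > 0 folded into R so the Hessian is positive definite:
#   R(theta) = (1/n) sum_i log(1 + exp(-y_i x_i^T theta)) + (lam/2) ||theta||^2

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def risk(X, y, theta, lam):
    return np.mean(np.logaddexp(0.0, -y * (X @ theta))) + 0.5 * lam * theta @ theta

def per_point_grads(X, y, theta):
    # Row i is grad_theta L(z_i, theta) = -y_i * sigmoid(-y_i x_i^T theta) * x_i
    return -(y * sigmoid(-y * (X @ theta)))[:, None] * X

def hessian(X, theta, lam):
    # H = (1/n) sum_i p_i (1 - p_i) x_i x_i^T + lam * I, with p_i = sigmoid(x_i^T theta)
    p = sigmoid(X @ theta)
    return (X.T * (p * (1.0 - p))) @ X / X.shape[0] + lam * np.eye(X.shape[1])

def fit(X, y, lam, iters=100):
    # Newton's method; halve the step if it would increase the risk.
    theta = np.zeros(X.shape[1])
    for _ in range(iters):
        g = per_point_grads(X, y, theta).mean(axis=0) + lam * theta
        if np.linalg.norm(g) < 1e-12:
            break
        step = np.linalg.solve(hessian(X, theta, lam), g)
        t, f0 = 1.0, risk(X, y, theta, lam)
        while risk(X, y, theta - t * step, lam) > f0 + 1e-12 and t > 1e-8:
            t *= 0.5
        theta = theta - t * step
    return theta

def influence_on_test_loss(X, y, theta, lam, x_test, y_test):
    # I_up,loss(z_i, z_test) = -grad L(z_test)^T H^{-1} grad L(z_i) for every i;
    # one solve with H is shared across all training points.
    g_test = per_point_grads(x_test[None, :], np.array([y_test]), theta)[0]
    s_test = np.linalg.solve(hessian(X, theta, lam), g_test)
    return -per_point_grads(X, y, theta) @ s_test

def self_influence(X, y, theta, lam):
    # I_up,loss(z_i, z_i) = -grad L(z_i)^T H^{-1} grad L(z_i); never positive.
    G = per_point_grads(X, y, theta)
    return -np.sum(G * np.linalg.solve(hessian(X, theta, lam), G.T).T, axis=1)

# Hypothetical data: a linear rule with 10% of the labels flipped.
rng = np.random.default_rng(0)
X = rng.normal(size=(300, 4))
y = np.where(X @ np.array([1.5, -1.0, 0.5, 2.0]) > 0, 1.0, -1.0)
flipped = rng.choice(300, size=30, replace=False)
y[flipped] *= -1
lam = 0.01
theta = fit(X, y, lam)

order = np.argsort(self_influence(X, y, theta, lam))  # most negative (most influential) first
loss_order = np.argsort(-np.logaddexp(0.0, -y * (X @ theta)))  # baseline: highest train loss first
caught = np.isin(order[:30], flipped).mean()  # share of flipped labels among the first 30 checked
```

`order` is the sequence in which a reviewer would inspect training points, and `loss_order` mirrors the highest-train-loss baseline of Fig 6. Because $H_{\hat\theta}$ is positive definite, every self-influence is non-positive, so sorting by the most negative value is the same as sorting by magnitude.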
Our case study is email spam classification, which relies on user-provided labels and is also vulnerable to adversarial attack (Biggio et al., 2011). We flipped the labels of a random 10% of the training data and then simulated manually inspecting a fraction of the training points, correcting them if they had been flipped. Using influence functions to prioritize the training points to inspect allowed us to repair the dataset (Fig 6, blue) without checking too many points, outperforming the baselines of checking points with the highest train loss (Fig 6, green) or at random (Fig 6, red). No method had access to the test data.

Figure 6. Fixing mislabeled examples. Plots of how test accuracy (left) and the fraction of flipped data detected (right) change with the fraction of train data checked, using different algorithms for picking points to check. Error bars show the std. dev. across 40 repeats of this experiment, with a different subset of labels flipped in each; error bars on the right are too small to be seen. These results are on the Enron1 spam dataset (Metsis et al., 2006), with 4,147 training and 1,035 test examples; we trained logistic regression on a bag-of-words representation of the emails.

6. Related Work

The use of influence-based diagnostics originated in statistics in the 70s and 80s, driven by seminal papers by Cook and others (Cook, 1977; Cook & Weisberg, 1980; 1982), though similar ideas appeared even earlier in other forms, e.g., the infinitesimal jackknife (Jaeckel, 1972). Earlier work focused on removing training points from linear models, with later work extending this to more general models and a wider variety of perturbations (Cook, 1986; Thomas & Cook, 1990; Chatterjee & Hadi, 1986; Wei et al., 1998). Most of this prior work focused on experiments with small datasets, e.g., $n = 24$ and $p = 10$ in Cook & Weisberg (1980), with special attention therefore paid to exact solutions, or if not possible, characterizations of the error terms.

Influence functions have not been used much in the ML literature, with some exceptions. Christmann & Steinwart (2004), Debruyne et al. (2008), and Liu et al. (2014) use influence functions to study model robustness and to do fast cross-validation in kernel methods. Wojnowicz et al. (2016) use matrix sketching to estimate Cook's distance, which is closely related to influence; they focus on prioritizing training points for human attention and derive methods specific to generalized linear models.

As noted in Section 5.2, our training-set attack is mathematically equivalent to an approach first explored by Biggio et al. (2012) in the context of SVMs, with follow-up work extending the framework and applying it to linear and logistic regression (Mei & Zhu, 2015b), topic modeling (Mei & Zhu, 2015a), and collaborative filtering (Li et al., 2016a). These papers derived the attack directly from the KKT conditions without considering influence, though for continuous data, the end result is equivalent. Influence functions additionally let us consider attacks on discrete data (Section 2.2), but we have not tested this empirically. Our work connects the literature on training-set attacks with work on "adversarial examples" (Goodfellow et al., 2015; Moosavi-Dezfooli et al., 2016), visually-imperceptible perturbations on test inputs.

In contrast to training-set attacks, Cadamuro et al. (2016) consider the task of taking an incorrect test prediction and finding a small subset of training data such that changing the labels on this subset makes the prediction correct. They provide a solution for OLS and Gaussian process models when the labels are continuous. Our work with influence functions allows us to solve this problem in a much larger range of models and in datasets with discrete labels.

7. Discussion

We have discussed a variety of applications, from creating training-set attacks to debugging models and fixing datasets. Underlying each of these applications is a common tool, the influence function, which is based on a simple idea: we can better understand model behavior by looking at how it was derived from its training data.

At their core, influence functions measure the effect of local changes: what happens when we upweight a point by an infinitesimally-small $\epsilon$? This locality allows us to derive efficient closed-form estimates, and as we show, they can be surprisingly effective. However, we might want to ask about more global changes, e.g., how does a subpopulation of patients from this hospital affect the model? Since influence functions depend on the model not changing too much, how to tackle this is an open question.

It seems inevitable that high-performing, complex, black-box models will become increasingly prevalent and important. We hope that the approach presented here, looking at the model through the lens of the training data, will become a standard part of the toolkit of developing, understanding, and diagnosing machine learning.

The code and data for replicating our experiments are available on GitHub http://bit.ly/gt-influence and Codalab http://bit.ly/cl-influence.
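The distinction drawn in the Discussion between local and global changes can be seen on a toy problem. The sketch below is our own illustration, not an experiment from the paper: it uses ridge regression (an assumption made so that exact retraining is a single linear solve) on synthetic data, and compares the first-order influence estimate of the parameter shift caused by removing training points, i.e. upweighting each by $\epsilon = -1/n$, against the shift obtained by actually refitting. It does this once for a single point and once for a whole subpopulation. All names are hypothetical.

```python
import numpy as np

# Assumption (ours): ridge regression, so exact retraining is one linear solve.
#   R_w(theta) = (1/n) sum_i w_i * 0.5 * (x_i^T theta - y_i)^2 + (lam/2) ||theta||^2
# Removing z_i sets w_i = 0, i.e. upweights z_i by eps = -1/n.

def fit_ridge(X, y, lam, w=None):
    n, d = X.shape
    w = np.ones(n) if w is None else w
    A = (X.T * w) @ X / n + lam * np.eye(d)
    return np.linalg.solve(A, (X.T * w) @ y / n)

def predicted_shift(X, y, lam, theta, idx):
    # First-order estimate: sum over removed z_i of -(1/n) * I_up,params(z_i)
    #   = (1/n) H^{-1} sum_i grad L(z_i, theta), with grad L = (x^T theta - y) x.
    n, d = X.shape
    H = X.T @ X / n + lam * np.eye(d)
    grads = (X[idx] @ theta - y[idx])[:, None] * X[idx]
    return np.linalg.solve(H, grads.sum(axis=0)) / n

def actual_shift(X, y, lam, theta, idx):
    # Ground truth: refit with the removed points' weights set to zero.
    w = np.ones(X.shape[0])
    w[idx] = 0.0
    return fit_ridge(X, y, lam, w) - theta

def relative_error(X, y, lam, theta, idx):
    pred = predicted_shift(X, y, lam, theta, idx)
    act = actual_shift(X, y, lam, theta, idx)
    return np.linalg.norm(pred - act) / np.linalg.norm(act)

# Hypothetical synthetic data.
rng = np.random.default_rng(1)
X = rng.normal(size=(1000, 5))
y = X @ np.array([2.0, -1.0, 0.0, 0.5, 1.0]) + 0.1 * rng.normal(size=1000)
lam = 0.1
theta = fit_ridge(X, y, lam)

j = int(np.argmax(np.abs(X @ theta - y)))  # local change: the worst-fit single point
single = relative_error(X, y, lam, theta, [j])
subpop = np.where(X[:, 0] > 1.0)[0]        # global change: roughly 16% of the data
group = relative_error(X, y, lam, theta, subpop)
```

For ridge regression the single-point case can be worked out by hand: by the Sherman–Morrison formula the exact shift equals the influence estimate divided by $1 - h_j$, where $h_j = x_j^\top H^{-1} x_j / n$, so `single` equals $h_j$ and is small when no point dominates. Removing a large group changes $H$ itself, which a first-order expansion around $\hat\theta$ does not capture, so `group` is expected to be noticeably larger.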
A. Deriving the influence function $\mathcal{I}_{\text{up,params}}$

For completeness, we provide a standard derivation of the influence function $\mathcal{I}_{\text{up,params}}$ in the context of loss minimization (M-estimation). This derivation is based on asymptotic arguments and is not fully rigorous; see van der Vaart (1998) and other statistics textbooks for a more thorough treatment.

Recall that $\hat\theta$ minimizes the empirical risk:

$$R(\theta) \stackrel{\text{def}}{=} \frac{1}{n}\sum_{i=1}^{n} L(z_i, \theta). \tag{6}$$

We further assume that $R$ is twice-differentiable and strictly convex in $\theta$, i.e.,

$$H_{\hat\theta} \stackrel{\text{def}}{=} \nabla^2 R(\hat\theta) = \frac{1}{n}\sum_{i=1}^{n} \nabla^2_\theta L(z_i, \hat\theta) \tag{7}$$

exists and is positive definite. This guarantees the existence of $H_{\hat\theta}^{-1}$, which we will use in the subsequent derivation.

The perturbed parameters $\hat\theta_{\epsilon,z}$ can be written as

$$\hat\theta_{\epsilon,z} = \arg\min_{\theta \in \Theta} \left\{ R(\theta) + \epsilon L(z, \theta) \right\}. \tag{8}$$

Define the parameter change $\Delta_\epsilon = \hat\theta_{\epsilon,z} - \hat\theta$, and note that, as $\hat\theta$ doesn't depend on $\epsilon$, the quantity we seek to compute can be written in terms of it:

$$\frac{d\hat\theta_{\epsilon,z}}{d\epsilon} = \frac{d\Delta_\epsilon}{d\epsilon}. \tag{9}$$

Since $\hat\theta_{\epsilon,z}$ is a minimizer of (8), let us examine its first-order optimality conditions:

$$0 = \nabla R(\hat\theta_{\epsilon,z}) + \epsilon \nabla L(z, \hat\theta_{\epsilon,z}). \tag{10}$$

Next, since $\hat\theta_{\epsilon,z} \to \hat\theta$ as $\epsilon \to 0$, we perform a Taylor expansion of the right-hand side:

$$0 \approx \left[\nabla R(\hat\theta) + \epsilon \nabla L(z, \hat\theta)\right] + \left[\nabla^2 R(\hat\theta) + \epsilon \nabla^2 L(z, \hat\theta)\right] \Delta_\epsilon, \tag{11}$$

where we have dropped $o(\|\Delta_\epsilon\|)$ terms.

Solving for $\Delta_\epsilon$, we get:

$$\Delta_\epsilon \approx -\left[\nabla^2 R(\hat\theta) + \epsilon \nabla^2 L(z, \hat\theta)\right]^{-1} \left[\nabla R(\hat\theta) + \epsilon \nabla L(z, \hat\theta)\right]. \tag{12}$$

Since $\hat\theta$ minimizes $R$, we have $\nabla R(\hat\theta) = 0$. Keeping only $O(\epsilon)$ terms, we have

$$\Delta_\epsilon \approx -\nabla^2 R(\hat\theta)^{-1} \nabla L(z, \hat\theta)\, \epsilon. \tag{13}$$

Combining with (7) and (9), we conclude that:

$$\begin{align} \left.\frac{d\hat\theta_{\epsilon,z}}{d\epsilon}\right|_{\epsilon=0} &= -H_{\hat\theta}^{-1} \nabla L(z, \hat\theta) \tag{14}\\ &\stackrel{\text{def}}{=} \mathcal{I}_{\text{up,params}}(z). \tag{15} \end{align}$$

B. Influence at non-convergence

Consider a training point $z$. When the model parameters $\tilde\theta$ are close to but not at a local minimum, $\mathcal{I}_{\text{up,params}}(z)$ is approximately equal to a constant (which does not depend on $z$) plus the change in parameters after upweighting $z$ and then taking a single Newton step from $\tilde\theta$. The high-level idea is that even though the gradient of the empirical risk at $\tilde\theta$ is not 0, the Newton step from $\tilde\theta$ can be decomposed into a component following the existing gradient (which does not depend on the choice of $z$) and a second component responding to the upweighted $z$ (which $\mathcal{I}_{\text{up,params}}(z)$ tracks).

Let $g \stackrel{\text{def}}{=} \frac{1}{n}\sum_{i=1}^{n} \nabla_\theta L(z_i, \tilde\theta)$ be the gradient of the empirical risk at $\tilde\theta$; since $\tilde\theta$ is not a local minimum, $g \neq 0$. After upweighting $z$ by $\epsilon$, the gradient at $\tilde\theta$ goes from $g \to g + \epsilon \nabla_\theta L(z, \tilde\theta)$, and the empirical Hessian goes from $H_{\tilde\theta} \to H_{\tilde\theta} + \epsilon \nabla^2_\theta L(z, \tilde\theta)$. A Newton step from $\tilde\theta$ therefore changes the parameters by:

$$N_{\epsilon,z} \stackrel{\text{def}}{=} -\left[H_{\tilde\theta} + \epsilon \nabla^2_\theta L(z, \tilde\theta)\right]^{-1} \left[g + \epsilon \nabla_\theta L(z, \tilde\theta)\right]. \tag{16}$$

Ignoring terms in $\epsilon g$, $\epsilon^2$, and higher, we get $N_{\epsilon,z} \approx -H_{\tilde\theta}^{-1}\left(g + \epsilon \nabla_\theta L(z, \tilde\theta)\right)$. Therefore, the actual change due to a Newton step $N_{\epsilon,z}$ is equal to a constant $-H_{\tilde\theta}^{-1} g$ (that doesn't depend on $z$) plus $\epsilon$ times $\mathcal{I}_{\text{up,params}}(z) = -H_{\tilde\theta}^{-1} \nabla_\theta L(z, \tilde\theta)$ (which captures the contribution of $z$).

Acknowledgements

We thank Jacob Steinhardt, Zhenghao Chen, and Hongseok Namkoong for helpful discussions and comments. This work was supported by a Future of Life Research Award and a Microsoft Research Faculty Fellowship.
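The result of Appendix A, equation (14), can be sanity-checked numerically: refit the model with $z$ upweighted by $+\epsilon$ and $-\epsilon$, take a central finite difference of $\hat\theta_{\epsilon,z}$, and compare it with $-H_{\hat\theta}^{-1} \nabla L(z, \hat\theta)$. The sketch below is an illustration under our own assumptions: L2-regularized logistic regression on synthetic data (the L2 term is folded into $R$ to keep it strictly convex, as (7) requires), with $z$ taken to be one of the training points. Function names are hypothetical.

```python
import numpy as np

# Assumptions (ours): logistic loss, labels in {-1, +1}, L2 term lam/2 ||theta||^2
# folded into R.  Point j gets weight 1/n + eps, so the objective is
# R(theta) + eps * L(z_j, theta), as in (8).

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def objective(X, y, w, lam, theta):
    # sum_i w_i * log(1 + exp(-y_i x_i^T theta)) + (lam/2) ||theta||^2
    return w @ np.logaddexp(0.0, -y * (X @ theta)) + 0.5 * lam * theta @ theta

def grad_hess(X, y, w, lam, theta):
    m = X @ theta
    g = X.T @ (-w * y * sigmoid(-y * m)) + lam * theta
    p = sigmoid(m)
    H = (X.T * (w * p * (1.0 - p))) @ X + lam * np.eye(X.shape[1])
    return g, H

def fit(X, y, w, lam, iters=100):
    # Newton's method with step halving; stops once the gradient vanishes.
    theta = np.zeros(X.shape[1])
    for _ in range(iters):
        g, H = grad_hess(X, y, w, lam, theta)
        if np.linalg.norm(g) < 1e-13:
            break
        step = np.linalg.solve(H, g)
        t, f0 = 1.0, objective(X, y, w, lam, theta)
        while objective(X, y, w, lam, theta - t * step) > f0 + 1e-12 and t > 1e-8:
            t *= 0.5
        theta = theta - t * step
    return theta

def influence_params(X, y, lam, j):
    # Closed form (14): I_up,params(z_j) = -H^{-1} grad L(z_j, theta_hat)
    n = X.shape[0]
    w = np.full(n, 1.0 / n)
    theta = fit(X, y, w, lam)
    _, H = grad_hess(X, y, w, lam, theta)
    grad_j = -y[j] * sigmoid(-y[j] * (X[j] @ theta)) * X[j]
    return -np.linalg.solve(H, grad_j)

def finite_difference(X, y, lam, j, eps=1e-4):
    # Central difference of theta_hat_{eps, z_j} around eps = 0, by actually refitting
    n = X.shape[0]
    w_plus, w_minus = np.full(n, 1.0 / n), np.full(n, 1.0 / n)
    w_plus[j] += eps
    w_minus[j] -= eps
    return (fit(X, y, w_plus, lam) - fit(X, y, w_minus, lam)) / (2.0 * eps)

# Hypothetical synthetic data.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = np.where(rng.random(200) < sigmoid(X @ np.array([1.0, -1.0, 0.5])), 1.0, -1.0)
lam = 0.01
analytic = influence_params(X, y, lam, j=7)
numeric = finite_difference(X, y, lam, j=7)
```

With the fits converged to machine precision, `analytic` and `numeric` should agree up to the $O(\epsilon^2)$ truncation error of the central difference; a large gap would point to a wrong gradient or Hessian.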
References

Abadi, M., Agarwal, A., Barham, P., Brevdo, E., Chen, Z., Citro, C., Corrado, G. S., Davis, A., Dean, J., Devin, M., Ghemawat, S., Goodfellow, I. J., Harp, A., Irving, G., Isard, M., Jia, Y., Józefowicz, R., Kaiser, L., Kudlur, M., Levenberg, J., Mané, D., Monga, R., Moore, S., Murray, D. G., Olah, C., Schuster, M., Shlens, J., Steiner, B., Sutskever, I., Talwar, K., Tucker, P. A., Vanhoucke, V., Vasudevan, V., Viégas, F. B., Vinyals, O., Warden, P., Wattenberg, M., Wicke, M., Yu, Y., and Zheng, X. Tensorflow: Large-scale machine learning on heterogeneous distributed systems. arXiv preprint arXiv:1603.04467, 2015.
Adler, P., Falk, C., Friedler, S. A., Rybeck, G., Scheidegger, C., Smith, B., and Venkatasubramanian, S. Auditing black-box models for indirect influence. arXiv preprint arXiv:1602.07043, 2016.
Agarwal, N., Bullins, B., and Hazan, E. Second order stochastic optimization in linear time. arXiv preprint arXiv:1602.03943, 2016.
Amershi, S., Chickering, M., Drucker, S. M., Lee, B., Simard, P., and Suh, J. Modeltracker: Redesigning performance analysis tools for machine learning. In Conference on Human Factors in Computing Systems (CHI), pp. 337–346, 2015.
Ben-David, S., Blitzer, J., Crammer, K., Kulesza, A., Pereira, F., and Vaughan, J. W. A theory of learning from different domains. Machine Learning, 79(1):151–175, 2010.
Biggio, B., Nelson, B., and Laskov, P. Support vector machines under adversarial label noise. ACML, 20:97–112, 2011.
Biggio, B., Nelson, B., and Laskov, P. Poisoning attacks against support vector machines. In International Conference on Machine Learning (ICML), pp. 1467–1474, 2012.
Cadamuro, G., Gilad-Bachrach, R., and Zhu, X. Debugging machine learning models. In ICML Workshop on Reliable Machine Learning in the Wild, 2016.
Chatterjee, S. and Hadi, A. S. Influential observations, high leverage points, and outliers in linear regression. Statistical Science, pp. 379–393, 1986.
Chollet, F. Keras, 2015.
Christmann, A. and Steinwart, I. On robustness properties of convex risk minimization methods for pattern recognition. Journal of Machine Learning Research (JMLR), 5(0):1007–1034, 2004.
Cook, R. D. Detection of influential observation in linear regression. Technometrics, 19:15–18, 1977.
Cook, R. D. Assessment of local influence. Journal of the Royal Statistical Society. Series B (Methodological), pp. 133–169, 1986.
Cook, R. D. and Weisberg, S. Characterizations of an empirical influence function for detecting influential cases in regression. Technometrics, 22:495–508, 1980.
Cook, R. D. and Weisberg, S. Residuals and influence in regression. New York: Chapman and Hall, 1982.
Datta, A., Sen, S., and Zick, Y. Algorithmic transparency via quantitative input influence: Theory and experiments with learning systems. In Security and Privacy (SP), 2016 IEEE Symposium on, pp. 598–617, 2016.
Debruyne, M., Hubert, M., and Suykens, J. A. Model selection in kernel based regression using the influence function. Journal of Machine Learning Research (JMLR), 9(0):2377–2400, 2008.
Donahue, J., Jia, Y., Vinyals, O., Hoffman, J., Zhang, N., Tzeng, E., and Darrell, T. Decaf: A deep convolutional activation feature for generic visual recognition. In International Conference on Machine Learning (ICML), volume 32, pp. 647–655, 2014.
Frénay, B. and Verleysen, M. Classification in the presence of label noise: a survey. IEEE Transactions on Neural Networks and Learning Systems, 25:845–869, 2014.
Goodfellow, I. J., Shlens, J., and Szegedy, C. Explaining and harnessing adversarial examples. In International Conference on Learning Representations (ICLR), 2015.
Goodman, B. and Flaxman, S. European union regulations on algorithmic decision-making and a "right to explanation". arXiv preprint arXiv:1606.08813, 2016.
Huang, L., Joseph, A. D., Nelson, B., Rubinstein, B. I., and Tygar, J. Adversarial machine learning. In Proceedings of the 4th ACM workshop on Security and artificial intelligence, pp. 43–58, 2011.
Jaeckel, L. A. The infinitesimal jackknife. Unpublished memorandum, Bell Telephone Laboratories, Murray Hill, NJ, 1972.
Kansagara, D., Englander, H., Salanitro, A., Kagen, D., Theobald, C., Freeman, M., and Kripalani, S. Risk prediction models for hospital readmission: a systematic review. JAMA, 306(15):1688–1698, 2011.
Kingma, D. and Ba, J. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
Krizhevsky, A., Sutskever, I., and Hinton, G. E. Imagenet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems (NIPS), pp. 1097–1105, 2012.
LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
Li, B., Wang, Y., Singh, A., and Vorobeychik, Y. Data poisoning attacks on factorization-based collaborative filtering. In Advances in Neural Information Processing Systems (NIPS), 2016a.
Li, J., Monroe, W., and Jurafsky, D. Understanding neural networks through representation erasure. arXiv preprint arXiv:1612.08220, 2016b.
Liu, D. C. and Nocedal, J. On the limited memory BFGS method for large scale optimization. Mathematical Programming, 45(1):503–528, 1989.
Liu, Y., Jiang, S., and Liao, S. Efficient approximation of cross-validation for kernel methods using Bouligand influence function. In International Conference on Machine Learning (ICML), pp. 324–332, 2014.
Martens, J. Deep learning via hessian-free optimization. In International Conference on Machine Learning (ICML), pp. 735–742, 2010.
Mei, S. and Zhu, X. The security of latent Dirichlet allocation. In Artificial Intelligence and Statistics (AISTATS), 2015a.
Mei, S. and Zhu, X. Using machine teaching to identify optimal training-set attacks on machine learners. In Association for the Advancement of Artificial Intelligence (AAAI), 2015b.
Metsis, V., Androutsopoulos, I., and Paliouras, G. Spam filtering with naive Bayes – which naive Bayes? In CEAS, volume 17, pp. 28–69, 2006.
Moosavi-Dezfooli, S., Fawzi, A., and Frossard, P. Deepfool: a simple and accurate method to fool deep neural networks. In Computer Vision and Pattern Recognition (CVPR), pp. 2574–2582, 2016.
Pearlmutter, B. A. Fast exact multiplication by the hessian. Neural Computation, 6(1):147–160, 1994.
Ribeiro, M. T., Singh, S., and Guestrin, C. "Why should I trust you?": Explaining the predictions of any classifier. In International Conference on Knowledge Discovery and Data Mining (KDD), 2016.
Russakovsky, O., Deng, J., Su, H., Krause, J., Satheesh, S., Ma, S., Huang, Z., Karpathy, A., Khosla, A., Bernstein, M., et al. ImageNet large scale visual recognition challenge. International Journal of Computer Vision, 115(3):211–252, 2015.
Shrikumar, A., Greenside, P., Shcherbina, A., and Kundaje, A. Not just a black box: Learning important features through propagating activation differences. arXiv preprint arXiv:1605.01713, 2016.
Simonyan, K., Vedaldi, A., and Zisserman, A. Deep inside convolutional networks: Visualising image classification models and saliency maps. arXiv preprint arXiv:1312.6034, 2013.
Springenberg, J. T., Dosovitskiy, A., Brox, T., and Riedmiller, M. Striving for simplicity: The all convolutional net. arXiv preprint arXiv:1412.6806, 2014.
Strack, B., DeShazo, J. P., Gennings, C., Olmo, J. L., Ventura, S., Cios, K. J., and Clore, J. N. Impact of HbA1c measurement on hospital readmission rates: analysis of 70,000 clinical database patient records. BioMed Research International, 2014, 2014.
Szegedy, C., Vanhoucke, V., Ioffe, S., Shlens, J., and Wojna, Z. Rethinking the Inception architecture for computer vision. In Computer Vision and Pattern Recognition (CVPR), pp. 2818–2826, 2016.
Theano D. Team. Theano: A Python framework for fast computation of mathematical expressions. arXiv preprint arXiv:1605.02688, 2016.
Thomas, W. and Cook, R. D. Assessing influence on predictions from generalized linear models. Technometrics, 32(1):59–65, 1990.
van der Vaart, A. W. Asymptotic statistics. Cambridge University Press, 1998.
Wei, B., Hu, Y., and Fung, W. Generalized leverage and its applications. Scandinavian Journal of Statistics, 25:25–37, 1998.
Wojnowicz, M., Cruz, B., Zhao, X., Wallace, B., Wolff, M., Luan, J., and Crable, C. "Influence sketching": Finding influential samples in large-scale regressions. arXiv preprint arXiv:1611.05923, 2016.