# mnist_tutorial.py
from absl import app, flags

import datasets
import itertools
import time

import jax.numpy as np
import numpy.random as npr
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import logsoftmax
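
# Note: in newer JAX releases the optimizers and stax modules were moved out of
# jax.experimental; if the three imports above fail, they may be available as
# jax.example_libraries.optimizers / jax.example_libraries.stax instead,
# depending on the installed JAX version.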

from cleverhans.jax.attacks.fast_gradient_method import fast_gradient_method
from cleverhans.jax.attacks.projected_gradient_descent import projected_gradient_descent

FLAGS = flags.FLAGS


def main(_):
    rng = random.PRNGKey(0)

    # Load MNIST dataset
    train_images, train_labels, test_images, test_labels = datasets.mnist()

    batch_size = 128
    batch_shape = (-1, 28, 28, 1)
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    train_images = np.reshape(train_images, batch_shape)
    test_images = np.reshape(test_images, batch_shape)
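
    # data_stream() is an infinite generator: each pass it reshuffles the
    # training set (fixed numpy seed for reproducibility) and yields
    # (images, labels) minibatches of size batch_size.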
    def data_stream():
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size : (i + 1) * batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]

    batches = data_stream()

    # Model, loss, and accuracy functions
    init_random_params, predict = stax.serial(
        stax.Conv(32, (8, 8), strides=(2, 2), padding="SAME"),
        stax.Relu,
        stax.Conv(128, (6, 6), strides=(2, 2), padding="VALID"),
        stax.Relu,
        stax.Conv(128, (5, 5), strides=(1, 1), padding="VALID"),
        stax.Flatten,
        stax.Dense(128),
        stax.Relu,
        stax.Dense(10),
    )
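
    # targets are one-hot, so logsoftmax(preds) * targets keeps only the
    # log-probability assigned to the correct class; loss() is therefore the
    # softmax cross-entropy (up to a constant scale from averaging over all
    # entries), and accuracy() is the fraction of argmax predictions that
    # match the labels.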
    def loss(params, batch):
        inputs, targets = batch
        preds = predict(params, inputs)
        return -np.mean(logsoftmax(preds) * targets)

    def accuracy(params, batch):
        inputs, targets = batch
        target_class = np.argmax(targets, axis=1)
        predicted_class = np.argmax(predict(params, inputs), axis=1)
        return np.mean(predicted_class == target_class)
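
    # optimizers.adam(0.001) returns the functional-optimizer triple
    # (opt_init, opt_update, get_params): opt_init wraps the parameters in an
    # optimizer state, opt_update applies one Adam step, and get_params
    # extracts the current parameters from that state.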

    # Instantiate an optimizer
    opt_init, opt_update, get_params = optimizers.adam(0.001)

    @jit
    def update(i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    # Initialize model
    _, init_params = init_random_params(rng, batch_shape)
    opt_state = opt_init(init_params)
    itercount = itertools.count()
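
    # itercount provides a monotonically increasing step index to opt_update;
    # the Adam implementation uses it for bias correction and for evaluating
    # any step-size schedule.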

    # Training loop
    print("\nStarting training...")
    for epoch in range(FLAGS.nb_epochs):
        start_time = time.time()
        for _ in range(num_batches):
            opt_state = update(next(itercount), opt_state, next(batches))
        epoch_time = time.time() - start_time

        # Evaluate model on clean data
        params = get_params(opt_state)
        train_acc = accuracy(params, (train_images, train_labels))
        test_acc = accuracy(params, (test_images, test_labels))
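
        # Craft untargeted L-infinity attacks against the current parameters:
        # FGM takes a single gradient-sign step of size eps, while PGD runs 40
        # iterations with step size 0.01, projecting back into the eps ball
        # after each step.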
        # Evaluate model on adversarial data
        model_fn = lambda images: predict(params, images)
        test_images_fgm = fast_gradient_method(model_fn, test_images, FLAGS.eps, np.inf)
        test_images_pgd = projected_gradient_descent(
            model_fn, test_images, FLAGS.eps, 0.01, 40, np.inf
        )
        test_acc_fgm = accuracy(params, (test_images_fgm, test_labels))
        test_acc_pgd = accuracy(params, (test_images_pgd, test_labels))

        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set accuracy: {}".format(train_acc))
        print("Test set accuracy on clean examples: {}".format(test_acc))
        print("Test set accuracy on FGM adversarial examples: {}".format(test_acc_fgm))
        print("Test set accuracy on PGD adversarial examples: {}".format(test_acc_pgd))


if __name__ == "__main__":
    flags.DEFINE_integer("nb_epochs", 8, "Number of epochs.")
    flags.DEFINE_float("eps", 0.3, "Total epsilon for FGM and PGD attacks.")

    app.run(main)
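
# Example invocation (using the default flag values defined above):
#   python mnist_tutorial.py --nb_epochs=8 --eps=0.3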