AlexNet Transfer Learning (Jupyter notebook export)
Transfer learning with AlexNet
In [ ]:
import random
import tensorflow as tf
import numpy as np
import os
from scipy import ndimage
import matplotlib.pyplot as plt
%matplotlib inline
# Open a session on the imported AlexNet graph and load variable values from
# the 'saved_models/alex_vars' checkpoint into it.
# NOTE(review): `graph` and `importer` are not defined in the visible cells;
# presumably they come from a tf.train.import_meta_graph(...) cell lost in
# this export (so `importer` would be a tf.train.Saver) -- confirm.
sess = tf.Session(graph=graph)
importer.restore(sess, 'saved_models/alex_vars')
# Build the new two-class output head on top of AlexNet features.
# NOTE(review): indentation was lost in this export, so which statements
# belong inside the two `with tf.name_scope(...)` blocks cannot be recovered
# from here; the scopes determine op names, so restore them carefully.
# `fc7` is not defined in the visible cells -- presumably the 4096-wide fc7
# activation of the imported AlexNet graph (the [4096, 2] weight shape
# assumes this); confirm.
with tf.name_scope('transfer'):
# Integer class ids fed at run time; the batch generator below emits 1 when
# the image file name contains 'dog' and 0 otherwise.
labels = tf.placeholder(tf.int32, [None])
one_hot_labels = tf.one_hot(labels, 2)
with tf.name_scope('cat_dog_final_layer'):
# Freshly created (not pretrained) weights for the new output layer, drawn
# with a very small stddev (0.001) so the layer starts close to zero.
weights = tf.Variable(tf.truncated_normal([4096, 2],
stddev=0.001),
name='final_weights')
biases = tf.Variable(tf.zeros([2]), name='final_biases')
# Raw class scores (fc7 @ weights + biases) and their softmax probabilities.
logits = tf.nn.xw_plus_b(fc7, weights, biases, name='logits')
prediction = tf.nn.softmax(logits, name='cat_dog_softmax')
# Per-example softmax cross-entropy between the new head's logits and the
# one-hot targets. Arguments are passed by keyword: TensorFlow >= 1.0 raises
# ValueError for positional arguments here, and the keywords also work on
# older releases whose positional order was (logits, labels).
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
    logits=logits, labels=one_hot_labels)
# Mean cross-entropy over the batch; this is the training objective.
loss = tf.reduce_mean(cross_entropy, name='cat_dog_loss')
# NOTE(review): the line below is the tail of a statement whose beginning
# was lost in this export (presumably an optimizer `.minimize(loss, ...)`
# call restricted to the new head's variables via `cat_dog_variables`, which
# is also not defined in the visible cells). As shown it is a syntax error;
# restore the missing line(s) from the original notebook.
var_list=cat_dog_variables)
# Accuracy metric: an example counts as correct when the argmax of the
# softmax prediction equals the argmax of its one-hot label.
with tf.name_scope('accuracy'):
label_prediction = tf.argmax(prediction, 1, name='predicted_label')
# Boolean per-example correctness; check_accuracy() sums this directly.
correct_prediction = tf.equal(label_prediction,
tf.argmax(one_hot_labels, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# NOTE(review): tf.initialize_all_variables() is deprecated in favour of
# tf.global_variables_initializer(). It initializes every global variable in
# the graph, which presumably includes the AlexNet variables restored above,
# so running it in the restored session would overwrite the pretrained
# values. Consider initializing only the new head's variables (e.g.
# tf.variables_initializer(cat_dog_variables)) -- confirm against the
# original notebook.
init = tf.initialize_all_variables()
In [ ]:
# NOTE(review): this opens a second session on the same graph and rebinds
# `sess`. The earlier session, into which the checkpoint was restored, is not
# closed in the visible code, and variable values live per session, so the
# restored values are not visible here. After `sess.run(init)` every
# variable holds its initializer value rather than the checkpoint value --
# confirm this is intended.
sess = tf.Session(graph=graph)
sess.run(init)
# Path of every entry in the dog image directory, in whatever (arbitrary)
# order os.listdir returns them.
dog_files = ['data/dogs_and_cats/dogs/' + fname
             for fname in os.listdir('data/dogs_and_cats/dogs')]
# NOTE(review): this is the body of the batch generator that check_accuracy()
# calls as get_batch(batch_size, data, max_epochs). Its `def get_batch(...)`
# header (which presumably also supplies `should_distort`) was lost in this
# export, so as shown the `yield` below sits at module level, which is a
# syntax error. `distort_graph`, `distort_result`, `resized_image` and
# `jpeg_name` come from cells not visible here, and `ntpath` is not among the
# visible imports.
# Separate session for the image-loading graph; it is never closed in the
# visible code.
distort_sess = tf.Session(graph=distort_graph)
epoch = 0
idx = 0
# Keep yielding batches until the data has wrapped around `max_epochs` times
# (the batch in which the last wrap happens is still yielded).
while epoch < max_epochs:
batch = []
labels = []
for i in range(batch_size):
# Wrap around at the end of the data: shuffle `data` in place (this mutates
# the caller's list) and count a new epoch.
# NOTE(review): after the reset, `data[idx + i]` reads data[i], so the first
# i items of the reshuffled list are skipped for this batch, and an
# IndexError is raised when i >= len(data). Because the boundary batch is
# still filled and yielded (and, when len(data) is a multiple of batch_size,
# one whole extra batch is produced), callers see more than len(data)
# examples per epoch.
if idx + i >= len(data):
random.shuffle(data)
epoch += 1
idx = 0
image_path = data[idx + i].encode()
# Run the image graph on this file path; judging by the names, presumably
# `distort_result` applies random distortions and `resized_image` only
# resizes -- confirm against the cell that builds distort_graph.
if should_distort:
val = distort_sess.run(distort_result,
feed_dict={jpeg_name: image_path})
else:
val = distort_sess.run(resized_image,
feed_dict={jpeg_name: image_path})
# Label 1 if the file name contains 'dog', else 0.
if b'dog' in ntpath.basename(image_path):
labels.append(1)
else:
labels.append(0)
batch.append(val)
idx += batch_size
yield batch, labels
In [ ]:
# NOTE(review): this re-runs the initializer, resetting every variable that
# `init` covers (including the new head's weights) to its initial value. If
# training ran in a cell before this one, the validation below would evaluate
# an untrained model -- confirm this line is intended.
sess.run(init)
Validate
In [ ]:
def check_accuracy(valid_data):
    """Evaluate the classifier on a held-out set and report its accuracy.

    Every example in ``valid_data`` is scored exactly once with
    ``correct_prediction``. A running accuracy is printed every 10 batches,
    and the final accuracy is printed at the end.

    Args:
        valid_data: list of image paths handed to ``get_batch``. The visible
            ``get_batch`` body shuffles its ``data`` argument in place when
            it wraps around, so this list may be reordered.

    Returns:
        The fraction of examples classified correctly (also printed).

    Raises:
        ValueError: if ``valid_data`` is empty.
    """
    batch_size = 50
    total = len(valid_data)
    if total == 0:
        raise ValueError('valid_data is empty; cannot compute accuracy')
    num_correct = 0
    num_seen = 0
    for batch_num, (data_batch, label_batch) in enumerate(
            get_batch(batch_size, valid_data, 1), start=1):
        feed_dict = {x: data_batch, labels: label_batch}
        correct_guesses = sess.run(correct_prediction,
                                   feed_dict=feed_dict)
        # get_batch keeps filling the batch that crosses the end of the data
        # (and may emit an extra batch) from a reshuffled list. Only score
        # examples not yet counted, so numerator and denominator agree.
        correct_guesses = correct_guesses[:total - num_seen]
        num_correct += int(np.sum(correct_guesses))
        num_seen += len(correct_guesses)
        if batch_num % 10 == 0:
            print('\tIntermediate accuracy: {}'.format((float(num_correct)
                  / float(num_seen))))
        # Every example has been scored once; skip any wrap-around batches.
        if num_seen >= total:
            break
    acc = num_correct / float(total)
    print('\nAccuracy: {}'.format(acc))
    return acc
In [ ]:
# Report accuracy on the held-out set.
# NOTE(review): `valid_data` is not defined in the visible cells.
check_accuracy(valid_data)