Notebook - Deep Neural Networks
1.1 Libraries
import numpy as np  # linear algebra
import matplotlib.pyplot as plt  # used to plot graphs
from sklearn.datasets import load_digits  # used to load the digits dataset
from sklearn.model_selection import train_test_split  # to split the data into two parts
from sklearn.preprocessing import MinMaxScaler  # NOTE(review): used in the preprocessing cell but was never imported
import logging

# Set up the logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Parameters:
-----------
layers: list of int
The number of neurons in each layer including the input and output layer
loss_func: str
The loss function to use. Options are 'mse' for mean squared error,␣
↪'log_loss' for logistic loss, and 'categorical_crossentropy' for categorical␣
↪crossentropy.
dropout_rate: float
The dropout rate for dropout regularization. Must be between 0 and 1.
grad_clip: float
The gradient clipping threshold.
"""
def __init__(self,
             layers,
             init_method='glorot_uniform',  # 'zeros', 'random', 'glorot_uniform', 'glorot_normal', 'he_uniform', 'he_normal'
             loss_func='mse',
             dropout_rate=0.5,
             clip_type='value',
             grad_clip=5.0
             ):
    """
    Build the network: one weight matrix and one bias row per consecutive
    pair of layer sizes, initialized according to `init_method`.

    Parameters
    ----------
    layers: list of int
        Number of neurons in each layer, including input and output layers.
    init_method: str
        Weight initialization scheme; one of 'zeros', 'random',
        'glorot_uniform', 'glorot_normal', 'he_uniform', 'he_normal'.
    loss_func: str
        Loss name ('mse', 'log_loss', or 'categorical_crossentropy').
    dropout_rate: float
        Dropout rate for dropout regularization, between 0 and 1.
    clip_type: str
        Gradient clipping mode: 'value' or 'norm'.
    grad_clip: float
        Gradient clipping threshold.

    Raises
    ------
    ValueError
        If `init_method` is not one of the recognized schemes.
    """
    self.layers = []
    self.loss_func = loss_func
    self.dropout_rate = dropout_rate
    self.clip_type = clip_type
    self.grad_clip = grad_clip
    self.init_method = init_method
    # Initialize one weight/bias pair per transition between layers
    for i in range(len(layers) - 1):
        if self.init_method == 'zeros':
            weights = np.zeros((layers[i], layers[i + 1]))
        elif self.init_method == 'random':
            weights = np.random.randn(layers[i], layers[i + 1])
        elif self.init_method == 'glorot_uniform':
            weights = self.glorot_uniform(layers[i], layers[i + 1])
        elif self.init_method == 'glorot_normal':
            weights = self.glorot_normal(layers[i], layers[i + 1])
        elif self.init_method == 'he_uniform':
            weights = self.he_uniform(layers[i], layers[i + 1])
        elif self.init_method == 'he_normal':
            weights = self.he_normal(layers[i], layers[i + 1])
        else:
            raise ValueError(f'Unknown initialization method {self.init_method}')
        self.layers.append({
            'weights': weights,
            'biases': np.zeros((1, layers[i + 1]))
        })
    # Per-epoch loss histories, filled during training
    self.train_loss = []
    self.test_loss = []
def __str__(self):
"""
Print the Neural Network architecture.
"""
# NOTE(review): the loop over the hidden/output layers (which defines `i`
# and `layer`, and whose `if` branch pairs with the orphan `else:` below)
# was lost in the PDF extraction — restore it from the original notebook.
structure = f"NN Layout:\nInput Layer: {len(self.layers[0]['weights'])}␣
↪neurons"
else:
# NOTE(review): appending "s\n..." right after a string that already ends
# in "neurons" is what produces the "neuronss" typo visible in the sample
# output ("Input Layer: 64 neuronss") — the pluralization hack is buggy.
structure += f"s\nHidden Layer {i+1}: {len(layer['weights'])}␣
↪neurons"
return structure
def glorot_uniform(self, fan_in, fan_out):
    """
    Glorot/Xavier uniform initialization: U(-limit, limit) with
    limit = sqrt(6 / (fan_in + fan_out)).

    NOTE(review): the `def` line was lost in the PDF extraction; the name and
    parameters are reconstructed from the dispatch in `__init__` and the
    docstring — confirm against the original notebook.

    Parameters
    ----------
    fan_in: int
        The number of input units in the weight tensor.
    fan_out: int
        The number of output units in the weight tensor.

    Returns
    -------
    numpy array
        The initialized weights, shape (fan_in, fan_out).
    """
    limit = np.sqrt(6 / (fan_in + fan_out))
    return np.random.uniform(-limit, limit, (fan_in, fan_out))
def he_uniform(self, fan_in, fan_out):
    """
    Uniform initialization scaled by the fan-in: U(-limit, limit) with
    limit = sqrt(2 / fan_in).

    NOTE(review): the `def` line was lost in the PDF extraction; the name is
    inferred from the `__init__` dispatch order — confirm. Also note the
    standard He-uniform limit is sqrt(6 / fan_in); this body uses
    sqrt(2 / fan_in). Kept as-is to preserve behavior — verify intent.

    Parameters
    ----------
    fan_in: int
        The number of input units in the weight tensor.
    fan_out: int
        The number of output units in the weight tensor.

    Returns
    -------
    numpy array
        The initialized weights, shape (fan_in, fan_out).
    """
    limit = np.sqrt(2 / fan_in)
    return np.random.uniform(-limit, limit, (fan_in, fan_out))
def glorot_normal(self, fan_in, fan_out):
    """
    Glorot/Xavier normal initialization: N(0, stddev) with
    stddev = sqrt(2 / (fan_in + fan_out)).

    NOTE(review): the `def` line was lost in the PDF extraction; the name and
    parameters are reconstructed from the dispatch in `__init__` and the
    docstring — confirm against the original notebook.

    Parameters
    ----------
    fan_in: int
        The number of input units in the weight tensor.
    fan_out: int
        The number of output units in the weight tensor.

    Returns
    -------
    numpy array
        The initialized weights, shape (fan_in, fan_out).
    """
    stddev = np.sqrt(2. / (fan_in + fan_out))
    return np.random.normal(0., stddev, size=(fan_in, fan_out))
def he_normal(self, fan_in, fan_out):
    """
    He normal initialization: N(0, stddev) with stddev = sqrt(2 / fan_in).

    NOTE(review): the `def` line was lost in the PDF extraction; the name and
    parameters are reconstructed from the dispatch in `__init__` and the
    docstring — confirm against the original notebook.

    Parameters
    ----------
    fan_in: int
        The number of input units in the weight tensor.
    fan_out: int
        The number of output units in the weight tensor.

    Returns
    -------
    numpy array
        The initialized weights, shape (fan_in, fan_out).
    """
    stddev = np.sqrt(2. / fan_in)
    return np.random.normal(0., stddev, size=(fan_in, fan_out))
# NOTE(review): the `def forward(self, X, is_training=...)` header was lost in
# the PDF extraction; only the docstring body and part of the loop survive.
Parameters:
-----------
X: numpy array
The input data
is_training: bool
Whether the forward pass is for training or testing/prediction
Returns:
--------
numpy array
The predicted output
"""
# Cache activations per layer (self.a[0] is the input) for use in backprop.
self.a = [X]
for i, layer in enumerate(self.layers):
# Affine transform followed by sigmoid activation.
z = np.dot(self.a[-1], layer['weights']) + layer['biases']
a = self.sigmoid(z)
if is_training and i < len(self.layers) - 1: # apply dropout to␣
↪all layers except the output layer
# NOTE(review): the construction of `dropout_mask` was lost in the
# extraction; presumably a Bernoulli mask built from
# self.dropout_rate (possibly inverted dropout) — restore from the
# original notebook and confirm the scaling convention.
5
a *= dropout_mask
self.a.append(a)
return self.a[-1]
# NOTE(review): the `def backward(self, X, y, learning_rate)` header, the loop
# that propagates deltas through the hidden layers, and the parameter-update
# statements were lost in the PDF extraction — only fragments survive below.
Parameters:
-----------
X: numpy array
The input data
y: numpy array
The target output
learning_rate: float
The learning rate
"""
# Batch size, used to average the gradients.
m = X.shape[0]
# Output-layer delta: prediction minus target.
self.dz = [self.a[-1] - y]
self.gradient_norms = [] # List to store the gradient norms
# NOTE(review): this records np.linalg.norm of the layer WEIGHTS, not of a
# gradient, despite the trailing comment and the list's name — verify which
# quantity was intended.
self.gradient_norms.append(np.linalg.norm(self.layers[i +␣
↪1]['weights'])) # Compute and store the gradient norm
# Deltas were accumulated output-to-input; reverse to layer order.
self.dz = self.dz[::-1]
self.gradient_norms = self.gradient_norms[::-1] # Reverse the list to␣
↪match the order of the layers
for i in range(len(self.layers)):
# Average weight/bias gradients over the batch.
grads_w = np.dot(self.a[i].T, self.dz[i]) / m
grads_b = np.sum(self.dz[i], axis=0, keepdims=True) / m
# gradient clipping
if self.clip_type == 'value':
grads_w = np.clip(grads_w, -self.grad_clip, self.grad_clip)
grads_b = np.clip(grads_b, -self.grad_clip, self.grad_clip)
elif self.clip_type == 'norm':
grads_w = self.clip_by_norm(grads_w, self.grad_clip)
grads_b = self.clip_by_norm(grads_b, self.grad_clip)
6
"""
Clip gradients by norm.
Parameters:
-----------
grads: numpy array
The gradients
clip_norm: float
The threshold for clipping
Returns:
--------
numpy array
The clipped gradients
"""
l2_norm = np.linalg.norm(grads)
if l2_norm > clip_norm:
grads = grads / l2_norm * clip_norm
return grads
def sigmoid(self, x):
    """
    Element-wise logistic sigmoid: 1 / (1 + exp(-x)).

    NOTE(review): the `def` line was lost in the PDF extraction; the name is
    grounded in the call `self.sigmoid(z)` in the forward pass.

    Parameters
    ----------
    x: numpy array
        The input data.

    Returns
    -------
    numpy array
        The output of the sigmoid function, in (0, 1).
    """
    return 1 / (1 + np.exp(-x))
def sigmoid_derivative(self, x):
    """
    Derivative of the sigmoid expressed in terms of the sigmoid's OUTPUT:
    for a = sigmoid(z), the derivative is a * (1 - a), so `x` here must be an
    activation value, not a pre-activation.

    NOTE(review): the `def` line was lost in the PDF extraction and this
    method's name is not visible anywhere in the surviving code — the name
    `sigmoid_derivative` is a best guess; confirm against the original.

    Parameters
    ----------
    x: numpy array
        Sigmoid output values.

    Returns
    -------
    numpy array
        The derivative of the sigmoid at those outputs.
    """
    return x * (1 - x)
Parameters:
-----------
model: NeuralNetwork
The neural network model to train
loss_func: str
The loss function to use. Options are 'mse' for mean squared error,␣
↪'log_loss' for logistic loss, and 'categorical_crossentropy' for categorical␣
↪crossentropy.
"""
def __init__(self, model, loss_func='mse'):
    """
    Bind the model to train and the loss to optimize.

    Parameters
    ----------
    model: NeuralNetwork
        The neural network model to train.
    loss_func: str
        Loss name: 'mse', 'log_loss', or 'categorical_crossentropy'.
    """
    # Per-epoch loss histories, appended to during training.
    self.train_loss = []
    self.val_loss = []
    self.loss_func = loss_func
    self.model = model
def calculate_loss(self, y_true, y_pred):
    """
    Compute the loss selected by `self.loss_func`.

    NOTE(review): the `def` line was lost in the PDF extraction; the name and
    parameters are reconstructed from the call site in `train` and the
    docstring — confirm against the original notebook.

    Parameters
    ----------
    y_true: numpy array
        The true output.
    y_pred: numpy array
        The predicted output.

    Returns
    -------
    float
        The loss value.

    Raises
    ------
    ValueError
        If `self.loss_func` is not a recognized loss name.
    """
    if self.loss_func == 'mse':
        return np.mean((y_pred - y_true)**2)
    elif self.loss_func == 'log_loss':
        # NOTE(review): np.log(0) yields -inf if y_pred hits exactly 0 or 1;
        # consider clipping y_pred to [eps, 1-eps] if that can occur.
        return -np.mean(y_true*np.log(y_pred) + (1-y_true)*np.log(1-y_pred))
    elif self.loss_func == 'categorical_crossentropy':
        return -np.mean(y_true*np.log(y_pred))
    else:
        raise ValueError('Invalid loss function')
"""
Train the neural network.
Parameters:
-----------
X_train: numpy array
The training input data
y_train: numpy array
The training target output
X_val: numpy array
The test input data
y_val: numpy array
The test target output
epochs: int
The number of epochs to train the model
learning_rate: float
The learning rate
early_stopping: bool
Whether to stop training early if the test loss doesn't improve for␣
↪a number of epochs
patience: int
The number of epochs to wait for an improvement in the test loss
"""
best_loss = np.inf
epochs_no_improve = 0
self.model.forward(X_val)
val_loss = self.calculate_loss(y_val, self.model.a[-1])
self.val_loss.append(val_loss)
9
# Early stopping
if early_stopping:
if val_loss < best_loss:
best_loss = val_loss
best_weights = [layer['weights'] for layer in self.model.
↪ layers]
epochs_no_improve = 0
else:
epochs_no_improve += 1
if epochs_no_improve == patience:
print('Early stopping!')
# Restore the best weights
for i, layer in enumerate(self.model.layers):
layer['weights'] = best_weights[i]
break
def plot_gradient_norms(self):
    """Plot the recorded gradient-norm history, one labelled line per layer."""
    for layer_idx, norm_history in enumerate(self.model.gradient_norms):
        plt.plot(norm_history, label=f'Layer {layer_idx + 1}')
    plt.legend()
    plt.show()
for i in range(10):
axes[i].imshow(digits.images[i], cmap='gray')
axes[i].axis('off')
axes[i].set_title(f"Label: {digits.target[i]}")
plt.tight_layout()
plt.show()
10
1.5 Data Preprocessing
[57]: # Preprocess the dataset
scaler = MinMaxScaler()
X = scaler.fit_transform(digits.data)
y = digits.target
# Split the training set into a smaller training set and a validation set
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.
↪2, random_state=42)
11
patience = 200
dropout_rate = 0.1
# Create the NN
nn = NeuralNetwork([input_size, 64, 64, output_size], loss_func=loss_func,␣
↪init_method=init_method, dropout_rate=dropout_rate)
NN Layout:
Input Layer: 64 neuronss
Hidden Layer 1: 64 neuronss
Hidden Layer 2: 64 neurons
1.7 Train NN
[82]: trainer = Trainer(nn, loss_func)
trainer.train(X_train, y_train, X_val, y_val, epochs=epochs,␣
↪learning_rate=learning_rate, early_stopping=early_stopping,␣
↪patience=patience)
12
INFO:__main__:Epoch 850: loss = 0.167, val_loss = 0.171
INFO:__main__:Epoch 900: loss = 0.160, val_loss = 0.160
INFO:__main__:Epoch 950: loss = 0.151, val_loss = 0.154
INFO:__main__:Epoch 1000: loss = 0.143, val_loss = 0.146
INFO:__main__:Epoch 1050: loss = 0.137, val_loss = 0.140
INFO:__main__:Epoch 1100: loss = 0.130, val_loss = 0.138
INFO:__main__:Epoch 1150: loss = 0.124, val_loss = 0.128
INFO:__main__:Epoch 1200: loss = 0.120, val_loss = 0.125
INFO:__main__:Epoch 1250: loss = 0.115, val_loss = 0.119
INFO:__main__:Epoch 1300: loss = 0.110, val_loss = 0.113
INFO:__main__:Epoch 1350: loss = 0.108, val_loss = 0.108
INFO:__main__:Epoch 1400: loss = 0.102, val_loss = 0.108
INFO:__main__:Epoch 1450: loss = 0.098, val_loss = 0.104
INFO:__main__:Epoch 1500: loss = 0.097, val_loss = 0.098
INFO:__main__:Epoch 1550: loss = 0.094, val_loss = 0.095
INFO:__main__:Epoch 1600: loss = 0.091, val_loss = 0.092
INFO:__main__:Epoch 1650: loss = 0.087, val_loss = 0.089
INFO:__main__:Epoch 1700: loss = 0.084, val_loss = 0.090
INFO:__main__:Epoch 1750: loss = 0.081, val_loss = 0.083
INFO:__main__:Epoch 1800: loss = 0.079, val_loss = 0.082
INFO:__main__:Epoch 1850: loss = 0.077, val_loss = 0.081
INFO:__main__:Epoch 1900: loss = 0.073, val_loss = 0.077
INFO:__main__:Epoch 1950: loss = 0.071, val_loss = 0.079
INFO:__main__:Epoch 2000: loss = 0.069, val_loss = 0.073
INFO:__main__:Epoch 2050: loss = 0.068, val_loss = 0.073
INFO:__main__:Epoch 2100: loss = 0.067, val_loss = 0.071
INFO:__main__:Epoch 2150: loss = 0.063, val_loss = 0.066
INFO:__main__:Epoch 2200: loss = 0.062, val_loss = 0.069
INFO:__main__:Epoch 2250: loss = 0.062, val_loss = 0.066
INFO:__main__:Epoch 2300: loss = 0.060, val_loss = 0.062
INFO:__main__:Epoch 2350: loss = 0.058, val_loss = 0.064
INFO:__main__:Epoch 2400: loss = 0.056, val_loss = 0.060
INFO:__main__:Epoch 2450: loss = 0.054, val_loss = 0.056
INFO:__main__:Epoch 2500: loss = 0.055, val_loss = 0.059
INFO:__main__:Epoch 2550: loss = 0.053, val_loss = 0.055
INFO:__main__:Epoch 2600: loss = 0.050, val_loss = 0.054
INFO:__main__:Epoch 2650: loss = 0.051, val_loss = 0.052
INFO:__main__:Epoch 2700: loss = 0.051, val_loss = 0.060
INFO:__main__:Epoch 2750: loss = 0.048, val_loss = 0.055
INFO:__main__:Epoch 2800: loss = 0.046, val_loss = 0.048
INFO:__main__:Epoch 2850: loss = 0.047, val_loss = 0.051
INFO:__main__:Epoch 2900: loss = 0.044, val_loss = 0.049
INFO:__main__:Epoch 2950: loss = 0.043, val_loss = 0.045
INFO:__main__:Epoch 3000: loss = 0.045, val_loss = 0.046
INFO:__main__:Epoch 3050: loss = 0.044, val_loss = 0.047
INFO:__main__:Epoch 3100: loss = 0.043, val_loss = 0.042
INFO:__main__:Epoch 3150: loss = 0.042, val_loss = 0.051
INFO:__main__:Epoch 3200: loss = 0.041, val_loss = 0.041
13
INFO:__main__:Epoch 3250: loss = 0.039, val_loss = 0.044
INFO:__main__:Epoch 3300: loss = 0.039, val_loss = 0.041
INFO:__main__:Epoch 3350: loss = 0.039, val_loss = 0.043
INFO:__main__:Epoch 3400: loss = 0.037, val_loss = 0.044
INFO:__main__:Epoch 3450: loss = 0.036, val_loss = 0.038
INFO:__main__:Epoch 3500: loss = 0.035, val_loss = 0.042
INFO:__main__:Epoch 3550: loss = 0.036, val_loss = 0.040
INFO:__main__:Epoch 3600: loss = 0.036, val_loss = 0.040
INFO:__main__:Epoch 3650: loss = 0.033, val_loss = 0.039
INFO:__main__:Epoch 3700: loss = 0.032, val_loss = 0.037
INFO:__main__:Epoch 3750: loss = 0.034, val_loss = 0.039
INFO:__main__:Epoch 3800: loss = 0.033, val_loss = 0.036
INFO:__main__:Epoch 3850: loss = 0.031, val_loss = 0.037
INFO:__main__:Epoch 3900: loss = 0.032, val_loss = 0.033
INFO:__main__:Epoch 3950: loss = 0.031, val_loss = 0.034
INFO:__main__:Epoch 4000: loss = 0.031, val_loss = 0.034
INFO:__main__:Epoch 4050: loss = 0.031, val_loss = 0.035
INFO:__main__:Epoch 4100: loss = 0.030, val_loss = 0.033
INFO:__main__:Epoch 4150: loss = 0.029, val_loss = 0.033
INFO:__main__:Epoch 4200: loss = 0.028, val_loss = 0.034
INFO:__main__:Epoch 4250: loss = 0.027, val_loss = 0.033
INFO:__main__:Epoch 4300: loss = 0.030, val_loss = 0.031
INFO:__main__:Epoch 4350: loss = 0.029, val_loss = 0.031
INFO:__main__:Epoch 4400: loss = 0.028, val_loss = 0.031
INFO:__main__:Epoch 4450: loss = 0.028, val_loss = 0.031
INFO:__main__:Epoch 4500: loss = 0.027, val_loss = 0.029
INFO:__main__:Epoch 4550: loss = 0.026, val_loss = 0.031
INFO:__main__:Epoch 4600: loss = 0.026, val_loss = 0.030
INFO:__main__:Epoch 4650: loss = 0.026, val_loss = 0.029
INFO:__main__:Epoch 4700: loss = 0.027, val_loss = 0.031
INFO:__main__:Epoch 4750: loss = 0.027, val_loss = 0.029
INFO:__main__:Epoch 4800: loss = 0.025, val_loss = 0.029
INFO:__main__:Epoch 4850: loss = 0.024, val_loss = 0.026
INFO:__main__:Epoch 4900: loss = 0.025, val_loss = 0.029
INFO:__main__:Epoch 4950: loss = 0.024, val_loss = 0.026
INFO:__main__:Epoch 5000: loss = 0.025, val_loss = 0.027
INFO:__main__:Epoch 5050: loss = 0.023, val_loss = 0.028
INFO:__main__:Epoch 5100: loss = 0.023, val_loss = 0.024
INFO:__main__:Epoch 5150: loss = 0.024, val_loss = 0.028
INFO:__main__:Epoch 5200: loss = 0.022, val_loss = 0.025
INFO:__main__:Epoch 5250: loss = 0.023, val_loss = 0.028
INFO:__main__:Epoch 5300: loss = 0.024, val_loss = 0.026
INFO:__main__:Epoch 5350: loss = 0.023, val_loss = 0.026
INFO:__main__:Epoch 5400: loss = 0.021, val_loss = 0.025
INFO:__main__:Epoch 5450: loss = 0.022, val_loss = 0.023
INFO:__main__:Epoch 5500: loss = 0.020, val_loss = 0.027
INFO:__main__:Epoch 5550: loss = 0.023, val_loss = 0.025
INFO:__main__:Epoch 5600: loss = 0.022, val_loss = 0.024
14
INFO:__main__:Epoch 5650: loss = 0.022, val_loss = 0.027
INFO:__main__:Epoch 5700: loss = 0.022, val_loss = 0.024
INFO:__main__:Epoch 5750: loss = 0.019, val_loss = 0.022
INFO:__main__:Epoch 5800: loss = 0.021, val_loss = 0.025
INFO:__main__:Epoch 5850: loss = 0.021, val_loss = 0.023
INFO:__main__:Epoch 5900: loss = 0.020, val_loss = 0.023
INFO:__main__:Epoch 5950: loss = 0.019, val_loss = 0.026
INFO:__main__:Epoch 6000: loss = 0.021, val_loss = 0.025
INFO:__main__:Epoch 6050: loss = 0.020, val_loss = 0.023
INFO:__main__:Epoch 6100: loss = 0.019, val_loss = 0.024
INFO:__main__:Epoch 6150: loss = 0.020, val_loss = 0.024
INFO:__main__:Epoch 6200: loss = 0.018, val_loss = 0.026
INFO:__main__:Epoch 6250: loss = 0.020, val_loss = 0.023
INFO:__main__:Epoch 6300: loss = 0.018, val_loss = 0.020
INFO:__main__:Epoch 6350: loss = 0.019, val_loss = 0.023
INFO:__main__:Epoch 6400: loss = 0.019, val_loss = 0.022
INFO:__main__:Epoch 6450: loss = 0.017, val_loss = 0.025
INFO:__main__:Epoch 6500: loss = 0.018, val_loss = 0.021
INFO:__main__:Epoch 6550: loss = 0.019, val_loss = 0.022
INFO:__main__:Epoch 6600: loss = 0.018, val_loss = 0.020
INFO:__main__:Epoch 6650: loss = 0.017, val_loss = 0.023
INFO:__main__:Epoch 6700: loss = 0.018, val_loss = 0.022
INFO:__main__:Epoch 6750: loss = 0.019, val_loss = 0.019
INFO:__main__:Epoch 6800: loss = 0.018, val_loss = 0.024
INFO:__main__:Epoch 6850: loss = 0.018, val_loss = 0.022
INFO:__main__:Epoch 6900: loss = 0.016, val_loss = 0.021
INFO:__main__:Epoch 6950: loss = 0.017, val_loss = 0.022
INFO:__main__:Epoch 7000: loss = 0.016, val_loss = 0.021
INFO:__main__:Epoch 7050: loss = 0.016, val_loss = 0.021
INFO:__main__:Epoch 7100: loss = 0.016, val_loss = 0.022
INFO:__main__:Epoch 7150: loss = 0.016, val_loss = 0.025
INFO:__main__:Epoch 7200: loss = 0.016, val_loss = 0.021
INFO:__main__:Epoch 7250: loss = 0.016, val_loss = 0.020
INFO:__main__:Epoch 7300: loss = 0.016, val_loss = 0.020
INFO:__main__:Epoch 7350: loss = 0.016, val_loss = 0.022
INFO:__main__:Epoch 7400: loss = 0.015, val_loss = 0.020
INFO:__main__:Epoch 7450: loss = 0.015, val_loss = 0.018
INFO:__main__:Epoch 7500: loss = 0.016, val_loss = 0.020
INFO:__main__:Epoch 7550: loss = 0.015, val_loss = 0.021
INFO:__main__:Epoch 7600: loss = 0.016, val_loss = 0.016
INFO:__main__:Epoch 7650: loss = 0.016, val_loss = 0.020
INFO:__main__:Epoch 7700: loss = 0.015, val_loss = 0.021
INFO:__main__:Epoch 7750: loss = 0.015, val_loss = 0.017
INFO:__main__:Epoch 7800: loss = 0.015, val_loss = 0.015
INFO:__main__:Epoch 7850: loss = 0.013, val_loss = 0.017
INFO:__main__:Epoch 7900: loss = 0.016, val_loss = 0.016
INFO:__main__:Epoch 7950: loss = 0.014, val_loss = 0.020
INFO:__main__:Epoch 8000: loss = 0.014, val_loss = 0.017
15
INFO:__main__:Epoch 8050: loss = 0.014, val_loss = 0.022
INFO:__main__:Epoch 8100: loss = 0.014, val_loss = 0.019
INFO:__main__:Epoch 8150: loss = 0.014, val_loss = 0.020
INFO:__main__:Epoch 8200: loss = 0.014, val_loss = 0.020
INFO:__main__:Epoch 8250: loss = 0.014, val_loss = 0.020
INFO:__main__:Epoch 8300: loss = 0.013, val_loss = 0.020
INFO:__main__:Epoch 8350: loss = 0.014, val_loss = 0.017
INFO:__main__:Epoch 8400: loss = 0.013, val_loss = 0.018
INFO:__main__:Epoch 8450: loss = 0.014, val_loss = 0.019
INFO:__main__:Epoch 8500: loss = 0.013, val_loss = 0.019
INFO:__main__:Epoch 8550: loss = 0.013, val_loss = 0.021
INFO:__main__:Epoch 8600: loss = 0.014, val_loss = 0.018
INFO:__main__:Epoch 8650: loss = 0.013, val_loss = 0.021
INFO:__main__:Epoch 8700: loss = 0.013, val_loss = 0.018
INFO:__main__:Epoch 8750: loss = 0.013, val_loss = 0.019
INFO:__main__:Epoch 8800: loss = 0.013, val_loss = 0.016
INFO:__main__:Epoch 8850: loss = 0.013, val_loss = 0.018
INFO:__main__:Epoch 8900: loss = 0.012, val_loss = 0.018
INFO:__main__:Epoch 8950: loss = 0.013, val_loss = 0.019
INFO:__main__:Epoch 9000: loss = 0.013, val_loss = 0.015
INFO:__main__:Epoch 9050: loss = 0.013, val_loss = 0.018
INFO:__main__:Epoch 9100: loss = 0.012, val_loss = 0.017
INFO:__main__:Epoch 9150: loss = 0.012, val_loss = 0.016
INFO:__main__:Epoch 9200: loss = 0.012, val_loss = 0.021
INFO:__main__:Epoch 9250: loss = 0.012, val_loss = 0.018
INFO:__main__:Epoch 9300: loss = 0.012, val_loss = 0.017
INFO:__main__:Epoch 9350: loss = 0.012, val_loss = 0.016
INFO:__main__:Epoch 9400: loss = 0.012, val_loss = 0.017
INFO:__main__:Epoch 9450: loss = 0.012, val_loss = 0.016
INFO:__main__:Epoch 9500: loss = 0.011, val_loss = 0.020
INFO:__main__:Epoch 9550: loss = 0.012, val_loss = 0.017
INFO:__main__:Epoch 9600: loss = 0.012, val_loss = 0.018
INFO:__main__:Epoch 9650: loss = 0.012, val_loss = 0.018
INFO:__main__:Epoch 9700: loss = 0.011, val_loss = 0.019
INFO:__main__:Epoch 9750: loss = 0.012, val_loss = 0.017
INFO:__main__:Epoch 9800: loss = 0.014, val_loss = 0.016
INFO:__main__:Epoch 9850: loss = 0.011, val_loss = 0.015
INFO:__main__:Epoch 9900: loss = 0.012, val_loss = 0.018
INFO:__main__:Epoch 9950: loss = 0.012, val_loss = 0.018
Accuracy: 96.11%
16
def smooth_curve(points, factor=0.9):
    """
    Exponentially smooth a sequence of values (EMA): each output is
    previous * factor + point * (1 - factor); the first point passes through.

    NOTE(review): the `def` line, list initialization, loop header, and `if`
    condition were lost in the PDF extraction; the signature and the default
    factor are reconstructed from the surviving body and the standard
    EMA-smoothing idiom — confirm against the original notebook.

    Parameters
    ----------
    points: sequence of float
        The values to smooth (e.g. a loss history).
    factor: float
        Smoothing factor in [0, 1); larger means smoother.

    Returns
    -------
    list of float
        The smoothed values, same length as `points`.
    """
    smoothed_points = []
    for point in points:
        if smoothed_points:
            previous = smoothed_points[-1]
            smoothed_points.append(previous * factor + point * (1 - factor))
        else:
            smoothed_points.append(point)
    return smoothed_points
smooth_train_loss = smooth_curve(trainer.train_loss)
smooth_val_loss = smooth_curve(trainer.val_loss)
17
1.9 Fine-Tune NN
[ ]: def objective(trial):
# Define hyperparameters
n_layers = trial.suggest_int('n_layers', 1, 10)
# n_layers=1
hidden_sizes = [trial.suggest_int(f'hidden_size_{i}', 32, 128) for i in␣
↪range(n_layers)]
↪grad_clip=clip_value)
return accuracy
1.10 Predict
[91]: best_trial = {'n_layers': 1, 'hidden_size_0': 33, 'dropout_rate': 0.
↪001524880086886879, 'learning_rate': 0.09916060658342357, 'init_method':␣
best_value = 0.978
18
[92]: epochs = 20000
best_nn = NeuralNetwork(layers=[input_size, study.best_trial.
↪params['hidden_size_0'], output_size],
init_method=study.best_trial.params['init_method'],
loss_func=loss_func,
dropout_rate=study.best_trial.params['dropout_rate'],
clip_type=study.best_trial.params['clip_type'],
grad_clip=study.best_trial.params['clip_value'])
best_trainer = Trainer(best_nn, loss_func)
best_trainer.train(X_train, y_train, X_test, y_test, epochs, study.best_trial.
↪params['learning_rate'], early_stopping=False)
19
INFO:__main__:Epoch 1500: loss = 0.038, val_loss = 0.038
INFO:__main__:Epoch 1550: loss = 0.037, val_loss = 0.037
INFO:__main__:Epoch 1600: loss = 0.036, val_loss = 0.036
INFO:__main__:Epoch 1650: loss = 0.035, val_loss = 0.035
INFO:__main__:Epoch 1700: loss = 0.034, val_loss = 0.034
INFO:__main__:Epoch 1750: loss = 0.034, val_loss = 0.033
INFO:__main__:Epoch 1800: loss = 0.033, val_loss = 0.032
INFO:__main__:Epoch 1850: loss = 0.032, val_loss = 0.032
INFO:__main__:Epoch 1900: loss = 0.031, val_loss = 0.031
INFO:__main__:Epoch 1950: loss = 0.031, val_loss = 0.031
INFO:__main__:Epoch 2000: loss = 0.030, val_loss = 0.030
INFO:__main__:Epoch 2050: loss = 0.029, val_loss = 0.029
INFO:__main__:Epoch 2100: loss = 0.029, val_loss = 0.029
INFO:__main__:Epoch 2150: loss = 0.028, val_loss = 0.028
INFO:__main__:Epoch 2200: loss = 0.028, val_loss = 0.028
INFO:__main__:Epoch 2250: loss = 0.027, val_loss = 0.027
INFO:__main__:Epoch 2300: loss = 0.027, val_loss = 0.027
INFO:__main__:Epoch 2350: loss = 0.027, val_loss = 0.027
INFO:__main__:Epoch 2400: loss = 0.026, val_loss = 0.026
INFO:__main__:Epoch 2450: loss = 0.026, val_loss = 0.026
INFO:__main__:Epoch 2500: loss = 0.025, val_loss = 0.025
INFO:__main__:Epoch 2550: loss = 0.025, val_loss = 0.025
INFO:__main__:Epoch 2600: loss = 0.025, val_loss = 0.025
INFO:__main__:Epoch 2650: loss = 0.024, val_loss = 0.024
INFO:__main__:Epoch 2700: loss = 0.024, val_loss = 0.024
INFO:__main__:Epoch 2750: loss = 0.024, val_loss = 0.024
INFO:__main__:Epoch 2800: loss = 0.023, val_loss = 0.023
INFO:__main__:Epoch 2850: loss = 0.023, val_loss = 0.023
INFO:__main__:Epoch 2900: loss = 0.023, val_loss = 0.023
INFO:__main__:Epoch 2950: loss = 0.022, val_loss = 0.023
INFO:__main__:Epoch 3000: loss = 0.022, val_loss = 0.023
INFO:__main__:Epoch 3050: loss = 0.022, val_loss = 0.022
INFO:__main__:Epoch 3100: loss = 0.022, val_loss = 0.022
INFO:__main__:Epoch 3150: loss = 0.021, val_loss = 0.022
INFO:__main__:Epoch 3200: loss = 0.021, val_loss = 0.022
INFO:__main__:Epoch 3250: loss = 0.021, val_loss = 0.022
INFO:__main__:Epoch 3300: loss = 0.021, val_loss = 0.021
INFO:__main__:Epoch 3350: loss = 0.020, val_loss = 0.021
INFO:__main__:Epoch 3400: loss = 0.020, val_loss = 0.021
INFO:__main__:Epoch 3450: loss = 0.020, val_loss = 0.021
INFO:__main__:Epoch 3500: loss = 0.019, val_loss = 0.020
INFO:__main__:Epoch 3550: loss = 0.019, val_loss = 0.020
INFO:__main__:Epoch 3600: loss = 0.019, val_loss = 0.020
INFO:__main__:Epoch 3650: loss = 0.019, val_loss = 0.020
INFO:__main__:Epoch 3700: loss = 0.019, val_loss = 0.020
INFO:__main__:Epoch 3750: loss = 0.019, val_loss = 0.019
INFO:__main__:Epoch 3800: loss = 0.018, val_loss = 0.019
INFO:__main__:Epoch 3850: loss = 0.018, val_loss = 0.019
20
INFO:__main__:Epoch 3900: loss = 0.018, val_loss = 0.019
INFO:__main__:Epoch 3950: loss = 0.018, val_loss = 0.019
INFO:__main__:Epoch 4000: loss = 0.018, val_loss = 0.019
INFO:__main__:Epoch 4050: loss = 0.018, val_loss = 0.019
INFO:__main__:Epoch 4100: loss = 0.017, val_loss = 0.018
INFO:__main__:Epoch 4150: loss = 0.017, val_loss = 0.019
INFO:__main__:Epoch 4200: loss = 0.017, val_loss = 0.019
INFO:__main__:Epoch 4250: loss = 0.017, val_loss = 0.018
INFO:__main__:Epoch 4300: loss = 0.017, val_loss = 0.018
INFO:__main__:Epoch 4350: loss = 0.017, val_loss = 0.018
INFO:__main__:Epoch 4400: loss = 0.017, val_loss = 0.018
INFO:__main__:Epoch 4450: loss = 0.016, val_loss = 0.018
INFO:__main__:Epoch 4500: loss = 0.016, val_loss = 0.018
INFO:__main__:Epoch 4550: loss = 0.016, val_loss = 0.018
INFO:__main__:Epoch 4600: loss = 0.016, val_loss = 0.017
INFO:__main__:Epoch 4650: loss = 0.016, val_loss = 0.018
INFO:__main__:Epoch 4700: loss = 0.016, val_loss = 0.018
INFO:__main__:Epoch 4750: loss = 0.016, val_loss = 0.017
INFO:__main__:Epoch 4800: loss = 0.015, val_loss = 0.017
INFO:__main__:Epoch 4850: loss = 0.015, val_loss = 0.017
INFO:__main__:Epoch 4900: loss = 0.015, val_loss = 0.017
INFO:__main__:Epoch 4950: loss = 0.015, val_loss = 0.017
INFO:__main__:Epoch 5000: loss = 0.015, val_loss = 0.017
INFO:__main__:Epoch 5050: loss = 0.015, val_loss = 0.017
INFO:__main__:Epoch 5100: loss = 0.015, val_loss = 0.017
INFO:__main__:Epoch 5150: loss = 0.015, val_loss = 0.017
INFO:__main__:Epoch 5200: loss = 0.014, val_loss = 0.017
INFO:__main__:Epoch 5250: loss = 0.015, val_loss = 0.016
INFO:__main__:Epoch 5300: loss = 0.014, val_loss = 0.016
INFO:__main__:Epoch 5350: loss = 0.014, val_loss = 0.016
INFO:__main__:Epoch 5400: loss = 0.014, val_loss = 0.016
INFO:__main__:Epoch 5450: loss = 0.014, val_loss = 0.016
INFO:__main__:Epoch 5500: loss = 0.014, val_loss = 0.016
INFO:__main__:Epoch 5550: loss = 0.014, val_loss = 0.016
INFO:__main__:Epoch 5600: loss = 0.014, val_loss = 0.016
INFO:__main__:Epoch 5650: loss = 0.013, val_loss = 0.016
INFO:__main__:Epoch 5700: loss = 0.013, val_loss = 0.016
INFO:__main__:Epoch 5750: loss = 0.013, val_loss = 0.016
INFO:__main__:Epoch 5800: loss = 0.013, val_loss = 0.016
INFO:__main__:Epoch 5850: loss = 0.013, val_loss = 0.016
INFO:__main__:Epoch 5900: loss = 0.013, val_loss = 0.015
INFO:__main__:Epoch 5950: loss = 0.013, val_loss = 0.015
INFO:__main__:Epoch 6000: loss = 0.013, val_loss = 0.017
INFO:__main__:Epoch 6050: loss = 0.013, val_loss = 0.016
INFO:__main__:Epoch 6100: loss = 0.013, val_loss = 0.016
INFO:__main__:Epoch 6150: loss = 0.013, val_loss = 0.015
INFO:__main__:Epoch 6200: loss = 0.013, val_loss = 0.015
INFO:__main__:Epoch 6250: loss = 0.012, val_loss = 0.015
21
INFO:__main__:Epoch 6300: loss = 0.012, val_loss = 0.015
INFO:__main__:Epoch 6350: loss = 0.012, val_loss = 0.015
INFO:__main__:Epoch 6400: loss = 0.012, val_loss = 0.015
INFO:__main__:Epoch 6450: loss = 0.012, val_loss = 0.015
INFO:__main__:Epoch 6500: loss = 0.012, val_loss = 0.015
INFO:__main__:Epoch 6550: loss = 0.012, val_loss = 0.015
INFO:__main__:Epoch 6600: loss = 0.012, val_loss = 0.015
INFO:__main__:Epoch 6650: loss = 0.012, val_loss = 0.015
INFO:__main__:Epoch 6700: loss = 0.011, val_loss = 0.016
INFO:__main__:Epoch 6750: loss = 0.012, val_loss = 0.015
INFO:__main__:Epoch 6800: loss = 0.011, val_loss = 0.015
INFO:__main__:Epoch 6850: loss = 0.012, val_loss = 0.015
INFO:__main__:Epoch 6900: loss = 0.011, val_loss = 0.015
INFO:__main__:Epoch 6950: loss = 0.011, val_loss = 0.015
INFO:__main__:Epoch 7000: loss = 0.011, val_loss = 0.014
INFO:__main__:Epoch 7050: loss = 0.011, val_loss = 0.014
INFO:__main__:Epoch 7100: loss = 0.011, val_loss = 0.014
INFO:__main__:Epoch 7150: loss = 0.011, val_loss = 0.015
INFO:__main__:Epoch 7200: loss = 0.011, val_loss = 0.014
INFO:__main__:Epoch 7250: loss = 0.011, val_loss = 0.014
INFO:__main__:Epoch 7300: loss = 0.011, val_loss = 0.014
INFO:__main__:Epoch 7350: loss = 0.011, val_loss = 0.014
INFO:__main__:Epoch 7400: loss = 0.011, val_loss = 0.014
INFO:__main__:Epoch 7450: loss = 0.011, val_loss = 0.014
INFO:__main__:Epoch 7500: loss = 0.010, val_loss = 0.014
INFO:__main__:Epoch 7550: loss = 0.010, val_loss = 0.014
INFO:__main__:Epoch 7600: loss = 0.011, val_loss = 0.014
INFO:__main__:Epoch 7650: loss = 0.010, val_loss = 0.014
INFO:__main__:Epoch 7700: loss = 0.010, val_loss = 0.013
INFO:__main__:Epoch 7750: loss = 0.010, val_loss = 0.014
INFO:__main__:Epoch 7800: loss = 0.010, val_loss = 0.014
INFO:__main__:Epoch 7850: loss = 0.010, val_loss = 0.014
INFO:__main__:Epoch 7900: loss = 0.010, val_loss = 0.014
INFO:__main__:Epoch 7950: loss = 0.010, val_loss = 0.014
INFO:__main__:Epoch 8000: loss = 0.010, val_loss = 0.014
INFO:__main__:Epoch 8050: loss = 0.010, val_loss = 0.014
INFO:__main__:Epoch 8100: loss = 0.010, val_loss = 0.014
INFO:__main__:Epoch 8150: loss = 0.010, val_loss = 0.014
INFO:__main__:Epoch 8200: loss = 0.010, val_loss = 0.014
INFO:__main__:Epoch 8250: loss = 0.010, val_loss = 0.014
INFO:__main__:Epoch 8300: loss = 0.009, val_loss = 0.014
INFO:__main__:Epoch 8350: loss = 0.009, val_loss = 0.014
INFO:__main__:Epoch 8400: loss = 0.009, val_loss = 0.014
INFO:__main__:Epoch 8450: loss = 0.009, val_loss = 0.014
INFO:__main__:Epoch 8500: loss = 0.009, val_loss = 0.013
INFO:__main__:Epoch 8550: loss = 0.009, val_loss = 0.014
INFO:__main__:Epoch 8600: loss = 0.009, val_loss = 0.014
INFO:__main__:Epoch 8650: loss = 0.009, val_loss = 0.013
22
INFO:__main__:Epoch 8700: loss = 0.009, val_loss = 0.014
INFO:__main__:Epoch 8750: loss = 0.009, val_loss = 0.014
INFO:__main__:Epoch 8800: loss = 0.009, val_loss = 0.013
INFO:__main__:Epoch 8850: loss = 0.009, val_loss = 0.014
INFO:__main__:Epoch 8900: loss = 0.009, val_loss = 0.013
INFO:__main__:Epoch 8950: loss = 0.009, val_loss = 0.013
INFO:__main__:Epoch 9000: loss = 0.009, val_loss = 0.013
INFO:__main__:Epoch 9050: loss = 0.009, val_loss = 0.013
INFO:__main__:Epoch 9100: loss = 0.009, val_loss = 0.013
INFO:__main__:Epoch 9150: loss = 0.009, val_loss = 0.014
INFO:__main__:Epoch 9200: loss = 0.009, val_loss = 0.013
INFO:__main__:Epoch 9250: loss = 0.009, val_loss = 0.013
INFO:__main__:Epoch 9300: loss = 0.009, val_loss = 0.013
INFO:__main__:Epoch 9350: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 9400: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 9450: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 9500: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 9550: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 9600: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 9650: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 9700: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 9750: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 9800: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 9850: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 9900: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 9950: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 10000: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 10050: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 10100: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 10150: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 10200: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 10250: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 10300: loss = 0.008, val_loss = 0.012
INFO:__main__:Epoch 10350: loss = 0.007, val_loss = 0.014
INFO:__main__:Epoch 10400: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 10450: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 10500: loss = 0.008, val_loss = 0.013
INFO:__main__:Epoch 10550: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 10600: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 10650: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 10700: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 10750: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 10800: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 10850: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 10900: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 10950: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 11000: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 11050: loss = 0.007, val_loss = 0.013
23
INFO:__main__:Epoch 11100: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 11150: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 11200: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 11250: loss = 0.007, val_loss = 0.012
INFO:__main__:Epoch 11300: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 11350: loss = 0.007, val_loss = 0.012
INFO:__main__:Epoch 11400: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 11450: loss = 0.007, val_loss = 0.012
INFO:__main__:Epoch 11500: loss = 0.007, val_loss = 0.012
INFO:__main__:Epoch 11550: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 11600: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 11650: loss = 0.007, val_loss = 0.012
INFO:__main__:Epoch 11700: loss = 0.007, val_loss = 0.012
INFO:__main__:Epoch 11750: loss = 0.007, val_loss = 0.013
INFO:__main__:Epoch 11800: loss = 0.007, val_loss = 0.012
INFO:__main__:Epoch 11850: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 11900: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 11950: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 12000: loss = 0.006, val_loss = 0.013
INFO:__main__:Epoch 12050: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 12100: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 12150: loss = 0.006, val_loss = 0.013
INFO:__main__:Epoch 12200: loss = 0.006, val_loss = 0.013
INFO:__main__:Epoch 12250: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 12300: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 12350: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 12400: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 12450: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 12500: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 12550: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 12600: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 12650: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 12700: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 12750: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 12800: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 12850: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 12900: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 12950: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 13000: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 13050: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 13100: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 13150: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 13200: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 13250: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 13300: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 13350: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 13400: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 13450: loss = 0.006, val_loss = 0.012
24
INFO:__main__:Epoch 13500: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 13550: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 13600: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 13650: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 13700: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 13750: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 13800: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 13850: loss = 0.006, val_loss = 0.012
INFO:__main__:Epoch 13900: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 13950: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 14000: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 14050: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 14100: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 14150: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 14200: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 14250: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 14300: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 14350: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 14400: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 14450: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 14500: loss = 0.005, val_loss = 0.013
INFO:__main__:Epoch 14550: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 14600: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 14650: loss = 0.005, val_loss = 0.013
INFO:__main__:Epoch 14700: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 14750: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 14800: loss = 0.005, val_loss = 0.013
INFO:__main__:Epoch 14850: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 14900: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 14950: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 15000: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 15050: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 15100: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 15150: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 15200: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 15250: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 15300: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 15350: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 15400: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 15450: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 15500: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 15550: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 15600: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 15650: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 15700: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 15750: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 15800: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 15850: loss = 0.005, val_loss = 0.012
25
INFO:__main__:Epoch 15900: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 15950: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 16000: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 16050: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 16100: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 16150: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 16200: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 16250: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 16300: loss = 0.005, val_loss = 0.012
INFO:__main__:Epoch 16350: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 16400: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 16450: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 16500: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 16550: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 16600: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 16650: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 16700: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 16750: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 16800: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 16850: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 16900: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 16950: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17000: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17050: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17100: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17150: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17200: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17250: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17300: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17350: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17400: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17450: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17500: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17550: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17600: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17650: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17700: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17750: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17800: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17850: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17900: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 17950: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 18000: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 18050: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 18100: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 18150: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 18200: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 18250: loss = 0.004, val_loss = 0.012
26
INFO:__main__:Epoch 18300: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 18350: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 18400: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 18450: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 18500: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 18550: loss = 0.004, val_loss = 0.011
INFO:__main__:Epoch 18600: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 18650: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 18700: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 18750: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 18800: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 18850: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 18900: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 18950: loss = 0.004, val_loss = 0.011
INFO:__main__:Epoch 19000: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 19050: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 19100: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 19150: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 19200: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 19250: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 19300: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 19350: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 19400: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 19450: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 19500: loss = 0.003, val_loss = 0.012
INFO:__main__:Epoch 19550: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 19600: loss = 0.003, val_loss = 0.013
INFO:__main__:Epoch 19650: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 19700: loss = 0.003, val_loss = 0.012
INFO:__main__:Epoch 19750: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 19800: loss = 0.003, val_loss = 0.011
INFO:__main__:Epoch 19850: loss = 0.003, val_loss = 0.012
INFO:__main__:Epoch 19900: loss = 0.004, val_loss = 0.012
INFO:__main__:Epoch 19950: loss = 0.003, val_loss = 0.012
Best accuracy: 97.78%
# Smooth the recorded per-epoch loss histories before plotting, to reduce
# epoch-to-epoch noise in the curves. `smooth_curve` and `trainer` are
# defined earlier in the notebook; `trainer.train_loss` / `trainer.val_loss`
# are the loss values logged during training.
# NOTE(review): removed stray page-number artifacts ("27", "28") that the
# PDF extraction had fused into this cell — they were no-op bare literals.
smooth_train_loss = smooth_curve(trainer.train_loss)
smooth_val_loss = smooth_curve(trainer.val_loss)