In [13]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models, datasets, callbacks, losses, optimizers, metrics
import tensorflow.keras.backend as K
In [2]:
IMAGE_SIZE = 32        # Fashion-MNIST images, padded from 28x28 to 32x32
BATCH_SIZE = 100
VALIDATION_SPLIT = 0.2
EMBEDDING_DIM = 2      # dimensionality of the latent space
BETA = 500             # weight on the reconstruction term of the loss
In [3]:
(x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()
# Scale to [0, 1], pad 28x28 -> 32x32, and add a channel axis
x_train = x_train.astype("float32") / 255.0
x_train = np.pad(x_train, ((0, 0), (2, 2), (2, 2)), constant_values=0.0)
x_train = np.expand_dims(x_train, -1)
x_test = x_test.astype("float32") / 255.0
x_test = np.pad(x_test, ((0, 0), (2, 2), (2, 2)), constant_values=0.0)
x_test = np.expand_dims(x_test, -1)
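# Quick shape check (illustrative, not in the original notebook): padding the
# 28x28 images by 2 pixels per side gives 32x32 inputs, which the encoder's
# three stride-2 convolutions reduce cleanly to 4x4.
print(x_train.shape, x_test.shape)  # (60000, 32, 32, 1) (10000, 32, 32, 1)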
def display(
    images, n=10, size=(20, 3), cmap="gray_r", as_type="float32", save_to=None
):
    """
    Displays the first n images from the supplied array.
    """
    if images.max() > 1.0:
        images = images / 255.0
    elif images.min() < 0.0:
        images = (images + 1.0) / 2.0
    plt.figure(figsize=size)
    for i in range(n):
        _ = plt.subplot(1, n, i + 1)
        plt.imshow(images[i].astype(as_type), cmap=cmap)
        plt.axis("off")
    if save_to:
        plt.savefig(save_to)
        print(f"\nSaved to {save_to}")
    plt.show()
display(x_train)
Before we build the VAE, we need to define a new layer that samples the latent vector z from the Gaussian parameterized by (z_mean, z_log_var).
In [14]:
class Sampling(layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        # Reparameterization trick: z = mu + sigma * epsilon, epsilon ~ N(0, I)
        epsilon = K.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
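The layer implements the reparameterization trick, z = z_mean + exp(0.5 * z_log_var) * epsilon with epsilon ~ N(0, I), which keeps sampling differentiable with respect to z_mean and z_log_var while isolating the randomness in epsilon. A minimal smoke test of the layer (hypothetical values, not part of the training code):

# With z_log_var = 0 (unit variance), samples should scatter around z_mean
# with a standard deviation close to 1.
demo_mean = tf.zeros((1000, 2))
demo_log_var = tf.zeros((1000, 2))
demo_z = Sampling()([demo_mean, demo_log_var])
print(tf.math.reduce_std(demo_z).numpy())  # approximately 1.0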
In [15]:
# Encoder
encoder_input = layers.Input(
    shape=(IMAGE_SIZE, IMAGE_SIZE, 1), name="encoder_input"
)
x = layers.Conv2D(32, (3, 3), strides=2, activation="relu", padding="same")(
    encoder_input
)
x = layers.Conv2D(64, (3, 3), strides=2, activation="relu", padding="same")(x)
x = layers.Conv2D(128, (3, 3), strides=2, activation="relu", padding="same")(x)
x = layers.Flatten()(x)
z_mean = layers.Dense(EMBEDDING_DIM, name="z_mean")(x)
z_log_var = layers.Dense(EMBEDDING_DIM, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = models.Model(encoder_input, [z_mean, z_log_var, z], name="encoder")
encoder.summary()
Model: "encoder"
Layer (type)                 Output Shape         Param #    Connected to
encoder_input (InputLayer)   (None, 32, 32, 1)    0          -
conv2d_3 (Conv2D)            (None, 16, 16, 32)   320        encoder_input
conv2d_4 (Conv2D)            (None, 8, 8, 64)     18,496     conv2d_3
conv2d_5 (Conv2D)            (None, 4, 4, 128)    73,856     conv2d_4
flatten_1 (Flatten)          (None, 2048)         0          conv2d_5
z_mean (Dense)               (None, 2)            4,098      flatten_1
z_log_var (Dense)            (None, 2)            4,098      flatten_1
sampling_1 (Sampling)        (None, 2)             0          z_mean, z_log_var
Total params: 100,868 (394.02 KB)
Trainable params: 100,868 (394.02 KB)
Non-trainable params: 0 (0.00 B)
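As a sanity check (an illustrative sketch, not in the original notebook), the encoder should map a batch of 32x32x1 images to three tensors of shape (batch, EMBEDDING_DIM):

dummy_batch = tf.zeros((4, IMAGE_SIZE, IMAGE_SIZE, 1))
m, lv, zz = encoder(dummy_batch)
print(m.shape, lv.shape, zz.shape)  # (4, 2) (4, 2) (4, 2)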
In [16]:
# Decoder
decoder_input = layers.Input(shape=(EMBEDDING_DIM,), name="decoder_input")
x = layers.Dense(2048)(decoder_input)
x = layers.Reshape((4, 4, 128))(x)
x = layers.Conv2DTranspose(
    128, (3, 3), strides=2, activation="relu", padding="same"
)(x)
x = layers.Conv2DTranspose(
    64, (3, 3), strides=2, activation="relu", padding="same"
)(x)
x = layers.Conv2DTranspose(
    32, (3, 3), strides=2, activation="relu", padding="same"
)(x)
decoder_output = layers.Conv2D(
    1,
    (3, 3),
    strides=1,
    activation="sigmoid",
    padding="same",
    name="decoder_output",
)(x)
decoder = models.Model(decoder_input, decoder_output, name="decoder")
decoder.summary()
Model: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ decoder_input (InputLayer) │ (None, 2) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_1 (Dense) │ (None, 2048) │ 6,144 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ reshape_1 (Reshape) │ (None, 4, 4, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_transpose_3 │ (None, 8, 8, 128) │ 147,584 │ │ (Conv2DTranspose) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_transpose_4 │ (None, 16, 16, 64) │ 73,792 │ │ (Conv2DTranspose) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_transpose_5 │ (None, 32, 32, 32) │ 18,464 │ │ (Conv2DTranspose) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ decoder_output (Conv2D) │ (None, 32, 32, 1) │ 289 │ └─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 246,273 (962.00 KB)
Trainable params: 246,273 (962.00 KB)
Non-trainable params: 0 (0.00 B)
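The decoder mirrors the encoder: a Dense layer expands the 2-D latent vector to 2048 units, which is reshaped to 4x4x128 and upsampled through three stride-2 transposed convolutions (4 -> 8 -> 16 -> 32). A corresponding shape check (illustrative only):

print(decoder(tf.zeros((4, EMBEDDING_DIM))).shape)  # (4, 32, 32, 1)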
In [17]:
class VAE(models.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, inputs):
        """Call the model on a particular input."""
        # Use the submodels owned by this model, not the module-level globals
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstruction = self.decoder(z)
        return z_mean, z_log_var, reconstruction

    def train_step(self, data):
        """Step run during training."""
        with tf.GradientTape() as tape:
            z_mean, z_log_var, reconstruction = self(data)
            # Reconstruction term: per-image binary cross-entropy, scaled by BETA
            reconstruction_loss = tf.reduce_mean(
                BETA
                * losses.binary_crossentropy(
                    data, reconstruction, axis=(1, 2, 3)
                )
            )
            # KL term: closed form for a diagonal Gaussian vs. N(0, I)
            kl_loss = tf.reduce_mean(
                tf.reduce_sum(
                    -0.5
                    * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)),
                    axis=1,
                )
            )
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Step run during validation."""
        if isinstance(data, tuple):
            data = data[0]
        z_mean, z_log_var, reconstruction = self(data)
        reconstruction_loss = tf.reduce_mean(
            BETA
            * losses.binary_crossentropy(data, reconstruction, axis=(1, 2, 3))
        )
        kl_loss = tf.reduce_mean(
            tf.reduce_sum(
                -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)),
                axis=1,
            )
        )
        total_loss = reconstruction_loss + kl_loss
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }
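The kl_loss term is the closed-form KL divergence between the diagonal Gaussian N(z_mean, exp(z_log_var)) and the standard normal prior, summed over latent dimensions. A standalone check of that formula on hand-picked values (illustrative only, not part of the model):

# For mu = 0.5 and log_var = 0 in each of 2 dimensions:
# KL = sum(-0.5 * (1 + 0 - 0.25 - 1)) = 2 * 0.125 = 0.25
mu = tf.constant([[0.5, 0.5]])
log_var = tf.zeros((1, 2))
kl = tf.reduce_sum(-0.5 * (1 + log_var - tf.square(mu) - tf.exp(log_var)), axis=1)
print(kl.numpy())  # [0.25]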
In [18]:
vae = VAE(encoder, decoder)
In [19]:
optimizer = optimizers.Adam(learning_rate=0.0005)
vae.compile(optimizer=optimizer)
In [20]:
callback_ckpt = keras.callbacks.ModelCheckpoint(
    filepath="vae.keras",
    save_best_only=True,
    verbose=1,
    mode="min",
    monitor="val_loss",
)
callback_stopping = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=5,
)
history = vae.fit(
    x_train,
    epochs=50,
    batch_size=BATCH_SIZE,
    shuffle=True,
    validation_data=(x_test, x_test),
    callbacks=[callback_ckpt, callback_stopping],  # a flat list, not nested lists
)
Epoch 1/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 42s 67ms/step - kl_loss: 3.6204 - reconstruction_loss: 196.9417 - total_loss: 200.5622 - val_kl_loss: 4.8065 - val_loss: 143.7216 - val_reconstruction_loss: 138.9151
Epoch 1: val_loss improved from inf to 143.72160, saving model to vae.keras
[...]
Epoch 23/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 38s 63ms/step - kl_loss: 5.4536 - reconstruction_loss: 124.0726 - total_loss: 129.5262 - val_kl_loss: 5.5333 - val_loss: 132.1018 - val_reconstruction_loss: 126.5685
Epoch 23: val_loss improved from 132.51706 to 132.10178, saving model to vae.keras
[...]
Epoch 28/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 38s 63ms/step - kl_loss: 5.5002 - reconstruction_loss: 123.7009 - total_loss: 129.2011 - val_kl_loss: 5.5304 - val_loss: 132.2309 - val_reconstruction_loss: 126.7005
Epoch 28: val_loss did not improve from 132.10178
(training log trimmed: val_loss fell from 143.72 to a best of 132.10 at epoch 23; early stopping ended training after epoch 28)
In [21]:
# best_model = keras.models.load_model("vae.keras")
# test_loss = best_model.evaluate(x_test, x_test)
# print(f"test loss = {test_loss:.3f}")
def training_curves(history):
    train_loss = history.history["total_loss"]
    val_loss = history.history["val_loss"]
    epochs = np.arange(1, len(train_loss) + 1)
    plt.plot(epochs, train_loss, "b--", label="training total loss")
    plt.plot(epochs, val_loss, "b", label="validation total loss")
    plt.title("Training and Validation VAE Total Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()

training_curves(history)
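Because test_step returns a dictionary of metrics, the in-memory model can also be scored directly on the test set (a quick sketch, assuming a tf.keras version where evaluate supports return_dict; the keys come from test_step above):

test_metrics = vae.evaluate(x_test, x_test, batch_size=BATCH_SIZE, return_dict=True)
print(test_metrics)  # {'loss': ..., 'reconstruction_loss': ..., 'kl_loss': ...}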
In [22]:
n_to_predict = 5000
example_images = x_test[:n_to_predict]
example_labels = y_test[:n_to_predict]
z_mean, z_log_var, reconstructions = vae.predict(example_images)
print("Example real clothing items")
display(example_images)
print("Reconstructions")
display(reconstructions)
157/157 ━━━━━━━━━━━━━━━━━━━━ 2s 11ms/step
Example real clothing items
Reconstructions
Next, get the encoder's sampled latent-space points (z) for each example image.
In [23]:
# Encode the example images
z_mean, z_var, z = encoder.predict(example_images)
figsize = 8
fig = plt.figure(figsize=(figsize * 2, figsize))
ax = fig.add_subplot(1, 2, 1)
plot_1 = ax.scatter(
    z[:, 0], z[:, 1], cmap="rainbow", c=example_labels, alpha=0.8, s=3
)
plt.colorbar(plot_1)
# Convert the latent embeddings to p-values via the standard normal CDF
from scipy.stats import norm
p = norm.cdf(z)
157/157 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step
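Applying the standard normal CDF maps each latent coordinate into [0, 1], so p shows how uniformly the encoder fills the prior. An optional visualization of the transformed embeddings (not in the original notebook):

fig = plt.figure(figsize=(8, 8))
plt.scatter(p[:, 0], p[:, 1], cmap="rainbow", c=example_labels, alpha=0.8, s=3)
plt.show()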
In [24]:
# Sample some points in the latent space, from the standard normal distribution
grid_width, grid_height = (6, 3)
z_sample = np.random.normal(size=(grid_width * grid_height, 2))
# Decode the sampled points
reconstructions = decoder.predict(z_sample)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step
In [25]:
# Draw a plot of...
figsize = 8
plt.figure(figsize=(figsize, figsize))
# ... the original embeddings ...
plt.scatter(z[:, 0], z[:, 1], c="black", alpha=0.5, s=2)
# ... and the newly generated points in the latent space
plt.scatter(z_sample[:, 0], z_sample[:, 1], c="#00B0F0", alpha=1, s=40)
plt.show()
# Add underneath a grid of the decoded images
fig = plt.figure(figsize=(figsize, grid_height * 2))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for i in range(grid_width * grid_height):
    ax = fig.add_subplot(grid_height, grid_width, i + 1)
    ax.axis("off")
    ax.text(
        0.5,
        -0.35,
        str(np.round(z_sample[i, :], 1)),
        fontsize=10,
        ha="center",
        transform=ax.transAxes,
    )
    ax.imshow(reconstructions[i, :, :], cmap="Greys")
In [27]:
# Colour the embeddings by their label (clothing type - see table)
figsize = 12
grid_size = 15
plt.figure(figsize=(figsize, figsize))
plt.scatter(
z[:, 0], z[:, 1], cmap="rainbow", c=example_labels, alpha=0.8, s=50
)
plt.colorbar()
# Space the grid points evenly in probability and map them back through the
# inverse normal CDF, so the grid follows the standard normal prior
# (avoid 0 and 1, which norm.ppf maps to -inf and +inf)
x = norm.ppf(np.linspace(0.01, 0.99, grid_size))
y = norm.ppf(np.linspace(0.99, 0.01, grid_size))
xv, yv = np.meshgrid(x, y)
xv = xv.flatten()
yv = yv.flatten()
grid = np.array(list(zip(xv, yv)))
reconstructions = decoder.predict(grid)
plt.scatter(grid[:, 0], grid[:, 1], c="black", alpha=1, s=10)
plt.show()
fig = plt.figure(figsize=(figsize, figsize))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for i in range(grid_size**2):
    ax = fig.add_subplot(grid_size, grid_size, i + 1)
    ax.axis("off")
    ax.imshow(reconstructions[i, :, :], cmap="Greys")
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step