In [13]:
import numpy as np
from numpy import random
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models, datasets, callbacks, losses, optimizers, metrics
import tensorflow.keras.backend as K
In [2]:
IMAGE_SIZE = 32
BATCH_SIZE = 100
VALIDATION_SPLIT = 0.2
EMBEDDING_DIM = 2

BETA = 500
In [3]:
(x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()

# Scale to [0, 1] and zero-pad the 28x28 images to 32x32 so that three
# stride-2 convolutions divide the spatial dimensions evenly
x_train = x_train.astype("float32") / 255
x_train = np.pad(x_train, ((0, 0), (2, 2), (2, 2)), constant_values=0.0)
x_train = np.expand_dims(x_train, -1)
x_test = x_test.astype("float32") / 255
x_test = np.pad(x_test, ((0, 0), (2, 2), (2, 2)), constant_values=0.0)
x_test = np.expand_dims(x_test, -1)

def display(
    images, n=10, size=(20, 3), cmap="gray_r", as_type="float32", save_to=None
):
    """
    Displays n random images from each one of the supplied arrays.
    """
    if images.max() > 1.0:
        images = images / 255.0
    elif images.min() < 0.0:
        images = (images + 1.0) / 2.0

    plt.figure(figsize=size)
    for i in range(n):
        _ = plt.subplot(1, n, i + 1)
        plt.imshow(images[i].astype(as_type), cmap=cmap)
        plt.axis("off")

    if save_to:
        plt.savefig(save_to)
        print(f"\nSaved to {save_to}")

    plt.show()

display(x_train)
[Figure: the first ten Fashion-MNIST training images]

Before we build the VAE, we need to define a new layer that samples the latent vector z from the distribution parameterized by (z_mean, z_log_var).
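The Sampling layer implements the reparameterization trick: rather than sampling z directly from a Gaussian parameterized by z_mean and z_log_var, it rewrites the draw as a deterministic function of those parameters plus independent noise, so gradients can flow back through them:

$$z = z_{\text{mean}} + \exp\!\left(\tfrac{1}{2}\, z_{\text{log\_var}}\right) \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$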

In [14]:
class Sampling(layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        # Draw standard normal noise, then shift and scale it by the predicted
        # mean and standard deviation (exp(0.5 * log_var) = sigma)
        epsilon = K.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
In [15]:
# Encoder
encoder_input = layers.Input(
    shape=(IMAGE_SIZE, IMAGE_SIZE, 1), name="encoder_input"
)
x = layers.Conv2D(32, (3, 3), strides=2, activation="relu", padding="same")(
    encoder_input
)
x = layers.Conv2D(64, (3, 3), strides=2, activation="relu", padding="same")(x)
x = layers.Conv2D(128, (3, 3), strides=2, activation="relu", padding="same")(x)
 
x = layers.Flatten()(x)
z_mean = layers.Dense(EMBEDDING_DIM, name="z_mean")(x)
z_log_var = layers.Dense(EMBEDDING_DIM, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])

encoder = models.Model(encoder_input, [z_mean, z_log_var, z], name="encoder")
encoder.summary()
Model: "encoder"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ encoder_input       │ (None, 32, 32, 1) │          0 │ -                 │
│ (InputLayer)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_3 (Conv2D)   │ (None, 16, 16,    │        320 │ encoder_input[0]… │
│                     │ 32)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_4 (Conv2D)   │ (None, 8, 8, 64)  │     18,496 │ conv2d_3[0][0]    │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_5 (Conv2D)   │ (None, 4, 4, 128) │     73,856 │ conv2d_4[0][0]    │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ flatten_1 (Flatten) │ (None, 2048)      │          0 │ conv2d_5[0][0]    │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ z_mean (Dense)      │ (None, 2)         │      4,098 │ flatten_1[0][0]   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ z_log_var (Dense)   │ (None, 2)         │      4,098 │ flatten_1[0][0]   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ sampling_1          │ (None, 2)         │          0 │ z_mean[0][0],     │
│ (Sampling)          │                   │            │ z_log_var[0][0]   │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 100,868 (394.02 KB)
 Trainable params: 100,868 (394.02 KB)
 Non-trainable params: 0 (0.00 B)
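As a quick shape check (a hypothetical extra cell, not part of the original run), the encoder can be called on a small batch directly; each of its three outputs should have shape (4, EMBEDDING_DIM):

# Hypothetical sanity check of the encoder's three outputs
z_mean_chk, z_log_var_chk, z_chk = encoder(x_train[:4])
print(z_mean_chk.shape, z_log_var_chk.shape, z_chk.shape)  # (4, 2) each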
In [16]:
# Decoder
decoder_input = layers.Input(shape=(EMBEDDING_DIM,), name="decoder_input")
x = layers.Dense(2048)(decoder_input)
x = layers.Reshape((4,4,128))(x)
x = layers.Conv2DTranspose(
    128, (3, 3), strides=2, activation="relu", padding="same"
)(x)
x = layers.Conv2DTranspose(
    64, (3, 3), strides=2, activation="relu", padding="same"
)(x)
x = layers.Conv2DTranspose(
    32, (3, 3), strides=2, activation="relu", padding="same"
)(x)
decoder_output = layers.Conv2D(
    1,
    (3, 3),
    strides=1,
    activation="sigmoid",
    padding="same",
    name="decoder_output",
)(x)

decoder = models.Model(decoder_input, decoder_output)
decoder.summary()
Model: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ decoder_input (InputLayer)      │ (None, 2)              │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 2048)           │         6,144 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ reshape_1 (Reshape)             │ (None, 4, 4, 128)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_transpose_3              │ (None, 8, 8, 128)      │       147,584 │
│ (Conv2DTranspose)               │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_transpose_4              │ (None, 16, 16, 64)     │        73,792 │
│ (Conv2DTranspose)               │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_transpose_5              │ (None, 32, 32, 32)     │        18,464 │
│ (Conv2DTranspose)               │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ decoder_output (Conv2D)         │ (None, 32, 32, 1)      │           289 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 246,273 (962.00 KB)
 Trainable params: 246,273 (962.00 KB)
 Non-trainable params: 0 (0.00 B)
In [17]:
class VAE(models.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, inputs):
        """Call the model on a particular input."""
        # Use the model's own sub-networks rather than the globals
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstruction = self.decoder(z)
        return z_mean, z_log_var, reconstruction

    def train_step(self, data):
        """Step run during training."""
        with tf.GradientTape() as tape:
            z_mean, z_log_var, reconstruction = self(data)
            reconstruction_loss = tf.reduce_mean(
                BETA
                * losses.binary_crossentropy(
                    data, reconstruction, axis=(1, 2, 3)
                )
            )
            kl_loss = tf.reduce_mean(
                tf.reduce_sum(
                    -0.5
                    * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)),
                    axis=1,
                )
            )
            total_loss = reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Step run during validation."""
        if isinstance(data, tuple):
            data = data[0]

        z_mean, z_log_var, reconstruction = self(data)
        reconstruction_loss = tf.reduce_mean(
            BETA
            * losses.binary_crossentropy(data, reconstruction, axis=(1, 2, 3))
        )
        kl_loss = tf.reduce_mean(
            tf.reduce_sum(
                -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)),
                axis=1,
            )
        )
        total_loss = reconstruction_loss + kl_loss

        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }
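For reference, the per-batch objective implemented in train_step (and mirrored in test_step) is a beta-weighted ELBO: the binary cross-entropy reconstruction term is scaled by BETA = 500, and the KL term is the closed form for a diagonal Gaussian measured against the standard normal prior:

$$\mathcal{L} = \beta\,\mathrm{BCE}(x, \hat{x}) \;-\; \tfrac{1}{2} \sum_{j=1}^{d} \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$

where $\mu$ is z_mean, $\log\sigma^2$ is z_log_var, $d$ is EMBEDDING_DIM, and both terms are averaged over the batch. A larger $\beta$ favors reconstruction quality over the regularity of the latent space.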
In [18]:
vae = VAE(encoder,decoder)
In [19]:
optimizer = optimizers.Adam(learning_rate=0.0005)
vae.compile(optimizer=optimizer)
In [20]:
callback_ckpt = keras.callbacks.ModelCheckpoint(
    filepath="vae.keras",
    save_best_only=True,
    verbose=1,
    mode="min",
    monitor="val_loss",
)
callback_stopping = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=5,
)


history = vae.fit(
    x_train,
    epochs=50,
    batch_size=BATCH_SIZE,
    shuffle=True,
    validation_data=(x_test, x_test),
    callbacks=[callback_ckpt, callback_stopping],
)
Epoch 1/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - kl_loss: 3.6191 - reconstruction_loss: 197.0072 - total_loss: 200.6263
Epoch 1: val_loss improved from inf to 143.72160, saving model to vae.keras
600/600 ━━━━━━━━━━━━━━━━━━━━ 42s 67ms/step - kl_loss: 3.6204 - reconstruction_loss: 196.9417 - total_loss: 200.5622 - val_kl_loss: 4.8065 - val_loss: 143.7216 - val_reconstruction_loss: 138.9151
Epoch 2/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step - kl_loss: 4.8654 - reconstruction_loss: 133.1712 - total_loss: 138.0367
Epoch 2: val_loss improved from 143.72160 to 139.37596, saving model to vae.keras
600/600 ━━━━━━━━━━━━━━━━━━━━ 45s 74ms/step - kl_loss: 4.8655 - reconstruction_loss: 133.1698 - total_loss: 138.0353 - val_kl_loss: 5.1290 - val_loss: 139.3760 - val_reconstruction_loss: 134.2469
Epoch 3/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 72ms/step - kl_loss: 5.0043 - reconstruction_loss: 130.0686 - total_loss: 135.0729
Epoch 3: val_loss improved from 139.37596 to 137.09091, saving model to vae.keras
600/600 ━━━━━━━━━━━━━━━━━━━━ 46s 76ms/step - kl_loss: 5.0043 - reconstruction_loss: 130.0680 - total_loss: 135.0723 - val_kl_loss: 5.2566 - val_loss: 137.0909 - val_reconstruction_loss: 131.8344
Epoch 4/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - kl_loss: 5.0656 - reconstruction_loss: 128.6648 - total_loss: 133.7305
Epoch 4: val_loss improved from 137.09091 to 136.22583, saving model to vae.keras
600/600 ━━━━━━━━━━━━━━━━━━━━ 43s 71ms/step - kl_loss: 5.0657 - reconstruction_loss: 128.6645 - total_loss: 133.7302 - val_kl_loss: 5.0631 - val_loss: 136.2258 - val_reconstruction_loss: 131.1627
Epoch 5/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - kl_loss: 5.1394 - reconstruction_loss: 127.7582 - total_loss: 132.8975
Epoch 5: val_loss improved from 136.22583 to 135.71503, saving model to vae.keras
600/600 ━━━━━━━━━━━━━━━━━━━━ 43s 72ms/step - kl_loss: 5.1394 - reconstruction_loss: 127.7581 - total_loss: 132.8974 - val_kl_loss: 5.1251 - val_loss: 135.7150 - val_reconstruction_loss: 130.5899
Epoch 6/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - kl_loss: 5.1423 - reconstruction_loss: 127.0273 - total_loss: 132.1695
Epoch 6: val_loss improved from 135.71503 to 134.45528, saving model to vae.keras
600/600 ━━━━━━━━━━━━━━━━━━━━ 43s 71ms/step - kl_loss: 5.1423 - reconstruction_loss: 127.0276 - total_loss: 132.1699 - val_kl_loss: 5.2265 - val_loss: 134.4553 - val_reconstruction_loss: 129.2288
Epoch 7/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 71ms/step - kl_loss: 5.1915 - reconstruction_loss: 126.8037 - total_loss: 131.9953
Epoch 7: val_loss did not improve from 134.45528
600/600 ━━━━━━━━━━━━━━━━━━━━ 45s 75ms/step - kl_loss: 5.1915 - reconstruction_loss: 126.8038 - total_loss: 131.9953 - val_kl_loss: 5.3299 - val_loss: 135.1975 - val_reconstruction_loss: 129.8676
Epoch 8/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 72ms/step - kl_loss: 5.2241 - reconstruction_loss: 126.0335 - total_loss: 131.2575
Epoch 8: val_loss did not improve from 134.45528
600/600 ━━━━━━━━━━━━━━━━━━━━ 46s 76ms/step - kl_loss: 5.2241 - reconstruction_loss: 126.0342 - total_loss: 131.2582 - val_kl_loss: 5.3514 - val_loss: 134.6660 - val_reconstruction_loss: 129.3146
Epoch 9/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - kl_loss: 5.2268 - reconstruction_loss: 126.4103 - total_loss: 131.6370
Epoch 9: val_loss improved from 134.45528 to 134.22743, saving model to vae.keras
600/600 ━━━━━━━━━━━━━━━━━━━━ 42s 71ms/step - kl_loss: 5.2268 - reconstruction_loss: 126.4100 - total_loss: 131.6367 - val_kl_loss: 5.3149 - val_loss: 134.2274 - val_reconstruction_loss: 128.9125
Epoch 10/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - kl_loss: 5.2455 - reconstruction_loss: 126.0977 - total_loss: 131.3432
Epoch 10: val_loss improved from 134.22743 to 134.18462, saving model to vae.keras
600/600 ━━━━━━━━━━━━━━━━━━━━ 42s 70ms/step - kl_loss: 5.2455 - reconstruction_loss: 126.0974 - total_loss: 131.3429 - val_kl_loss: 5.5390 - val_loss: 134.1846 - val_reconstruction_loss: 128.6456
Epoch 11/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - kl_loss: 5.2846 - reconstruction_loss: 125.8529 - total_loss: 131.1376
Epoch 11: val_loss improved from 134.18462 to 134.08937, saving model to vae.keras
600/600 ━━━━━━━━━━━━━━━━━━━━ 42s 71ms/step - kl_loss: 5.2846 - reconstruction_loss: 125.8528 - total_loss: 131.1374 - val_kl_loss: 5.3501 - val_loss: 134.0894 - val_reconstruction_loss: 128.7393
Epoch 12/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - kl_loss: 5.2762 - reconstruction_loss: 125.6316 - total_loss: 130.9078
Epoch 12: val_loss improved from 134.08937 to 133.87436, saving model to vae.keras
600/600 ━━━━━━━━━━━━━━━━━━━━ 42s 70ms/step - kl_loss: 5.2762 - reconstruction_loss: 125.6314 - total_loss: 130.9076 - val_kl_loss: 5.3252 - val_loss: 133.8744 - val_reconstruction_loss: 128.5491
Epoch 13/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - kl_loss: 5.3113 - reconstruction_loss: 125.2049 - total_loss: 130.5162
Epoch 13: val_loss improved from 133.87436 to 133.60133, saving model to vae.keras
600/600 ━━━━━━━━━━━━━━━━━━━━ 43s 71ms/step - kl_loss: 5.3113 - reconstruction_loss: 125.2051 - total_loss: 130.5164 - val_kl_loss: 5.3738 - val_loss: 133.6013 - val_reconstruction_loss: 128.2276
Epoch 14/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - kl_loss: 5.3304 - reconstruction_loss: 124.8728 - total_loss: 130.2032
Epoch 14: val_loss improved from 133.60133 to 133.26132, saving model to vae.keras
600/600 ━━━━━━━━━━━━━━━━━━━━ 42s 70ms/step - kl_loss: 5.3304 - reconstruction_loss: 124.8733 - total_loss: 130.2037 - val_kl_loss: 5.4722 - val_loss: 133.2613 - val_reconstruction_loss: 127.7891
Epoch 15/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - kl_loss: 5.3427 - reconstruction_loss: 124.8862 - total_loss: 130.2288
Epoch 15: val_loss improved from 133.26132 to 133.07672, saving model to vae.keras
600/600 ━━━━━━━━━━━━━━━━━━━━ 44s 73ms/step - kl_loss: 5.3427 - reconstruction_loss: 124.8863 - total_loss: 130.2290 - val_kl_loss: 5.5144 - val_loss: 133.0767 - val_reconstruction_loss: 127.5624
Epoch 16/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step - kl_loss: 5.3468 - reconstruction_loss: 124.8457 - total_loss: 130.1925
Epoch 16: val_loss did not improve from 133.07672
600/600 ━━━━━━━━━━━━━━━━━━━━ 44s 73ms/step - kl_loss: 5.3468 - reconstruction_loss: 124.8456 - total_loss: 130.1924 - val_kl_loss: 5.5070 - val_loss: 133.0948 - val_reconstruction_loss: 127.5878
Epoch 17/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 71ms/step - kl_loss: 5.3613 - reconstruction_loss: 124.6065 - total_loss: 129.9678
Epoch 17: val_loss did not improve from 133.07672
600/600 ━━━━━━━━━━━━━━━━━━━━ 45s 75ms/step - kl_loss: 5.3613 - reconstruction_loss: 124.6067 - total_loss: 129.9680 - val_kl_loss: 5.7375 - val_loss: 133.5044 - val_reconstruction_loss: 127.7669
Epoch 18/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step - kl_loss: 5.3761 - reconstruction_loss: 124.4428 - total_loss: 129.8188
Epoch 18: val_loss improved from 133.07672 to 132.99593, saving model to vae.keras
600/600 ━━━━━━━━━━━━━━━━━━━━ 45s 74ms/step - kl_loss: 5.3761 - reconstruction_loss: 124.4430 - total_loss: 129.8191 - val_kl_loss: 5.5359 - val_loss: 132.9959 - val_reconstruction_loss: 127.4600
Epoch 19/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - kl_loss: 5.3926 - reconstruction_loss: 124.4739 - total_loss: 129.8665
Epoch 19: val_loss did not improve from 132.99593
600/600 ━━━━━━━━━━━━━━━━━━━━ 42s 70ms/step - kl_loss: 5.3926 - reconstruction_loss: 124.4739 - total_loss: 129.8665 - val_kl_loss: 5.3588 - val_loss: 133.1993 - val_reconstruction_loss: 127.8404
Epoch 20/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 72ms/step - kl_loss: 5.4165 - reconstruction_loss: 124.0994 - total_loss: 129.5159
Epoch 20: val_loss did not improve from 132.99593
600/600 ━━━━━━━━━━━━━━━━━━━━ 45s 76ms/step - kl_loss: 5.4165 - reconstruction_loss: 124.0997 - total_loss: 129.5163 - val_kl_loss: 5.6431 - val_loss: 133.0858 - val_reconstruction_loss: 127.4427
Epoch 21/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - kl_loss: 5.4349 - reconstruction_loss: 124.1673 - total_loss: 129.6022
Epoch 21: val_loss improved from 132.99593 to 132.51706, saving model to vae.keras
600/600 ━━━━━━━━━━━━━━━━━━━━ 43s 72ms/step - kl_loss: 5.4349 - reconstruction_loss: 124.1672 - total_loss: 129.6022 - val_kl_loss: 5.5932 - val_loss: 132.5171 - val_reconstruction_loss: 126.9238
Epoch 22/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - kl_loss: 5.4462 - reconstruction_loss: 124.2803 - total_loss: 129.7265
Epoch 22: val_loss did not improve from 132.51706
600/600 ━━━━━━━━━━━━━━━━━━━━ 39s 65ms/step - kl_loss: 5.4462 - reconstruction_loss: 124.2800 - total_loss: 129.7262 - val_kl_loss: 5.5098 - val_loss: 132.6221 - val_reconstruction_loss: 127.1123
Epoch 23/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - kl_loss: 5.4536 - reconstruction_loss: 124.0727 - total_loss: 129.5263
Epoch 23: val_loss improved from 132.51706 to 132.10178, saving model to vae.keras
600/600 ━━━━━━━━━━━━━━━━━━━━ 38s 63ms/step - kl_loss: 5.4536 - reconstruction_loss: 124.0726 - total_loss: 129.5262 - val_kl_loss: 5.5333 - val_loss: 132.1018 - val_reconstruction_loss: 126.5685
Epoch 24/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - kl_loss: 5.4688 - reconstruction_loss: 123.9932 - total_loss: 129.4619
Epoch 24: val_loss did not improve from 132.10178
600/600 ━━━━━━━━━━━━━━━━━━━━ 38s 63ms/step - kl_loss: 5.4688 - reconstruction_loss: 123.9930 - total_loss: 129.4617 - val_kl_loss: 5.5359 - val_loss: 132.7426 - val_reconstruction_loss: 127.2067
Epoch 25/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - kl_loss: 5.4781 - reconstruction_loss: 123.9249 - total_loss: 129.4030
Epoch 25: val_loss did not improve from 132.10178
600/600 ━━━━━━━━━━━━━━━━━━━━ 38s 64ms/step - kl_loss: 5.4781 - reconstruction_loss: 123.9248 - total_loss: 129.4029 - val_kl_loss: 5.5658 - val_loss: 132.4855 - val_reconstruction_loss: 126.9197
Epoch 26/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - kl_loss: 5.4895 - reconstruction_loss: 123.9124 - total_loss: 129.4019
Epoch 26: val_loss did not improve from 132.10178
600/600 ━━━━━━━━━━━━━━━━━━━━ 38s 64ms/step - kl_loss: 5.4895 - reconstruction_loss: 123.9121 - total_loss: 129.4016 - val_kl_loss: 5.5336 - val_loss: 132.5003 - val_reconstruction_loss: 126.9667
Epoch 27/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - kl_loss: 5.4955 - reconstruction_loss: 123.6972 - total_loss: 129.1927
Epoch 27: val_loss did not improve from 132.10178
600/600 ━━━━━━━━━━━━━━━━━━━━ 38s 64ms/step - kl_loss: 5.4955 - reconstruction_loss: 123.6970 - total_loss: 129.1925 - val_kl_loss: 5.3785 - val_loss: 132.5270 - val_reconstruction_loss: 127.1486
Epoch 28/50
600/600 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - kl_loss: 5.5002 - reconstruction_loss: 123.7011 - total_loss: 129.2012
Epoch 28: val_loss did not improve from 132.10178
600/600 ━━━━━━━━━━━━━━━━━━━━ 38s 63ms/step - kl_loss: 5.5002 - reconstruction_loss: 123.7009 - total_loss: 129.2011 - val_kl_loss: 5.5304 - val_loss: 132.2309 - val_reconstruction_loss: 126.7005
In [21]:
# best_model = keras.models.load_model("vae.keras")
# test_loss = best_model.evaluate(x_test, x_test)
# print(f"test loss = {test_loss:.3f}")


def training_curves(history):
    train_loss = history.history["total_loss"]
    val_loss = history.history["val_loss"]

    epochs = np.arange(1, len(train_loss) + 1)

    plt.plot(epochs, train_loss, "b--", label="training total loss")
    plt.plot(epochs, val_loss, "b", label="validation total loss")
    plt.title("Training and Validation VAE Total Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()


training_curves(history)
[Figure: training and validation total-loss curves]
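To score the best checkpoint (the commented-out lines at the top of the previous cell), one route is to rebuild the model and load the saved weights back in. A minimal sketch, assuming a Keras 3 install whose load_weights can read weights out of a .keras archive; the subclassed VAE defines no get_config, so keras.models.load_model would need extra serialization work:

# Sketch under the assumptions above: rebuild, create variables, reload weights
best_vae = VAE(encoder, decoder)
best_vae.compile(optimizer=optimizers.Adam(learning_rate=0.0005))
_ = best_vae(x_test[:1])            # call once so the variables exist
best_vae.load_weights("vae.keras")  # assumption: weight loading from .keras
print(best_vae.evaluate(x_test, x_test, return_dict=True))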
In [22]:
n_to_predict = 5000
example_images = x_test[:n_to_predict]
example_labels = y_test[:n_to_predict]
z_mean, z_log_var, reconstructions = vae.predict(example_images)

print("Example real clothing items")
display(example_images)
print("Reconstructions")
display(reconstructions)
157/157 ━━━━━━━━━━━━━━━━━━━━ 2s 11ms/step
Example real clothing items
[Figure: ten example test images]
Reconstructions
[Figure: the corresponding VAE reconstructions]

Next, we get the encoder's latent-space outputs (means, log-variances, and sampled points z) for each example image.

In [23]:
# Encode the example images
z_mean, z_var, z = encoder.predict(example_images)

figsize = 8
fig = plt.figure(figsize=(figsize * 2, figsize))
ax = fig.add_subplot(1, 2, 1)
plot_1 = ax.scatter(
    z[:, 0], z[:, 1], cmap="rainbow", c=example_labels, alpha=0.8, s=3
)
plt.colorbar(plot_1)

# Map the latent coordinates through the standard normal CDF, converting each
# embedding to a probability in [0, 1]; p is not used further below, but the
# same idea, inverted via norm.ppf, drives the grid sampling later on
from scipy.stats import norm

p = norm.cdf(z)
157/157 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step
[Figure: scatter plot of the sampled latent points z, colored by clothing label]
In [24]:
# Sample some points in the latent space from the standard normal distribution
# (the KL term pushes the encoder's outputs towards N(0, I), so this is where
# decodable points live)
grid_width, grid_height = (6, 3)
z_sample = np.random.normal(size=(grid_width * grid_height, 2))

# Decode the sampled points
reconstructions = decoder.predict(z_sample)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step
In [25]:
# Draw a plot of...
figsize = 8
plt.figure(figsize=(figsize, figsize))

# ... the original embeddings ...
plt.scatter(z[:, 0], z[:, 1], c="black", alpha=0.5, s=2)

# ... and the newly generated points in the latent space
plt.scatter(z_sample[:, 0], z_sample[:, 1], c="#00B0F0", alpha=1, s=40)
plt.show()

# Add underneath a grid of the decoded images
fig = plt.figure(figsize=(figsize, grid_height * 2))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

for i in range(grid_width * grid_height):
    ax = fig.add_subplot(grid_height, grid_width, i + 1)
    ax.axis("off")
    ax.text(
        0.5,
        -0.35,
        str(np.round(z_sample[i, :], 1)),
        fontsize=10,
        ha="center",
        transform=ax.transAxes,
    )
    ax.imshow(reconstructions[i, :, :], cmap="Greys")
[Figure: latent embeddings (black) with the newly sampled points (blue)]
[Figure: grid of decoded images for the sampled points, annotated with their latent coordinates]
In [27]:
# Colour the embeddings by their label (clothing type - see table)
figsize = 12
grid_size = 15
plt.figure(figsize=(figsize, figsize))
plt.scatter(
    z[:, 0], z[:, 1], cmap="rainbow", c=example_labels, alpha=0.8, s=50
)
plt.colorbar()

# Lay down a grid of points that are equally spaced in probability under the
# standard normal; avoid the endpoints 0 and 1, whose quantiles are infinite
x = norm.ppf(np.linspace(0.01, 0.99, grid_size))
y = norm.ppf(np.linspace(0.99, 0.01, grid_size))
xv, yv = np.meshgrid(x, y)
xv = xv.flatten()
yv = yv.flatten()
grid = np.array(list(zip(xv, yv)))

reconstructions = decoder.predict(grid)
plt.scatter(grid[:, 0], grid[:, 1], c="black", alpha=1, s=10)
plt.show()

fig = plt.figure(figsize=(figsize, figsize))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for i in range(grid_size**2):
    ax = fig.add_subplot(grid_size, grid_size, i + 1)
    ax.axis("off")
    ax.imshow(reconstructions[i, :, :], cmap="Greys")
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
[Figure: latent embeddings colored by label with the probability-spaced grid overlaid]
[Figure: 15x15 grid of decoded images spanning the latent space]
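A natural follow-on experiment (not in the original notebook) is to walk in a straight line between two latent points and decode each step; a minimal sketch reusing the z array computed above:

# Hypothetical extra cell: interpolate between two encoded test images
z_a, z_b = z[0], z[1]
alphas = np.linspace(0, 1, 10)[:, None]         # (10, 1) mixing weights
z_path = (1 - alphas) * z_a + alphas * z_b      # (10, EMBEDDING_DIM)
path_images = decoder.predict(z_path)
display(path_images, n=10)

Because the KL term keeps the latent space continuous and densely packed around the origin, the intermediate decodings should morph smoothly from one garment into the other.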