#
# SciNet's DAT112, Neural Network Programming.
# Lecture 6, 23 April 2026.
# Erik Spence.
#
# This file, mnist_vae.py, contains code used for lecture 6.  It is a
# script which builds and trains a variational autoencoder, and
# applies it to MNIST data.  Some of the code has been stolen from
# https://keras.io/examples/generative/vae/
#

import tensorflow as tf

## These commands are needed if using a GPU on Mist.
#gpus = tf.config.list_physical_devices('GPU')
#for gpu in gpus:
#      tf.config.experimental.set_memory_growth(gpu, True)

import keras.models as km
import keras.layers as kl
import keras.ops as ko
import keras.metrics as kmetrics
import keras.random as kr

from keras.datasets import mnist

## the below two lines are needed for current versions of tensorflow
## with slightly older versions of numpy.
#from tensorflow.python.framework.ops import disable_eager_execution
#disable_eager_execution()

import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "3"


####################################################################


# The size of the latent space.
latent_dim = 100

# Get the data.  The labels are discarded: an autoencoder is trained to
# reproduce its own input, so only the images are needed.
(x_train, _), (x_test, _) = mnist.load_data()

# Add a trailing channel axis, so each image matches the encoder's
# 28 x 28 x 1 input, and divide by 255 so the pixel values fall in the
# same range as the decoder's sigmoid output.  The number of images is
# inferred (-1) rather than hard-coded, so this still works on a subset
# or a differently sized split.
x_train = x_train.astype('float32').reshape(-1, 28, 28, 1) / 255.
x_test = x_test.astype('float32').reshape(-1, 28, 28, 1) / 255.


####################################################################


def sampling(args):

    """Draw latent samples using the reparameterization trick.

    ``args`` is the pair ``(z_mean, z_log_var)`` produced by the
    encoder.  Standard-normal noise with the same shape as ``z_mean``
    is drawn, and ``z_mean + exp(z_log_var / 2) * noise`` is returned.
    The sample is therefore a differentiable function of both
    ``z_mean`` and ``z_log_var``, which is what lets the decoding
    network be trained through it.

    """

    mean, log_var = args

    # exp(log_var / 2) turns a log-variance into a standard deviation.
    std_dev = ko.exp(0.5 * log_var)

    # Zero-mean, unit-variance Gaussian noise shaped like the mean.
    noise = kr.normal(ko.shape(mean))

    return mean + std_dev * noise


####################################################################

# The Encoder

# The image fed into the encoder: 28 x 28 pixels, one channel.
input_image = kl.Input(shape = (28, 28, 1))

# Two stages of (3 x 3 'same' convolution + 2 x 2 max pooling), first
# with 16 and then with 32 feature maps:
#   28 x 28 x 1  -> conv -> 28 x 28 x 16 -> pool -> 14 x 14 x 16
#   14 x 14 x 16 -> conv -> 14 x 14 x 32 -> pool ->  7 x  7 x 32
x = input_image
for n_filters in (16, 32):
    x = kl.Conv2D(n_filters, kernel_size = (3, 3),
                  activation = 'relu',
                  padding = 'same')(x)
    x = kl.MaxPooling2D(pool_size = (2, 2),
                        strides = (2, 2))(x)

# Collapse the 7 x 7 x 32 feature maps into one flat vector, which
# feeds both of the encoder's output layers.
x = kl.Flatten()(x)

# The two encoder outputs: the mean and the log-variance of the latent
# distribution, each of length latent_dim.
z_mean = kl.Dense(latent_dim, activation = 'linear')(x)
z_log_var = kl.Dense(latent_dim, activation = 'linear')(x)

# Wrap the sampling function in a layer that takes the previous two
# outputs as its two inputs and returns a latent sample z.
z = kl.Lambda(sampling, output_shape = (latent_dim,))([z_mean, z_log_var])


####################################################################

# The Decoder

# The decoder takes a single latent vector of length latent_dim.
decoder_input = kl.Input(shape = (latent_dim,))

# Expand the latent vector with a fully-connected layer, then fold its
# 7 * 7 * 32 values back into a 7 x 7 x 32 stack of feature maps.
x = kl.Dense(7 * 7 * 32, activation = 'relu')(decoder_input)
x = kl.Reshape((7, 7, 32))(x)

# Mirror the encoder: two stages of (3 x 3 'same' transpose convolution
# + 2 x upsampling), first with 32 and then with 16 feature maps:
#    7 x  7 x 32 -> tconv ->  7 x  7 x 32 -> upsample -> 14 x 14 x 32
#   14 x 14 x 32 -> tconv -> 14 x 14 x 16 -> upsample -> 28 x 28 x 16
for n_filters in (32, 16):
    x = kl.Conv2DTranspose(n_filters, kernel_size = (3, 3),
                           activation = 'relu',
                           padding = 'same')(x)
    x = kl.UpSampling2D(size = (2, 2))(x)

# A final single-channel transpose convolution produces the decoder's
# 28 x 28 x 1 output image; the sigmoid keeps every pixel in (0, 1).
decoded = kl.Conv2DTranspose(1, kernel_size = (3, 3),
                             activation = 'sigmoid',
                             padding = 'same')(x)


####################################################################

# Build the models.  The encoder maps an image to the latent mean, the
# latent log-variance and a sampled latent vector z; the decoder maps a
# latent vector back to an image.
encoder = km.Model(input_image, [z_mean, z_log_var, z])
decoder = km.Model(decoder_input, decoded)


####################################################################


class VAE(km.Model):

    """A variational autoencoder assembled from an encoder and a decoder.

    The encoder must return ``(z_mean, z_log_var, z)`` for a batch of
    images, and the decoder must map ``z`` back to images comparable
    with the input.  The training objective is the per-image
    reconstruction error plus the KL divergence of each image's latent
    Gaussian from a standard normal, both averaged over the batch.

    """

    ## Initialize our VAE class.
    def __init__(self, encoder, decoder, **kwargs):

        ## Initialize the km.Model class.
        super().__init__(**kwargs)

        ## Add the encoder and decoder.
        self.encoder = encoder
        self.decoder = decoder

        ## Add some loss trackers.  Each is named after the key under
        ## which its running mean is returned by train_step/test_step.
        self.total_loss_tracker = kmetrics.Mean(name = "total_loss")
        self.reconstruction_loss_tracker = kmetrics.Mean(
            name = "reconstruction_loss")
        self.kl_loss_tracker = kmetrics.Mean(name = "kl_loss")


    ## The metrics that we're going to track.  Listing them here lets
    ## Keras reset them at the start of each epoch and of evaluate().
    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]


    @staticmethod
    def _images_from(data):
        """Return the image batch from the data handed to a step.

        When only images are passed to fit()/evaluate() the batch is the
        images themselves; if it arrives as a tuple/list (e.g. because
        targets were also supplied) the images are its first element.
        """
        if isinstance(data, (tuple, list)):
            return data[0]
        return data


    def _vae_losses(self, input_image, training):
        """Run a batch through the VAE and return its losses.

        Returns ``(total_loss, reconstruction_loss, kl_loss)``, each a
        scalar averaged over the batch.
        """

        ## Run the data through the encoder.
        z_mean, z_log_var, z = self.encoder(input_image, training = training)

        ## Run the sampled data, z, through the decoder, to get the
        ## output image.
        output_image = self.decoder(z, training = training)

        ## Reconstruction loss: mean-squared error scaled by the 784
        ## (= 28 x 28) pixels per image, i.e. the squared error summed
        ## over each image and averaged over the batch.  Unscaled, the
        ## mean is too small to matter next to the KL term.
        r_loss = 784 * ko.mean(ko.square(input_image - output_image))

        ## KL divergence of each image's latent Gaussian from a standard
        ## normal, summed over the latent dimensions and then averaged
        ## over the batch.  Averaging keeps the total loss a scalar;
        ## a per-sample vector would make tape.gradient differentiate
        ## its sum, scaling the gradients by the batch size.
        kl_loss = 1 + z_log_var - ko.square(z_mean) - ko.exp(z_log_var)
        kl_loss = ko.mean(-0.5 * ko.sum(kl_loss, axis = -1))

        return r_loss + kl_loss, r_loss, kl_loss


    def _record_losses(self, total_loss, r_loss, kl_loss):
        """Fold one batch's losses into the trackers; return their means."""
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(r_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {
            "total_loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }


    ## Alas, the training step needs to be built explicitly.
    def train_step(self, input_image):
        """Take one optimizer step on a batch; return the running losses."""

        input_image = self._images_from(input_image)

        ## Record the forward pass, so that the gradients needed for
        ## SGD can be taken.
        with tf.GradientTape() as tape:
            total_loss, r_loss, kl_loss = self._vae_losses(
                input_image, training = True)

        ## Take the gradient of the (scalar) total loss with respect to
        ## all of the trainable weights, and apply it with our optimizer.
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        return self._record_losses(total_loss, r_loss, kl_loss)


    ## The evaluation step: the same losses as training, but with no
    ## gradients taken and no weights updated.
    def test_step(self, input_image):
        """Compute the losses on a batch without training on it."""

        input_image = self._images_from(input_image)
        total_loss, r_loss, kl_loss = self._vae_losses(
            input_image, training = False)

        return self._record_losses(total_loss, r_loss, kl_loss)


    ## Evaluate held-out data (e.g. the test set) and return the losses
    ## as a dict.
    def evaluate(self, input_image = None, **kwargs):
        """Evaluate the VAE on ``input_image`` without changing its weights.

        Delegates to Keras' batched ``Model.evaluate``, which resets the
        loss trackers and then runs ``test_step`` on each batch.  By
        default this returns a dict keyed by ``'total_loss'``,
        ``'reconstruction_loss'`` and ``'kl_loss'``.  Any other keyword
        arguments (``batch_size``, ``verbose``, ...) are passed through.
        """
        kwargs.setdefault("x", input_image)
        kwargs.setdefault("return_dict", True)
        return super().evaluate(**kwargs)


## Wrap the encoder/decoder pair in our VAE class, attach the Adam
## optimizer, and train for 30 epochs in mini-batches of 128 images.
vae = VAE(encoder, decoder)
vae.compile(optimizer = 'adam')
fit = vae.fit(x_train,
              batch_size = 128,
              epochs = 30,
              verbose = 2)

# Evaluate the model on the test images and report the total loss.
score = vae.evaluate(x_test)
print('score is', score['total_loss'])

# Save the encoder and decoder as separate files, so either one can be
# reloaded on its own.
for model, filename in ((encoder, 'mnist_encoder.keras'),
                        (decoder, 'mnist_decoder.keras')):
    model.save(filename)


####################################################################
