# SciNet's DAT112, Neural Network Programming.
# Lecture 8, 30 April 2026.
# Erik Spence.
#
# This file, q_cartpole_player.py, contains code used for the example
# used in lecture 8.  It defines a player which learns to play
# CartPole using a neural-network Q-learning strategy.
#

########################################################################


import gymnasium as gym

import random
import numpy as np
import numpy.random as npr

import keras.layers as kl
import keras.models as km


#############################################################


class QCartPolePlayer(object):

    def __init__(self):
        """
        Plays CartPole by implementing a NN Q-learning strategy.
        """

        ## Load CartPole environment.
        self.env = gym.make("CartPole-v1", render_mode = 'human')
        
        ## The discount rate applied to future rewards (gamma).
        self.future_reward_discount = 0.95

        ## Size of observation space.
        self.state_dim = self.env.observation_space.shape[0]

        ## The number of possible actions (push left, push right).
        self.num_actions = int(self.env.action_space.n)
        
        ## The parameters controlling the probability of making a
        ## random move, instead of one chosen by the NN.
        self.final_random_prob = 0.05
        self.random_decay = 0.995
        self.random_action_prob = 1.0
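
        ## This implements an epsilon-greedy schedule: the random-move
        ## probability starts at 1.0 and is multiplied by 0.995 after
        ## each training step, down to a floor of 0.05.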

        ## Build the neural network.
        self.build_model()

        ## Size of the observations collection.
        self.max_obs_length = 10000
        self.observations = []
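
        ## The observations list serves as an experience replay
        ## buffer: training on random samples of past moves, rather
        ## than only the most recent ones, decorrelates the training
        ## data.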
        
        self.mini_batch_size = 128
        
        # Have we started training the NN yet?
        self.started_training = False
        

    ####################################

    
    def build_model(self):

        ## This function builds the NN.
        
        ## The input: the current state of the environment.
        input_state = kl.Input(shape = (self.state_dim,))

        ## Create a NN with two fully connected hidden layers.
        x = kl.Dense(24, activation = 'relu')(input_state)
        x = kl.Dense(24, activation = 'relu')(x)
        
        ## The output layer: one linear unit per action, giving the
        ## predicted Q value of each action.
        q = kl.Dense(self.num_actions, activation = 'linear')(x)

        ## Create the model.
        self.q_model = km.Model(inputs = input_state, outputs = q)

        self.q_model.compile(loss = 'mse' , optimizer = 'adam')
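
        ## The resulting network maps a state (4 numbers for
        ## CartPole: cart position and velocity, pole angle and
        ## angular velocity) to one Q value per action.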
        

    ####################################
    

    def get_next_action(self, state):

        ## Choose the next action, epsilon-greedy style: with
        ## probability random_action_prob make a random move,
        ## otherwise take the action the NN predicts to be best.
        
        if (npr.rand() < self.random_action_prob):

            action = npr.randint(self.num_actions)

        else:

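            ## The network expects a batch of inputs, so reshape the
            ## single state into a batch of size 1.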
            new_state = state.reshape(1, len(state))
            action = np.argmax(self.q_model.predict(new_state,
                                                    verbose = 0))
        
        return action

    
    ####################################

    
    def remember(self, observation):

        ## Save observations.  We use these to continuously train the
        ## NN.
        
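        ## If the buffer is full, drop the oldest observation to make
        ## room for the new one.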
        if len(self.observations) == self.max_obs_length:
            self.observations = self.observations[1:]

        self.observations.append(observation)

        
    ####################################

    
    def train_model(self):

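        ## Don't train until we have enough observations to fill a
        ## mini-batch.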
        if len(self.observations) < self.mini_batch_size:
            return

        if (not self.started_training):
            print('Begin training')
            self.started_training = True

        ## Sample a mini-batch of observations on which to train.
        mini_batch = random.sample(self.observations,
                                   self.mini_batch_size)
        
        # Take the mini-batch apart.
        previous_states, actions, rewards, current_states, dones = \
                                list(zip(*mini_batch))

        ## Start from the network's current predictions.  For the
        ## actions which were not taken, the target stays at the
        ## predicted value, so they contribute nothing to the loss.
        q_values = self.q_model.predict(np.array(previous_states),
                                        verbose = 0)

        ## For the actions which were taken, the target is the
        ## observed reward, plus the discounted value of the best
        ## next action if the move did not end the game.
        q_update = rewards + \
                   (1 - np.array(dones)) * self.future_reward_discount * \
                   np.max(self.q_model.predict(np.array(current_states),
                                               verbose = 0),
                          axis = 1)
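
        ## This is the Bellman update: target = r + gamma * max_a' Q(s', a');
        ## the (1 - dones) factor zeroes the future term for moves
        ## that ended the game.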

        ## For those actions which were chosen, assign the above
        ## values.
        q_values[range(self.mini_batch_size), actions] = q_update

        ## Fit the network to the updated Q-value targets.
        self.q_model.fit(np.array(previous_states), q_values,
                         verbose = 0)

        ## Update the randomness variables.
        self.random_action_prob *= self.random_decay
        self.random_action_prob = max(self.final_random_prob,
                                      self.random_action_prob)
    

#############################################################


if __name__ == "__main__" :

    ## Create the cartpole object.
    player = QCartPolePlayer()

    ## We'll play the game 200 times.
    for i in range(200):

        ## Reset the state of the environment, to the starting
        ## defaults.
        state = player.env.reset()[0]

        ## Total score, kept just for kicks.
        total = 0
        
        ## Run until this game is done.
        done = False
        j = 0
        while not done:

            ## Update the screen.
            player.env.render()

            ## Grab the current state.
            prev_state = state

            ## Determine our choice of the next action, given the
            ## current state.
            action = player.get_next_action(prev_state)
            
            ## Step the game forward, given that choice of action.
            ## Gymnasium returns separate 'terminated' and 'truncated'
            ## flags; the game is over if either is set.
            state, reward, terminated, truncated, info = \
                player.env.step(action)
            done = terminated or truncated

            ## If the pole fell, penalize the player.  Reaching the
            ## step limit (truncation) counts as a success, so it is
            ## not penalized.
            reward = reward if not terminated else -reward

            ## Add the last choice of action to our collection of
            ## observations.  We store 'terminated' rather than
            ## 'done', since only true termination should zero the
            ## future-reward term in the Bellman update.
            player.remember([prev_state, action, reward, state,
                             terminated])

            ## Keep a running score.
            total += reward

            ## Every tenth step, update the neural network with the
            ## new data.
            if ((j % 10) == 0):
                player.train_model()

            j += 1

        print('Game', i, 'total reward:', total)
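
    ## Close the environment, shutting down the rendering window.
    player.env.close()

    ## A possible extension, not part of the lecture code: save the
    ## trained network for later use, e.g.
    ##   player.q_model.save('q_cartpole.keras')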


#############################################################
