#
# Compute Ontario Summer School
# Machine Learning in Python
# 13 June 2018
# Erik Spence
#
# This file, plot_residuals_hist.py, contains some functions for
# plotting the fit to some noisy data, and then alternatively
# either the histogram of the residuals or the residuals vs x.
# Several such plots are generated for the class.
#

#######################################################################


import numpy as np
import matplotlib.pylab as plt

# The module containing the function which generates the data.
import regression as reg


# Limits for the plots.
MIN = -1.0
MAX = 1.0


#######################################################################


def calc_residuals(x, y, degree):

    """
    Fit a polynomial of the given degree to the x and y data.

    Returns a two-element tuple: the fit, as a numpy poly1d object
    which can be called on x values to evaluate the polynomial, and
    the residuals, y minus the fit evaluated at x.

    """

    # Least-squares polynomial coefficients, wrapped up as a callable
    # polynomial.
    coefficients = np.polyfit(x, y, degree)
    polynomial = np.poly1d(coefficients)

    # The fit's predictions at the data points.
    predicted = polynomial(x)

    # The residuals are what is left of the data after removing the fit.
    residuals = y - predicted

    return polynomial, residuals


#######################################################################


def make_hist_plot(x, y, x2, fit, res, degree, save=False):

    """
    Given x and y data, the polynomial fit to that data, the fit
    residuals, the polynomial degree, and some x points for plotting
    the fit, plot two subplots on a single new figure:

    1) The data and the polynomial fit, with the degree of the fit
    indicated above the plot

    2) The histogram of the residuals.

    Arguments:

    x, y: the data points, plotted as black circles.

    x2: the x values at which the fit is evaluated to draw the fit
    line.

    fit: the fit, called as fit(x2) to get the y values of the fit
    line.

    res: the fit residuals to be histogrammed.

    degree: the polynomial degree, used in the title and in the
    output file name.

    save: if True, the figure is saved to
    'images/fit_residuals_hist_<degree>.pdf' and then closed.  This
    function does not create the 'images' directory.  If False (the
    default), nothing is written and the figure is left open as the
    current pyplot figure.

    """

    # Create empty figure, and set up first subplot.
    plt.figure(figsize=(3, 3))
    plt.subplot(2, 1, 1)

    # Put the title over the subplot.
    plt.title("Degree " + str(degree))

    # Plot the data.
    plt.plot(x, y, 'ko', ms=4)
    plt.xlim((MIN, MAX))

    # Plot the fit.
    plt.plot(x2, fit(x2), 'g-')

    # Switch to the lower plot.
    plt.subplot(2, 1, 2)

    # Plot the histogram of the residuals.
    plt.hist(res)

    # Tighten up the spacing around the subplots.
    plt.tight_layout(pad=0.1)

    # Only write the image, and release the figure, when asked to.
    if save:

        # Create the file name for the image.
        f = 'images/fit_residuals_hist_' + str(degree) + '.pdf'

        plt.savefig(f, transparent=True)
        plt.close()


#######################################################################
    

def make_res_plot(x, y, x2, fit, res, degree, save=False):

    """
    Given x and y data, the polynomial fit to that data, the fit
    residuals, the polynomial degree, and some x points for plotting
    the fit, plot two subplots on a single new figure:

    1) The data and the polynomial fit, with the degree of the fit
    indicated above the plot

    2) The residuals versus x, with a horizontal line at y = 0.

    Arguments:

    x, y: the data points, plotted as black circles.

    x2: the x values at which the fit is evaluated to draw the fit
    line.

    fit: the fit, called as fit(x2) to get the y values of the fit
    line.

    res: the fit residuals, plotted against x.

    degree: the polynomial degree, used in the title and in the
    output file name.

    save: if True, the figure is saved to
    'images/fit_residuals_<degree>.pdf' and then closed.  This
    function does not create the 'images' directory.  If False (the
    default), nothing is written and the figure is left open as the
    current pyplot figure.

    """

    # Create empty figure, and set up first subplot.
    plt.figure(figsize=(3, 3))
    plt.subplot(2, 1, 1)

    # Put the title over the subplot.
    plt.title("Degree " + str(degree))

    # Plot the data.
    plt.plot(x, y, 'ko', ms=4)
    plt.xlim((MIN, MAX))

    # Plot the fit.
    plt.plot(x2, fit(x2), 'g-')

    # Switch to the lower plot.
    plt.subplot(2, 1, 2)

    # Plot the residuals, and put in a horizontal line at y = 0.
    plt.plot(x, res, 'ko', ms=4)
    plt.axhline(0)

    # Tighten up the spacing around the subplots.
    plt.tight_layout(pad=0.1)

    # Only write the image, and release the figure, when asked to.
    if save:

        # Create the file name for the image.
        f = 'images/fit_residuals_' + str(degree) + '.pdf'

        plt.savefig(f, transparent=True)
        plt.close()

    
#######################################################################


if __name__ == "__main__":

    # The number of points in the fit.
    n = 40

    # Get new data.
    x, y = reg.noisy_data(n)

    # The x values for the fit line.  These span the same range as
    # the x-axis limits used by the plotting functions, so the fit
    # line always covers the full width of the plot.
    x2 = np.linspace(MIN, MAX, 100)

    # Degree 1 fit and histogram.
    fit1, res1 = calc_residuals(x, y, 1)
    make_hist_plot(x, y, x2, fit1, res1, 1)

    # Degree 13 fit, histogram and residuals.
    fit13, res13 = calc_residuals(x, y, 13)
    make_hist_plot(x, y, x2, fit13, res13, 13)
    make_res_plot(x, y, x2, fit13, res13, 13)

    # Degree 3 fit and residuals.
    fit3, res3 = calc_residuals(x, y, 3)
    make_res_plot(x, y, x2, fit3, res3, 3)

    # The figures created above are still open; display them, since
    # otherwise the script would exit without producing any output.
    plt.show()


