# Copyright (c) 2011-2014, B.I.Stepanov Institute of Physics, National Academy
# of Sciences of Belarus.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

import copy
import numpy
import numpy.random
import random

from PyQt4.QtCore import *

from LiricPerturbanceParams import *
from LiricRetrievalProcess import *
from MatlabDummyProcess import *
from MatlabMultiProcess import *

__all__ = ['LiricPerturbanceProcess']

# *****************************************************************************
class LiricPerturbanceProcess(MatlabMultiProcess):
    """'MatlabMultiProcess' whose sequence of processes is produced by a
    'LiricPerturbanceGenerator' built from the constructor arguments."""

    # ---- Public methods -----------------------------------------------------
    def __init__(self, lidarInputs, photometerInput, algorithmParams,
        perturbanceParams):
        """Parameters:
          - 'lidarInputs', 'photometerInput', 'algorithmParams': same as in
          'LiricRetrievalProcess'.
          - 'perturbanceParams' : 'LiricPerturbanceParams' instance."""

        MatlabMultiProcess.__init__(self)

        # The generator object only stores its arguments at this point; the
        # individual processes are created when the generator is iterated.
        generator = LiricPerturbanceGenerator(lidarInputs, photometerInput,
            algorithmParams, perturbanceParams)
        self.processGenerator = generator

    # ---- Private overridden methods -----------------------------------------
    def getMatlabProcessGenerator(self):
        """Return a fresh generator over the processes of this analysis."""

        processes = self.processGenerator.generateMatlabProcesses()
        return processes

# *****************************************************************************
class LiricPerturbanceGenerator:
    """Creates the processes that make up a perturbance analysis.

    Keeps the retrieval inputs together with the perturbance settings and
    exposes 'generateMatlabProcesses', a generator over the processes."""

    # ---- Public methods -----------------------------------------------------
    def __init__(self, lidarInputs, photometerInput, algorithmParams,
        perturbanceParams):
        """Store the arguments as attributes; nothing is processed here."""

        self.perturbanceParams = perturbanceParams
        self.algorithmParams = algorithmParams
        self.photometerInput = photometerInput
        self.lidarInputs = lidarInputs

    def generateMatlabProcesses(self):
        """Yield the processes of the multi-step retrieval one at a time.

        Yields, in order:
          - a 'LiricRetrievalProcess' on the unmodified inputs;
          - if 'perturbanceParams.perturbanceEnabled' is true, one
          'LiricRetrievalProcess' on freshly perturbed inputs for each of
          the 'perturbanceParams.evaluations' evaluations;
          - in that case also a final 'MatlabDummyProcess' carrying the
          reference output extended with dispersion data and noise params.

        The output of a yielded process is read right after the generator
        is resumed, so the consumer presumably runs every process before
        asking for the next one."""

        # Reference retrieval without any input data perturbations.
        referenceProcess = LiricRetrievalProcess(self.lidarInputs,
            self.photometerInput, self.algorithmParams)
        yield referenceProcess

        params = self.perturbanceParams

        # With perturbance analysis switched off, the reference retrieval is
        # the only step.
        if not params.perturbanceEnabled:
            return

        # The dispersion data is accumulated in a deep copy: the original
        # retrieval output instance must stay untouched, because it may
        # still be used by the retrieval output dialog.
        summary = copy.deepcopy(referenceProcess.getRetrievalOutput())

        # Four lidar inputs select the polarimetric variant.
        polarimetric = (len(self.lidarInputs) == 4)

        initDispersionData(summary, polarimetric)

        for _ in range(params.evaluations):

            # Lidar inputs are perturbed before the photometer input.
            noisyLidarInputs = []
            for lidarInput in self.lidarInputs:
                noisyLidarInputs.append(
                    createModifiedLidarInput(lidarInput, params))

            noisyPhotometerInput = createModifiedPhotometerInput(
                self.photometerInput, params)

            trialProcess = LiricRetrievalProcess(noisyLidarInputs,
                noisyPhotometerInput, self.algorithmParams)

            self._seedTrialProcess(trialProcess, summary, polarimetric)
            trialProcess.setHideIntermediateOutputs(True)

            yield trialProcess

            updateDispersionData(summary,
                trialProcess.getRetrievalOutput(), polarimetric)

        finalizeDispersionData(summary, params.evaluations, polarimetric)
        writeNoiseParams(summary, params, polarimetric)

        # The accumulated result is delivered as the last element of the
        # sequence of intermediate results.
        yield MatlabDummyProcess(summary)

    # ---- Private methods ----------------------------------------------------
    def _seedTrialProcess(self, trialProcess, summary, polarimetric):
        """Pass the profiles and ratios stored in 'summary' to
        'trialProcess' as its init data."""

        ratios = [summary.ratio355, summary.ratio532, summary.ratio1064]

        if polarimetric:
            ratios.append(summary.ratio532C)
            trialProcess.setInitDataPolarimetric(summary.profileFine,
                summary.profileSpherical, summary.profileSpheroid, ratios)
        else:
            trialProcess.setInitDataSimplified(summary.profileFine,
                summary.profileCoarse, ratios)

# **** Private functions ------------------------------------------------------
def writeNoiseParams(retrievalOutput, perturbanceParams, isPolarimetric):
    """Copy the effective noise settings into 'retrievalOutput'.

    Per-channel white noise, linear and phase function settings are stored
    multiplied by their respective common factor ('whiteNoiseCommon',
    'linearCommon', 'phaseCommon'); concentration and sphericity settings
    are stored as they are.

    The '532C' channel, the spherical/spheroid F11 values and the spheroid
    F22 value are written only when 'isPolarimetric' is true; the coarse
    F11 values are written only when it is false."""

    wavelengths = ('355', '532', '1064')

    retrievalOutput.perturbanceEvaluations = perturbanceParams.evaluations

    # Lidar channels: the cross-polarized 532 channel is polarimetric only.
    channels = list(wavelengths)
    if isPolarimetric:
        channels.append('532C')

    for channel in channels:
        setattr(retrievalOutput, 'whiteNoise' + channel,
            getattr(perturbanceParams, 'whiteNoise' + channel) *
            perturbanceParams.whiteNoiseCommon)

    for channel in channels:
        setattr(retrievalOutput, 'linearNoise' + channel,
            getattr(perturbanceParams, 'linear' + channel) *
            perturbanceParams.linearCommon)

    retrievalOutput.vFineNoise = perturbanceParams.concentrationFine
    retrievalOutput.vCoarseNoise = perturbanceParams.concentrationCoarse
    retrievalOutput.sphericityNoise = perturbanceParams.concentrationSphericity

    # Aerosol modes whose F11 settings are reported for this variant.
    if isPolarimetric:
        modes = ('Fine', 'Spherical', 'Spheroid')
    else:
        modes = ('Fine', 'Coarse')

    for mode in modes:
        for wavelength in wavelengths:
            setattr(retrievalOutput, 'f11Noise' + mode + wavelength,
                getattr(perturbanceParams, 'phaseF11' + mode + wavelength) *
                perturbanceParams.phaseCommon)

    if isPolarimetric:
        retrievalOutput.f22NoiseSpheroid532 = (
            perturbanceParams.phaseF22Spheroid532 *
            perturbanceParams.phaseCommon)

def initDispersionData(mainRetrievalOutput, isPolarimetric):
    """Create the dispersion attributes of 'mainRetrievalOutput'.

    For every (value, dispersion) attribute pair returned by
    'getDispersionAttributePairs', the dispersion attribute is set to the
    value subtracted from itself."""

    for valueName, dispersionName in getDispersionAttributePairs(
        isPolarimetric):

        value = getattr(mainRetrievalOutput, valueName)

        # Subtracting the value from itself gives a zero array for an array
        # value and a zero scalar for a scalar one, so the accumulator
        # matches the data type of the value it belongs to.
        setattr(mainRetrievalOutput, dispersionName, value - value)

def updateDispersionData(mainRetrievalOutput, perturbedOutput, isPolarimetric):
    """Add the squared deviation of one perturbed retrieval to the
    dispersion attributes of 'mainRetrievalOutput'.

    The deviation is taken between the value found in 'perturbedOutput'
    and the value of the same attribute in 'mainRetrievalOutput'."""

    for valueName, dispersionName in getDispersionAttributePairs(
        isPolarimetric):

        reference = getattr(mainRetrievalOutput, valueName)
        delta = getattr(perturbedOutput, valueName) - reference

        accumulated = getattr(mainRetrievalOutput, dispersionName)
        setattr(mainRetrievalOutput, dispersionName,
            accumulated + delta * delta)

def finalizeDispersionData(mainRetrievalOutput, evaluations, isPolarimetric):
    """Replace every dispersion attribute of 'mainRetrievalOutput' with
    the square root of its value divided by 'evaluations'.

    'updateDispersionData' accumulates squared deviations in these
    attributes, so this step turns each of them into a root-mean-square
    deviation over 'evaluations' perturbed retrievals."""

    for _, dispersionName in getDispersionAttributePairs(isPolarimetric):

        sumOfSquares = getattr(mainRetrievalOutput, dispersionName)

        setattr(mainRetrievalOutput, dispersionName,
            numpy.sqrt(sumOfSquares / evaluations))

def getDispersionAttributePairs(isPolarimetric):
    """Return a list of (value attribute, dispersion attribute) name pairs.

    Profile pairs come first, followed by the lidar ratio pairs.  The
    polarimetric list has 'Fine', 'Spherical' and 'Spheroid' profiles and
    an extra '532C' ratio; the other list has 'Fine' and 'Coarse' profiles
    and the three ratios '355', '532' and '1064'."""

    ratioSuffixes = ['355', '532', '1064']

    if isPolarimetric:
        profileSuffixes = ['Fine', 'Spherical', 'Spheroid']
        ratioSuffixes.append('532C')
    else:
        profileSuffixes = ['Fine', 'Coarse']

    pairs = [('profile' + suffix, 'profileDispersion' + suffix)
        for suffix in profileSuffixes]
    pairs.extend(('ratio' + suffix, 'ratioDispersion' + suffix)
        for suffix in ratioSuffixes)

    return pairs

def createModifiedLidarInput(mainLidarInput, perturbanceParams):
    """Return a perturbed deep copy of 'mainLidarInput'.

    The 'normalizedSignal' of the copy gets additive white noise and is
    then multiplied by a linear gain; the result is stored as 'float32'.
    The settings are read from 'perturbanceParams' using the channel id of
    the input ('whiteNoise<id>', 'linear<id>') and the matching common
    factors.  'mainLidarInput' itself is left unmodified."""

    perturbed = copy.deepcopy(mainLidarInput)
    channel = perturbed.getChannelId()

    binIndex = numpy.arange(perturbed.firstInputIndex,
        perturbed.lastInputIndex + 1, dtype = numpy.float64)
    binCount = len(binIndex)
    assert binCount == len(perturbed.normalizedSignal)

    # Position of every bin relative to the reference point: 0.0 at index
    # zero, 1.0 at the reference point.
    relativeIndex = binIndex / perturbed.refPointIndex

    # White noise: standard deviation 'sigma' at the reference point,
    # growing with the square of the relative index.
    sigma = (getattr(perturbanceParams, 'whiteNoise' + channel) *
        perturbanceParams.whiteNoiseCommon)

    if sigma == 0.0:
        whiteNoise = numpy.zeros(binCount)
    else:
        whiteNoise = (numpy.random.normal(0.0, sigma, binCount) *
            relativeIndex ** 2)

    # Linear perturbance: a factor drawn uniformly from the symmetric
    # interval, applied in full at index zero and fading linearly to 0.0
    # at the reference point.
    tiltBound = (getattr(perturbanceParams, 'linear' + channel) *
        perturbanceParams.linearCommon)
    tilt = random.uniform(-tiltBound, tiltBound)
    gain = 1.0 + tilt * (numpy.ones(binCount) - relativeIndex)

    noisySignal = (perturbed.normalizedSignal + whiteNoise) * gain
    perturbed.normalizedSignal = noisySignal.astype(numpy.float32)

    return perturbed

def createModifiedPhotometerInput(mainPhotometerInput, perturbanceParams):
    """Return a perturbed deep copy of 'mainPhotometerInput'.

    For every wavelength/mode combination 'modifyPhaseFunction' is called
    on the copy with a random F11 factor; the F22/F11 factor is random for
    the 532 spheroid case only and 1.0 otherwise.  Afterwards 'vFine',
    'vCoarse' and 'sphericity' are randomized and 'onDataLoaded' is
    called.  'mainPhotometerInput' itself is left unmodified."""

    perturbed = copy.deepcopy(mainPhotometerInput)

    def drawPhaseFactor(paramName):
        # 1.0 plus a uniform deviate from the symmetric interval whose
        # half-width is the named setting times the common phase factor.
        bound = (getattr(perturbanceParams, paramName) *
            perturbanceParams.phaseCommon)
        return 1.0 + random.uniform(-bound, bound)

    for wlSuffix in ('355', '532', '1064'):
        for modeSuffix in ('Fine', 'Coarse', 'Spherical', 'Spheroid'):

            f11Factor = drawPhaseFactor('phaseF11' + modeSuffix + wlSuffix)

            # F22 is perturbed for spheroids at 532 only.
            f22By11Factor = 1.0
            if (modeSuffix, wlSuffix) == ('Spheroid', '532'):
                f22By11Factor = drawPhaseFactor('phaseF22Spheroid532')

            perturbed.modifyPhaseFunction(modeSuffix, wlSuffix,
                f11Factor, f22By11Factor)

    # Total concentrations: shifted by a uniform deviate whose bound is the
    # relative setting times the current value.
    fineSpread = perturbed.vFine * perturbanceParams.concentrationFine
    perturbed.vFine += random.uniform(-fineSpread, fineSpread)

    coarseSpread = perturbed.vCoarse * perturbanceParams.concentrationCoarse
    perturbed.vCoarse += random.uniform(-coarseSpread, coarseSpread)

    # Sphericity: redrawn uniformly around the current value, with the
    # interval clipped to [0.0, 100.0].
    center = perturbed.sphericity
    halfWidth = center * perturbanceParams.concentrationSphericity
    lowest = max(center - halfWidth, 0.0)
    highest = min(center + halfWidth, 100.0)
    perturbed.sphericity = random.uniform(lowest, highest)

    # Modify attributes for spherical and spheroid concentrations.
    perturbed.onDataLoaded()

    return perturbed
