# Copyright (c) 2011-2014, B.I.Stepanov Institute of Physics, National Academy
# of Sciences of Belarus.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

import copy
import numpy
import numpy.random
import random

from PyQt4.QtCore import *

from RamanPerturbanceParams import *
from RamanRetrievalProcess import *
from MatlabDummyProcess import *
from MatlabMultiProcess import *

__all__ = ['RamanPerturbanceProcess']

# *****************************************************************************
class RamanPerturbanceProcess(MatlabMultiProcess):
    """Multi-step Matlab process for a Raman retrieval with optional
    perturbance analysis.

    The actual sequence of sub-processes comes from a
    'RamanPerturbanceGenerator'; this class only hands that sequence over to
    the 'MatlabMultiProcess' machinery."""

    # ---- Public methods -----------------------------------------------------
    def __init__(self, lidarInputs, ramanParams, perturbanceParams):
        """Parameters:
          - 'lidarInputs', 'ramanParams': same as in 'RamanRetrievalProcess'.
          - 'perturbanceParams' : 'RamanPerturbanceParams' instance."""

        MatlabMultiProcess.__init__(self)

        generator = RamanPerturbanceGenerator(
            lidarInputs, ramanParams, perturbanceParams)
        self.processGenerator = generator

    # ---- Private overridden methods -----------------------------------------
    def getMatlabProcessGenerator(self):
        """Return a new generator over the Matlab processes to be run."""

        generator = self.processGenerator
        return generator.generateMatlabProcesses()

# *****************************************************************************
class RamanPerturbanceGenerator:
    """Generates the Matlab processes of a Raman perturbance analysis.

    The sequence is: one 'RamanRetrievalProcess' on the unperturbed inputs;
    then, only if perturbance analysis is enabled,
    'perturbanceParams.evaluations' retrievals on randomly perturbed copies
    of the inputs, followed by a 'MatlabDummyProcess' carrying the unperturbed
    retrieval output extended with dispersion data and the noise parameters
    that were used."""

    # ---- Public methods -----------------------------------------------------
    def __init__(self, lidarInputs, ramanParams, perturbanceParams):
        """Parameters:
          - 'lidarInputs', 'ramanParams': same as in 'RamanRetrievalProcess'.
            'lidarInputs[0]' is treated as the source channel and
            'lidarInputs[1]' as the Raman channel.
          - 'perturbanceParams' : 'RamanPerturbanceParams' instance."""

        self.lidarInputs = lidarInputs
        self.ramanParams = ramanParams
        self.perturbanceParams = perturbanceParams

    def generateMatlabProcesses(self):
        """Yield the Matlab processes of the multi-step retrieval, in order.

        The retrieval output of each yielded 'RamanRetrievalProcess' is read
        as soon as the generator is resumed, so the caller presumably runs
        every yielded process to completion before requesting the next one
        (NOTE(review): confirm against 'MatlabMultiProcess')."""

        def wavelengthSuffix(lidarInput):
            # Raman channel IDs end with an 'R' letter that the
            # 'RamanPerturbanceParams' attribute names do not carry. Strip it
            # only when present - the same rule 'createModifiedLidarInput'
            # applies when perturbing the inputs - so the reported noise
            # parameters always match the ones actually used.
            channelId = lidarInput.getChannelId()
            if channelId.endswith('R'):
                channelId = channelId[:-1]
            return channelId

        # Retrieve aerosol profiles without any input data perturbations.
        mainRamanProcess = RamanRetrievalProcess(self.lidarInputs,
            self.ramanParams)

        yield mainRamanProcess

        # If perturbance analysis is off, one retrieval is just enough.
        if not self.perturbanceParams.perturbanceEnabled:
            return

        # When appended with perturbance-induced dispersion data, this will
        # become the final result of the multi-step retrieval process.
        #
        # Do not modify the original retrieval output instance, because it may
        # still be used by the retrieval output dialog.
        retrievalOutput = copy.deepcopy(mainRamanProcess.getRetrievalOutput())

        initDispersionData(retrievalOutput)

        for _ in range(self.perturbanceParams.evaluations):

            modifiedLidarInputs = [createModifiedLidarInput(lidarInput,
                self.perturbanceParams) for lidarInput in self.lidarInputs]

            ramanProcess = RamanRetrievalProcess(modifiedLidarInputs,
                self.ramanParams)

            # Hand the unperturbed results to the perturbed retrieval
            # (NOTE(review): presumably used as its initial data/guess).
            ramanProcess.setInitData(
                retrievalOutput.profileBackscatter,
                retrievalOutput.profileLidar,
                [retrievalOutput.ratioSource, retrievalOutput.ratioRaman])

            ramanProcess.setHideIntermediateOutputs(True)

            yield ramanProcess

            updateDispersionData(retrievalOutput,
                ramanProcess.getRetrievalOutput())

        finalizeDispersionData(retrievalOutput,
            self.perturbanceParams.evaluations)

        wlSuffixSource = wavelengthSuffix(self.lidarInputs[0])
        wlSuffixRaman = wavelengthSuffix(self.lidarInputs[1])

        writeNoiseParams(retrievalOutput, self.perturbanceParams,
            wlSuffixSource, wlSuffixRaman)

        # Append the final retrieval result to the list of intermediate results
        # realized during the multi-step retrieval process.
        yield MatlabDummyProcess(retrievalOutput)

# **** Private functions ------------------------------------------------------
def writeNoiseParams(retrievalOutput, perturbanceParams, wlSuffixSource,
    wlSuffixRaman):
    """Record the perturbance settings of the analysis in 'retrievalOutput'.

    Stores the number of evaluations and, for both the source and the Raman
    channel, the effective white-noise amplitude and linear-perturbance bound:
    the channel-specific attribute of 'perturbanceParams' (selected by
    'wlSuffixSource' / 'wlSuffixRaman') multiplied by the matching common
    factor."""

    retrievalOutput.perturbanceEvaluations = perturbanceParams.evaluations

    # (output attribute, parameter prefix, wavelength suffix, common factor)
    layout = (
        ('whiteNoiseSource', 'whiteNoise', wlSuffixSource, 'whiteNoiseCommon'),
        ('whiteNoiseRaman', 'whiteNoise', wlSuffixRaman, 'whiteNoiseCommon'),
        ('linearNoiseSource', 'linear', wlSuffixSource, 'linearCommon'),
        ('linearNoiseRaman', 'linear', wlSuffixRaman, 'linearCommon'),
    )

    for outputAttr, paramPrefix, wlSuffix, commonAttr in layout:
        channelValue = getattr(perturbanceParams, paramPrefix + wlSuffix)
        commonValue = getattr(perturbanceParams, commonAttr)
        setattr(retrievalOutput, outputAttr, channelValue * commonValue)

def initDispersionData(mainRetrievalOutput):
    """Create a zero-valued dispersion accumulator in 'mainRetrievalOutput'
    for every (value, dispersion) attribute pair.

    Each accumulator is computed as 'value - value', so it is a zero array
    when the value is an array and a zero scalar when the value is a scalar.
    NaN or infinite entries of the value produce NaN in the accumulator."""

    for valueAttr, dispersionAttr in getDispersionAttributePairs():
        value = getattr(mainRetrievalOutput, valueAttr)
        setattr(mainRetrievalOutput, dispersionAttr, value - value)

def updateDispersionData(mainRetrievalOutput, perturbedOutput):
    """Add the squared deviation of one perturbed retrieval from the
    unperturbed one to the dispersion accumulators of 'mainRetrievalOutput'.

    Parameters:
      - 'mainRetrievalOutput': unperturbed output holding the accumulators.
      - 'perturbedOutput': output of one retrieval on perturbed inputs."""

    for valueAttr, dispersionAttr in getDispersionAttributePairs():
        reference = getattr(mainRetrievalOutput, valueAttr)
        perturbed = getattr(perturbedOutput, valueAttr)
        accumulated = getattr(mainRetrievalOutput, dispersionAttr)

        delta = perturbed - reference
        setattr(mainRetrievalOutput, dispersionAttr,
            accumulated + delta * delta)

def finalizeDispersionData(mainRetrievalOutput, evaluations):
    """Turn the accumulated squared deviations in 'mainRetrievalOutput' into
    root-mean-square deviations from the unperturbed retrieval.

    Parameters:
      - 'mainRetrievalOutput': retrieval output holding the dispersion
        accumulators filled by 'updateDispersionData'.
      - 'evaluations': number of perturbed retrievals that were accumulated;
        must be positive (with zero the result is undefined)."""

    # Force true division: with an integer 'evaluations' and an integer
    # accumulator (e.g. integer-valued scalar ratios), Python 2 '/' would
    # silently floor the mean before the square root is taken.
    count = float(evaluations)

    for valueAttr, dispersionAttr in getDispersionAttributePairs():
        accumulated = getattr(mainRetrievalOutput, dispersionAttr)
        setattr(mainRetrievalOutput, dispersionAttr,
            numpy.sqrt(accumulated / count))

def getDispersionAttributePairs():
    """Return a new list of (value attribute, dispersion attribute) name pairs
    for the retrieval-output quantities covered by the perturbance analysis.
    """

    valueAttrs = ('profileBackscatter', 'profileLidar',
        'ratioSource', 'ratioRaman')
    dispersionAttrs = ('profileDispersionBackscatter',
        'profileDispersionLidar', 'ratioDispersionSource',
        'ratioDispersionRaman')

    return list(zip(valueAttrs, dispersionAttrs))

def createModifiedLidarInput(mainLidarInput, perturbanceParams):
    """Return a randomly perturbed deep copy of 'mainLidarInput'.

    Two perturbances are applied to the copy's 'normalizedSignal':
      - additive Gaussian white noise, scaled by the squared ratio of the bin
        index to 'refPointIndex' (so its standard deviation equals the
        white-noise amplitude at the reference point) and forced to zero at
        the reference point itself;
      - a multiplicative factor '1 + f * (1 - index / refPointIndex)' with
        'f' drawn uniformly from [-bound, bound], which vanishes at the
        reference point.
    Amplitude and bound are the channel-specific 'perturbanceParams'
    attributes multiplied by the common factors. The result is float32.

    'numpy.random' is consumed only when the white-noise amplitude is
    non-zero; the standard 'random' module is consumed on every call."""

    perturbed = copy.deepcopy(mainLidarInput)

    channelId = perturbed.getChannelId()
    # Raman channel IDs end with an 'R' letter that the parameter attribute
    # names do not carry.
    wlSuffix = channelId[:-1] if channelId.endswith('R') else channelId

    refIndex = perturbed.refPointIndex
    binIndex = numpy.arange(perturbed.firstInputIndex,
        perturbed.lastInputIndex + 1, dtype = numpy.float64)
    binCount = len(binIndex)
    assert binCount == len(perturbed.normalizedSignal)

    # --- White noise, growing with (bin index / reference index) squared.
    whiteAmplitude = (getattr(perturbanceParams, 'whiteNoise' + wlSuffix) *
        perturbanceParams.whiteNoiseCommon)

    if whiteAmplitude == 0.0:
        additiveNoise = numpy.zeros(binCount)
    else:
        additiveNoise = (numpy.random.normal(0.0, whiteAmplitude, binCount) *
            (binIndex / refIndex) ** 2)
        # No noise exactly at the reference point.
        additiveNoise[refIndex - perturbed.firstInputIndex] = 0.0

    # --- Linear perturbance: full strength at index 0, zero at reference.
    linearBound = (getattr(perturbanceParams, 'linear' + wlSuffix) *
        perturbanceParams.linearCommon)
    slope = random.uniform(-linearBound, linearBound)
    gain = 1.0 + slope * (numpy.ones(binCount) - binIndex / refIndex)

    perturbed.normalizedSignal = (
        (perturbed.normalizedSignal + additiveNoise) * gain).astype(
        numpy.float32)

    return perturbed
