# Copyright (c) 2011-2014, B.I.Stepanov Institute of Physics, National Academy
# of Sciences of Belarus.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

import copy
import numpy
import numpy.random
import random

from PyQt4.QtCore import *

from PolarPerturbanceParams import *
from PolarRetrievalProcess import *
from MatlabDummyProcess import *
from MatlabMultiProcess import *

__all__ = ['PolarPerturbanceProcess']

# *****************************************************************************
class PolarPerturbanceProcess(MatlabMultiProcess):
    """Multi-step polarization retrieval with input perturbance analysis.

    The MATLAB processes to run are supplied by a 'PolarPerturbanceGenerator'
    built from the constructor arguments."""

    # ---- Public methods -----------------------------------------------------
    def __init__(self, lidarInputs, polarParams, perturbanceParams):
        """Parameters:
          - 'lidarInputs', 'polarParams': same as in 'PolarRetrievalProcess'.
          - 'perturbanceParams' : 'PolarPerturbanceParams' instance."""

        # The base class is initialized first, before any state of our own.
        MatlabMultiProcess.__init__(self)

        generator = PolarPerturbanceGenerator(
            lidarInputs, polarParams, perturbanceParams)
        self.processGenerator = generator

    # ---- Private overridden methods -----------------------------------------
    def getMatlabProcessGenerator(self):
        """Return a fresh generator of the MATLAB processes to be run."""

        source = self.processGenerator
        return source.generateMatlabProcesses()

# *****************************************************************************
class PolarPerturbanceGenerator:
    """Produces the sequence of retrieval processes for perturbance analysis.

    The first process is always the unperturbed retrieval. When perturbance
    analysis is enabled, it is followed by 'evaluations' retrievals on
    randomly perturbed copies of the inputs. The last process is a
    'MatlabDummyProcess' carrying a copy of the unperturbed output, extended
    with dispersion data and the noise settings that were used."""

    # ---- Public methods -----------------------------------------------------
    def __init__(self, lidarInputs, polarParams, perturbanceParams):
        """Parameters:
          - 'lidarInputs', 'polarParams': passed to 'PolarRetrievalProcess'.
          - 'perturbanceParams': 'PolarPerturbanceParams' instance."""

        self.lidarInputs = lidarInputs
        self.polarParams = polarParams
        self.perturbanceParams = perturbanceParams

    def generateMatlabProcesses(self):
        """Generator yielding the MATLAB processes of the multi-step retrieval.

        NOTE(review): the caller presumably runs each yielded process to
        completion before resuming this generator. The result of a yielded
        process is read via 'getRetrievalOutput()' right after the 'yield'
        statement, so this ordering matters."""

        # Retrieve aerosol profiles without any input data perturbations.
        mainPolarProcess = PolarRetrievalProcess(self.lidarInputs,
            self.polarParams)

        yield mainPolarProcess

        # If perturbance analysis is off, one retrieval is just enough.
        if not self.perturbanceParams.perturbanceEnabled:
            return

        # When appended with perturbance-induced dispersion data, this will
        # become the final result of the multi-step retrieval process.
        #
        # Do not modify the original retrieval output instance, because it may
        # still be used by the retrieval output dialog.
        retrievalOutput = copy.deepcopy(mainPolarProcess.getRetrievalOutput())

        # Adds zeroed dispersion accumulators to the copied output.
        initDispersionData(retrievalOutput)

        # NOTE(review): assumes the first lidar input is the base (non-cross)
        # channel; createModifiedLidarInput asserts every non-cross input
        # matches this value.
        basePolarization = self.lidarInputs[0].polarization

        for i in range(self.perturbanceParams.evaluations):

            # Fresh random perturbations of every lidar signal and of the
            # retrieval parameters for this evaluation.
            modifiedLidarInputs = [createModifiedLidarInput(lidarInput,
                self.perturbanceParams, basePolarization)
                for lidarInput in self.lidarInputs]

            modifiedPolarParams = createModifiedPolarParams(self.polarParams,
                self.perturbanceParams)

            polarProcess = PolarRetrievalProcess(modifiedLidarInputs,
                modifiedPolarParams)

            # Pass the unperturbed backscatter profiles to the perturbed run
            # (presumably as the initial approximation -- confirm in
            # PolarRetrievalProcess.setInitData).
            polarProcess.setInitData(
                retrievalOutput.backscatterParallel,
                retrievalOutput.backscatterCross)

            polarProcess.setHideIntermediateOutputs(True)

            yield polarProcess

            # Accumulate squared deviations of this perturbed result from the
            # unperturbed one.
            updateDispersionData(retrievalOutput,
                polarProcess.getRetrievalOutput())

        # Turn the accumulated sums into root-mean-square deviations.
        finalizeDispersionData(retrievalOutput,
            self.perturbanceParams.evaluations)

        # Cross-channel parameter names carry a '2' before the base suffix
        # (same convention as in createModifiedLidarInput).
        wlSuffixBase = str(basePolarization)
        wlSuffixCross = '2' + wlSuffixBase

        writeNoiseParams(retrievalOutput, self.perturbanceParams,
            wlSuffixBase, wlSuffixCross)

        # Append the final retrieval result to the list of intermediate results
        # realized during the multi-step retrieval process.
        yield MatlabDummyProcess(retrievalOutput)

# **** Private functions ------------------------------------------------------
def writeNoiseParams(retrievalOutput, perturbanceParams, wlSuffixBase,
    wlSuffixCross):
    """Record the perturbance settings used into 'retrievalOutput'.

    Each stored noise level is the per-channel parameter (selected by the
    base or cross suffix) multiplied by the corresponding common factor.
    The number of evaluations is also stored."""

    retrievalOutput.perturbanceEvaluations = perturbanceParams.evaluations

    # (output attribute, per-channel parameter, common multiplier parameter)
    # The table order matches the order in which attributes are written.
    noiseTable = (
        ('whiteNoiseBase', 'whiteNoise' + wlSuffixBase, 'whiteNoiseCommon'),
        ('whiteNoiseCross', 'whiteNoise' + wlSuffixCross, 'whiteNoiseCommon'),
        ('linearNoiseBase', 'linear' + wlSuffixBase, 'linearCommon'),
        ('linearNoiseCross', 'linear' + wlSuffixCross, 'linearCommon'),
    )

    for outputName, channelParam, commonParam in noiseTable:
        level = (getattr(perturbanceParams, channelParam) *
            getattr(perturbanceParams, commonParam))
        setattr(retrievalOutput, outputName, level)

def initDispersionData(mainRetrievalOutput):
    """Attach zero-initialized dispersion accumulators to the output.

    Sets 'backscatterDispersionParallel' and 'depolarizationDispersion'
    to arrays shaped like the parallel backscatter and the depolarization
    ratio (cross / parallel) respectively."""

    parallel = mainRetrievalOutput.backscatterParallel
    depolarization = mainRetrievalOutput.backscatterCross / parallel

    # Subtracting an array from itself yields zeros of the same shape and
    # dtype for finite entries, and NaN where the entry is NaN or infinite.
    mainRetrievalOutput.backscatterDispersionParallel = parallel - parallel
    mainRetrievalOutput.depolarizationDispersion = (
        depolarization - depolarization)

def updateDispersionData(mainRetrievalOutput, perturbedOutput):
    """Add one perturbed result's squared deviations to the accumulators.

    Deviations are taken against 'mainRetrievalOutput' for the parallel
    backscatter and for the depolarization ratio (cross / parallel). The
    accumulators are updated in place."""

    def depolarizationOf(output):
        return output.backscatterCross / output.backscatterParallel

    parallelShift = (perturbedOutput.backscatterParallel -
        mainRetrievalOutput.backscatterParallel)
    depolarizationShift = (depolarizationOf(perturbedOutput) -
        depolarizationOf(mainRetrievalOutput))

    # In-place '+=' keeps the existing accumulator arrays and their dtype.
    mainRetrievalOutput.backscatterDispersionParallel += (
        parallelShift * parallelShift)
    mainRetrievalOutput.depolarizationDispersion += (
        depolarizationShift * depolarizationShift)

def finalizeDispersionData(mainRetrievalOutput, evaluations):
    """Convert accumulated squared deviations into RMS deviations.

    Each accumulator is replaced by sqrt(sum / evaluations)."""

    for attrName in ('backscatterDispersionParallel',
                     'depolarizationDispersion'):
        accumulated = getattr(mainRetrievalOutput, attrName)
        setattr(mainRetrievalOutput, attrName,
            numpy.sqrt(accumulated / evaluations))

def createModifiedLidarInput(mainLidarInput, perturbanceParams,
    basePolarization):
    """Return a deep copy of 'mainLidarInput' with a randomly perturbed signal.

    Two perturbations are applied to 'normalizedSignal':
      - white noise from numpy.random.normal, scaled by the squared ratio of
        sample index to 'refPointIndex' and forced to zero at the reference
        point;
      - a random linear gain error from random.uniform, equal to the drawn
        factor at index 0 and falling linearly to zero at the reference point.
    Cross-polarized inputs (polarization == 2) use parameters suffixed with
    '2' + base polarization; every other input must be the base channel.
    The resulting signal is cast to float32."""

    perturbed = copy.deepcopy(mainLidarInput)

    isCross = mainLidarInput.polarization == 2
    if not isCross:
        assert mainLidarInput.polarization == basePolarization
    suffix = ('2' if isCross else '') + str(basePolarization)

    firstIndex = perturbed.firstInputIndex
    refIndex = perturbed.refPointIndex

    # Float dtype keeps the index ratios below from using integer division.
    positions = numpy.arange(firstIndex, perturbed.lastInputIndex + 1,
        dtype = numpy.float64)
    count = len(positions)
    assert count == len(perturbed.normalizedSignal)

    relativePosition = positions / refIndex

    # White noise: amplitude equals 'whiteAmplitude' at the reference point
    # and grows with the squared index ratio (index is presumably
    # proportional to range -- confirm).
    whiteAmplitude = (getattr(perturbanceParams, 'whiteNoise' + suffix) *
        perturbanceParams.whiteNoiseCommon)

    if whiteAmplitude == 0.0:
        noise = numpy.zeros(count)
    else:
        noise = (numpy.random.normal(0.0, whiteAmplitude, count) *
            relativePosition ** 2)
        noise[refIndex - firstIndex] = 0.0

    # Linear gain error: 'slope' at index 0, zero at the reference point.
    bound = (getattr(perturbanceParams, 'linear' + suffix) *
        perturbanceParams.linearCommon)
    slope = random.uniform(-bound, bound)
    gain = 1.0 + slope * (1.0 - relativePosition)

    distorted = (perturbed.normalizedSignal + noise) * gain
    perturbed.normalizedSignal = distorted.astype(numpy.float32)

    return perturbed

def createModifiedPolarParams(mainPolarParams, perturbanceParams):
    """Return a deep copy of 'mainPolarParams' with randomly scaled coefficients.

    Each of 'aerBackscatterRatioRef' and 'lidarCalibration' (suffixes '0' and
    '3') is multiplied by (1 + u). Here u is drawn uniformly from
    [-v, v], and v is the same-named attribute of 'perturbanceParams'."""

    result = copy.deepcopy(mainPolarParams)

    # All available parameters are perturbed, even those that won't be used
    # in this particular retrieval, to make things simpler. The order (suffix
    # outer, coefficient inner) fixes the order of random draws.
    paramNames = [coeffName + wlSuffix
        for wlSuffix in ('0', '3')
        for coeffName in ('aerBackscatterRatioRef', 'lidarCalibration')]

    for paramName in paramNames:
        baseValue = getattr(mainPolarParams, paramName)
        spread = getattr(perturbanceParams, paramName)
        scale = 1.0 + random.uniform(-spread, spread)
        setattr(result, paramName, baseValue * scale)

    return result
