# Copyright (c) 2011-2014, B.I.Stepanov Institute of Physics, National Academy
# of Sciences of Belarus.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

import datetime
import numpy

from MatlabProcess import *
from PreparedLidarInput import *
from PolarOutput import *
from PolarParams import *

__all__ = ['PolarRetrievalProcess']

# *****************************************************************************
class PolarRetrievalProcess(MatlabProcess):
    """Wrapper around a standalone Matlab application that implements aerosol
    backscatter and depolarization retrieval algorithm utilizing polarized
    lidar signals.

    Attributes:
      - 'lidarInputs': a pair of 'PreparedLidarInput' instances representing
        measurements for the base and cross-polarized lidar channels (in this
        order). Base lidar measurement may be either unpolarized or
        parallel-polarized (its 'polarization' must be 0 or 3); the
        cross-polarized one must have 'polarization' equal to 2.
      - 'polarParams': a 'PolarParams' instance.

      - 'profileInitParallel', 'profileInitCross': initial approximations to
        the aerosol profiles to be retrieved, or 'None's if default (zero)
        values should be used instead.

      - 'ratiosInit': initialized to 'None'; not referenced by this class
        itself.

      - 'hideIntermediateOutputs': 'True', if 'getAllRetrievalOutputs' method
        of this class should return only the last retrieval output instance
        (similar to 'getRetrievalOutput()'). This is 'False' by default."""

    # ---- Public overridden methods ------------------------------------------
    def __init__(self, lidarInputs, polarParams):
        """Create a retrieval process for the given pair of lidar inputs
        (base channel first, cross-polarized channel second) and the given
        'PolarParams' instance."""

        MatlabProcess.__init__(self)

        # ---- Members ----
        self.lidarInputs = lidarInputs
        self.polarParams = polarParams

        self.profileInitParallel = None
        self.profileInitCross = None

        self.ratiosInit = None

        self.hideIntermediateOutputs = False

    # ---- Public methods -----------------------------------------------------
    def setInitData(self, profileInitParallel, profileInitCross):
        """Set initial approximations to the parallel and cross-polarized
        aerosol backscatter profiles. Passing 'None' selects the default
        zero approximation in 'writeInputData'."""

        self.profileInitParallel = profileInitParallel
        self.profileInitCross = profileInitCross

    def setHideIntermediateOutputs(self, hideIntermediateOutputs):
        """Choose whether 'getAllRetrievalOutputs' should report only the
        final retrieval output."""

        self.hideIntermediateOutputs = hideIntermediateOutputs

    # ---- Public overridden methods ------------------------------------------
    def getAllRetrievalOutputs(self):
        """Return all retrieval outputs, or a one-element list with only the
        last one if 'hideIntermediateOutputs' is set."""

        if self.hideIntermediateOutputs:
            return [self.getRetrievalOutput()]

        return MatlabProcess.getAllRetrievalOutputs(self)

    # ---- Private overridden methods -----------------------------------------
    def getMatlabProgramName(self):
        """Return the name of the Matlab program run by this process."""
        return 'polar'

    def writeInputData(self, hdfGroup):
        """Prepare input file for the Matlab application.

        All the consistency checks on 'lidarInputs' are performed before any
        dataset is written, so a failed check does not leave 'hdfGroup'
        partially filled. Raises 'AssertionError' if the inputs do not
        satisfy the assumptions of the algorithm."""

        base = self.lidarInputs[0]
        cross = self.lidarInputs[1]

        assert base.polarization in (0, 3)
        assert cross.polarization == 2

        # The algorithm assumes that both channels share the same height grid
        # and the same first / reference / last point indices.
        for attrName in ('gridHeightStep', 'firstInputIndex',
                'refPointIndex', 'lastInputIndex'):
            assert all(getattr(lidarInput, attrName) == getattr(base, attrName)
                for lidarInput in self.lidarInputs), attrName

        params = self.polarParams
        baseSuffix = self._getBaseSuffix()
        gridSize = self._getGridSize()

        # Weighting coefficients.
        hdfGroup.create_dataset('k', data = [
            getattr(params, 'weighting' + baseSuffix),
            getattr(params, 'weighting2' + baseSuffix)])
        hdfGroup.create_dataset('nu', data = [
            params.weightSmooth2,
            params.weightSmooth3])

        hdfGroup.create_dataset('hStep', data = [base.gridHeightStep])

        # Store all the numbers as floats (the default Matlab data type).
        # Integers are not accepted by some Matlab functions (e.g.
        # 'sparse', for indexing) and need to be converted anyway.
        hdfGroup.create_dataset('Nbeg', data = [float(base.firstInputIndex)])
        hdfGroup.create_dataset('Nref', data = [float(base.refPointIndex)])
        hdfGroup.create_dataset('Nfin', data = [float(base.lastInputIndex)])

        hdfGroup.create_dataset('betaMref', data = [
            base.getCorrectedRefMolBackscatter(
                params.molDepolarization, params.parallelLeakage)])

        hdfGroup.create_dataset('gamma', data = [params.lidarRatio])

        # All the arrays are converted to 'float64' data type, as otherwise
        # some Matlab functions (e.g. 'sparse') would refuse to work.
        hdfGroup.create_dataset('S', data = [
            self.extendArray(lidarInput.normalizedSignal,
                lidarInput.firstInputIndex, gridSize)
            for lidarInput in self.lidarInputs])

        hdfGroup.create_dataset('Omega', data = [
            self.extendArray(lidarInput.lidarDispersion,
                lidarInput.firstInputIndex, gridSize)
            for lidarInput in self.lidarInputs])

        hdfGroup.create_dataset('betaM', data = [
            self.extendArray(lidarInput.getCorrectedMolBackscatter(
                # Use the same depolarization for all the wavelengths, as
                # the difference in values is negligible.
                params.molDepolarization),
                lidarInput.firstInputIndex, gridSize)
            for lidarInput in self.lidarInputs])

        hdfGroup.create_dataset('tauM', data = [
            self.extendArray(base.molThickness, base.firstInputIndex,
                gridSize)])

        hdfGroup.create_dataset('beta2Coeff', data = [
            1.0 if baseSuffix == '0' else 0.0])
        hdfGroup.create_dataset('leakage', data = [params.parallelLeakage])

        hdfGroup.create_dataset('Z', data = [
            self.extendArray(base.getInputZenithAngles(),
                base.firstInputIndex, gridSize)])

        # Initial approximation to parallel backscatter profile.
        if self.profileInitParallel is None:
            beta0 = numpy.zeros(gridSize)
        else:
            beta0 = self.profileInitParallel.astype(numpy.float64)
        hdfGroup.create_dataset('beta0', data = [beta0])

        # Initial approximation to depolarization ratio profile. It can only
        # be derived when both initial profiles are given; otherwise fall
        # back to zeros instead of failing on a 'None' operand.
        # NOTE(review): zero values in 'profileInitParallel' yield inf/nan
        # ratios here.
        if self.profileInitParallel is None or self.profileInitCross is None:
            d0 = numpy.zeros(gridSize)
        else:
            d0 = (self.profileInitCross / self.profileInitParallel).astype(
                numpy.float64)
        hdfGroup.create_dataset('d0', data = [d0])

        aerBackscatterRatioRef = getattr(params,
            'aerBackscatterRatioRef' + baseSuffix)
        lidarCalibration = getattr(params, 'lidarCalibration' + baseSuffix)

        # Middle value for aerosol backscatter ratio at the reference point.
        hdfGroup.create_dataset('R0', data = [aerBackscatterRatioRef])
        # Middle value for the lidar correction coefficient.
        hdfGroup.create_dataset('Q0', data = [lidarCalibration])

        # Relative boundary deltas for the correction variables.
        hdfGroup.create_dataset('Rtolerance', data = [
            getattr(params, 'aerBackscatterTolerance' + baseSuffix)])
        hdfGroup.create_dataset('Qtolerance', data = [
            getattr(params, 'calibrationTolerance' + baseSuffix)])

        # Set initial approximations of the correction coefficients to be
        # coincident with the middle values.
        hdfGroup.create_dataset('Rinit', data = [aerBackscatterRatioRef])
        hdfGroup.create_dataset('Qinit', data = [lidarCalibration])

        hdfGroup.create_dataset('TolFun', data = [params.tolFun])
        hdfGroup.create_dataset('TolX', data = [params.tolX])

    def readOutputData(self, hdfGroup, iteration, errorPrefix):
        """Extract data for the given iteration from the Matlab data file and
        return a 'PolarOutput' instance holding it."""

        output = PolarOutput()

        base = self.lidarInputs[0]
        params = self.polarParams
        baseSuffix = self._getBaseSuffix()
        gridSize = self._getGridSize()

        # Assume that geodetic coordinates of lidar measurements coincide.
        output.latitude = base.latitude
        output.longitude = base.longitude
        output.altitude = base.altitude

        # Obtain boundary dates and times for the set of lidar measurements.
        startDateTime = min(lidarInput.getStartDateTime()
            for lidarInput in self.lidarInputs)
        stopDateTime = max(lidarInput.getStopDateTime()
            for lidarInput in self.lidarInputs)

        output.startDate = startDateTime.date()
        output.startTime = startDateTime.time()
        output.stopDate = stopDateTime.date()
        output.stopTime = stopDateTime.time()

        # Save the date and time of the moment when the retrieval has finished.
        retrievalDateTime = datetime.datetime.now()

        output.retrievalDate = retrievalDateTime.date()
        output.retrievalTime = retrievalDateTime.time()

        # Copy the common data grid height step from a lidar input. Always use
        # zero zenith angle in output databases.
        output.gridStep = base.gridHeightStep
        output.gridZenithAngle = 0.0

        output.polarizationBase = base.polarization

        # These two values could initially be different, but now must be the
        # same by the assumptions of the algorithm.
        output.firstInputIndexBase = base.firstInputIndex
        output.firstInputIndexCross = self.lidarInputs[1].firstInputIndex

        # Convert TropoExport's molecular atmosphere IDs to constants used by
        # the automated retriever, if possible ('None' for unknown IDs).
        atmoModels = {'STD' : 'Standard atmosphere', 'CIRA' : 'CIRA 1986',
            'CUST' : 'Custom'}
        output.atmosphereModel = atmoModels.get(base.atmoModel)

        # Copy parameters from the lidar inputs. The suffixes denote lidar
        # channels in attribute names.
        for channelId, lidarInput in zip(['Base', 'Cross'], self.lidarInputs):

            setattr(output, 'signalId' + channelId, lidarInput.localId)

            setattr(output, 'dispersion' + channelId,
                # Store exactly the same value that was stored as 'Omega'
                # in 'writeInputData'.
                self.extendArray(lidarInput.lidarDispersion,
                    lidarInput.firstInputIndex, gridSize).astype(
                    # Convert 'float64' values returned by 'self.extendArray'
                    # to 'float32' ones used by the 'PolarOutput' attributes.
                    numpy.float32))

            setattr(output, 'molThicknessError' + channelId,
                lidarInput.molThicknessError)
            setattr(output, 'aerosolThicknessError' + channelId,
                lidarInput.aerosolThicknessError)
            setattr(output, 'totalBackscatterError' + channelId,
                lidarInput.totalBackscatterError)

        # Copy algorithm parameters from 'self.polarParams'.
        output.weightingBase = getattr(params, 'weighting' + baseSuffix)
        output.weightingCross = getattr(params, 'weighting2' + baseSuffix)

        output.weightSmoothParallel = params.weightSmooth3
        output.weightSmoothCross = params.weightSmooth2

        # Copy polarization parameters from 'self.polarParams'.
        output.parallelLeakage = params.parallelLeakage
        output.molDepolarization = params.molDepolarization
        output.lidarRatio = params.lidarRatio

        # Read the retrieved aerosol backscatter profiles. The cross-polarized
        # profile is obtained from the parallel one and the retrieved 'd'.
        output.backscatterParallel = self.readArray(
            hdfGroup, 'beta', 0, iteration, errorPrefix)
        output.backscatterCross = output.backscatterParallel * self.readArray(
            hdfGroup, 'd', 0, iteration, errorPrefix)

        # Copy molecular backscatter from the lidar input.
        output.molBackscatter = self.extendArray(
            # Store exactly the same value that was stored as 'betaM'
            # in 'writeInputData'.
            base.getCorrectedMolBackscatter(params.molDepolarization),
            base.firstInputIndex, gridSize).astype(
            # Convert 'float64' values returned by 'self.extendArray'
            # to 'float32' ones used by the 'PolarOutput' attributes.
            numpy.float32)

        # Read the profiles participating in the optimization equations.
        output.measuredSignalBase = self.readArray(
            hdfGroup, 'LmeasR', 0, iteration, errorPrefix)
        output.measuredSignalCross = self.readArray(
            hdfGroup, 'YmeasQ', 0, iteration, errorPrefix)
        output.calculatedSignalBase = self.readArray(
            hdfGroup, 'Lcalc', 0, iteration, errorPrefix)
        output.calculatedSignalCross = self.readArray(
            hdfGroup, 'Ycalc', 0, iteration, errorPrefix)

        # Read the retrieved aerosol backscatter ratio at the reference point.
        output.aerBackscatterRatioRef = self.readScalar(
            hdfGroup, 'R', 0, iteration, errorPrefix)
        # Read the retrieved cross-polarized lidar correction coefficient.
        output.lidarCorrection = self.readScalar(
            hdfGroup, 'Q', 0, iteration, errorPrefix)

        # Read residuals of the optimization equations.
        output.residualBase = self.readScalar(
            hdfGroup, 'Psi1', 0, iteration, errorPrefix)
        output.residualCross = self.readScalar(
            hdfGroup, 'Psi1', 1, iteration, errorPrefix)

        output.resSmoothParallel = self.readScalar(
            hdfGroup, 'Psi2', 0, iteration, errorPrefix)
        output.resSmoothCross = self.readScalar(
            hdfGroup, 'Psi2', 1, iteration, errorPrefix)

        output.onDataLoaded()
        return output

    # ---- Private methods ----------------------------------------------------
    def _getBaseSuffix(self):
        """Return the suffix selecting base-channel-specific 'PolarParams'
        attributes: '0' when the base channel polarization is 0, '3'
        otherwise."""

        return '0' if self.lidarInputs[0].polarization == 0 else '3'

    def _getGridSize(self):
        """Return the number of points in the common data grid, i.e. one past
        the largest 'lastInputIndex' among the lidar inputs."""

        return max(lidarInput.lastInputIndex + 1
            for lidarInput in self.lidarInputs)
