# Copyright (c) 2011-2014, B.I.Stepanov Institute of Physics, National Academy
# of Sciences of Belarus.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

import datetime
import numpy

from MatlabProcess import *
from PreparedLidarInput import *
from RamanOutput import *
from RamanParams import *

__all__ = ['RamanRetrievalProcess']

# *****************************************************************************
class RamanRetrievalProcess(MatlabProcess):
    """Wrapper around a standalone Matlab application that implements aerosol
    backscatter and lidar ratio retrieval algorithm utilizing Raman lidar
    signals.

    Attributes:
      - 'lidarInputs': a pair of 'PreparedLidarInput' instances representing
        measurements for the source and Raman lidar channels. Source lidar
        measurement may be either unpolarized or parallel-polarized.
      - 'ramanParams': a 'RamanParams' instance.

      - 'profileInitBackscatter', 'profileInitLidar': initial approximations to
        the aerosol profiles to be retrieved, or 'None's if default values
        should be used instead.
      - 'ratiosInit': initial approximations to backscatter ratios at the
        reference points (a list of 2 numbers), or 'None' if default values
        should be used instead.

      - 'hideIntermediateOutputs': 'True', if 'getAllRetrievalOutputs' method
        of this class should return only the last retrieval output instance
        (similar to 'getRetrievalOutput()'). This is 'False' by default."""

    # ---- Public overridden methods ------------------------------------------
    def __init__(self, lidarInputs, ramanParams):
        """Store the lidar measurements and the algorithm parameters. All
        the initial approximations start as 'None' (use default values)."""

        MatlabProcess.__init__(self)

        # ---- Members ----
        self.lidarInputs = lidarInputs
        self.ramanParams = ramanParams

        self.profileInitBackscatter = None
        self.profileInitLidar = None

        self.ratiosInit = None

        self.hideIntermediateOutputs = False

    # ---- Public methods -----------------------------------------------------
    def setInitData(self, profileInitBackscatter, profileInitLidar,
        ratiosInit):
        """Set initial approximations for the retrieval.

        Any of the arguments may be 'None'; 'writeInputData' substitutes
        default values for the ones that are. If 'ratiosInit' is given, it
        must hold exactly 2 items (one per lidar channel).

        Raises 'AssertionError' if the number of lidar inputs or the length
        of 'ratiosInit' is not 2."""

        assert len(self.lidarInputs) == 2

        # 'None' is a documented value meaning "use the defaults", so the
        # length may be checked only when actual ratios are supplied.
        assert ratiosInit is None or len(ratiosInit) == 2

        self.profileInitBackscatter = profileInitBackscatter
        self.profileInitLidar = profileInitLidar

        self.ratiosInit = ratiosInit

    def setHideIntermediateOutputs(self, hideIntermediateOutputs):
        """Choose whether 'getAllRetrievalOutputs' should return only the
        last retrieval output ('True') or everything the base class returns
        ('False')."""

        self.hideIntermediateOutputs = hideIntermediateOutputs

    # ---- Public overridden methods ------------------------------------------
    def getAllRetrievalOutputs(self):
        """Return a list of retrieval outputs. If 'hideIntermediateOutputs'
        is set, the list holds only the result of 'getRetrievalOutput()';
        otherwise the call is forwarded to 'MatlabProcess'."""

        if self.hideIntermediateOutputs:
            return [self.getRetrievalOutput()]

        return MatlabProcess.getAllRetrievalOutputs(self)

    # ---- Private overridden methods -----------------------------------------
    def getMatlabProgramName(self):
        """Return the name of the Matlab program wrapped by this class."""
        return 'raman'

    def writeInputData(self, hdfGroup):
        """Prepare input file for the Matlab application.

        'hdfGroup' must provide a 'create_dataset(name, data = ...)' method
        (presumably an HDF5 group object -- confirm against 'MatlabProcess').
        One dataset is created per input variable of the Matlab program.

        Raises 'AssertionError' if the lidar inputs do not form a supported
        source/Raman pair or do not share the same grid height step."""

        (suffixSource, suffixRaman) = self.getLidarChannelSuffixes()

        # The common grid must be long enough for the longest input.
        gridSize = max(lidarInput.lastInputIndex + 1
            for lidarInput in self.lidarInputs)

        # Weighting coefficients.
        hdfGroup.create_dataset('k', data = [
            getattr(self.ramanParams, 'weighting' + suffixSource),
            getattr(self.ramanParams, 'weighting' + suffixRaman)])
        hdfGroup.create_dataset('d', data = [
            getattr(self.ramanParams,
                'weightSmoothBackscatter' + suffixSource),
            getattr(self.ramanParams, 'weightSmoothLidar' + suffixSource),
            getattr(self.ramanParams, 'weightDeviateLidar' + suffixSource)])

        # Aerosol attenuation ratio is specified in the parameters.
        hdfGroup.create_dataset('etaA', data = [
            getattr(self.ramanParams, 'attenuationRatio' + suffixSource)])

        # Molecular backscatter/attenuation ratio is also required to handle
        # cases when zenith angle profiles and/or boundaries of the lidar data
        # grid sections selected for the retrieval are not the same.
        hdfGroup.create_dataset('etaM', data = [
            getMolBackscatterCoeff(self.lidarInputs[1].wavelength) /
            getMolBackscatterCoeff(self.lidarInputs[0].wavelength)])

        # Initial approximation to the backscatter ratio profile.
        hdfGroup.create_dataset('Theta0', data = [
            numpy.ones(gridSize) * 0.1 if self.profileInitBackscatter is None
                else self.profileInitBackscatter.astype(numpy.float64)])
        # Initial approximation to the lidar ratio profile.
        hdfGroup.create_dataset('Gamma0', data = [
            numpy.ones(gridSize) * 50.0 if self.profileInitLidar is None
                else self.profileInitLidar.astype(numpy.float64)])

        # A single height step is passed to Matlab, so all the inputs must
        # agree on it.
        heightStep = self.lidarInputs[0].gridHeightStep
        assert all(lidarInput.gridHeightStep == heightStep
            for lidarInput in self.lidarInputs)

        hdfGroup.create_dataset('hStep', data = [heightStep])

        # Store all the numbers as floats (the default Matlab data type).
        # Integers are not accepted by some Matlab functions (e.g.
        # 'sparse', for indexing) and need to be converted anyway.
        hdfGroup.create_dataset('Nbeg', data = [
            float(lidarInput.firstInputIndex)
            for lidarInput in self.lidarInputs])
        hdfGroup.create_dataset('Nref', data = [
            float(lidarInput.refPointIndex)
            for lidarInput in self.lidarInputs])
        hdfGroup.create_dataset('Nfin', data = [
            float(lidarInput.lastInputIndex)
            for lidarInput in self.lidarInputs])

        hdfGroup.create_dataset('betaMref', data = [
            lidarInput.getCorrectedRefMolBackscatter(
                # Use the same depolarization for all the wavelengths, as
                # the difference in values is negligible.
                self.ramanParams.molDepolarization)
            for lidarInput in self.lidarInputs])

        # All the arrays are converted to 'float64' data type, as otherwise
        # some Matlab functions (e.g. 'sparse') would refuse to work.
        hdfGroup.create_dataset('S', data = [
            self.extendArray(lidarInput.normalizedSignal,
                lidarInput.firstInputIndex, gridSize)
            for lidarInput in self.lidarInputs])

        hdfGroup.create_dataset('Omega', data = [
            self.extendArray(lidarInput.lidarDispersion,
                lidarInput.firstInputIndex, gridSize)
            for lidarInput in self.lidarInputs])

        hdfGroup.create_dataset('betaM', data = [
            self.extendArray(
                lidarInput.getCorrectedMolBackscatter(
                    # Use the same depolarization for all the wavelengths,
                    # as the difference in values is negligible.
                    self.ramanParams.molDepolarization),
                lidarInput.firstInputIndex, gridSize)
            for lidarInput in self.lidarInputs])

        hdfGroup.create_dataset('tauM', data = [
            self.extendArray(lidarInput.molThickness,
                lidarInput.firstInputIndex, gridSize)
            for lidarInput in self.lidarInputs])

        hdfGroup.create_dataset('Z', data = [
            self.extendArray(lidarInput.getInputZenithAngles(),
                lidarInput.firstInputIndex, gridSize)
            for lidarInput in self.lidarInputs])

        # Initial backscatter ratios at the reference points: the values set
        # via 'setInitData' take precedence over the ones in the parameters.
        hdfGroup.create_dataset('R0', data =
            [self.ramanParams.ratioSource, self.ramanParams.ratioRaman]
            if self.ratiosInit is None else self.ratiosInit)

        # Allowed range of the ratios: the parameter value plus/minus the
        # relative limit.
        hdfGroup.create_dataset('RLower', data = [
            self.ramanParams.ratioSource *
                (1.0 - self.ramanParams.ratioSourceLimit),
            self.ramanParams.ratioRaman *
                (1.0 - self.ramanParams.ratioRamanLimit)
        ])
        hdfGroup.create_dataset('RUpper', data = [
            self.ramanParams.ratioSource *
                (1.0 + self.ramanParams.ratioSourceLimit),
            self.ramanParams.ratioRaman *
                (1.0 + self.ramanParams.ratioRamanLimit)
        ])

        hdfGroup.create_dataset('TolFun', data = [
            self.ramanParams.tolFun])
        hdfGroup.create_dataset('TolX', data = [
            self.ramanParams.tolX])

    def readOutputData(self, hdfGroup, iteration, errorPrefix):
        """Extract data for the given iteration from the Matlab data file and
        return a 'RamanOutput' instance holding it.

        'hdfGroup', 'iteration' and 'errorPrefix' are passed unchanged to
        'self.readArray' and 'self.readScalar', which are defined outside of
        this class. Everything that is not read from 'hdfGroup' is copied
        from 'self.lidarInputs' and 'self.ramanParams'.

        Raises 'AssertionError' if the lidar inputs do not form a supported
        source/Raman pair."""

        output = RamanOutput()

        (suffixSource, suffixRaman) = self.getLidarChannelSuffixes()

        sourceInput = self.lidarInputs[0]
        ramanInput = self.lidarInputs[1]

        # Assume that geodetic coordinates of lidar measurements coincide.
        output.latitude = sourceInput.latitude
        output.longitude = sourceInput.longitude
        output.altitude = sourceInput.altitude

        # Obtain boundary dates and times for the set of lidar measurements.
        startDateTime = min(lidarInput.getStartDateTime()
            for lidarInput in self.lidarInputs)
        stopDateTime = max(lidarInput.getStopDateTime()
            for lidarInput in self.lidarInputs)

        output.startDate = startDateTime.date()
        output.startTime = startDateTime.time()
        output.stopDate = stopDateTime.date()
        output.stopTime = stopDateTime.time()

        # Save the date and time of the moment when the retrieval has finished.
        retrievalDateTime = datetime.datetime.now()

        output.retrievalDate = retrievalDateTime.date()
        output.retrievalTime = retrievalDateTime.time()

        # Copy the common data grid height step from a lidar input. Always use
        # zero zenith angle in output databases.
        output.gridStep = sourceInput.gridHeightStep
        output.gridZenithAngle = 0.0

        output.firstInputIndexSource = sourceInput.firstInputIndex
        output.firstInputIndexRaman  = ramanInput.firstInputIndex

        output.wavelengthSource = sourceInput.wavelength
        output.wavelengthRaman  = ramanInput.wavelength

        # The same attribute is assigned again by the per-channel loop below.
        output.polarizationSource = sourceInput.polarization

        # Convert TropoExport's molecular atmosphere IDs to constants used by
        # the automated retriever, if possible. Unknown IDs are stored as
        # 'None'.
        atmoModels = {'STD' : 'Standard atmosphere', 'CIRA' : 'CIRA 1986',
            'CUST' : 'Custom'}

        output.atmosphereModel = atmoModels.get(sourceInput.atmoModel)

        gridSize = max(lidarInput.lastInputIndex + 1
            for lidarInput in self.lidarInputs)

        # Suffixes used to denote lidar channels in attribute names.
        channelIds = ['Source', 'Raman']

        # Copy parameters from the lidar inputs.
        for i, lidarInput in enumerate(self.lidarInputs):
            channelId = channelIds[i]

            setattr(output, 'signalId' + channelId, lidarInput.localId)

            setattr(output, 'polarization' + channelId,
                lidarInput.polarization)

            setattr(output, 'dispersion' + channelId,
                # Store exactly the same value that was stored as 'Omega'
                # in 'writeInputData'.
                self.extendArray(lidarInput.lidarDispersion,
                    lidarInput.firstInputIndex, gridSize).astype(
                    # Convert 'float64' values returned by 'self.extendArray'
                    # to 'float32' ones used by the 'RamanOutput' attributes.
                    numpy.float32))

            setattr(output, 'molThicknessError' + channelId,
                lidarInput.molThicknessError)
            setattr(output, 'aerosolThicknessError' + channelId,
                lidarInput.aerosolThicknessError)
            setattr(output, 'totalBackscatterError' + channelId,
                lidarInput.totalBackscatterError)

        # Copy algorithm parameters from 'self.ramanParams'.
        setattr(output, 'weighting' + channelIds[0],
            getattr(self.ramanParams, 'weighting' + suffixSource))
        setattr(output, 'weighting' + channelIds[1],
            getattr(self.ramanParams, 'weighting' + suffixRaman))

        output.weightSmoothBackscatter = getattr(
            self.ramanParams, 'weightSmoothBackscatter' + suffixSource)
        output.weightSmoothLidar = getattr(
            self.ramanParams, 'weightSmoothLidar' + suffixSource)

        # Copy polarization parameters from 'self.ramanParams'.
        output.molDepolarization = self.ramanParams.molDepolarization
        output.attenuationRatio = getattr(
            self.ramanParams, 'attenuationRatio' + suffixSource)

        # Read retrieved profiles of the aerosol characteristics.
        output.profileBackscatter = self.readArray(
            hdfGroup, 'Theta', 0, iteration, errorPrefix)
        output.profileLidar = self.readArray(
            hdfGroup, 'Gamma', 0, iteration, errorPrefix)

        # Copy molecular backscatter from the lidar input.
        output.molBackscatter = self.extendArray(
            # Store exactly the same value that was stored as 'betaM'
            # in 'writeInputData'.
            sourceInput.getCorrectedMolBackscatter(
                self.ramanParams.molDepolarization),
            sourceInput.firstInputIndex, gridSize).astype(
            # Convert 'float64' values returned by 'self.extendArray'
            # to 'float32' ones used by the 'RamanOutput' attributes.
            numpy.float32)

        # Calculate and store actual characteristics of the aerosol.
        output.aerBackscatter = (output.profileBackscatter *
            output.molBackscatter)
        output.aerExtinction = output.profileLidar * output.aerBackscatter

        # For lidar signals, return only the relevant portions of the arrays.
        for i, lidarInput in enumerate(self.lidarInputs):
            setattr(output, 'measuredSignal' + channelIds[i],
                self.readArray(hdfGroup, 'LmeasR', i, iteration, errorPrefix,
                lidarInput.lastInputIndex + 1))

            setattr(output, 'calculatedSignal' + channelIds[i],
                self.readArray(hdfGroup, 'Lcalc', i, iteration, errorPrefix,
                lidarInput.lastInputIndex + 1))

        # Read retrieved backscatter ratios at the reference points.
        for i, _ in enumerate(self.lidarInputs):
            setattr(output, 'ratio' + channelIds[i],
                self.readScalar(hdfGroup, 'R', i, iteration, errorPrefix))

        # Read residuals of the optimization equations.
        for i, _ in enumerate(self.lidarInputs):
            setattr(output, 'residual' + channelIds[i],
                self.readScalar(hdfGroup, 'Psi1', i, iteration, errorPrefix))

        output.resSmoothBackscatter = self.readScalar(
            hdfGroup, 'Psi2', 0, iteration, errorPrefix)
        output.resSmoothLidar = self.readScalar(
            hdfGroup, 'Psi2', 1, iteration, errorPrefix)

        output.resDeviateLidar = self.readScalar(
            hdfGroup, 'Psi2', 2, iteration, errorPrefix)

        output.onDataLoaded()
        return output

    # ---- Private methods ----------------------------------------------------
    def getLidarChannelSuffixes(self):
        """Make sure that the number and order of lidar measurements in the
        'self.lidarInputs' list, as well as their wavelengths and
        polarizations, meet the expectations of the retrieval algorithm.

        Return a pair of suffix strings used to denote the current source and
        Raman lidar channels in 'RamanParams'.

        Raises 'AssertionError' if any of the expectations is not met."""

        # There must be two measurements: the source one and the Raman one.
        assert len(self.lidarInputs) == 2

        sourceInput = self.lidarInputs[0]
        ramanInput = self.lidarInputs[1]

        assert sourceInput.polarization in (0, 3)
        assert ramanInput.polarization == 1

        # Currently, there are only two pairs of lidar wavelengths supported.
        if sourceInput.wavelength == 355.0:
            assert ramanInput.wavelength == 387.0
            return ('355', '387')

        if sourceInput.wavelength == 532.0:
            assert ramanInput.wavelength == 607.0
            return ('532', '607')

        # Raise explicitly instead of using 'assert False': assertions are
        # stripped in optimized mode ('python -O'), which would make this
        # method silently return 'None'.
        raise AssertionError('Unsupported source lidar wavelength: %r'
            % (sourceInput.wavelength,))

# **** Private functions ******************************************************
def getMolBackscatterCoeff(wavelength):
    """Return relative scale coefficient for molecular backscatter at the given
    wavelength. These values may be used to convert molecular backscatters from
    one wavelength into another.

    Molecular backscatter profile at a given wavelength is proportional to the
    molecular density profile multiplied by this coefficient.

    The coefficient equals 1.0 at the wavelength of 1064."""

    # This code was copied from 'getBackscatterByDensity' in
    # 'AtmosphereModels.py'. Molecular backscatter coefficient's dependence on
    # wavelength is the same regardless of the atmosphere model.

    referenceWavelength = 1064.0
    correctionLimit = 607.0

    if wavelength < correctionLimit:
        # Below the limit the plain fourth-power law is corrected to fit
        # empirical data: the exponent grows linearly with the distance
        # from the limit.
        exponent = 4.0 + 0.000234 * (correctionLimit - wavelength)
    else:
        # Molecular backscatter decreases roughly proportional to fourth
        # power of wavelength.
        exponent = 4.0

    # Constants are stripped from the return value.
    return (referenceWavelength / wavelength) ** exponent
