# Copyright (c) 2011-2014, B.I.Stepanov Institute of Physics, National Academy
# of Sciences of Belarus.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

import numpy

from common.utils import txt

from common.DataRecord import *

from LidarInput import *

__all__ = ['PreparedLidarInput']

# *****************************************************************************
class PreparedLidarInput(LidarInput):
    """Lidar measurement at a single wavelength, prepared for aerosol profile
    retrieval and supplemented with molecular atmosphere profiles.

    Attributes (in addition to those defined in 'LidarInput'):
      - 'firstInputIndex', 'lastInputIndex': first and last indices of data
        selected for aerosol retrieval (with respect to the data grid defined
        by 'LidarInput.gridStep' and 'LidarInput.gridZenithAngle').
      - 'refPointIndex': index of the lidar signal normalization point.

      - 'refMolBackscatter': molecular backscatter at the reference point (for
        the unpolarized lidar channel).

      - 'molThicknessError', 'aerosolThicknessError', 'totalBackscatterError':
        estimated values for relative errors of molecular optical thickness,
        aerosol optical thickness and total (molecular + aerosol) backscatter
        coefficient used in the dispersion calculation.

      - 'atmoModel': textual identifier for the molecular atmosphere model used
        during the export procedure: 'STD' for International Standard
        Atmosphere 1976, 'CIRA' for COSPAR International Reference Atmosphere
        1986, and 'CUST' for a custom molecular atmosphere model.

      - 'normalizedSignal': Numpy array holding the lidar signal, normalized to
        the reference point. Index zero of the array corresponds to
        'firstInputIndex' of the data grid; array length is 'lastInputIndex -
        firstInputIndex + 1'.
      - 'lidarDispersion': Numpy array holding weighting coefficients for lidar
        signal equations (same dimensions as in 'normalizedSignal').
      - 'molBackscatter': Numpy array holding molecular backscatter
        coefficients for the unpolarized lidar channel (same dimensions as in
        'normalizedSignal').
      - 'molThickness': Numpy array holding molecular optical thickness
        relative to the reference point (same dimensions as in
        'normalizedSignal').

      - 'lidarRefValue': averaged value of the lidar signal at the reference
        point."""

    # ---- Class attributes ---------------------------------------------------
    # Field descriptors appended to those inherited from 'LidarInput'. Each
    # entry pairs an attribute name of this class with a database field name
    # and a database type; array-valued fields additionally carry the Numpy
    # element type. NOTE(review): exact argument semantics are defined by
    # 'DataField', which is not visible here -- confirm there.
    fields = LidarInput.fields + [
        DataField('firstInputIndex', 'Left',   'SMALLINT'),
        DataField('lastInputIndex',  'Right',  'SMALLINT'),
        DataField('refPointIndex',   'RP',     'SMALLINT'),

        DataField('refMolBackscatter', 'BRM', 'SINGLE'),

        DataField('molThicknessError',     'RE OMT', 'SINGLE'),
        DataField('aerosolThicknessError', 'RE OAT', 'SINGLE'),
        DataField('totalBackscatterError', 'RE GBS', 'SINGLE'),

        DataField('atmoModel', 'AtmoModel', 'VARCHAR', required = False),

        DataField('normalizedSignal', 'Effective Value', 'LONGBINARY',
            numpy.float32),
        DataField('lidarDispersion', 'Dispersion Value', 'LONGBINARY',
            numpy.float32),
        DataField('molBackscatter', 'Molecular Model', 'LONGBINARY',
            numpy.float32),
        DataField('molThickness', 'Optical Thickness', 'LONGBINARY',
            numpy.float32),

        DataField('lidarRefValue', 'RefVal', 'SINGLE', required = False)
    ]

    # 'TropoExport' is not listed in 'fields' but is referenced by the WHERE
    # clause below. NOTE(review): both attributes are presumably consumed by
    # the base class when it builds the SELECT statement, and the meaning of
    # the accepted 'TropoExport' values (negative, 32767, 1) is not documented
    # in this file -- confirm against the exporter.
    extraFieldNames = ['TropoExport']
    selectStatementExtraText = (' WHERE TropoExport < 0 OR '
        'TropoExport = 32767 OR TropoExport = 1')

    # Error message texts specific to the lidar input file (presumably
    # reported by the database loading code of the base class).
    openDatabaseErrorMessage = (
        'Lidar input file is not a valid Access database or may not be opened '
        'for reading')
    queryFailureErrorMessage = (
        'Lidar input: failed to query data from the database file')
    invalidFormatErrorPrefix = (
        'Lidar input file format is invalid: ')

    # ---- Public methods -----------------------------------------------------
    def getInputZenithAngles(self):
        """Return a view of 'self.getZenithAngles()' array representing only
        those data that have been selected for aerosol retrieval, so that array
        dimensions are fully compatible with 'normalizedSignal' etc."""

        # Input indices refer to the data grid, so they are shifted by
        # 'firstDataIndex' before slicing (this assumes that index zero of
        # 'getZenithAngles()' corresponds to 'firstDataIndex').
        return self.getZenithAngles()[
            self.firstInputIndex - self.firstDataIndex :
            self.lastInputIndex + 1 - self.firstDataIndex]

    def getCorrectedRefMolBackscatter(self, molDepolarization,
        parallelLeakage = 0.0):
        """Return molecular backscatter coefficient at the reference point,
        corrected according to polarization of the 'self' lidar measurement.

        Arguments are passed unchanged to 'correctMolBackscatter', which
        raises 'ValueError' for an unsupported 'self.polarization' value."""

        return self.correctMolBackscatter(self.refMolBackscatter,
            molDepolarization, parallelLeakage)

    def getCorrectedMolBackscatter(self, molDepolarization,
        parallelLeakage = 0.0):
        """Return Numpy array holding molecular backscatter coefficients
        corrected according to polarization of the 'self' lidar measurement.

        Arguments are passed unchanged to 'correctMolBackscatter', which
        raises 'ValueError' for an unsupported 'self.polarization' value."""

        return self.correctMolBackscatter(self.molBackscatter,
            molDepolarization, parallelLeakage)

    def getErrorMessage(self):
        """Return an error message describing the malformed data or 'None' if
        all attribute values are plausible and consistent.

        The checks are performed in the following order, and the message for
        the first failed one is returned:
          - checks of the base class ('LidarInput.getErrorMessage');
          - presence of all required attribute values;
          - 'firstInputIndex' is smaller than 'lastInputIndex';
          - the input index range lies within the data index range;
          - 'refPointIndex' is not smaller than 'firstInputIndex';
          - every input array has 'lastInputIndex - firstInputIndex + 1'
            elements."""

        baseMessage = LidarInput.getErrorMessage(self)
        if baseMessage is not None:
            return baseMessage

        # Check if some of the values were missing in the database.
        # Note: this will actually duplicate checks for the base class.
        for name in self.getRequiredAttributeNames():
            if getattr(self, name) is None:
                return ('%s database field contains no data' %
                    txt.quote(self.getFieldName(name)))

        if self.firstInputIndex >= self.lastInputIndex:
            return ('%s and %s database field values are inconsistent' %
                (txt.quote(self.getFieldName('firstInputIndex')),
                txt.quote(self.getFieldName('lastInputIndex'))))

        if (self.firstInputIndex < self.firstDataIndex or
            self.lastInputIndex > self.lastDataIndex):
            return ('%s and %s database field values are inconsistent with '
                '%s and %s' %
                (txt.quote(self.getFieldName('firstInputIndex')),
                txt.quote(self.getFieldName('lastInputIndex')),
                txt.quote(self.getFieldName('firstDataIndex')),
                txt.quote(self.getFieldName('lastDataIndex'))))

        if self.refPointIndex < self.firstInputIndex:
            return ('%s database field value is smaller than %s' %
                (txt.quote(self.getFieldName('refPointIndex')),
                txt.quote(self.getFieldName('firstInputIndex'))))

        inputArrayLength = self.lastInputIndex - self.firstInputIndex + 1

        # Check sizes of the input arrays.
        for name in ('normalizedSignal', 'lidarDispersion',
            'molBackscatter', 'molThickness'):

            value = getattr(self, name)

            if len(value) != inputArrayLength:
                # Sizes are reported in bytes rather than in elements.
                valueBytes = len(value) * value.itemsize
                requiredBytes = inputArrayLength * value.itemsize

                return (
                    'size of the %s database field is %s bytes instead of %s' %
                    (txt.quote(self.getFieldName(name)),
                    txt.quoteNumber(valueBytes),
                    txt.quoteNumber(requiredBytes)))

        # All checks have passed.
        return None

    # ---- Private methods ----------------------------------------------------
    def correctMolBackscatter(self, betaTotal, molDepolarization,
        parallelLeakage = 0.0):
        """Transform the given molecular backscatter coefficient for the
        unpolarized lidar channel (either a scalar or an array) into an
        appropriate one for the 'self' lidar measurement.

        Arguments:
          - 'betaTotal': molecular backscatter coefficient for the unpolarized
            channel.
          - 'molDepolarization': molecular depolarization value used in the
            correction factors.
          - 'parallelLeakage': additional term added to 'molDepolarization'
            in the numerator for cross-polarized channels only.

        Return corrected version of the coefficient.

        Raise 'ValueError' if 'self.polarization' is not one of the supported
        codes (0, 1, 2, 3). Previously such values silently produced 'None',
        which only failed later, in the arithmetic of the caller."""

        if self.polarization in (0, 1):
            # For unpolarized and Raman channels, correction is not needed.
            return betaTotal

        elif self.polarization == 3:
            # For parallel-polarized channels, subtract the cross-polarized
            # component from the total backscatter.
            return 1.0 / (1.0 + molDepolarization) * betaTotal

        elif self.polarization == 2:
            # For cross-polarized channels, include both the cross-polarized
            # component and a correction term for instrumental inaccuracies.
            return ((molDepolarization + parallelLeakage) /
                (1.0 + molDepolarization) * betaTotal)

        # Fail loudly instead of falling off the end of the function.
        raise ValueError(
            'unsupported lidar channel polarization code: %r' %
            (self.polarization,))
