# Copyright (c) 2011-2014, B.I.Stepanov Institute of Physics, National Academy
# of Sciences of Belarus.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

import datetime
import math
import numpy

from common.utils import txt

from common.DataRecord import *

# Public names exported by 'from <this module> import *'.
__all__ = ['LidarInput']

# *****************************************************************************
class LidarInput(DataRecord):
    """Input data from a single unprocessed lidar measurement.

    Attributes:
      - 'localId': a text that presumably uniquely identifies the measurement.
      - 'latitude', 'longitude', 'altitude': geodetic coordinates of the
        measurement site, in degrees and meters above sea level.
      - 'startDate', 'stopDate', 'startTime', 'stopTime': time boundaries for
        the series of lidar shots that constitute the measurement.

      - 'regNumber': registration number of the lidar channel, denoting its
        wavelength, polarization, aperture etc. in a uniform way. The database
        field is optional; a missing value is replaced with an empty string in
        'onDataLoaded'.
      - 'wavelength': nominal lidar signal wavelength in nanometers.
      - 'polarization': '0' for unpolarized signal, '1' for Raman signal, '2'
        and '3' for cross and parallel receivers for a linearly polarized
        signal.

      - 'gridStep': (uniform) step of the data grid, in meters.
      - 'gridZenithAngle': zenith angle of the data grid, in degrees. Real
        height step is 'gridStep * cos(gridZenithAngle)'. Index zero of the
        grid corresponds to an along-track segment that has 'gridStep' length
        and starts at the measurement point. If 'originalZenithAngles' array
        is 'None', this will correspond to the actual zenith angle of the lidar
        measurement; otherwise 'originalZenithAngles' should be used for that.
      - 'firstDataIndex', 'lastDataIndex': first and last indices of the data
        grid for which data is present in the arrays. Add 'firstDataIndex' to
        physical indices of the data arrays to obtain indices of the data grid.

      - 'accumulation': number of lidar shots combined in this measurement.
      - 'background': background signal or 0.0 if this signal is a joined one.
      - 'nonlinearity': estimation of error coefficient describing nonlinearity
        of the receiving channel.
      - 'syncNoise': estimation of error coefficient describing the receiver's
        synchronous noise.
      - 'nonSyncNoise': estimation of error coefficient describing the
        receiver's nonsynchronous noise.

      - 'signal': Numpy array holding the raw lidar signal profile, including
        the background (unless this signal is a joined one, in which case the
        background signal would have been subtracted during a join operation,
        whereas portions of the signal itself may have been multiplied by some
        appropriate factors in order to make the final profile smooth).

      - 'originalZenithAngles': Numpy array holding zenith angles of the lidar
        measurements, in degrees, at the given height, before a signal joining
        operation, or 'None' if this signal is not a joined one.
      - 'originalSignal': Numpy array holding raw lidar signal values before a
        signal joining operation, or 'None' if this signal is not a joined one.
      - 'originalAccumulations': Numpy array holding 'accumulation's of the
        original lidar signals, or 'None' if this signal is not a joined one.
      - 'originalBackgrounds': Numpy array holding 'background's of the
        original lidar signals, or 'None' if this signal is not a joined
        one.

    Attributes that are calculated manually in 'onDataLoaded':
      - 'gridHeightStep': height step of the data grid (see 'gridStep' and
        'gridZenithAngle'), or 'None' if either of those values is missing."""

    # ---- Class attributes ---------------------------------------------------
    # Name of the database table that holds the data records.
    tableName = 'Specific'

    fields = [
        DataField('localId',   'IDLocal',   'VARCHAR'),
        DataField('latitude',  'Latitude',  'SINGLE' ),
        DataField('longitude', 'Longitude', 'SINGLE' ),
        DataField('altitude',  'Altitude',  'SINGLE' ),
        DataField('startDate', 'StartDate', 'DATE'   ),
        DataField('stopDate',  'StopDate',  'DATE'   ),
        DataField('startTime', 'StartTime', 'TIME'   ),
        DataField('stopTime',  'StopTime',  'TIME'   ),

        DataField('regNumber',    'Reg Number',   'VARCHAR', required = False),
        DataField('wavelength',   'Wavelength',   'SINGLE'  ),
        DataField('polarization', 'Polarization', 'SMALLINT'),

        DataField('gridStep',        'Step',   'SINGLE'  ),
        DataField('gridZenithAngle', 'Zenith', 'SINGLE'  ),
        DataField('firstDataIndex',  'N1',     'SMALLINT'),
        DataField('lastDataIndex',   'N2',     'SMALLINT'),

        DataField('accumulation',  'Accumulation',  'INT'    ),
        DataField('background',    'Background',    'SINGLE' ),
        DataField('nonlinearity',  'Nonlinear',     'SINGLE' ),
        DataField('syncNoise',     'Synchron',      'SINGLE' ),
        DataField('nonSyncNoise',  'NonSynchron',   'SINGLE' ),

        DataField('signal', 'OLEData', 'LONGBINARY', numpy.float32),

        DataField('originalZenithAngles', 'ZenithArray', 'LONGBINARY',
            numpy.float32, required = False),
        DataField('originalSignal', 'OriginalSignalArray', 'LONGBINARY',
            numpy.float32, required = False),
        DataField('originalAccumulations', 'AccumulationArray', 'LONGBINARY',
            numpy.float32, required = False),
        DataField('originalBackgrounds', 'BackgroundArray', 'LONGBINARY',
            numpy.float32, required = False)
    ]

    openDatabaseErrorMessage = (
        'Lidar input file is not a valid Access database or may not be opened '
        'for reading')
    queryFailureErrorMessage = (
        'Lidar input: failed to query data from the database file')
    invalidFormatErrorPrefix = (
        'Lidar input file format is invalid: ')

    # ---- Public methods -----------------------------------------------------
    def getStartDateTime(self):
        """Return the start of the measurement as a 'datetime.datetime'
        built from 'startDate' and 'startTime', or 'None' if either of them
        is missing."""

        if self.startDate is None or self.startTime is None:
            return None
        return datetime.datetime.combine(self.startDate, self.startTime)

    def getStopDateTime(self):
        """Return the end of the measurement as a 'datetime.datetime' built
        from 'stopDate' and 'stopTime', or 'None' if either of them is
        missing."""

        if self.stopDate is None or self.stopTime is None:
            return None
        return datetime.datetime.combine(self.stopDate, self.stopTime)

    def getZenithAngles(self):
        """Return actual zenith angles for the original lidar measurements."""

        return self.getOriginalProfile(
            self.originalZenithAngles, self.gridZenithAngle)

    def getAccumulations(self):
        """Return actual accumulations for the original lidar measurements."""

        return self.getOriginalProfile(
            self.originalAccumulations, self.accumulation)

    def getBackgrounds(self):
        """Return actual backgrounds for the original lidar measurements."""

        return self.getOriginalProfile(
            self.originalBackgrounds, self.background)

    def getRawSignal(self):
        """Return raw signal values for the original lidar measurements:
        'originalSignal' for a joined signal, 'signal' otherwise."""

        if self.originalSignal is not None:
            return self.originalSignal
        return self.signal

    def getChannelId(self):
        """Return the lidar channel's textual identifier for a retrieval
        algorithm or 'None' if this lidar data may not be used with any of the
        retrieval algorithms."""

        # (nominal wavelength in nm, accepted 'polarization' codes, identifier)
        # The entries are mutually exclusive, so their order does not matter.
        channels = (
            (355.0,  (0, 3), '355'),
            (355.0,  (2,),   '355C'),
            (387.0,  (1,),   '387R'),
            (532.0,  (0, 3), '532'),
            (532.0,  (2,),   '532C'),
            (607.0,  (1,),   '607R'),
            (1064.0, (0, 3), '1064'),
        )

        for wavelength, polarizations, channelId in channels:
            if (self.wavelength == wavelength and
                    self.polarization in polarizations):
                return channelId
        return None

    def getErrorMessage(self):
        """Return an error message describing the malformed data or 'None' if
        the attribute values are plausible and consistent."""

        # Check if some of the values were missing in the database. After this
        # loop every required attribute is known to be non-'None'.
        for name in self.getRequiredAttributeNames():
            if getattr(self, name) is None:
                return ('%s database field contains no data' %
                    txt.quote(self.getFieldName(name)))

        if self.getStartDateTime() >= self.getStopDateTime():
            return ('%s/%s and %s/%s database field values are inconsistent' %
                (txt.quote(self.getFieldName('startDate')),
                txt.quote(self.getFieldName('startTime')),
                txt.quote(self.getFieldName('stopDate')),
                txt.quote(self.getFieldName('stopTime'))))

        # Range checks are written as 'not (a <= x <= b)' so that NaN values
        # are rejected too.
        if not (-90.0 <= self.latitude <= 90.0):
            return ('%s database field value is out of range' %
                txt.quote(self.getFieldName('latitude')))

        if not (-180.0 <= self.longitude <= 180.0):
            return ('%s database field value is out of range' %
                txt.quote(self.getFieldName('longitude')))

        # An empty 'regNumber' is accepted: the field is optional and
        # 'onDataLoaded' replaces a missing value with an empty string.

        if self.wavelength <= 0.0:
            return ('%s database field value is less than or equal to zero' %
                txt.quote(self.getFieldName('wavelength')))

        if self.polarization not in [0, 1, 2, 3]:
            return ('%s database field has an inadmissible value' %
                txt.quote(self.getFieldName('polarization')))

        if self.gridStep <= 0.0:
            return ('%s database field value is less than or equal to zero' %
                txt.quote(self.getFieldName('gridStep')))

        if not (0.0 <= self.gridZenithAngle < 90.0):
            return ('%s database field value is out of range' %
                txt.quote(self.getFieldName('gridZenithAngle')))

        if not (0 <= self.firstDataIndex < self.lastDataIndex):
            return ('%s and %s database field values are inconsistent' %
                (txt.quote(self.getFieldName('firstDataIndex')),
                txt.quote(self.getFieldName('lastDataIndex'))))

        if self.accumulation <= 0:
            return ('%s database field value is less than or equal to zero' %
                txt.quote(self.getFieldName('accumulation')))

        for name in ('nonlinearity', 'syncNoise', 'nonSyncNoise'):
            if getattr(self, name) < 0.0:
                return ('%s database field value is less than zero' %
                    txt.quote(self.getFieldName(name)))

        # Both grid indices are inclusive.
        dataArrayLength = self.lastDataIndex - self.firstDataIndex + 1

        # Check sizes of the arrays; optional arrays may be 'None'.
        for name in ('signal', 'originalZenithAngles', 'originalSignal',
            'originalAccumulations', 'originalBackgrounds'):

            arrayValue = getattr(self, name)
            if arrayValue is not None and len(arrayValue) != dataArrayLength:

                valueBytes = len(arrayValue) * arrayValue.itemsize
                requiredBytes = dataArrayLength * arrayValue.itemsize

                return (
                    'size of the %s database field is %s bytes instead of %s' %
                    (txt.quote(self.getFieldName(name)),
                    txt.quoteNumber(valueBytes),
                    txt.quoteNumber(requiredBytes)))

        # Check constraints for the 'originalZenithAngles' array with the same
        # range as 'gridZenithAngle'. The vectorized test is negated, like the
        # scalar checks above, so NaN items are reported as out of range
        # instead of slipping through two separate '<' / '>=' comparisons.
        angles = self.originalZenithAngles
        if angles is not None:
            if not numpy.all((angles >= 0.0) & (angles < 90.0)):
                return ('one of the items of the %s array is out of range' %
                    txt.quote(self.getFieldName('originalZenithAngles')))

        return None

    # ---- Private overridden methods -----------------------------------------
    def onDataLoaded(self):
        """Derive attributes that are not stored in the database directly.

        Hook overridden from 'DataRecord'; presumably invoked after the field
        values have been loaded (the caller is not visible here)."""

        # Calculate the vertical height step of the data grid.
        if self.gridStep is None or self.gridZenithAngle is None:
            self.gridHeightStep = None
        else:
            self.gridHeightStep = (self.gridStep *
                math.cos(self.gridZenithAngle * math.pi / 180.0))

        # 'regNumber' is an optional field; normalize a missing value.
        if self.regNumber is None:
            self.regNumber = ''

    # ---- Private methods ----------------------------------------------------
    def getOriginalProfile(self, originalProfile, scalarValue):
        """Return a Numpy array holding original signal data regardless of
        whether 'self' is a raw lidar signal or a joined one. 'originalProfile'
        has to be the array used to store the data in joined signals, and
        'scalarValue' has to be the value of the scalar database field used to
        store the same data in normal measurements.

        Returns 'None' if neither source of data is usable, i.e. there is no
        'originalProfile' and either 'scalarValue' or one of the data indices
        is missing."""

        # For a joined signal, some of its attributes may change with height.
        if originalProfile is not None:
            return originalProfile

        # Return 'None' for malformed data.
        if (scalarValue is None or
            self.firstDataIndex is None or self.lastDataIndex is None):
            return None

        # Attributes of normal signals always remain the same, so the scalar
        # value is repeated over the whole (inclusive) index range.
        dataArrayLength = self.lastDataIndex - self.firstDataIndex + 1
        return numpy.array([scalarValue] * dataArrayLength)
