# Copyright (c) 2011-2014, B.I.Stepanov Institute of Physics, National Academy
# of Sciences of Belarus.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

from PyQt4.QtCore import *
from PyQt4.QtGui import *

from common.utils import gui
from common.utils import txt
from common.utils import utils

from common.DataTableWidget import *
from common.StatusSignalingWidget import *

from PlotWidgets import *

# Public names exported by this module via 'from ... import *'.
__all__ = [
    'LidarTableWidget',
    'LidarChannelInfo',
]

# *****************************************************************************
class LidarChannelInfo:
    """Description of a lidar input channel of an aerosol retrieval algorithm.

    Attributes:
      - 'wavelength': nominal lidar signal wavelength in nanometers.
      - 'polarizations': a tuple of polarization IDs admissible for this lidar
        channel (see 'LidarInput.polarization'). Normally this should contain
        just a single element. The only combination of several polarization IDs
        for a single lidar channel that is currently acceptable is '(0, 3)',
        which means that the algorithm is capable of processing both
        unpolarized and parallel-polarized lidar signals in a similar
        manner."""

    # Human-readable descriptions of the known sets of polarization IDs, as
    # '(singular, plural)' pairs. The singular form starts a sentence that
    # reports a missing measurement; the plural form is embedded in a sentence
    # that reports too many measurements. Keys are frozensets so that the
    # order of IDs in 'polarizations' (e.g. '(0, 3)' vs '(3, 0)') is
    # irrelevant.
    _CHANNEL_DESCRIPTIONS = {
        frozenset([0]): ('Unpolarized lidar measurement',
            'unpolarized lidar measurements'),
        frozenset([0, 3]): ('Unpolarized lidar measurement',
            'unpolarized lidar measurements'),
        frozenset([1]): ('Raman lidar measurement',
            'Raman lidar measurements'),
        frozenset([2]): ('Cross-polarized lidar measurement',
            'cross-polarized lidar measurements'),
        frozenset([3]): ('Parallel-polarized lidar measurement',
            'parallel-polarized lidar measurements'),
    }

    def __init__(self, wavelength, polarizations):
        """It's possible to pass an ordinary integer value for 'polarizations'
        (which will be equivalent to a singleton tuple). A list of IDs is
        accepted as well and is converted to a tuple."""

        self.wavelength = wavelength

        # Convert integers to singleton tuples and lists to tuples.
        if isinstance(polarizations, int):
            polarizations = (polarizations, )
        elif isinstance(polarizations, list):
            polarizations = tuple(polarizations)

        # Make sure that all the polarization IDs are valid, and that there
        # is at least one of them (otherwise no measurement could ever match
        # this channel).
        assert isinstance(polarizations, tuple)
        assert len(polarizations) > 0
        assert all(polarization in (0, 1, 2, 3)
            for polarization in polarizations)

        self.polarizations = polarizations

    def compatibleWith(self, other):
        """Check if 'self' and 'other' lidar channels may be used together as
        input channels in a single aerosol retrieval algorithm.

        Two channels are compatible unless they share the wavelength and at
        least one admissible polarization ID (i.e. unless they duplicate each
        other)."""

        return (self.wavelength != other.wavelength or
            set(self.polarizations).isdisjoint(other.polarizations))

    def matchesLidarInput(self, lidarInput):
        """Check if the given 'LidarInput' instance is acceptable as input
        data for lidar channel described by 'self'."""

        return (lidarInput.wavelength == self.wavelength and
            lidarInput.polarization in self.polarizations)

    def _describeChannel(self, plural = False):
        """Return a human-readable description of the kind of lidar
        measurement accepted by this channel, in singular (e.g. 'Raman lidar
        measurement') or, if 'plural' is true, plural (e.g. 'Raman lidar
        measurements') form."""

        key = frozenset(self.polarizations)

        if key in self._CHANNEL_DESCRIPTIONS:
            singular, pluralForm = self._CHANNEL_DESCRIPTIONS[key]
        else:
            # The constructor admits combinations not listed in the table;
            # describe them generically instead of failing.
            idsStr = ', '.join('%d' % polarization
                for polarization in sorted(key))
            singular = 'Lidar measurement with polarization IDs %s' % idsStr
            pluralForm = ('lidar measurements with polarization IDs %s' %
                idsStr)

        if plural:
            return pluralForm
        return singular

    def getErrorMessageForLidarData(self, lidarInputList):
        """Check if the given list of 'LidarInput' instances contains exactly
        one lidar measurement matching the lidar channel described by 'self'.

        Return 'None' if the checking succeeds and an appropriate error string
        otherwise."""

        # The number of matching measurements within 'lidarInputList'.
        matchCount = sum(1 for lidarInput in lidarInputList
            if self.matchesLidarInput(lidarInput))

        if matchCount == 1:
            return None

        # Use the same precision for wavelengths here and in the lidar table
        # of a 'LidarTableWidget'.
        wavelengthStr = txt.quote('%.1f nm' % self.wavelength,
            addQuotes = False)

        if matchCount == 0:
            return '%s at %s is not selected' % (
                self._describeChannel(), wavelengthStr)

        return 'Too many %s are selected for %s' % (
            self._describeChannel(plural = True), wavelengthStr)

# *****************************************************************************
class LidarTableWidget(StatusSignalingWidget):
    """Widget responsible for selection of lidar input data for an aerosol
    retrieval algorithm.

    This widget is also responsible for checking if a set of measurements
    selected by the user is suitable for a retrieval. It is assumed that lidar
    channels to be selected must correspond exactly to one of the fixed sets
    of 'LidarChannelInfo' instances specified during the initialization of the
    widget (see 'addAllowedChannelSequence'). If measurements selected by the
    user exactly match one of the allowed channel sets, then error status of
    the widget is cleared. Otherwise, an appropriate error message is set with
    'setErrorMessage'."""

    # ---- Signals ------------------------------------------------------------
    # This is fired along with 'statusChanged' any time when the set of lidar
    # data records selected by the user changes.
    selectionChanged = pyqtSignal()

    # ---- Public methods -----------------------------------------------------
    def __init__(self, parent = None):
        StatusSignalingWidget.__init__(self, parent)

        # ---- Members ----
        # This will be a list of 'LidarInput' instances.
        self.dataList = []

        # This will be a list of sequences of 'LidarChannelInfo' instances.
        self.allowedChannelSequences = []
        # This will be a list of attributes of 'LidarInput's that must be
        # different from 'None' in order for the measurements to be considered
        # valid.
        self.requiredAttributes = []

        # ---- Layout ----
        layout = QVBoxLayout()

        self.table = DataTableWidget()
        # Allow multiple row selection for lidar signals.
        self.table.setSelectionMode(QAbstractItemView.MultiSelection)
        self.table.itemSelectionChanged.connect(self.onSelectionChanged)

        self.signalPlot = LidarPlotWidget()

        self.splitter = QSplitter(Qt.Horizontal)
        self.splitter.setChildrenCollapsible(False)
        self.splitter.addWidget(self.table)
        self.splitter.addWidget(self.signalPlot)
        layout.addWidget(self.splitter)

        layout.setContentsMargins(0, 0, 0, 0)
        self.setLayout(layout)

        # ---- Initialization -------------------------------------------------
        self.table.setColumns([
            TableWidgetColumn('startDate', 'Date', 'date',
                'Start date of the lidar measurement'),
            TableWidgetColumn('startTime', 'TStart', 'time',
                'Start time of the lidar measurement'),
            TableWidgetColumn('stopTime', 'TStop', 'time',
                'Stop time of the lidar measurement'),
            TableWidgetColumn('wavelength', 'Wave', '%.1f',
                'Lidar channel wavelength (nm)'),
            TableWidgetColumn('polarization', 'Polar', '%d',
                'Lidar channel polarization ID'),
            TableWidgetColumn('gridHeightStep', 'HStep', '%.1f',
                'Height step of the lidar data grid (m)'),
            TableWidgetColumn('firstInputIndex', 'Left', '%d',
                'First index of the data section prepared for retrieval'),
            TableWidgetColumn('lastInputIndex', 'Right', '%d',
                'Last index of the data section prepared for retrieval'),
            TableWidgetColumn('refPointIndex', 'RP', '%d',
                'Reference point index'),
            TableWidgetColumn('localId', 'Local ID', '%s',
                'Lidar measurement textual identifier')
        ])

    def addAllowedChannelSequence(self, channelSequence):
        """Specify a set of lidar channels that should be considered as a valid
        input to an aerosol retrieval algorithm when selected by the user.

        Call this function one or more times immediately after the constructor.
        'channelSequence' has to be an ordered sequence of 'LidarChannelInfo'
        instances. The order of lidar channels specified here will be used to
        sort measurements returned by 'getData'. If several sets of lidar
        channels are specified by means of successive calls to this function,
        then any of these would be available for selection by the user."""

        for i, lidarChannel in enumerate(channelSequence):
            assert isinstance(lidarChannel, LidarChannelInfo)
            # Make sure that there are no coinciding channels in the list.
            for j in range(i):
                assert lidarChannel.compatibleWith(channelSequence[j])

        self.allowedChannelSequences.append(channelSequence)

    def addRequiredAttribute(self, attributeName):
        """Specify the name of a 'LidarInput' attribute that must be different
        from 'None' in order for the lidar measurement to be considered
        valid."""

        self.requiredAttributes.append(attributeName)

    def getTableWidget(self):
        """Return the 'DataTableWidget' that lists the lidar measurements."""

        return self.table

    def saveSettings(self, settings):
        """Store the splitter state in 'settings' under the
        'LidarTableWidget' group."""

        with utils.SettingsGrouper(settings, 'LidarTableWidget'):
            settings.setValue('splitterState', self.splitter.saveState())

    def loadSettings(self, settings):
        """Restore the splitter state saved by 'saveSettings', or split the
        widget evenly if no state was saved."""

        with utils.SettingsGrouper(settings, 'LidarTableWidget'):

            if settings.contains('splitterState'):
                self.splitter.restoreState(
                    settings.value('splitterState').toByteArray())
            else:
                self.splitter.setSizes([1000, 1000])

    def connectDataList(self, lidarDataList):
        """Replace the measurements shown by the widget with 'lidarDataList'
        (a list of 'LidarInput' instances).

        The table and the plot are cleared first. If the list is empty, an
        error message is set. Otherwise, the table is populated, rows that
        match none of the allowed lidar channels are disabled, and the
        selection status is re-evaluated."""

        # Suspend the selection changed signal while the table cleans up.
        with gui.SignalBlocker(self.table):
            # Clear the table widget.
            self.table.clearTable()

        # Clear the plot widget.
        self.signalPlot.clear()

        self.dataList = lidarDataList

        if len(self.dataList) == 0:
            self.setErrorMessage(
                'Lidar input database contains no records processed with '
                'TropoExport tool')
            return

        # From this time on, no error checking is performed.
        self.table.populateTable(self.dataList)

        # Disable rows that may not be used in the retrieval, i.e. rows that
        # match none of the channels of any allowed channel sequence.
        for row, lidarInput in enumerate(self.dataList):
            rowIsAcceptable = any(lidarChannel.matchesLidarInput(lidarInput)
                for channelSequence in self.allowedChannelSequences
                for lidarChannel in channelSequence)

            if not rowIsAcceptable:
                self.table.disableRow(row)

        # Show the error message reporting lack of selected measurements and
        # clear the plot widget.
        self.onSelectionChanged()

    def getData(self):
        """Return a list of the currently selected 'LidarInput' instances, in
        a predefined order, that is suitable for an aerosol profile retrieval,
        or 'None' if the retrieval is not possible or ambiguous for the
        currently selected lidar data."""

        if self.getStatusMessage() is not None:
            return None

        selectedData = self.getSelectedData()

        matchingSequence = None

        for channelSequence in self.allowedChannelSequences:
            if self._isOneToOneMatch(channelSequence, selectedData):
                matchingSequence = channelSequence
                break

        # If there is no matching sequence, an error message has to be set.
        assert matchingSequence is not None

        # 'selectedData', sorted in accordance with 'matchingSequence'.
        orderedData = []

        for lidarChannel in matchingSequence:
            # For each of the lidar channels, there is exactly one measurement
            # that matches it.
            for lidarInput in selectedData:
                if lidarChannel.matchesLidarInput(lidarInput):
                    orderedData.append(lidarInput)
                    break

        # Just to feel safe.
        assert len(orderedData) == len(selectedData) == len(matchingSequence)
        return orderedData

    def getSelectedData(self):
        """Return a list of the currently selected 'LidarInput' instances,
        without any constraint checking (unlike 'getData'), and unsorted."""

        selectedRows = self.table.getSelectedRows()

        return [self.dataList[row] for row in selectedRows]

    # ---- Private overridden methods -----------------------------------------
    def setStatusMessage(self, statusMessage):

        StatusSignalingWidget.setStatusMessage(self, statusMessage)

        # In this widget, status manipulation methods (either 'setErrorMessage'
        # or 'clearStatusMessage') are always called in sync with changes in
        # selection of the lidar data records (including the case of clearing
        # the selection upon loading of a new lidar database file).
        self.selectionChanged.emit()

    # ---- Private slots ------------------------------------------------------
    def onSelectionChanged(self):
        """Check if a complete and unambiguous set of lidar measurements is
        currently selected in the table widget, update the lidar signal plots
        and set the error message appropriately."""

        # Assure that the table contains data, so that error message set in
        # 'connectDataList' won't get overwritten.
        assert len(self.dataList) == self.table.rowCount() > 0

        selectedData = self.getSelectedData()

        if len(selectedData) == 0:
            self.signalPlot.clear()
            self.setErrorMessage('Lidar measurements are not selected')
            return

        # Plot the selected signals even if there are errors in some of them,
        # but do not touch malformed data.
        self.signalPlot.plotData([lidarInput for lidarInput in selectedData
            if lidarInput.getErrorMessage() is None])

        # Per-measurement and cross-measurement consistency is checked first;
        # only consistent selections are matched against the channel sets.
        errorMessage = self._getMeasurementsErrorMessage(selectedData)
        if errorMessage is None:
            errorMessage = self._getChannelsErrorMessage(selectedData)

        if errorMessage is None:
            self.clearStatusMessage()
        else:
            self.setErrorMessage(errorMessage)

    # ---- Private methods ----------------------------------------------------
    @staticmethod
    def _allEqual(values):
        """Return True if every element of 'values' compares equal (via '!=')
        to the first one. Note that each element, including the first one, is
        compared, so a value that is unequal to itself (e.g. NaN) makes the
        result False."""

        values = list(values)
        return not any(value != values[0] for value in values)

    @staticmethod
    def _isOneToOneMatch(channelSequence, lidarInputList):
        """Check if 'lidarInputList' matches 'channelSequence' one-to-one.

        One-to-one matching between selected and required sets of lidar
        channels holds when each of the required lidar channels is selected
        exactly once and there are no superfluous selections."""

        return (len(lidarInputList) == len(channelSequence) and
            all(lidarChannel.getErrorMessageForLidarData(lidarInputList) is None
                for lidarChannel in channelSequence))

    def _getMeasurementsErrorMessage(self, selectedData):
        """Validate the selected measurements individually and against each
        other (grid step, geodetic coordinates, atmosphere model).

        Return 'None' if no problem is found, or the error message describing
        the first problem otherwise."""

        errorPrefix = 'One of the selected lidar measurements is invalid: '

        # Check that all the selected lidar measurements are valid.
        for lidarInput in selectedData:
            errorMessage = lidarInput.getErrorMessage()
            if errorMessage is not None:
                return errorPrefix + errorMessage

            for attributeName in self.requiredAttributes:
                if getattr(lidarInput, attributeName) is None:
                    return (errorPrefix +
                        '%s database field contains no data' %
                        txt.quote(lidarInput.getFieldName(attributeName)))

        # Check that lidar signal weighting coefficients (which are inverse
        # to the signal dispersion) are not infinite.
        for lidarInput in selectedData:
            if any(lidarInput.lidarDispersion == 0.0):
                errorMessage = ('Dispersion of one of the selected lidar '
                    'signals is zero at some height')
                # The most common cause of zero dispersion is zero height above
                # the measurement point, caused by zero index of the data grid.
                if lidarInput.firstInputIndex == 0:
                    errorMessage += ('. Try modifying the %s boundary '
                        'of the signal' % txt.quote('Left'))

                return errorMessage

        # Check that all the selected lidar profiles share the same grid.
        if not self._allEqual(lidarInput.gridHeightStep
                for lidarInput in selectedData):
            return ('Grid height steps of the selected lidar '
                'measurements are inconsistent')

        # Check that geodetic coordinates of the measurements are the same.
        latitudes = [lidarInput.latitude for lidarInput in selectedData]
        longitudes = [lidarInput.longitude for lidarInput in selectedData]
        altitudes = [lidarInput.altitude for lidarInput in selectedData]

        if not (self._allEqual(latitudes) and self._allEqual(longitudes) and
                self._allEqual(altitudes)):
            return ('Geodetic coordinates of the selected lidar '
                'measurements are inconsistent')

        # Check that atmosphere models of the measurements are the same.
        if not self._allEqual(lidarInput.atmoModel
                for lidarInput in selectedData):
            return ('Atmosphere models used in the selected '
                'lidar measurements are not the same')

        return None

    def _getChannelsErrorMessage(self, selectedData):
        """Check if there's a predefined set of lidar channels that is
        perfectly matched by 'selectedData'.

        Return 'None' if such a set exists. Otherwise, return the first error
        of the best-fitting channel set (the one with the fewest errors, the
        earliest one winning ties)."""

        minErrorMessageList = None

        for channelSequence in self.allowedChannelSequences:

            # Actual list of error messages for the current channel set.
            errorMessageList = [message for message in
                (lidarChannel.getErrorMessageForLidarData(selectedData)
                    for lidarChannel in channelSequence)
                if message is not None]

            # If there's a one-to-one matching, we're done.
            if (len(selectedData) == len(channelSequence) and
                    len(errorMessageList) == 0):
                return None

            # Update the minimal error message list (i.e. the list of errors
            # for the best-fitting channel set).
            if (minErrorMessageList is None or
                    len(errorMessageList) < len(minErrorMessageList)):
                minErrorMessageList = errorMessageList

        # If 'minErrorMessageList' is empty, then each of the lidar channels
        # was selected exactly once for the best-fitting channel set, but
        # there were superfluous selections. The same generic message is also
        # used when no channel sets were registered at all.
        if minErrorMessageList:
            return minErrorMessageList[0]

        return 'Too many lidar measurements are selected'
