# Copyright (c) 2011-2014, B.I.Stepanov Institute of Physics, National Academy
# of Sciences of Belarus.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

import h5py
import numpy
import os.path
import re

from PyQt4.QtCore import *

from common.utils import txt

# Only the base wrapper class is exported by 'from ... import *'.
__all__ = ['MatlabProcess']

# *****************************************************************************
class MatlabProcess(QObject):
    """Base class for wrappers around standalone optimization procedures
    implemented in Matlab language.

    Call 'startRetrieval' to launch the Matlab process; use 'getErrorMessage'
    and 'getRetrievalOutput' to get the retrieval results upon completion.

    Signals:
      - 'retrievalFinished': the retrieval process has either successfully
        completed or terminated with an error.
      - 'outputAvailable(outputLine)': new line of text has been printed by the
        Matlab application to either standard output or standard error stream.
      - 'iterationCompleted(iterCount)': a new iteration has been completed by
        the optimization process. 'iterCount' is the number of the iteration,
        starting with 0 for the initial approximation.

    Matlab scripts (named '<program-name>.m') are assumed to be stored in the
    'code/binary/MATLAB' subdirectory of the program package root, so that
    compiled Matlab executables are placed in
    'code/binary/MATLAB/<program-name>/distrib'.

    Auxiliary data files in HDF5 format (with a predefined structure) are
    used to pass input data as well as optimization results (including the
    whole set of intermediate data produced during the steps of the
    iterative process) to and from the Matlab program.

    Data type used to represent the retrieval results is not predetermined;
    the value returned by 'readOutputData' actually defines it.

    Redefine 'getMatlabProgramName' method to specify the name of the Matlab
    application. Redefine 'writeInputData' along with 'readOutputData' to
    implement the procedure of passing the data to and from the auxiliary data
    file. Use 'extendArray', 'readArray', and 'readScalar' for assistance in
    the process of data preparation and extraction, if required.

    An instance may be reused: every call to 'startRetrieval' discards the
    results and the error message left by the previous run, and a cancelled
    retrieval does not prevent starting a new one."""

    # ---- Signals ------------------------------------------------------------
    retrievalFinished = pyqtSignal()
    outputAvailable = pyqtSignal(QString)
    iterationCompleted = pyqtSignal(int)

    # Matches a line describing a completed iteration, as normally printed by
    # Matlab's 'lsqnonlin' function; group 1 is the iteration number.
    _iterationLinePattern = re.compile(r'\s+(\d+)\s+\d+\s+\d\.\d+e[+-]\d+\s+')

    # ---- Public overridable methods -----------------------------------------
    def __init__(self):
        QObject.__init__(self)

        programName = self.getMatlabProgramName()
        matlabDir = os.path.join('code/binary/MATLAB', programName, 'distrib')

        # The compiled executable and the auxiliary HDF5 data file live in the
        # same directory, which is also used as the process working directory.
        self.applicationFilePath = os.path.join(matlabDir,
            programName + '.exe')
        self.dataFilePath = os.path.join(matlabDir, 'data.h5')

        self.matlabProcess = QProcess()

        # Deliver standard error through the same channel as standard output,
        # so that 'onOutputAvailable' sees both.
        self.matlabProcess.setProcessChannelMode(QProcess.MergedChannels)

        self.matlabProcess.readyRead.connect(self.onOutputAvailable)
        self._connectCompletionSlots()

        self.errorMessage = None
        # This will be a list of data instances (whose format would be specific
        # to the particular optimization procedure) describing results of the
        # optimization process obtained at each of the successive iterations.
        self.retrievalOutputs = None

    # ---- Public methods -----------------------------------------------------
    def startRetrieval(self, errorMessage = None):
        """Launch the optimization process and return immediately.

        If 'errorMessage' is not 'None', the optimization process won't start,
        and 'retrievalFinished' signal will be fired as soon as program control
        returns to the application event loop (with 'getErrorMessage' returning
        the specified error message). The same thing will happen (but with a
        different error message) if problems would arise during initialization
        of the auxiliary data file.

        Results and error message of any previous retrieval are discarded."""

        # Forget the outcome of a previous run. A stale error message would
        # otherwise make 'onProcessFinished' ignore the completion of this
        # run, so that 'retrievalFinished' would never be emitted.
        self.errorMessage = None
        self.retrievalOutputs = None

        if errorMessage is not None:
            self.errorMessage = errorMessage
            # Report the retrieval as a failed one.
            QTimer.singleShot(0, self.onInitFailed)
            return

        try:
            # Write the algorithm input data to the intermediate file.
            self.writeMatlabInput()

        except txt.Error as e:
            self.errorMessage = e.text
            # Retrieval is not possible. Report it as a failed one.
            QTimer.singleShot(0, self.onInitFailed)
            return

        # Start the Matlab process. All the related signals have been already
        # connected in '__init__'.
        self.matlabProcess.setWorkingDirectory(
            os.path.dirname(self.dataFilePath))
        self.matlabProcess.start(self.applicationFilePath)

    def cancelRetrieval(self):
        """Stop the running retrieval process or do nothing if the process has
        not been started yet.

        No 'retrievalFinished' signal is emitted for a cancelled retrieval, as
        this is neither success nor failure. 'startRetrieval' may be called
        again afterwards."""

        if self.matlabProcess.state() == QProcess.NotRunning:
            return

        # Don't fire signals, as this is neither success nor failure. The slots
        # are detached only for the duration of the termination, so that a
        # later retrieval still reports its completion.
        self._disconnectCompletionSlots()
        try:
            self.matlabProcess.kill()
            # 'waitForFinished' blocks until the 'finished' signal has been
            # emitted, i.e. while the completion slots are still detached.
            # NOTE(review): if it times out (30 s by default), the late
            # 'finished' signal will reach 'onProcessFinished'.
            self.matlabProcess.waitForFinished()
        finally:
            self._connectCompletionSlots()

    def getErrorMessage(self):
        """Return a rich text message describing a failed retrieval or 'None'
        if the retrieval has succeeded or has not finished yet."""

        return self.errorMessage

    def getRetrievalOutput(self, iteration = -1):
        """Return a data instance (specific to the particular optimization
        procedure) describing a successful retrieval, or 'None' if the
        retrieval has failed or has not finished yet.

        If 'iteration' is not '-1', optimization process snapshot at the given
        iteration is returned instead of the final retrieval results."""

        if self.retrievalOutputs is None:
            return None

        return self.retrievalOutputs[iteration]

    def getAllRetrievalOutputs(self):
        """Return the complete list of the output data instances (specific to
        the particular optimization procedure), representing optimization
        process snapshots for all the iterations, or 'None' if the retrieval
        has failed or has not finished yet."""

        return self.retrievalOutputs

    # ---- Protected overridable methods --------------------------------------
    def getMatlabProgramName(self):
        """Return the name of the Matlab script performing the optimization
        procedure (without the '.m' extension)."""
        return None

    def writeInputData(self, hdfGroup):
        """Write data required by the optimization process to the 'input'
        subgroup of the auxiliary HDF5 file.

        The 'input' HDF group (passed as the argument) is opened for writing.
        Apart from the name of the group, there are no restrictions imposed on
        the format of the data."""
        pass

    def readOutputData(self, hdfGroup, iteration, errorPrefix):
        """Extract data describing the given iteration of the optimization
        process from the 'output' subgroup of the auxiliary data file,
        returning a proper data instance.

        If the format of the data file is invalid, raise 'txt.Error' exception
        using the given error prefix.

        The 'output' HDF group (passed as the argument) is opened for reading.
        This group must contain a scalar integer dataset named 'iterCount'
        containing the total number of iterations (including the initial
        approximation, i.e. the zeroth one) that have been realized in the
        optimization process.

        Apart from the name of the output group and the 'iterCount' dataset,
        there are no restrictions imposed on the format of the data.
        Nevertheless, these data should contain information about any of the
        iterations of the optimization process (normally, that should be
        achieved by adding an extra dimension to the respective HDF
        datasets)."""
        return None

    # ---- Protected methods --------------------------------------------------
    @staticmethod
    def extendArray(dataArray, firstDataIndex, extendedSize):
        """Return a copy of the given Numpy array converted to 'float64' data
        type, padded with 'firstDataIndex' copies of its first element at the
        beginning and with appropriate number of zeros at the end, so that the
        final length of the array is 'extendedSize'."""

        assert extendedSize >= firstDataIndex + len(dataArray)
        assert len(dataArray) > 0

        # Construct an array of zeros of the required length.
        extArray = numpy.zeros(extendedSize, numpy.float64)

        # Fill in the data.
        extArray[firstDataIndex : firstDataIndex + len(dataArray)] = dataArray
        # Replace initial zeros with copies of the initial array's first
        # element.
        extArray[0 : firstDataIndex] = dataArray[0]

        return extArray

    @staticmethod
    def readArray(hdfGroup, datasetName, arrayIndex, iteration, errorPrefix,
        arraySize = None):
        """Return a Numpy array with 'float32' data type corresponding to the
        'arrayIndex' row and 'iteration' layer (3-rd dimension) of an
        'hdfGroup's 3-dimensional dataset with the given name.

        Note that, as seen through h5py, the iteration is the FIRST index of
        the dataset (presumably because Matlab stores arrays in column-major
        order).

        If 'arraySize' is not 'None', return at most 'arraySize' first elements
        of the dataset row only.

        Raise 'txt.Error' using the given error prefix on failure, including
        the case of 'iteration' or 'arrayIndex' being out of range."""

        try:
            dataset = hdfGroup[datasetName]
        except KeyError:
            raise txt.Error(errorPrefix + 'there is no dataset named %s' %
                txt.quote(datasetName))

        try:
            # Assume that 'dataset' is 3-dimensional.
            arrayData = dataset[iteration, arrayIndex]

        except (ValueError, IndexError):
            # An invalid selection is reported with 'ValueError'; an
            # out-of-range index may be reported with 'IndexError' depending
            # on the h5py version. Either way the dataset doesn't fit.
            raise txt.Error(errorPrefix + '%s dataset has invalid shape' %
                txt.quote(datasetName))

        dataArray = numpy.array(arrayData, dtype = numpy.float32)
        if arraySize is not None and arraySize < len(dataArray):
            # Keep only the leading 'arraySize' elements.
            dataArray = dataArray[:arraySize]

        return dataArray

    @staticmethod
    def readScalar(hdfGroup, datasetName, arrayIndex, iteration, errorPrefix):
        """Same as 'readArray', but assume that the resulting array length is 1
        and return a scalar floating-point value representing that single array
        element.

        Raise 'txt.Error' using the given error prefix if the array has an
        unexpected shape."""

        data = MatlabProcess.readArray(hdfGroup, datasetName, arrayIndex,
            iteration, errorPrefix)

        # Check if the value of the row is a scalar.
        if data.shape != (1, ):
            raise txt.Error(errorPrefix + '%s dataset has invalid shape' %
                txt.quote(datasetName))

        # Convert the single element to the native Python data type (explicit
        # indexing: converting an ndim > 0 array to a scalar is deprecated).
        return float(data[0])

    # ---- Private methods ----------------------------------------------------
    def _connectCompletionSlots(self):
        # Attach the slots reporting termination of the Matlab process.
        self.matlabProcess.error.connect(self.onProcessError)
        self.matlabProcess.finished.connect(self.onProcessFinished)

    def _disconnectCompletionSlots(self):
        # Detach the slots attached by '_connectCompletionSlots'.
        self.matlabProcess.error.disconnect(self.onProcessError)
        self.matlabProcess.finished.disconnect(self.onProcessFinished)

    def writeMatlabInput(self):
        """Prepare input file for the Matlab application.

        Raise 'txt.Error' if the file cannot be written."""

        try:
            # 'w' means 'create the file; truncate if it already exists'.
            hdfFile = h5py.File(self.dataFilePath, 'w')

            try:
                inputGroup = hdfFile.create_group('input')

                self.writeInputData(inputGroup)

            finally:
                hdfFile.close()

        except IOError:
            raise txt.Error('Failed to write the Matlab input file (%s)' %
                txt.quotePath(self.dataFilePath))

    def readMatlabOutput(self):
        """Construct a list of the output data instances representing snapshots
        of the optimization process at different iterations, based on the
        Matlab application output file.

        Final retrieval results will normally be represented by the last
        element of the returned list, which is guaranteed to be non-empty.

        Raise 'txt.Error' if the file cannot be read or has invalid format."""

        retrievalOutputs = []

        # Open the Matlab output file.
        try:
            hdfFile = h5py.File(self.dataFilePath, 'r')

        except IOError:
            raise txt.Error('Failed to read the Matlab output file (%s)' %
                txt.quotePath(self.dataFilePath))

        try:
            invalidFormatPrefix = ('Failed to read the Matlab output file '
                '(%s): ' % txt.quotePath(self.dataFilePath))

            # Open the 'output' HDF group used to store retrieval output data.
            try:
                outputGroup = hdfFile['output']

            except KeyError:
                raise txt.Error(invalidFormatPrefix +
                    'there is no group named %s' % txt.quote('output'))

            # Read the iteration count from the 'iterCount' scalar dataset.
            try:
                iterDataset = outputGroup['iterCount']
            except KeyError:
                raise txt.Error(invalidFormatPrefix +
                    'there is no dataset named %s' % txt.quote('iterCount'))

            iterArray = numpy.array(iterDataset)
            if iterArray.shape != (1, ):
                raise txt.Error(invalidFormatPrefix +
                    '%s dataset has invalid shape' % txt.quote('iterCount'))

            iterCount = int(iterArray[0])

            # The initial approximation (iteration 0) must always be present;
            # an empty list would make 'getRetrievalOutput' fail later.
            if iterCount < 1:
                raise txt.Error(invalidFormatPrefix +
                    '%s dataset has invalid value' % txt.quote('iterCount'))

            # Fill in the data separately for each of the iterations.
            for iteration in range(iterCount):
                retrievalOutputs.append(self.readOutputData(
                    outputGroup, iteration, invalidFormatPrefix))

        finally:
            hdfFile.close()

        return retrievalOutputs

    # ---- Private slots ------------------------------------------------------
    def onInitFailed(self):

        self.retrievalFinished.emit()

    def onOutputAvailable(self):

        while self.matlabProcess.canReadLine():
            # Use 'QString' to convert 'QByteArray' into a Unicode string.
            outputLine = QString(self.matlabProcess.readLine())

            # Don't display the warning about locale mismatch.
            if outputLine.startsWith('MATLAB:I18n:InconsistentLocale'):
                continue

            self.outputAvailable.emit(outputLine)

            # Parse output text describing a completed iteration that is
            # normally generated by Matlab's 'lsqnonlin' function.
            matchObj = self._iterationLinePattern.match(outputLine)
            if matchObj is not None:
                self.iterationCompleted.emit(int(matchObj.group(1)))

    def onProcessError(self, processError):

        if processError == QProcess.FailedToStart:
            template = 'Failed to start the retrieval algorithm: %s'
        else:
            template = 'Failed to run the retrieval algorithm: %s'

        self.errorMessage = template % txt.quotePath(self.applicationFilePath)

        self.retrievalFinished.emit()

    def onProcessFinished(self, exitCode, exitStatus):

        # In a case of internal process error, 'onProcessError' may be called
        # prior to 'onProcessFinished'. In such a case, 'retrievalFinished'
        # signal is emitted by the former, whereas this method does nothing.
        # ('startRetrieval' clears the error message, so a message left from
        # an earlier run cannot trigger this early return.)
        if self.errorMessage is not None:
            return

        try:
            # Read the algorithm output data from the intermediate file.
            self.retrievalOutputs = self.readMatlabOutput()

        except txt.Error as e:
            self.errorMessage = e.text

        self.retrievalFinished.emit()
