# Copyright (c) 2011-2014, B.I.Stepanov Institute of Physics, National Academy
# of Sciences of Belarus.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

import datetime
import numpy

from LiricOutput import *
from ManualParams import *
from MatlabProcess import *
from PhotometerInput import *
from PreparedLidarInput import *

# Only the retrieval process class is exported; the module-level helper
# functions defined at the end of this module are private to it.
__all__ = ['LiricRetrievalProcess']

# *****************************************************************************
class LiricRetrievalProcess(MatlabProcess):
    """Wrapper around a standalone Matlab application that implements aerosol
    profile retrieval algorithm utilizing combined lidar and CIMEL Sun
    photometer data.

    Attributes:
      - 'lidarInputs': a list of 'PreparedLidarInput' instances representing 3
        or 4 lidar measurements, in a predefined order: unpolarized (or
        parallel-polarized) measurements at 355, 532, and 1064 nm, and
        optionally a cross-polarized measurement at 532 nm (or 355 nm).
      - 'photometerInput': a 'PhotometerInput' instance.
      - 'algorithmParams': a 'ManualParams' instance.

      - 'profileInitFine', 'profileInitCoarse', 'profileInitSpherical',
        'profileInitSpheroid': initial approximations to the aerosol profiles
        to be retrieved, or 'None's if default values should be used instead.
      - 'ratiosInit': initial approximations to backscatter ratios at the
        reference points (a list of 3 or 4 numbers, depending on the size of
        'lidarInputs'), or 'None' if default values should be used instead.

      - 'hideIntermediateOutputs': 'True', if 'getAllRetrievalOutputs' method
        of this class should return only the last retrieval output instance
        (similar to 'getRetrievalOutput()'). This is 'False' by default."""

    # ---- Public overridden methods ------------------------------------------
    def __init__(self, lidarInputs, photometerInput, algorithmParams):
        """Store the retrieval inputs. All initial approximations start as
        'None' (defaults are used) and intermediate outputs are not hidden."""

        MatlabProcess.__init__(self)

        # ---- Members ----
        self.lidarInputs = lidarInputs
        self.photometerInput = photometerInput
        self.algorithmParams = algorithmParams

        self.profileInitFine = None
        self.profileInitCoarse = None
        self.profileInitSpherical = None
        self.profileInitSpheroid = None

        self.ratiosInit = None

        self.hideIntermediateOutputs = False

    # ---- Public methods -----------------------------------------------------
    def setInitDataSimplified(self, profileInitFine, profileInitCoarse,
        ratiosInit):
        """Set initial approximations for the 3-channel retrieval ('Fine' and
        'Coarse' aerosol modes).

        Any argument may be 'None' to request the default values used by
        'writeInputData'. A non-'None' 'ratiosInit' must hold 3 numbers, one
        per lidar channel."""

        assert len(self.lidarInputs) == 3
        # 'None' is documented as "use defaults", so do not call 'len' on it.
        assert ratiosInit is None or len(ratiosInit) == 3

        self.profileInitFine = profileInitFine
        self.profileInitCoarse = profileInitCoarse

        self.ratiosInit = ratiosInit

    def setInitDataPolarimetric(self, profileInitFine, profileInitSpherical,
        profileInitSpheroid, ratiosInit):
        """Set initial approximations for the 4-channel retrieval ('Fine',
        'Spherical' and 'Spheroid' aerosol modes).

        Any argument may be 'None' to request the default values used by
        'writeInputData'. A non-'None' 'ratiosInit' must hold 4 numbers, one
        per lidar channel."""

        assert len(self.lidarInputs) == 4
        # 'None' is documented as "use defaults", so do not call 'len' on it.
        assert ratiosInit is None or len(ratiosInit) == 4

        self.profileInitFine = profileInitFine
        self.profileInitSpherical = profileInitSpherical
        self.profileInitSpheroid = profileInitSpheroid

        self.ratiosInit = ratiosInit

    def setHideIntermediateOutputs(self, hideIntermediateOutputs):
        """Set the 'hideIntermediateOutputs' flag (see the class docstring)."""

        self.hideIntermediateOutputs = hideIntermediateOutputs

    # ---- Public overridden methods ------------------------------------------
    def getAllRetrievalOutputs(self):
        """Return all retrieval outputs, or only the final one (as a
        one-element list) if 'hideIntermediateOutputs' is set."""

        if self.hideIntermediateOutputs:
            return [self.getRetrievalOutput()]

        return MatlabProcess.getAllRetrievalOutputs(self)

    # ---- Private overridden methods -----------------------------------------
    def getMatlabProgramName(self):
        """Name of the standalone Matlab program run by this process."""
        return 'liric'

    def writeInputData(self, hdfGroup):
        """Prepare input file for the Matlab application.

        Writes into 'hdfGroup' the datasets read by the Matlab program: the
        numbers of channels and aerosol modes, weighting parameters,
        photometer-derived aerosol coefficients ('b', 'a', 'V'), initial
        approximations ('C0', 'R0'), per-channel grid indices, lidar signals
        and dispersions, molecular profiles, zenith angles and optimization
        tolerances."""

        ###
        # from PreparedLidarInput import PreparedLidarInput
        # PreparedLidarInput.exportDataListToExcel(
            # '__debug/lidarInputs.xls', self.lidarInputs)

        # Just to feel safe.
        assertLidarInputsValid(self.lidarInputs)

        # Number of lidar channels.
        hdfGroup.create_dataset('J', data = [len(self.lidarInputs)])
        # Number of aerosol modes to be retrieved.
        hdfGroup.create_dataset('M', data = [len(self.lidarInputs) - 1])

        heightStep = self.lidarInputs[0].gridHeightStep
        assert all(lidarInput.gridHeightStep == heightStep
            for lidarInput in self.lidarInputs)

        hdfGroup.create_dataset('hStep', data = [heightStep])

        # The number of indices to be stripped from the bottom part of the data
        # grid (and hence from the retrieval process and its results). This is
        # to be done if altitudes of the photometer and lidar measurement sites
        # are different, as the bottom boundary for the lidar equations is
        # determined by location of the photometer, not the lidar.
        #
        # If 'gridDisplacement' is negative, grid gets expanded (i.e.,
        # additional lidar equations are added).
        gridDisplacement, gridSize = self._getGridGeometry()

        assert gridDisplacement <= min(
            lidarInput.firstInputIndex for lidarInput in self.lidarInputs)

        if len(self.lidarInputs) == 3:
            # There are 3 input channels and 2 aerosol modes.
            hdfGroup.create_dataset('k', data = [
                self.algorithmParams.weighting355,
                self.algorithmParams.weighting532,
                self.algorithmParams.weighting1064])
            hdfGroup.create_dataset('f', data = [
                self.algorithmParams.weightingFine,
                self.algorithmParams.weightingCoarse])
            hdfGroup.create_dataset('d', data = [
                self.algorithmParams.weightSmoothFine,
                self.algorithmParams.weightSmoothCoarse])

            # Use total or parallel-only backscatter coefficients depending
            # on the actually used input channel.
            hdfGroup.create_dataset('b', data = [[
                getAerosolBackscatter(self.lidarInputs[0],
                    self.photometerInput.bFine355,
                    self.photometerInput.bFine355P),
                getAerosolBackscatter(self.lidarInputs[1],
                    self.photometerInput.bFine532,
                    self.photometerInput.bFine532P),
                getAerosolBackscatter(self.lidarInputs[2],
                    self.photometerInput.bFine1064,
                    self.photometerInput.bFine1064P)], [

                getAerosolBackscatter(self.lidarInputs[0],
                    self.photometerInput.bCoarse355,
                    self.photometerInput.bCoarse355P),
                getAerosolBackscatter(self.lidarInputs[1],
                    self.photometerInput.bCoarse532,
                    self.photometerInput.bCoarse532P),
                getAerosolBackscatter(self.lidarInputs[2],
                    self.photometerInput.bCoarse1064,
                    self.photometerInput.bCoarse1064P)]])

            # Attenuations are the same regardless of polarization.
            hdfGroup.create_dataset('a', data = [[
                self.photometerInput.aFine355,
                self.photometerInput.aFine532,
                self.photometerInput.aFine1064], [

                self.photometerInput.aCoarse355,
                self.photometerInput.aCoarse532,
                self.photometerInput.aCoarse1064]])

            hdfGroup.create_dataset('V', data = [
                self.photometerInput.vFineCorr,
                self.photometerInput.vCoarseCorr])

            hdfGroup.create_dataset('C0', data = [
                numpy.ones(gridSize) * 0.01 if self.profileInitFine is None
                    else self.profileInitFine.astype(numpy.float64),
                numpy.ones(gridSize) * 0.002 if self.profileInitCoarse is None
                    else self.profileInitCoarse.astype(numpy.float64)])

        elif len(self.lidarInputs) == 4:
            # There are 4 input channels and 3 aerosol modes.
            hdfGroup.create_dataset('k', data = [
                self.algorithmParams.weighting355,
                self.algorithmParams.weighting532,
                self.algorithmParams.weighting1064,
                self.algorithmParams.weighting532C])
            hdfGroup.create_dataset('f', data = [
                self.algorithmParams.weightingFine,
                self.algorithmParams.weightingSpherical,
                self.algorithmParams.weightingSpheroid])
            hdfGroup.create_dataset('d', data = [
                self.algorithmParams.weightSmoothFine,
                self.algorithmParams.weightSmoothSpherical,
                self.algorithmParams.weightSmoothSpheroid])

            # Use total, parallel-only, or cross-only backscatter
            # coefficients depending on the input channel.
            #
            # NOTE(review): when the cross-polarized channel is at 355 nm,
            # 'parallelLeakage532' is still used as its parallel leakage; no
            # 355 nm leakage parameter is referenced in this module. Confirm.
            hdfGroup.create_dataset('b', data = [[
                getAerosolBackscatter(self.lidarInputs[0],
                    self.photometerInput.bFine355,
                    self.photometerInput.bFine355P),
                getAerosolBackscatter(self.lidarInputs[1],
                    self.photometerInput.bFine532,
                    self.photometerInput.bFine532P),
                getAerosolBackscatter(self.lidarInputs[2],
                    self.photometerInput.bFine1064,
                    self.photometerInput.bFine1064P),
                getAerosolBackscatter(self.lidarInputs[3],
                    self.photometerInput.bFine532,
                    self.photometerInput.bFine532P,
                    self.photometerInput.bFine532C,
                    self.algorithmParams.parallelLeakage532)
                if self.lidarInputs[3].wavelength == 532.0 else
                getAerosolBackscatter(self.lidarInputs[3],
                    self.photometerInput.bFine355,
                    self.photometerInput.bFine355P,
                    self.photometerInput.bFine355C,
                    self.algorithmParams.parallelLeakage532)], [

                getAerosolBackscatter(self.lidarInputs[0],
                    self.photometerInput.bSpherical355,
                    self.photometerInput.bSpherical355P),
                getAerosolBackscatter(self.lidarInputs[1],
                    self.photometerInput.bSpherical532,
                    self.photometerInput.bSpherical532P),
                getAerosolBackscatter(self.lidarInputs[2],
                    self.photometerInput.bSpherical1064,
                    self.photometerInput.bSpherical1064P),
                getAerosolBackscatter(self.lidarInputs[3],
                    self.photometerInput.bSpherical532,
                    self.photometerInput.bSpherical532P,
                    self.photometerInput.bSpherical532C,
                    self.algorithmParams.parallelLeakage532)
                if self.lidarInputs[3].wavelength == 532.0 else
                getAerosolBackscatter(self.lidarInputs[3],
                    self.photometerInput.bSpherical355,
                    self.photometerInput.bSpherical355P,
                    self.photometerInput.bSpherical355C,
                    self.algorithmParams.parallelLeakage532)], [

                getAerosolBackscatter(self.lidarInputs[0],
                    self.photometerInput.bSpheroid355,
                    self.photometerInput.bSpheroid355P),
                getAerosolBackscatter(self.lidarInputs[1],
                    self.photometerInput.bSpheroid532,
                    self.photometerInput.bSpheroid532P),
                getAerosolBackscatter(self.lidarInputs[2],
                    self.photometerInput.bSpheroid1064,
                    self.photometerInput.bSpheroid1064P),
                getAerosolBackscatter(self.lidarInputs[3],
                    self.photometerInput.bSpheroid532,
                    self.photometerInput.bSpheroid532P,
                    self.photometerInput.bSpheroid532C,
                    self.algorithmParams.parallelLeakage532)
                if self.lidarInputs[3].wavelength == 532.0 else
                getAerosolBackscatter(self.lidarInputs[3],
                    self.photometerInput.bSpheroid355,
                    self.photometerInput.bSpheroid355P,
                    self.photometerInput.bSpheroid355C,
                    self.algorithmParams.parallelLeakage532)]])

            # Attenuations are the same regardless of polarization.
            hdfGroup.create_dataset('a', data = [[
                self.photometerInput.aFine355,
                self.photometerInput.aFine532,
                self.photometerInput.aFine1064,
                self.photometerInput.aFine532
                if self.lidarInputs[3].wavelength == 532.0 else
                self.photometerInput.aFine355], [

                self.photometerInput.aSpherical355,
                self.photometerInput.aSpherical532,
                self.photometerInput.aSpherical1064,
                self.photometerInput.aSpherical532
                if self.lidarInputs[3].wavelength == 532.0 else
                self.photometerInput.aSpherical355], [

                self.photometerInput.aSpheroid355,
                self.photometerInput.aSpheroid532,
                self.photometerInput.aSpheroid1064,
                self.photometerInput.aSpheroid532
                if self.lidarInputs[3].wavelength == 532.0 else
                self.photometerInput.aSpheroid355]])

            hdfGroup.create_dataset('V', data = [
                self.photometerInput.vFineCorr,
                self.photometerInput.vSphericalCorr,
                self.photometerInput.vSpheroidCorr])

            # 'Spherical' and 'Spheroid' aerosol modes are initialized in
            # such a way that their sum is equal to that of the 'Coarse'
            # mode used in the 2-mode aerosol retrieval above.
            hdfGroup.create_dataset('C0', data = [
                numpy.ones(gridSize) * 0.01 if self.profileInitFine is None
                    else self.profileInitFine.astype(numpy.float64),
                numpy.ones(gridSize) * 0.001
                    if self.profileInitSpherical is None
                    else self.profileInitSpherical.astype(numpy.float64),
                numpy.ones(gridSize) * 0.001
                    if self.profileInitSpheroid is None
                    else self.profileInitSpheroid.astype(numpy.float64)])

        # Store all the numbers as floats (the default Matlab data type).
        # Integers are not accepted by some Matlab functions (e.g.
        # 'sparse', for indexing) and need to be converted anyway.
        hdfGroup.create_dataset('Nbeg', data = [
            float(lidarInput.firstInputIndex - gridDisplacement)
            for lidarInput in self.lidarInputs])
        hdfGroup.create_dataset('Nref', data = [
            float(lidarInput.refPointIndex - gridDisplacement)
            for lidarInput in self.lidarInputs])
        hdfGroup.create_dataset('Nfin', data = [
            float(lidarInput.lastInputIndex - gridDisplacement)
            for lidarInput in self.lidarInputs])

        hdfGroup.create_dataset('betaMref', data = [
            lidarInput.getCorrectedRefMolBackscatter(
                # Use the same depolarization for all the wavelengths, as
                # the difference in values is negligible.
                self.algorithmParams.molDepolarization532,
                # This is used with cross-polarized 532 nm channel only.
                self.algorithmParams.parallelLeakage532)
                for lidarInput in self.lidarInputs])

        # All the arrays are converted to 'float64' data type, as otherwise
        # some Matlab functions (e.g. 'sparse') would refuse to work.
        hdfGroup.create_dataset('S', data = [
            self.extendArray(lidarInput.normalizedSignal,
                lidarInput.firstInputIndex - gridDisplacement, gridSize)
                for lidarInput in self.lidarInputs])

        hdfGroup.create_dataset('Omega', data = [
            self.extendArray(lidarInput.lidarDispersion,
                lidarInput.firstInputIndex - gridDisplacement, gridSize)
                for lidarInput in self.lidarInputs])

        hdfGroup.create_dataset('betaM', data = [
            self.extendArray(lidarInput.getCorrectedMolBackscatter(
                # Use the same depolarization for all the wavelengths, as
                # the difference in values is negligible.
                self.algorithmParams.molDepolarization532,
                # This is used with cross-polarized 532 nm channel only.
                self.algorithmParams.parallelLeakage532),
                lidarInput.firstInputIndex - gridDisplacement, gridSize)
                for lidarInput in self.lidarInputs])

        hdfGroup.create_dataset('tauM', data = [
            self.extendArray(lidarInput.molThickness,
            lidarInput.firstInputIndex - gridDisplacement, gridSize)
            for lidarInput in self.lidarInputs])

        hdfGroup.create_dataset('Z', data = [
            self.extendArray(lidarInput.getInputZenithAngles(),
            lidarInput.firstInputIndex - gridDisplacement, gridSize)
            for lidarInput in self.lidarInputs])

        hdfGroup.create_dataset('R0', data =
            numpy.ones(len(self.lidarInputs)) * 1.1 if self.ratiosInit is None
            else self.ratiosInit)

        hdfGroup.create_dataset('TolFun', data = [
            self.algorithmParams.tolFun])
        hdfGroup.create_dataset('TolX', data = [
            self.algorithmParams.tolX])

    def readOutputData(self, hdfGroup, iteration, errorPrefix):
        """Extract data for the given iteration from the Matlab data file and
        return a 'LiricOutput' instance holding it.

        Besides the retrieved profiles, signals, ratios and residuals read
        from 'hdfGroup', the output receives metadata copied from the lidar
        inputs, the photometer input and 'self.algorithmParams'."""

        output = LiricOutput()

        # Just to feel safe.
        assertLidarInputsValid(self.lidarInputs)

        # Assume that geodetic coordinates of lidar measurements coincide.
        output.latitude = self.lidarInputs[0].latitude
        output.longitude = self.lidarInputs[0].longitude
        output.altitude = self.lidarInputs[0].altitude

        # Obtain boundary dates and times for the set of lidar measurements.
        startDateTime = min(lidarInput.getStartDateTime()
            for lidarInput in self.lidarInputs)
        stopDateTime = max(lidarInput.getStopDateTime()
            for lidarInput in self.lidarInputs)

        output.startDate = startDateTime.date()
        output.startTime = startDateTime.time()
        output.stopDate = stopDateTime.date()
        output.stopTime = stopDateTime.time()

        # Save the date and time of the photometer measurement.
        output.photometerDate = self.photometerInput.date
        output.photometerTime = self.photometerInput.time

        # Save the date and time of the moment when the retrieval has finished.
        retrievalDateTime = datetime.datetime.now()

        output.retrievalDate = retrievalDateTime.date()
        output.retrievalTime = retrievalDateTime.time()

        # Copy the common data grid height step from a lidar input. Always use
        # zero zenith angle in output databases.
        output.gridStep = self.lidarInputs[0].gridHeightStep
        output.gridZenithAngle = 0.0

        # This value is filled in by the automated retriever only.
        output.joinIndex = None

        # Convert TropoExport's molecular atmosphere IDs to constants used by
        # the automated retriever, if possible; unknown IDs map to 'None'.
        atmoModels = {'STD' : 'Standard atmosphere', 'CIRA' : 'CIRA 1986',
            'CUST' : 'Custom'}

        output.atmosphereModel = atmoModels.get(self.lidarInputs[0].atmoModel)

        # Suffixes used to denote input wavelengths in attribute names.
        wavelengthIds = ['355', '532', '1064']

        # Suffixes used to denote lidar channels in attribute names, one per
        # entry of 'self.lidarInputs'.
        #
        # NOTE(review): the optional 4th channel is always stored under the
        # '532C' suffix, even when it is a 355 nm cross-polarized channel
        # (which 'assertLidarInputsValid' accepts). Confirm this is intended.
        channelIds = (wavelengthIds + ['532C'])[:len(self.lidarInputs)]

        # Suffixes used to denote aerosol modes in attribute names.
        if len(self.lidarInputs) == 3:
            # There are 2 modes in the 3-channel retrieval.
            modeSuffixes = ['Fine', 'Coarse']
        elif len(self.lidarInputs) == 4:
            # There are 3 modes in the 4-channel retrieval.
            modeSuffixes = ['Fine', 'Spherical', 'Spheroid']

        # A complete set of suffixes will be used for aerosol attenuation and
        # backscatter coefficients, so that the full aerosol model is stored.
        fullModeSuffixes = ['Fine', 'Spherical', 'Spheroid', 'Coarse']

        # Apply the same correction to the data grid as in 'writeInputData'.
        gridDisplacement, gridSize = self._getGridGeometry()

        # Copy parameters from the lidar inputs.
        for channelId, lidarInput in zip(channelIds, self.lidarInputs):

            setattr(output, 'firstInputIndex' + channelId,
                lidarInput.firstInputIndex)

            setattr(output, 'signalId' + channelId, lidarInput.localId)
            # 'Near' signal IDs are only filled in by the automated retriever.
            setattr(output, 'signalId' + channelId + 'Near', None)

            setattr(output, 'polarization' + channelId,
                lidarInput.polarization)

            setattr(output, 'dispersion' + channelId,
                # Store exactly the same value that was stored as 'Omega'
                # in 'writeInputData'.
                #
                # Convert 'float64' values returned by 'self.extendArray'
                # to 'float32' ones used by the 'LiricOutput' attributes.
                self.extendArray(lidarInput.lidarDispersion,
                    lidarInput.firstInputIndex - gridDisplacement,
                    gridSize).astype(numpy.float32))

            setattr(output, 'molThicknessError' + channelId,
                lidarInput.molThicknessError)
            setattr(output, 'aerosolThicknessError' + channelId,
                lidarInput.aerosolThicknessError)
            setattr(output, 'totalBackscatterError' + channelId,
                lidarInput.totalBackscatterError)

        # Fill in parameters obtained from the photometer measurement.
        output.vFine = self.photometerInput.vFineCorr
        output.vCoarse = self.photometerInput.vCoarseCorr
        output.sphericity = self.photometerInput.sphericity

        # Copy aerosol backscatter parameters from the photometer input.
        for wavelengthId in wavelengthIds:
            for modeSuffix in fullModeSuffixes:
                for polarSuffix in ('', 'P', 'C'):

                    attrName = 'b' + modeSuffix + wavelengthId + polarSuffix
                    setattr(output, attrName, getattr(self.photometerInput,
                        attrName))

        # Copy aerosol attenuation parameters from the photometer input.
        # Attenuations do not depend on polarization.
        for wavelengthId in wavelengthIds:
            for modeSuffix in fullModeSuffixes:

                attrName = 'a' + modeSuffix + wavelengthId
                setattr(output, attrName, getattr(self.photometerInput,
                    attrName))

        # Copy algorithm parameters from 'self.algorithmParams'.
        for channelId in channelIds:
            setattr(output, 'weighting' + channelId,
                getattr(self.algorithmParams, 'weighting' + channelId))

        for modeSuffix in modeSuffixes:
            setattr(output, 'weighting' + modeSuffix,
                getattr(self.algorithmParams, 'weighting' + modeSuffix))
            setattr(output, 'weightSmooth' + modeSuffix,
                getattr(self.algorithmParams, 'weightSmooth' + modeSuffix))

        # Copy polarization parameters from 'self.algorithmParams'.
        for name in self.algorithmParams.getAttributeNames(
            'Polarization_parameters'):
            setattr(output, name, getattr(self.algorithmParams, name))

        # Read aerosol concentration profiles.
        for j, modeSuffix in enumerate(modeSuffixes):
            setattr(output, 'profile' + modeSuffix,
                self.readArray(hdfGroup, 'C', j, iteration, errorPrefix))

        # For lidar signals, return only the relevant portions of the arrays.
        for i, (channelId, lidarInput) in enumerate(
            zip(channelIds, self.lidarInputs)):

            signalLength = lidarInput.lastInputIndex - gridDisplacement + 1

            setattr(output, 'measuredSignal' + channelId,
                self.readArray(hdfGroup, 'LmeasR', i, iteration, errorPrefix,
                    signalLength))

            setattr(output, 'calculatedSignal' + channelId,
                self.readArray(hdfGroup, 'Lcalc', i, iteration, errorPrefix,
                    signalLength))

        # Read retrieved backscatter ratios at the reference points.
        for i, channelId in enumerate(channelIds):
            setattr(output, 'ratio' + channelId,
                self.readScalar(hdfGroup, 'R', i, iteration, errorPrefix))

        # Read residuals of the optimization equations.
        for i, channelId in enumerate(channelIds):
            setattr(output, 'residual' + channelId,
                self.readScalar(hdfGroup, 'Psi1', i, iteration, errorPrefix))

        for j, modeSuffix in enumerate(modeSuffixes):
            setattr(output, 'residual' + modeSuffix,
                self.readScalar(hdfGroup, 'Psi2', j, iteration, errorPrefix))

            setattr(output, 'resSmooth' + modeSuffix,
                self.readScalar(hdfGroup, 'Psi3', j, iteration, errorPrefix))

        output.onDataLoaded()
        return output

    # ---- Private methods ----------------------------------------------------
    def _getGridGeometry(self):
        """Return '(gridDisplacement, gridSize)' for the common data grid.

        'gridDisplacement' is the number of indices stripped from the bottom
        of the grid (negative values expand it). It is derived from
        'algorithmParams.photometerDisplacement' and the height step of the
        first lidar input. 'gridSize' is the resulting number of grid points.
        Both values are integers, because they are used as array sizes and
        index offsets."""

        heightStep = self.lidarInputs[0].gridHeightStep

        # 'round' returns a float on Python 2, so convert explicitly. 'float'
        # also prevents truncating integer division on Python 2.
        gridDisplacement = int(round(
            float(self.algorithmParams.photometerDisplacement) / heightStep))

        gridSize = max(lidarInput.lastInputIndex + 1
            for lidarInput in self.lidarInputs) - gridDisplacement

        return gridDisplacement, gridSize

# **** Private functions ******************************************************
def assertLidarInputsValid(lidarInputs):
    """Make sure that the number and order of lidar measurements in the
    'lidarInputs' list, as well as their wavelengths and polarizations, meet
    the expectations of the retrieval algorithm preparation and post-processing
    code (see 'writeInputData' and 'readOutputData').

    Expected layout (polarization codes as used in 'getAerosolBackscatter'):
      - entries 0-2: 355, 532 and 1064 nm, each unpolarized (0) or
        parallel-polarized (3);
      - optional entry 3: a cross-polarized (2) measurement at 532 or 355 nm.

    Raises 'AssertionError' if the expectations are not met. All checks are
    'assert' statements, so they are skipped when Python runs with '-O'."""

    # Check the count first, so that a too-short list fails the assertion
    # instead of raising 'IndexError' below.
    assert len(lidarInputs) in (3, 4)

    # Unpolarized measurements are expected to be specified first and obey a
    # strict order of wavelengths.
    assert lidarInputs[0].wavelength == 355.0
    assert lidarInputs[1].wavelength == 532.0
    assert lidarInputs[2].wavelength == 1064.0

    # Parallel-polarized measurements are allowed to replace some or all of the
    # unpolarized ones.
    assert all(lidarInput.polarization in (0, 3)
        for lidarInput in lidarInputs[0 : 3])

    # A cross-polarized measurement (at 532 or 355 nm) has to be the last one,
    # if present.
    if len(lidarInputs) == 4:
        assert lidarInputs[3].wavelength in (532.0, 355.0)
        assert lidarInputs[3].polarization == 2

def getAerosolBackscatter(lidarInput, bTotal, bParallel,
    bCross = 0.0, parallelLeakage = 0.0):
    """Pick the aerosol backscatter coefficient matching a lidar channel.

    The choice depends on 'lidarInput.polarization':
      - 0 (unpolarized): 'bTotal';
      - 3 (parallel-polarized): 'bParallel';
      - 2 (cross-polarized): 'bCross' plus the 'parallelLeakage' fraction of
        'bParallel', which accounts for instrumental inaccuracies.
    Any other polarization code fails an assertion."""

    channelPolarization = lidarInput.polarization

    if channelPolarization == 0:
        return bTotal

    if channelPolarization == 3:
        return bParallel

    # Only the cross-polarized case is left; the parallel component leaking
    # into the cross channel is added on top of the cross backscatter.
    assert channelPolarization == 2
    leakedParallel = parallelLeakage * bParallel
    return bCross + leakedParallel
