# Copyright (c) 2011-2014, B.I.Stepanov Institute of Physics, National Academy
# of Sciences of Belarus.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

"""Auxiliary functions to work with Microsoft Access databases via ADO."""

import win32com.client

# **** Public functions *******************************************************
def openDatabase(filePath):
    """Open the existing Access database file located at 'filePath'.

    Return the opened ADO connection object ('ADODB.Connection').

    Raise 'pywintypes.com_error' on failure."""

    # 'EnsureDispatch' builds (or reuses) a precompiled wrapper for the COM
    # interface, so later data access goes through the faster early binding.
    adoConnection = win32com.client.gencache.EnsureDispatch('ADODB.Connection')
    connectionString = getConnectionString(filePath)
    adoConnection.Open(connectionString)
    return adoConnection

def createDatabase(filePath):
    """Create a new, empty Access database file at 'filePath'.

    Return the opened ADO connection object, as returned by the ADOX
    'Catalog.Create' call.

    Raise 'pywintypes.com_error' on failure."""

    # 'EnsureDispatch' precompiles the ADOX interface wrapper.
    catalog = win32com.client.gencache.EnsureDispatch('ADOX.Catalog')
    connectionString = getConnectionString(filePath)
    newConnection = catalog.Create(connectionString)
    return newConnection

def getTableNames(dbConnection):
    """Return the list of available table names for an opened database.

    Names are returned in the order the TABLES schema recordset yields them.
    The recordset is not filtered by TABLE_TYPE here, so the result may also
    contain system tables or views (NOTE(review): confirm that is intended).
    The schema recordset is always closed, even if reading it fails."""

    ADSCHEMA_TABLES = 20  # ADO 'SchemaEnum' value 'adSchemaTables'

    def eachRow(recordSet):
        # Yield the 'Fields' of every record in turn; the cursor is advanced
        # only after the consumer has finished with the current record.
        while not recordSet.EOF:
            yield recordSet.Fields
            recordSet.MoveNext()

    tables = dbConnection.OpenSchema(ADSCHEMA_TABLES)
    try:
        return [fields.Item('TABLE_NAME').Value for fields in eachRow(tables)]
    finally:
        tables.Close()

def getFieldNames(dbConnection, tableName):
    """Return the list of field names of a table in an opened database.

    Each column's TABLE_NAME is compared to 'tableName' with '==', i.e.
    case-sensitively; an unknown table yields an empty list. The schema
    recordset is always closed, even if reading it fails."""

    ADSCHEMA_COLUMNS = 4  # ADO 'SchemaEnum' value 'adSchemaColumns'

    def eachRow(recordSet):
        # Yield the 'Fields' of every record in turn; the cursor is advanced
        # only after the consumer has finished with the current record.
        while not recordSet.EOF:
            yield recordSet.Fields
            recordSet.MoveNext()

    # The schema is opened without restrictions, so the columns of every
    # table are scanned and the filtering by table name happens here.
    columns = dbConnection.OpenSchema(ADSCHEMA_COLUMNS)
    try:
        return [fields.Item('COLUMN_NAME').Value
                for fields in eachRow(columns)
                if fields.Item('TABLE_NAME').Value == tableName]
    finally:
        columns.Close()

def createTable(dbConnection, tableName, fieldNames, fieldDataTypes):
    """Create a new table in an opened database, given the sequences of its
    field names and SQL data types.

    'fieldNames' and 'fieldDataTypes' must be non-empty sequences of equal
    length; the i-th data type string (e.g. 'INT') is applied to the i-th
    field. Table and field names are quoted with 'quoteName'; the data type
    strings are inserted into the statement verbatim.

    Raise 'ValueError' if the sequences are empty or differ in length.
    Raise 'pywintypes.com_error' on failure."""

    # Validate explicitly instead of using 'assert', which is stripped under
    # 'python -O'; without this check 'zip' below would silently drop
    # unmatched names or types.
    if len(fieldNames) == 0 or len(fieldNames) != len(fieldDataTypes):
        raise ValueError('fieldNames and fieldDataTypes must be non-empty '
                         'sequences of equal length')

    # Construct a comma-separated list of field name-type pairs.
    fieldsDefinition = ', '.join(
        quoteName(fieldName) + ' ' + dataType
        for fieldName, dataType in zip(fieldNames, fieldDataTypes))

    dbConnection.Execute('CREATE TABLE %s (%s)' %
        (quoteName(tableName), fieldsDefinition))

def getDataTypeCode(sqlDataType):
    """Map a SQL data type name to its ADO 'DataTypeEnum' constant.

    Return the constant for the data type named 'sqlDataType', or 'None' if
    that data type is not supported. Names are matched exactly
    (case-sensitively).

    Warning: data type names in different database systems may differ."""

    # Ordered (SQL type name, ADO 'DataTypeEnum' value) pairs.
    typeCodes = (
        ('DATE', 7),          # 'adDate'
        ('TIME', 7),          # 'adDate'
        ('INT', 3),           # 'adInteger'
        ('SMALLINT', 2),      # 'adSmallInt'
        ('SINGLE', 4),        # 'adSingle'
        ('DOUBLE', 5),        # 'adDouble'
        ('BIT', 11),          # 'adBoolean'
        ('VARCHAR', 202),     # 'adVarWChar'
        ('LONGTEXT', 203),    # 'adLongVarWChar'
        ('LONGBINARY', 205),  # 'adLongVarBinary'
    )

    for typeName, adoCode in typeCodes:
        if sqlDataType == typeName:
            return adoCode

    return None

def getDataTypeName(sqlDataType):
    """Describe a SQL data type in human-readable form.

    Return a descriptive string for the data type named 'sqlDataType', or
    'None' if that data type is not supported. Names are matched exactly
    (case-sensitively)."""

    # Ordered (SQL type name, description) pairs.
    descriptions = (
        ('DATE', 'date'),
        ('TIME', 'time'),
        ('INT', 'four-byte signed integer'),
        ('SMALLINT', 'two-byte signed integer'),
        ('SINGLE', 'single precision floating point number'),
        ('DOUBLE', 'double precision floating point number'),
        ('BIT', 'boolean value'),
        ('VARCHAR', 'null-terminated Unicode character string'),
        ('LONGTEXT', 'long null-terminated Unicode character string'),
        ('LONGBINARY', 'long binary value'),
    )

    for typeName, description in descriptions:
        if sqlDataType == typeName:
            return description

    return None

def quoteName(fieldName):
    """Return a safe SQL representation of a database field or table name.

    The name is enclosed in square brackets, which Access requires for names
    containing spaces (and which is harmless for all other names).

    Raise 'ValueError' if the name contains a closing square bracket: it
    would terminate the bracketed identifier early and let the rest of the
    name be parsed as SQL. Access object names may not contain brackets at
    all, so no valid name is rejected."""

    if ']' in fieldName:
        raise ValueError('Invalid database object name: %r' % (fieldName,))

    return '[' + fieldName + ']'

# **** Private functions ******************************************************
def getConnectionString(filePath):
    """Build the connection string for opening an existing Access database
    file with an 'ADODB.Connection' object, using the Jet 4.0 OLE DB
    provider."""

    # An ADO connection string value must be quoted when it begins with a
    # single or double quote or contains a semicolon, and the surrounding
    # quote character must be doubled wherever it appears inside the value.
    # The path is quoted unconditionally, which is always valid and avoids
    # having to detect those cases.
    QUOTE = '"'
    safePath = QUOTE + filePath.replace(QUOTE, QUOTE * 2) + QUOTE

    return 'Provider=Microsoft.Jet.OLEDB.4.0;Data Source=%s;' % (safePath,)
