# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import codecs
import os
import sys

from awscli.customizations.commands import BasicCommand
from awscli.customizations.configure.writer import ConfigFileWriter
from awscli.customizations.utils import uni_print


class ConfigureImportCommand(BasicCommand):
    """Implementation of the ``aws configure import`` subcommand.

    The CSV text given via ``--csv`` is handed to a credential parser and
    every resulting credential is passed to a credential importer together
    with the path of the credentials file and the requested profile prefix.
    """

    NAME = 'import'
    DESCRIPTION = (
        'Import CSV credentials generated from the AWS web console. '
        'Entries in the CSV will be imported as profiles in the AWS '
        'credentials file, with the profile name matching the IAM User Name.'
    )
    EXAMPLES = (
        'aws configure import --csv file://credentials.csv\n\n'
        'aws configure import --csv file://credentials.csv --skip-invalid\n\n'
        'aws configure import --csv file://credentials.csv '
        '--profile-prefix test-\n\n'
    )
    ARG_TABLE = [
        {
            'name': 'csv',
            'required': True,
            'help_text': (
                'The credentials in CSV format generated by the AWS web console. '
                'The CSV file must contain the "User name", "Access key ID", and '
                '"Secret access key" headers. '
                'If passing a CSV file path instead of a CSV-formatted string, '
                'the "file://" prefix is required.'
            ),
            'cli_type_name': 'string',
        },
        {
            'name': 'skip-invalid',
            'dest': 'skip_invalid',
            'help_text': (
                'Skip entries that are invalid or do not have programmatic '
                'access instead of failing.'
            ),
            'default': False,
            'action': 'store_true',
        },
        {
            'name': 'profile-prefix',
            'dest': 'profile_prefix',
            'help_text': (
                'Adds the specified prefix to the beginning of all profile names.'
            ),
            'default': '',
            'cli_type_name': 'string',
        },
    ]

    def __init__(
        self, session, csv_parser=None, importer=None, out_stream=None
    ):
        """Create the command, building default collaborators when omitted.

        :param session: Session object forwarded to ``BasicCommand``; it is
            also queried for the ``credentials_file`` config variable.
        :param csv_parser: Object exposing ``parse_credentials(contents)``
            and a writable ``strict`` attribute. Defaults to a new
            ``CSVCredentialParser``.
        :param importer: Object exposing ``import_credential(credential,
            path, profile_prefix=...)``. Defaults to a ``CredentialImporter``
            wrapping a new ``ConfigFileWriter``.
        :param out_stream: Stream the summary message is printed to.
            Defaults to ``sys.stdout``.
        """
        super(ConfigureImportCommand, self).__init__(session)
        if csv_parser is None:
            csv_parser = CSVCredentialParser()
        self._csv_parser = csv_parser

        if importer is None:
            writer = ConfigFileWriter()
            importer = CredentialImporter(writer)
        self._importer = importer

        if out_stream is None:
            out_stream = sys.stdout
        self._out_stream = out_stream

        # Same default as the ``--profile-prefix`` argument, so that
        # ``_import_csv`` does not fail with AttributeError when it is
        # reached before ``_run_main`` has stored the parsed value.
        self._profile_prefix = ''

    def _get_config_path(self):
        """Return the ``credentials_file`` config value with ``~`` expanded."""
        config_file = self._session.get_config_variable('credentials_file')
        return os.path.expanduser(config_file)

    def _import_csv(self, contents):
        """Parse ``contents`` and import every credential found in it.

        :param contents: CSV text holding the credentials.
        :raises ValueError: If ``contents`` looks like a path to an existing
            file rather than CSV text (see ``_check_possible_filepath``).
        """
        self._check_possible_filepath(contents)
        config_path = self._get_config_path()
        credentials = self._csv_parser.parse_credentials(contents)
        for credential in credentials:
            self._importer.import_credential(
                credential,
                config_path,
                profile_prefix=self._profile_prefix,
            )
        import_msg = 'Successfully imported %s profile(s)\n' % len(credentials)
        uni_print(import_msg, out_file=self._out_stream)

    def _check_possible_filepath(self, csv_data):
        """Reject input that is most likely a file path missing ``file://``.

        A single-line value that names an existing filesystem entry and does
        not start with ``file://`` is assumed to be a path the user forgot
        to prefix.

        :raises ValueError: When ``csv_data`` matches that description.
        """
        if ('\n' not in csv_data and
            os.path.exists(csv_data) and
            not csv_data.startswith('file://')):
            raise ValueError(
                "You may be passing a file to import without the 'file://' prefix. "
                "To import a CSV file, use --csv file://path/to/file.csv"
            )

    def _run_main(self, parsed_args, parsed_globals):
        """Apply the parsed options, run the import and return exit code 0."""
        # ``--skip-invalid`` turns strict parsing off.
        self._csv_parser.strict = not parsed_args.skip_invalid
        self._profile_prefix = parsed_args.profile_prefix
        self._import_csv(parsed_args.csv)
        return 0


class CredentialParserError(Exception):
    """Error type used to report credential CSV contents that cannot be parsed."""


class CSVCredentialParser:
    """Parse credential CSV text into ``(user name, akid, sak)`` tuples.

    Header matching ignores case and surrounding whitespace, and a UTF-8
    byte order mark at the start of a header cell is discarded. Cells are
    obtained by splitting each line on ``,``; quoted fields are not
    interpreted.

    When ``strict`` is true (the default), the first data row lacking a
    value for a required header aborts parsing with
    ``CredentialParserError``; otherwise such rows are dropped silently.
    Whitespace-only lines among the data rows are ignored in both modes.
    """

    _USERNAME_HEADER = 'User Name'
    _AKID_HEADER = 'Access Key ID'
    _SAK_HEADER = 'Secret Access key'
    _EXPECTED_HEADERS = [_USERNAME_HEADER, _AKID_HEADER, _SAK_HEADER]

    _EMPTY_CSV = 'Provided CSV contains no contents'
    _HEADER_NOT_FOUND = 'Expected header "%s" not found'
    _ROW_MISSING_HEADER = 'Row missing value for header "%s"'
    _INVALID_ROW = 'Failed to parse entry #%s: %s'

    def __init__(self, strict=True):
        """Store whether invalid rows raise (``True``) or are skipped."""
        self.strict = strict

    def _format_header(self, header):
        """Return ``header`` normalized for comparison."""
        # Remove leading UTF BOM character if present
        if header.startswith(codecs.BOM_UTF8.decode()):
            header = header[1:]
        return header.lower().strip()

    def _parse_csv_headers(self, header):
        """Split the header line on commas and normalize every cell."""
        return [self._format_header(h) for h in header.split(',')]

    def _extract_expected_header_indices(self, headers):
        """Map each expected header to its column index in ``headers``.

        :param headers: Normalized header cells of the CSV.
        :raises CredentialParserError: If an expected header is absent.
        """
        indices = {}
        for header in self._EXPECTED_HEADERS:
            formatted_header = self._format_header(header)
            try:
                indices[header] = headers.index(formatted_header)
            except ValueError:
                # The ValueError carries no extra information for the user.
                raise CredentialParserError(
                    self._HEADER_NOT_FOUND % header
                ) from None
        return indices

    def _parse_csv_row(self, row, header_indices):
        """Extract the expected values from a single data row.

        :returns: Dict keyed by expected header with stripped cell values.
        :raises CredentialParserError: If a required cell is missing or
            empty after stripping.
        """
        item = {}
        cols = row.split(',')
        for header, index in header_indices.items():
            # A row with fewer columns than the header is treated the same
            # way as a row with an empty cell.
            value = cols[index].strip() if index < len(cols) else ''
            if not value:
                raise CredentialParserError(self._ROW_MISSING_HEADER % header)
            item[header] = value
        return item

    def _parse_csv_rows(self, rows, header_indices):
        """Parse all data rows, honouring ``self.strict`` for invalid ones.

        :raises CredentialParserError: In strict mode, for the first invalid
            row; the message contains the 1-based row number and the
            underlying error is attached as ``__cause__``.
        """
        parsed_rows = []
        for count, row in enumerate(rows, 1):
            # A whitespace-only line (e.g. a doubled trailing newline) holds
            # no entry at all, so it is skipped instead of being reported as
            # an invalid credential. It still advances ``count`` so entry
            # numbers keep matching the line position in the input.
            if not row.strip():
                continue
            try:
                item = self._parse_csv_row(row, header_indices)
            except CredentialParserError as e:
                if not self.strict:
                    continue
                raise CredentialParserError(
                    self._INVALID_ROW % (count, e)
                ) from e
            parsed_rows.append(item)
        return parsed_rows

    def _parse_csv(self, csv):
        """Parse CSV text into a list of per-row dicts.

        The first line is taken as the header line.

        :raises CredentialParserError: If ``csv`` is empty or whitespace
            only, a required header is missing, or (strict mode) a row is
            invalid.
        """
        if not csv.strip():
            raise CredentialParserError(self._EMPTY_CSV)

        lines = csv.splitlines()
        parsed_headers = self._parse_csv_headers(lines[0])
        header_indices = self._extract_expected_header_indices(parsed_headers)
        return self._parse_csv_rows(lines[1:], header_indices)

    def _convert_rows_to_credentials(self, parsed_rows):
        """Turn row dicts into ``(username, akid, sak)`` tuples."""
        return [
            (
                row.get(self._USERNAME_HEADER),
                row.get(self._AKID_HEADER),
                row.get(self._SAK_HEADER),
            )
            for row in parsed_rows
        ]

    def parse_credentials(self, contents):
        """Parse CSV ``contents`` into a list of credential tuples.

        :param contents: CSV text whose first line is the header line.
        :returns: List of ``(username, access key id, secret access key)``
            tuples in input order.
        :raises CredentialParserError: See ``_parse_csv``.
        """
        # Expected format is:
        # User name,Password,Access key ID,Secret access key,Console login link
        # username1,pw,akid,sak,https://console.link
        # username2,pw,akid,sak,https://console.link
        parsed_rows = self._parse_csv(contents)
        return self._convert_rows_to_credentials(parsed_rows)


class CredentialImporter:
    """Hands credentials to a config writer, one profile section at a time."""

    def __init__(self, writer):
        """Keep the writer; it must provide ``update_config(values, path)``."""
        self._config_writer = writer

    def import_credential(
        self, credential, credentials_file, profile_prefix=''
    ):
        """Pass one credential to the writer as a profile section.

        :param credential: Three-item sequence of user name, access key id
            and secret access key.
        :param credentials_file: Path forwarded unchanged to the writer.
        :param profile_prefix: Text prepended to the user name to form the
            section name.
        """
        user_name, access_key_id, secret_access_key = credential
        section_name = profile_prefix + user_name
        # Keys are inserted in the same order as before in case the writer
        # is order-sensitive.
        profile_values = {'__section__': section_name}
        profile_values['aws_access_key_id'] = access_key_id
        profile_values['aws_secret_access_key'] = secret_access_key
        self._config_writer.update_config(profile_values, credentials_file)
