#!/usr/bin/python3
# SPDX-License-Identifier: GPL-3.0-or-later
"""
Utility for setting passwords inside disk/VM images.

Written for FreedomBox images.  Works on Qemu, VirtualBox and raw disk
images.
"""

import argparse
import getpass
import logging
import os
import subprocess
import sys
import tempfile

# Module-level logger shared by every function in this script; main()
# configures logging at INFO level before anything is logged through it.
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


def main():
    """Parse the command line and set the password inside the image.

    The password comes from ``--password`` when given, otherwise it is
    prompted for interactively.  On failure the process exits with status 1
    when an external command failed, 2 when interrupted from the keyboard
    and 3 for any other error.
    """
    logging.basicConfig(level=logging.INFO)

    cli = argparse.ArgumentParser(
        description='Change password of a user inside a disk image file')
    cli.add_argument(
        'image',
        help='Disk image file (.img or .vdi) inside which user manipulation '
        'is sought')
    cli.add_argument('user', help='User account to change password for')
    cli.add_argument('--password', help='New password for user')
    options = cli.parse_args()

    # Both of these run outside the try block: check_requirements()
    # terminates the process by itself when something is missing.
    kind = get_image_type(options)
    check_requirements(kind)

    try:
        new_password = options.password
        if not new_password:
            # Prompting happens inside the try so that Ctrl-C at the prompt
            # is reported like any other interruption.
            new_password = take_password()

        perform_operations(options, new_password, kind)
    except subprocess.CalledProcessError as error:
        logger.error('Error running command: %s', error.cmd)
        captured = error.output
        if captured:
            logger.error('Error output - %s', captured.decode())
        sys.exit(1)
    except KeyboardInterrupt:
        logger.error('Command terminated by user action')
        sys.exit(2)
    except Exception as error:  # pylint: disable=broad-except
        logger.exception('Unexpected error: %s', error)
        sys.exit(3)


def check_requirements(image_type):
    """Verify that this machine can operate on an image of *image_type*.

    Checks, in order: that the effective user is root, that the mapping
    tool for the image type ('raw' -> kpartx, 'vm' -> qemu-nbd) is found by
    ``which``, and that /proc/filesystems lists btrfs.  The first failed
    check is logged and the process exits with status -1.
    """
    logger.info('Checking for necessary dependencies')

    running_as_root = os.geteuid() == 0
    if not running_as_root:
        logger.error('Due to limitations of the tools involved, you need to '
                     'run this command as "root" user or using the "sudo" '
                     'command.')
        sys.exit(-1)

    # Mapping tool needed for each image type and the complaint to log when
    # that tool is absent.
    tool_for_type = {
        'raw': ('kpartx',
                '"kpartx" command not found.  On Debian based '
                'systems it is provided by the package "kpartx".'),
        'vm': ('qemu-nbd',
               '"qemu-nbd" command not found.  On Debian based '
               'systems it is provided by the package "qemu-utils".'),
    }
    if image_type in tool_for_type:
        tool, complaint = tool_for_type[image_type]
        try:
            subprocess.check_output(['which', tool])
        except subprocess.CalledProcessError:
            logger.error(complaint)
            sys.exit(-1)

    # A line of /proc/filesystems is either "<name>" or "nodev <name>", so
    # compare whole whitespace-separated words rather than substrings.
    with open('/proc/filesystems') as filesystems:
        has_btrfs = any('btrfs' in entry.split() for entry in filesystems)

    if not has_btrfs:
        logger.error('Your kernel does not support Btrfs filesystem.')
        sys.exit(-1)


def get_image_type(arguments):
    """Return the type of the disk image: 'vm' or 'raw'.

    The decision is made purely on the file extension of
    ``arguments.image``: ``.vdi`` and ``.qcow2`` (in any letter case) are
    VM images, everything else is treated as a raw image.
    """
    # splitext() only looks at the last path component and needs a real
    # ".ext" suffix, so a dot-less file that happens to be called "vdi" or
    # "qcow2" is not mistaken for a VM image.
    extension = os.path.splitext(arguments.image)[1].lower()
    if extension in ('.vdi', '.qcow2'):
        return 'vm'

    return 'raw'


def take_password():
    """Interactively ask for the new password and return it.

    The password is requested twice without echo; the prompts repeat until
    both entries are identical.
    """
    while True:
        first_entry = getpass.getpass('Enter new password: ')
        second_entry = getpass.getpass('Re-enter new password: ')
        if first_entry != second_entry:
            logger.error('Passwords do not match\n')
            continue

        return first_entry


def perform_operations(arguments, password, image_type):
    """Change the password of ``arguments.user`` inside ``arguments.image``.

    The image is mapped to block devices, its root device is mounted and
    the password is changed there.  Once a step has succeeded its cleanup
    is guaranteed: the mount is undone even if changing the password fails,
    and the mapping is undone even if mounting fails.
    """
    mapping = map_disk_image(arguments.image, image_type)
    root_device = mapping['root_device']
    logger.info('Root device is - %s', root_device)

    try:
        mounted = mount_disk_image(root_device)

        try:
            change_password(mounted['root_path'], arguments.user, password)
        finally:
            unmount_disk_image(mounted)
    finally:
        unmap_disk_image(mapping)


def map_disk_image(disk_image, image_type):
    """Expose the partitions of *disk_image* as block devices.

    Dispatches to the VM mapper when *image_type* is 'vm' and to the raw
    mapper otherwise, returning whatever mapping information that mapper
    produces.
    """
    mapper = map_vm_disk_image if image_type == 'vm' else map_raw_disk_image
    return mapper(disk_image)


def map_vm_disk_image(disk_image, device='/dev/nbd7'):
    """Map the partitions inside a VM disk image as block devices.

    The image is attached to the network block device *device* with
    qemu-nbd and the last partition listed by fdisk is taken to be the root
    device.

    Returns a dict with the keys 'root_device' (the partition to mount),
    'image_type' (always 'vm') and 'mapped_device' (the NBD device to
    disconnect later).

    Raises subprocess.CalledProcessError when one of the external commands
    fails and RuntimeError when no partition is found on the device.  In
    both cases the NBD device is disconnected again before the exception
    propagates, because the caller can only clean up a mapping that was
    returned to it.
    """
    logger.info('Adding partition mappings for VM disk image - %s', disk_image)
    subprocess.check_call(['modprobe', 'nbd', 'max_part=64'])
    subprocess.check_call(['qemu-nbd', '--connect=' + device, disk_image])

    try:
        subprocess.check_call(['partprobe', device])
        output = subprocess.check_output(
            ['fdisk', '-o', 'Device', '-l', device])

        # Keep only the partition rows (e.g. "/dev/nbd7p2").  The disk
        # summary line starts with "Disk", and the column header and any
        # trailing remarks do not start with the device path either.
        partitions = [
            line.strip() for line in output.decode().splitlines()
            if line.strip().startswith(device)
        ]
        if not partitions:
            raise RuntimeError('No partitions found on ' + device +
                               ' for disk image ' + disk_image)
    except BaseException:
        # Best effort: do not let a failing disconnect hide the real error.
        subprocess.call(['qemu-nbd', '--disconnect', device])
        raise

    return {
        'root_device': partitions[-1],
        'image_type': 'vm',
        'mapped_device': device
    }


def map_raw_disk_image(disk_image):
    """Map the partitions inside a raw disk image as block devices.

    Runs ``kpartx -a -v -s`` on the image and takes the last partition it
    reports as the root device.

    Returns a dict with the keys 'root_device' (path under /dev/mapper),
    'image_type' (always 'raw') and 'mapped_image' (the image path, needed
    to remove the mappings later).

    Raises subprocess.CalledProcessError when kpartx fails and RuntimeError
    when kpartx did not report any mapped partition.
    """
    logger.info('Adding partition mappings for raw disk image - %s',
                disk_image)
    output = subprocess.check_output(['kpartx', '-a', '-v', '-s', disk_image])

    # In verbose mode every created mapping is reported on a line of the
    # form "add map <name> (<major>:<minor>): ...".  Anything else (blank
    # lines, warnings) is skipped instead of being indexed blindly.
    devices = []
    for line in output.decode().splitlines():
        fields = line.split()
        if len(fields) > 2 and fields[:2] == ['add', 'map']:
            devices.append(fields[2])

    if not devices:
        # Best effort: release whatever kpartx may have set up, because the
        # caller only unmaps after a successful return.
        subprocess.call(['kpartx', '-d', disk_image])
        raise RuntimeError('No partitions found in disk image ' + disk_image)

    return {
        'root_device': '/dev/mapper/' + devices[-1],
        'image_type': 'raw',
        'mapped_image': disk_image
    }


def mount_disk_image(root_device):
    """Mount the root device on a new temporary directory.

    Returns a dict with the keys 'mount_path' (where the device is mounted;
    this is what must be unmounted later) and 'root_path' (the directory to
    use as the root of the installed system).

    Raises subprocess.CalledProcessError when mount fails, or OSError when
    it cannot be executed; the temporary directory is removed again in both
    cases.
    """
    mount_path = tempfile.mkdtemp()

    logger.info('Mounting %s on %s', root_device, mount_path)
    try:
        subprocess.check_call(['mount', root_device, mount_path])
    except (subprocess.CalledProcessError, OSError):
        # Nothing is mounted, so the caller never gets a mount_info to
        # clean up: remove the directory here instead of leaking it.
        os.rmdir(mount_path)
        raise

    mount_info = {'mount_path': mount_path, 'root_path': mount_path}

    # XXX: Assumption that if root/@ exists, it is btrfs and that we
    # are going to use that snapshot to work on it.
    snapshot_path = os.path.join(mount_path, '@')
    if os.path.isdir(snapshot_path):
        mount_info['root_path'] = snapshot_path

    return mount_info


def change_password(root_path, user, password):
    """Change a user's password inside chroot directory.

    *root_path* is handed to ``chpasswd --root`` and the new credentials
    are fed to it on standard input as a single ``user:password`` record.

    Raises ValueError when *user* or *password* cannot be represented as
    one such record, and subprocess.CalledProcessError when chpasswd fails.
    """
    # chpasswd reads one "user:password" record per line and splits it at
    # the first colon.  A line break would smuggle in extra records for
    # other accounts, and a colon in the user name would shift the split.
    if '\n' in user or '\r' in user or ':' in user:
        raise ValueError('User name must not contain ":" or line breaks')

    if '\n' in password or '\r' in password:
        raise ValueError('Password must not contain line breaks')

    logger.info('Changing password for %s inside %s', user, root_path)
    chpasswd_input = '{0}:{1}'.format(user, password)

    # XXX: Providing crypt method is not recommended.  However, without crypt
    # method, the passwd encryption happens using PAM and that does not seem to
    # be working in a chroot.
    subprocess.check_output(
        ['chpasswd', '--root', root_path, '--crypt-method', 'SHA512'],
        input=chpasswd_input.encode())


def unmount_disk_image(mount_info):
    """Undo a mount: unmount the device and delete its mount point.

    *mount_info* must carry the mount point under the key 'mount_path'.
    Returns that mount point path.
    """
    target = mount_info['mount_path']
    logger.info('Unmounting %s', target)

    umount_command = ['umount', target]
    subprocess.check_call(umount_command)

    # The directory is empty again once the filesystem is unmounted.
    os.rmdir(target)
    return target


def unmap_disk_image(map_info):
    """Unmap the disk image partitions.

    Picks the VM or the raw unmapper according to ``map_info['image_type']``
    and passes *map_info* on to it.
    """
    is_vm_image = map_info['image_type'] == 'vm'
    unmapper = unmap_vm_disk_image if is_vm_image else unmap_raw_disk_image
    unmapper(map_info)


def unmap_vm_disk_image(map_info):
    """Unmap the VM disk image partitions.

    Disconnects the NBD device stored under ``map_info['mapped_device']``
    using qemu-nbd.
    """
    nbd_device = map_info['mapped_device']
    logger.info('Removing partition mappings from VM device - %s', nbd_device)

    disconnect_command = ['qemu-nbd', '--disconnect', nbd_device]
    subprocess.check_output(disconnect_command)


def unmap_raw_disk_image(map_info):
    """Unmap the raw disk image partitions.

    Runs ``kpartx -d`` on the image path stored under
    ``map_info['mapped_image']``.
    """
    image_path = map_info['mapped_image']
    logger.info('Removing partition mappings from raw disk - %s', image_path)

    delete_command = ['kpartx', '-d', image_path]
    subprocess.check_output(delete_command)


# Run the tool only when this file is executed as a script; importing it as
# a module has no side effects beyond the definitions above.
if __name__ == '__main__':
    main()
