#!/usr/bin/python3

import sys
import argparse

import apt


def checkpoint() -> None:
    """Print the names of all installed packages, one per line.

    The names are de-duplicated and written to standard output in sorted
    order, so the output can be saved and later passed to ``restore``.
    """
    cache = apt.Cache()
    installed = {entry.name for entry in cache if entry.is_installed}
    for pkg_name in sorted(installed):
        print(pkg_name)


def restore(fname: str) -> None:
    """Make the set of installed packages match the list stored in *fname*.

    The file is expected to hold one package name per line (the format
    printed by ``checkpoint``). Installed packages that are not listed are
    marked for removal, listed packages that are not installed are marked
    for installation, and the changes are then committed through apt.

    Args:
        fname: Path of the file listing the desired package names.

    Raises:
        OSError: If *fname* cannot be opened or read.
    """
    # Read the whole list up front and close the file deterministically.
    # Blank lines are dropped so a stray empty line never ends up as an
    # empty "package name" in the desired set.
    with open(fname) as listing:
        desired = {name for name in map(str.strip, listing) if name}

    cache = apt.Cache()
    for pkg in cache:
        if pkg.is_installed:
            if pkg.name not in desired:
                print("removing", pkg)
                # NOTE(review): auto_fix=False presumably keeps apt's
                # resolver from altering other packages - confirm against
                # the python-apt documentation.
                pkg.mark_delete(auto_fix=False)
        elif pkg.name in desired:
            print("installing", pkg)
            # NOTE(review): auto_inst=False presumably avoids pulling in
            # dependencies automatically; the list is assumed to be
            # dependency-complete - confirm.
            pkg.mark_install(auto_fix=False, auto_inst=False)
    cache.commit()


def _make_parser() -> argparse.ArgumentParser:
    """Build the command-line parser with its two subcommands.

    Each subcommand stores a ``func`` default on the parsed namespace; that
    callable takes the namespace and runs the command. When no subcommand
    is given, the namespace has no ``func`` attribute.

    Returns:
        The configured top-level parser.
    """
    parser = argparse.ArgumentParser()
    sub = parser.add_subparsers()
    # TODO: make checkpoint and restore symmetric to use either stdin/stdout or
    # FILE in both cases.
    cmd = sub.add_parser(
        "checkpoint",
        help="preserve the set of installed packages",
        description="Prints the set of installed packages.",
    )
    # The lambdas adapt each command to a uniform "takes the namespace"
    # signature so the caller can dispatch without knowing the command.
    cmd.set_defaults(func=lambda ns: checkpoint())
    cmd = sub.add_parser(
        "restore",
        help="restore the set of installed packages",
        description="Adjusts the system retaining exactly the packages listed in FILE.",
    )
    cmd.add_argument("fname", metavar="FILE", help="preserved apt state")
    cmd.set_defaults(func=lambda ns: restore(ns.fname))
    return parser


def main() -> None:
    """Parse the command line and run the selected subcommand.

    If no subcommand was given, the usage help is printed instead.
    """
    parser = _make_parser()
    args = parser.parse_args()
    handler = getattr(args, "func", None)
    if handler is None:
        # No subcommand selected, so no handler was attached.
        parser.print_help()
        return
    handler(args)


# Run the command-line interface only when executed as a script, so that
# importing this module has no side effects.
if __name__ == "__main__":
    main()
