#!/usr/bin/python3

import argparse
import collections
import contextlib
import datetime
import fnmatch
import glob
import json
import os
import shlex
import signal
import subprocess
import sys
import tempfile
import time
from subprocess import call
from subprocess import check_call as cc
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    List,
    Optional,
    Tuple,
    Type,
    TypedDict,
    TypeVar,
    cast,
)

import dateutil.parser  # type: ignore

# Dataset types, compared against the `type` column returned by `zfs list`.
TYPE_VOLUME = "volume"
TYPE_FILESYSTEM = "filesystem"

# Name of the per-pool child dataset under which snapshot clones are created.
ZFS_SNAPSHOT_ROOT = ".borg-offsite-backup"
# Appended to a dataset name, followed by the execution date, to form the
# name of the snapshot taken for a backup run.
SNAPSHOT_PREFIX = "@borg-"
# Directory under which the device nodes of ZFS volumes are looked up.
ZVOL_DEV_ROOT = "/dev/zvol"

# Default value of --telemetry-timeout (used as a subprocess timeout).
DEFAULT_TELEMETRY_TIMEOUT = 600
# SSH client options; each entry is passed as the argument of a `-o` flag
# when BORG_RSH is assembled.
SSH_OPTS = [
    "ServerAliveInterval 300",
    "ServerAliveCountMax 2",
    "ConnectTimeout 45",
]

# Command prefix prepended to every borg invocation (empty by default).
echo: List[str] = []
# Set to True by the SIGTERM handler; checked by main() to skip further work.
terminated = False


def sigterm(*_signal_args: Any) -> None:
    """Signal handler that records a termination request.

    Whatever positional arguments the handler is invoked with are ignored;
    the only effect is that the module-level ``terminated`` flag becomes
    ``True``.
    """
    global terminated
    terminated = True


def status(string: str, *args: Any) -> None:
    """Print a progress message on standard error.

    When extra positional arguments are supplied, ``string`` is used as a
    printf-style template and interpolated with them; otherwise it is
    printed verbatim.  The stream is flushed afterwards when it offers a
    ``flush`` attribute.
    """
    message = string % args if args else string
    print(message, file=sys.stderr)
    stream = sys.stderr
    if hasattr(stream, "flush"):
        stream.flush()


def sudo(*cmd: str, **kwargs: Any) -> None:
    """Run ``cmd`` with root privileges and raise if it fails.

    The command is executed directly when the current user is already
    uid 0, and through ``sudo`` otherwise.  Keyword arguments are forwarded
    to ``subprocess.check_call``.
    """
    prefix = ["sudo"] if os.getuid() != 0 else []
    cc([*prefix, *cmd], **kwargs)


def output(*cmd: str, **kwargs: Any) -> str:
    """Run ``cmd`` and return its standard output as text.

    Keyword arguments are forwarded to ``subprocess.check_output``, which
    raises ``CalledProcessError`` on a non-zero exit status.
    """
    captured: str = subprocess.check_output(
        cmd,
        universal_newlines=True,
        **kwargs,
    )
    return captured


def sudoutput(*cmd: str) -> str:
    """Run ``cmd`` as root and return its standard output as text.

    ``sudo`` is only prepended when the current user is not uid 0.
    """
    if os.getuid() == 0:
        return output(*cmd)
    return output("sudo", *cmd)


def dataset_exists(dataset_name: str) -> bool:
    """Return True when ``zfs list <dataset_name>`` exits successfully.

    Output of the command is discarded.
    """
    exit_status = call(
        ["zfs", "list", dataset_name],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    return exit_status == 0


def is_mountpoint(mp: str) -> bool:
    """Return True when ``mountpoint -q -- <mp>`` exits with status zero."""
    probe = ["mountpoint", "-q", "--", mp]
    exit_status = call(probe)
    return exit_status == 0


def list_props(
    dataset_name: str, props: List[str], recursive: bool
) -> List[Dict[str, str]]:
    """Query ZFS properties of a dataset via ``zfs list -H -o``.

    Args:
        dataset_name: dataset to list.
        props: property (column) names to request, in order.
        recursive: when true, ``-r`` is added to the command.

    Returns:
        One dict per non-blank output line, mapping each requested property
        name to the corresponding tab-separated field.
    """
    cmd = ["zfs", "list", "-o", ",".join(props), "-H"]
    if recursive:
        cmd.append("-r")
    cmd.append(dataset_name)

    rows: List[Dict[str, str]] = []
    for line in output(*cmd).splitlines():
        if not line.strip():
            continue
        rows.append(dict(zip(props, line.split("\t"))))
    return rows


def create_dataset(dataset: str) -> None:
    """Create ``dataset`` with ``zfs create`` and a fixed set of properties.

    The dataset is created unmounted (``mountpoint=none``); the remaining
    properties are passed through to ``zfs create`` verbatim.
    """
    properties = (
        "com.sun:auto-snapshot=false",
        "org.qubes-os:part-of-qvm-pool=true",
        "secondarycache=metadata",
        "mountpoint=none",
    )
    options: List[str] = []
    for prop in properties:
        options += ["-o", prop]
    sudo("zfs", "create", *options, dataset)


Y = TypeVar("Y")


def retrier(fun: Callable[..., Y], count: int, timeout: float) -> Y:
    """Call ``fun()`` until it succeeds, retrying up to ``count`` times.

    Args:
        fun: zero-argument callable to invoke.
        count: number of retries allowed after the first failed call; must
            be at least 1.
        timeout: seconds to sleep after each failed call that is retried.

    Returns:
        Whatever the first successful ``fun()`` call returns.

    Raises:
        ValueError: if ``count`` is smaller than 1.
        Exception: whatever the final attempt raised once the retries are
            exhausted.  Exceptions that are not ``Exception`` subclasses
            (such as ``KeyboardInterrupt``) propagate immediately.
    """
    if count < 1:
        raise ValueError("count must be larger than zero")
    retries_left = count
    while retries_left >= 1:
        try:
            return fun()
        except Exception:
            time.sleep(timeout)
            retries_left -= 1
    # Retries exhausted: the last attempt's outcome, good or bad, is final.
    return fun()


def snapshot_exists(snapshot_name: str) -> bool:
    """Return True when ``zfs list -t snapshot <snapshot_name>`` succeeds.

    Output of the command is discarded.
    """
    exit_status = call(
        ["zfs", "list", "-t", "snapshot", snapshot_name],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    return exit_status == 0


def create_snapshot(snapshot_name: str) -> None:
    """Announce and take the ZFS snapshot named ``snapshot_name``."""
    message = f"Snapshotting current {snapshot_name}"
    status(message)
    zfs_args = ("zfs", "snapshot", snapshot_name)
    sudo(*zfs_args)


def clone_snapshot(snapshot_name: str, dataset_name: str) -> None:
    """Announce and clone ``snapshot_name`` into ``dataset_name``.

    ``zfs clone`` is invoked with ``-p``.
    """
    message = f"Cloning {snapshot_name} to {dataset_name}"
    status(message)
    zfs_args = ("zfs", "clone", "-p", snapshot_name, dataset_name)
    sudo(*zfs_args)


def mount_dataset(dataset_name: str, mountpoint: str) -> None:
    """Announce and set the ``mountpoint`` property of ``dataset_name``."""
    message = f"Mounting {dataset_name} onto {mountpoint}"
    status(message)
    assignment = "mountpoint=" + mountpoint
    sudo("zfs", "set", assignment, dataset_name)


def unmount_dataset(dataset_name: str) -> None:
    """Announce and set ``mountpoint=none`` on ``dataset_name``."""
    message = f"Unmounting dataset {dataset_name}"
    status(message)
    zfs_args = ("zfs", "set", "mountpoint=none", dataset_name)
    sudo(*zfs_args)


def destroy_dataset_recursively(dataset: str) -> None:
    """Announce and run ``zfs destroy -r`` on ``dataset``."""
    message = f"Destroying {dataset} recursively"
    status(message)
    zfs_args = ("zfs", "destroy", "-r", dataset)
    sudo(*zfs_args)


@contextlib.contextmanager
def multi_context(*cms: Callable[[], Any]) -> Generator[Any, None, None]:
    """Enter several context managers as one.

    Each argument is a zero-argument factory returning a context manager.
    The factories are called and entered in order; the yielded value is the
    list of their ``__enter__`` results.  All entered managers are exited
    (in reverse order, via ``contextlib.ExitStack``) when the block ends.
    """
    with contextlib.ExitStack() as stack:
        entered: List[Any] = []
        for factory in cms:
            entered.append(stack.enter_context(factory()))
        yield entered


class DatasetToBackup(object):
    """One dataset selection entry from the configuration.

    Attributes:
        name: dataset name, or a pattern when ``glob`` is true.
        recursive: whether descendants of ``name`` are included.
        glob: whether ``name`` is a pattern to match dataset names against.
    """

    def __init__(self, name: str, recursive: bool, glob: bool):
        """Store the selection.

        Raises:
            ValueError: if both ``recursive`` and ``glob`` are requested.
        """
        # Explicit raise rather than assert: the check must survive
        # ``python -O`` and should surface as a configuration error.
        if recursive and glob:
            raise ValueError(
                f"dataset {name!r} cannot be both recursive and glob "
                f"(recursive={recursive!r}, glob={glob!r})"
            )
        self.name = name
        self.recursive = recursive
        self.glob = glob


class Cfg(object):
    """Backup configuration, normally loaded from a JSON file.

    Class attributes hold the defaults; ``from_file`` overrides them per
    instance with the values found in the configuration file.
    """

    backup_path: str = ""
    bridge_vm: Optional[str] = None
    backup_server: str = ""
    compression = "auto,zstd"
    backup_user: str = ""
    prune_on_days: Optional[list[int]] = None
    keep_daily = 7
    keep_weekly = 4
    keep_monthly = 12

    def __init__(self) -> None:
        self.filesystems_to_backup: List[str] = []
        self.datasets_to_backup: List[DatasetToBackup] = []
        self.exclude_patterns: List[str] = []
        self.prune_on_days = []

    def as_dict(self) -> Dict[str, Any]:
        """Return the public, non-callable attributes as a dictionary.

        Names starting with an underscore (including interpreter-provided
        ones such as ``__dict__`` and ``__module__``) are left out.
        """
        d: Dict[str, Any] = {}
        for k in dir(self):
            if k.startswith("_"):
                continue
            v = getattr(self, k)
            if not callable(v):
                d[k] = v
        return d

    @staticmethod
    def _parse_dataset(dset: Any) -> "DatasetToBackup":
        """Turn one ``datasets_to_backup`` entry into a DatasetToBackup."""
        if isinstance(dset, str):
            return DatasetToBackup(dset, False, False)
        if isinstance(dset, dict):
            return DatasetToBackup(
                dset["name"],
                dset.get("recursive", False),
                dset.get("glob", False),
            )
        raise ValueError(f"Invalid dataset to backup: {dset}")

    @staticmethod
    def _is_number(v: Any) -> bool:
        """Tell whether ``v`` is a real number (booleans excluded)."""
        return isinstance(v, (int, float)) and not isinstance(v, bool)

    @classmethod
    def from_file(cls: Type["Cfg"], fn: str) -> "Cfg":
        """Load and validate the configuration stored in JSON file ``fn``.

        The three list settings (``datasets_to_backup``,
        ``filesystems_to_backup``, ``exclude_patterns``) may be given as a
        list or as an object, in which case the object's keys are used.
        Other settings are applied when they are non-empty; numeric values
        are applied even when they are zero.

        Raises:
            ValueError: if the file is not a JSON object, a dataset entry
                is malformed, a mandatory setting is missing, or nothing
                is selected for backup.
        """
        with open(fn, "r") as f:
            records = json.load(f)
        if not isinstance(records, dict):
            raise ValueError("%s must contain a JSON object at the top level" % fn)

        list_settings = (
            "datasets_to_backup",
            "filesystems_to_backup",
            "exclude_patterns",
        )
        c = cls()
        for k, v in records.items():
            if k in list_settings:
                if hasattr(v, "items"):
                    v = list(v.keys())
                if k == "datasets_to_backup":
                    c.datasets_to_backup = [cls._parse_dataset(d) for d in v]
                else:
                    setattr(c, k, v)
            elif v or cls._is_number(v):
                # Empty strings, null and empty containers keep the default,
                # but an explicit 0 (e.g. keep_weekly) must be honored.
                setattr(c, k, v)

        for p in ("backup_path", "backup_server", "backup_user"):
            if not getattr(c, p):
                raise ValueError("%s must be defined in %s" % (p, fn))
        if not c.datasets_to_backup and not c.filesystems_to_backup:
            raise ValueError(
                "either datasets_to_backup or filesystems_to_backup "
                f"must be defined and nonempty in {fn}"
            )
        return c


class SnapshotContext(object):
    """Context manager that exposes ZFS datasets under ``tmpdir`` for backup.

    On entry every selected dataset is snapshotted and the snapshot is
    cloned below ``<pool>/.borg-offsite-backup``.  Clones of file systems
    are mounted at ``tmpdir + <original mountpoint>``; for volumes a copy
    of the clone's device node is placed at ``tmpdir + /dev/zvol/<dataset>``
    and ``read_special`` is set to True.  On exit the mounts, device nodes,
    clones and snapshots are removed again.
    """

    def __init__(
        self,
        datasets_to_backup: List[DatasetToBackup],
        execdate: str,
        tmpdir: str,
    ) -> None:
        self.datasets_to_backup = datasets_to_backup
        self.execdate = execdate
        self.tmpdir = tmpdir
        # Becomes True in __enter__ as soon as one volume is staged.
        self.read_special = False

    def datasets_to_backup_sorted_by_mountpoint(self) -> List[Tuple[str, str, str]]:
        """Resolve the configured selections to concrete datasets.

        Returns:
            ``(name, type, mountpoint)`` tuples without duplicates, sorted
            by type and then mountpoint (so a mountpoint sorts before any
            path nested under it).  File systems whose mountpoint is
            ``none`` are left out.
        """
        columns = ["name", "mountpoint", "type"]
        arrange: List[Tuple[str, str, str]] = []
        already: set[str] = set()
        for d in self.datasets_to_backup:
            if d.glob:
                # List the whole pool and keep the names matching the pattern.
                dd = [
                    row
                    for row in list_props(d.name.partition("/")[0], columns, True)
                    if fnmatch.fnmatch(row["name"], d.name)
                ]
            else:
                dd = list_props(d.name, columns, d.recursive)
            for dataset in dd:
                name, type_, mp = (
                    dataset["name"],
                    dataset["type"],
                    dataset["mountpoint"],
                )
                if name in already:
                    continue
                if type_ == TYPE_FILESYSTEM and mp == "none":
                    continue
                arrange.append((name, type_, mp))
                already.add(name)
        return sorted(arrange, key=lambda x: x[1] + "/" + x[2])

    def roots(self) -> List[str]:
        """Return the distinct ``<pool>/.borg-offsite-backup`` dataset names."""
        return list(
            set(
                os.path.join(d.name.partition("/")[0], ZFS_SNAPSHOT_ROOT)
                for d in self.datasets_to_backup
            )
        )

    def snapshot_to_target(self, d: str) -> str:
        """Map dataset ``pool/a/b`` to clone ``pool/.borg-offsite-backup/a/b``."""
        pool, _, rest = d.partition("/")
        return os.path.join(pool, ZFS_SNAPSHOT_ROOT, rest).rstrip("/")

    def __enter__(self) -> "SnapshotContext":
        for root in self.roots():
            if not dataset_exists(root):
                create_dataset(root)

        # First pass: snapshot and clone everything that is missing.
        for d, typ, mp in self.datasets_to_backup_sorted_by_mountpoint():
            snapshot_name = d + SNAPSHOT_PREFIX + self.execdate
            target_name = self.snapshot_to_target(d)
            if not snapshot_exists(snapshot_name):
                create_snapshot(snapshot_name)

            if not dataset_exists(target_name):
                clone_snapshot(snapshot_name, target_name)

        # Second pass: expose the clones below tmpdir.
        for d, typ, mp in self.datasets_to_backup_sorted_by_mountpoint():
            target_name = self.snapshot_to_target(d)

            if typ == TYPE_VOLUME:
                self.read_special = True
                source_dev = os.path.join(ZVOL_DEV_ROOT, target_name)
                target_dev = self.tmpdir + os.path.join(ZVOL_DEV_ROOT, d)
                status("Creating device file %s from %s", target_dev, source_dev)
                sudo("mkdir", "-p", "--", os.path.dirname(target_dev))

                def remove_and_copy(p: str, q: str) -> None:
                    sudo("rm", "-f", q)
                    sudo("cp", "-faL", p, q)

                # Retried: the copy may fail until the source node shows up.
                # NOTE(review): presumably the clone's device node appears
                # asynchronously — confirm.
                retrier(
                    lambda: remove_and_copy(source_dev, target_dev),
                    20,
                    0.05,
                )
            else:
                mp = self.tmpdir + mp.rstrip("/")
                if not os.path.exists(mp) or not is_mountpoint(mp):
                    mount_dataset(target_name, mp)

        return self

    def __exit__(self, *unused: Any) -> None:
        for d, typ, mp in reversed(self.datasets_to_backup_sorted_by_mountpoint()):
            mp = self.tmpdir + mp.rstrip("/")
            target_name = self.snapshot_to_target(d)
            if typ == TYPE_VOLUME:
                target_dev = self.tmpdir + os.path.join(ZVOL_DEV_ROOT, d)
                # The staged path is a device node, not a regular file, so
                # os.path.isfile() would never see it; test for presence.
                if os.path.lexists(target_dev):
                    status("Deleting device file %s", target_dev)
                    sudo("rm", "-f", target_dev)
            else:
                if is_mountpoint(mp):
                    unmount_dataset(target_name)

        for root in self.roots():
            if dataset_exists(root):
                destroy_dataset_recursively(root)

        for d, typ, mp in reversed(self.datasets_to_backup_sorted_by_mountpoint()):
            snapshot_name = d + SNAPSHOT_PREFIX + self.execdate
            if snapshot_exists(snapshot_name):
                destroy_dataset_recursively(snapshot_name)


class BindMountContext(object):
    """Context manager that bind-mounts plain directories under ``tmpdir``.

    Each path in ``filesystems_to_backup`` is bind-mounted at
    ``tmpdir + path`` on entry and unmounted again on exit.
    """

    def __init__(self, filesystems_to_backup: List[str], tmpdir: str) -> None:
        self.tmpdir = tmpdir
        self.filesystems_to_backup = filesystems_to_backup

    def fsmounts(self) -> Generator[Tuple[str, str], None, None]:
        """Yield ``(source, mountpoint)`` pairs ordered by source path."""
        for source in sorted(self.filesystems_to_backup):
            target = os.path.join(self.tmpdir + source)
            yield (source, target)

    def __enter__(self) -> None:
        # Decide what needs mounting before touching anything.
        pending = sorted(
            (target, source)
            for source, target in self.fsmounts()
            if not (os.path.exists(target) and is_mountpoint(target))
        )
        for target, source in pending:
            status("Creating %s", target)
            if not os.path.exists(target):
                sudo("mkdir", "-m", "700", "-p", "--", target)
            if is_mountpoint(target):
                continue
            status("Mounting %s onto %s", source, target)
            sudo("mount", "--bind", source, target)

    def __exit__(self, *unused: Any) -> None:
        mounts = list(self.fsmounts())
        mounts.reverse()
        for _source, target in mounts:
            if not is_mountpoint(target):
                continue
            try:
                status("Unmounting file system %s", target)
                sudo("umount", target)
            except subprocess.CalledProcessError as exc:
                # Best effort: report, pause briefly, move on to the next.
                status("Unmounting file system %s failed: %s", target, exc)
                time.sleep(1)


class TmpDirContext(object):
    """Context manager owning the staging directory ``tmpdir``.

    The directory is created (mode 700) on entry.  On exit every directory
    below it, and the directory itself, is removed with ``rmdir``, so
    removal fails rather than deleting data if anything but empty
    directories is left.
    """

    def __init__(self, tmpdir: str) -> None:
        self.tmpdir = tmpdir

    def __enter__(self) -> None:
        sudo("mkdir", "-m", "700", "-p", "--", self.tmpdir)

    def __exit__(self, *unused: Any) -> None:
        """Remove the directory tree.

        Raises:
            Exception: if ``findmnt`` reports a mount for ``tmpdir``.
            subprocess.CalledProcessError: if the removal command fails.
        """
        if call(["findmnt", self.tmpdir]) == 0:
            raise Exception(
                "Will not remove anything under the temporary "
                "directory since file systems are still mounted on it."
            )

        try:
            status("Removing %s", self.tmpdir)
            qtmp = shlex.quote(self.tmpdir)
            sudo(
                "bash",
                "-c",
                f"find {qtmp} -depth -type d -print0 | xargs -0 rmdir -v --",
            )
        except subprocess.CalledProcessError:
            self._report_leftovers()
            raise

    def _report_leftovers(self) -> None:
        """Print what is still below ``tmpdir``, never raising for a failed
        diagnostic command.

        ``findmnt`` already exited non-zero for this path a moment ago, so
        running the diagnostics with a checking helper would replace the
        removal error the caller is about to re-raise.
        """
        for diag in (["find", self.tmpdir], ["findmnt", self.tmpdir]):
            try:
                res = subprocess.run(
                    diag, stdout=subprocess.PIPE, universal_newlines=True
                )
            except OSError as exc:
                status("Could not run %s: %s", diag[0], exc)
                continue
            status(res.stdout)


def check_connectivity(
    cmd: List[str], *, attempts: int = 10, delay: float = 3
) -> None:
    """Verify that ``cmd`` reaches an SSH server.

    ``cmd`` is run with a single newline on its standard input and a
    10-second timeout; the attempt succeeds when the command exits with
    status zero and its combined output contains ``SSH-``.

    Args:
        cmd: command that connects to the server and relays its output.
        attempts: maximum number of tries (at least 1).
        delay: seconds to wait between tries.

    Raises:
        ValueError: if ``attempts`` is smaller than 1.
        subprocess.CalledProcessError: if the last try exited non-zero.
        subprocess.TimeoutExpired: if the last try timed out.
        RuntimeError: if the last try exited cleanly but no SSH banner was
            seen.
    """
    if attempts < 1:
        raise ValueError("attempts must be at least 1")

    p: Optional["subprocess.CompletedProcess[str]"] = None
    last_timeout: Optional[subprocess.TimeoutExpired] = None
    for attempt in range(attempts):
        if attempt:
            # Pause only between tries, not after the final one.
            time.sleep(delay)
        try:
            p = subprocess.run(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                input="\n",
                timeout=10,
                text=True,
            )
        except subprocess.TimeoutExpired as exc:
            # A hung probe is a failed try like any other; keep retrying.
            status("Check for connectivity timed out after %s seconds", exc.timeout)
            p, last_timeout = None, exc
            continue
        if p.returncode != 0:
            status(f"Check for connectivity failed with error {p.returncode}")
        elif "SSH-" not in p.stdout:
            status(
                "Check for connectivity did not find an SSH server (stdout: %r)",
                p.stdout,
            )
        else:
            return

    if p is None:
        assert last_timeout is not None
        raise last_timeout
    p.check_returncode()
    # Exit status was zero on the last try, yet no banner was ever seen.
    raise RuntimeError(
        "Check for connectivity did not find an SSH server after %d attempts"
        % attempts
    )


def nc(host: str) -> List[str]:
    """Build the ``nc`` command line that connects to port 22 of ``host``."""
    port = "22"
    wait_seconds = "45"
    command = ["nc", "--no-shutdown", "-w", wait_seconds]
    command.extend([host, port])
    return command


class VmBridgeContext(object):
    """Context manager that routes borg's SSH traffic through a VM.

    On entry the VM named ``vm_name`` is checked for responsiveness and for
    connectivity to ``host``, and ``BORG_RSH`` is pointed at
    ``borg-offsite-backup-helper``.  On exit ``BORG_RSH`` is removed and the
    VM is shut down unless it was already running on entry.
    """

    def __init__(self, vm_name: str, host: str) -> None:
        self.vm_name = vm_name
        self.host = host
        self.was_running_before = False

    def __enter__(self) -> None:
        def wrap(cmd: List[str]) -> List[str]:
            # Run ``cmd`` inside the VM, quoting it into one shell string.
            quoted = " ".join(shlex.quote(x) for x in cmd)
            return ["qvm-run", "--nogui", "-a", "--pass-io", self.vm_name, quoted]

        listing = output("qvm-ls", self.vm_name)
        if "Running" in listing:
            self.was_running_before = True

        reply = output(*wrap(["echo", "yes"]))
        if "yes" not in reply:
            raise RuntimeError("VM %s not functional" % self.vm_name)

        check_connectivity(cmd=wrap(nc(self.host)))

        helper_args: List[str] = [arg for opt in SSH_OPTS for arg in ("-o", opt)]
        helper_args.append(self.vm_name)
        quoted_args = " ".join(shlex.quote(x) for x in helper_args)
        os.environ["BORG_RSH"] = "borg-offsite-backup-helper " + quoted_args

    def __exit__(self, *ignored: Any) -> None:
        os.environ.pop("BORG_RSH", None)
        if self.was_running_before:
            return
        cc(["qvm-shutdown", "--wait", self.vm_name])


class NoBridgeContext(object):
    """Context manager for a direct SSH connection to the backup server.

    On entry connectivity to ``host`` is checked and ``BORG_RSH`` is set to
    an ``ssh`` command carrying the configured options; on exit
    ``BORG_RSH`` is removed from the environment.
    """

    def __init__(self, host: str) -> None:
        self.host = host

    def __enter__(self) -> None:
        check_connectivity(cmd=nc(self.host))

        ssh_args: List[str] = [arg for opt in SSH_OPTS for arg in ("-o", opt)]
        quoted_args = " ".join(shlex.quote(x) for x in ssh_args)
        os.environ["BORG_RSH"] = "ssh " + quoted_args

    def __exit__(self, *ignored: Any) -> None:
        os.environ.pop("BORG_RSH", None)


@contextlib.contextmanager
def CdContext(path: str) -> Generator[str, None, None]:
    """Temporarily change the working directory to ``path``.

    Yields ``path``; the directory that was current on entry is restored
    when the block ends, whether or not it raised.
    """
    original_dir = os.getcwd()
    os.chdir(path)
    try:
        yield path
    finally:
        os.chdir(original_dir)


@contextlib.contextmanager
def DummyContext(*_ignored: Any) -> Generator[str, None, None]:
    """Context manager that does nothing and yields ``"DummyContext"``.

    Any positional arguments are accepted and ignored.
    """
    placeholder = "DummyContext"
    yield placeholder


def is_locked() -> bool:
    """Tell whether a ``lock.exclusive`` entry exists in borg's cache.

    Looks for ``~/.cache/borg/*/lock.exclusive`` and returns True when at
    least one path matches.
    """
    home = os.path.expanduser("~")
    pattern = os.path.join(home, ".cache", "borg", "*", "lock.exclusive")
    matches = glob.glob(pattern)
    return len(matches) > 0


def unlock() -> None:
    """Announce and run ``borg break-lock``, raising if it fails."""
    status("Unlocking repository")
    break_lock = [*echo, "borg", "break-lock"]
    subprocess.check_call(break_lock)


def run(
    config: Cfg,
    execdate: str,
    compression: str,
    subcommand: str,
    read_special: bool,
    args: List[str],
) -> int:
    """Run a borg subcommand and return its exit status.

    Any stale lock detected by ``is_locked`` is broken first.  For
    ``create`` an archive named ``execdate`` is made from the current
    directory, with the configured exclude patterns written to a temporary
    exclude file; any other subcommand is passed to borg together with
    ``args`` unchanged.

    Args:
        config: configuration providing ``exclude_patterns``.
        execdate: archive name used by ``create``.
        compression: value for borg's ``--compression`` option.
        subcommand: borg subcommand to run.
        read_special: when true, ``--read-special`` is passed to ``create``.
        args: extra arguments forwarded to borg.

    Returns:
        borg's exit status, except that status 1 from ``create`` (a
        warning) is reported as 0.
    """
    if is_locked():
        unlock()

    if subcommand != "create":
        return subprocess.call(echo + ["borg", subcommand] + args)

    with tempfile.TemporaryDirectory(
        prefix="borg-offsite-backup", suffix="XXXXXXXXXX"
    ) as excludedir:
        excludefn = os.path.join(excludedir, "exclude")
        # Leaving the with-block closes (and thereby flushes) the file
        # before borg reads it.
        with open(excludefn, "w") as excludefile:
            excludefile.write("\n".join(config.exclude_patterns))

        params = ["--progress"] if not os.environ.get("QUIET") else []
        # Probe the installed borg once for optional features.
        create_help = output("borg", "create", "--help")
        if "comment" in create_help:
            params += ["--comment", "global"]
        if "--sparse" in create_help:
            params += ["--sparse"]
        if read_special:
            params += ["--read-special"]

        cmd = (
            [
                "time",
                "nice",
                "ionice",
                "-c3",
                # For debugging, insert here:
                # "strace", "-ff", "-efile", "-s2048",
                "borg",
                "create",
                "--exclude-caches",
                "--keep-exclude-tags",
                "--exclude-from",
                excludefn,
                "--debug",
                "--compression",
                compression,
            ]
            + params
            + args
            + ["::" + execdate, "."]
        )
        ret = subprocess.call(echo + cmd)

    # Exit status 1 from create is just a warning.
    return 0 if ret == 1 else ret


def run_prune(keep_daily: int, keep_weekly: int, keep_monthly: int) -> int:
    """Run ``borg prune`` with the given retention and return its exit status.

    Any stale lock detected by ``is_locked`` is broken first.  ``--progress``
    is added unless the ``QUIET`` environment variable is set and non-empty.
    """
    if is_locked():
        unlock()

    cmd = ["time", "nice", "ionice", "-c3", "borg", "prune", "--stats"]
    retention = (
        ("--keep-daily", keep_daily),
        ("--keep-weekly", keep_weekly),
        ("--keep-monthly", keep_monthly),
    )
    for flag, amount in retention:
        cmd += [flag, str(amount)]
    if not os.environ.get("QUIET"):
        cmd.append("--progress")
    return subprocess.call(echo + cmd)


def run_compact() -> int:
    """Run ``borg compact --info`` and return its exit status.

    Any stale lock detected by ``is_locked`` is broken first.  ``--progress``
    is added unless the ``QUIET`` environment variable is set and non-empty.
    """
    if is_locked():
        unlock()

    cmd = ["time", "nice", "ionice", "-c3", "borg", "compact", "--info"]
    if not os.environ.get("QUIET"):
        cmd.append("--progress")
    return subprocess.call(echo + cmd)


def run_collector(promfile: str, telemetry_timeout: int) -> None:
    """Collect repository statistics and write them as Prometheus metrics.

    The statistics come from re-invoking this script with
    ``info --last=1000 --json``.  Every metric name is prefixed with
    ``borg_``.

    Args:
        promfile: destination.  ``-`` or ``/dev/stdout`` writes to standard
            output.  A name without a directory separator is placed in the
            textfile directory read from the command line of a running
            Node Exporter.  Any other value is used as the path itself.
        telemetry_timeout: seconds to wait for the ``info`` subprocess.

    Raises:
        AssertionError: if ``promfile`` has no directory part and the
            textfile collector directory could not be detected.
        subprocess.CalledProcessError, subprocess.TimeoutExpired: if the
            ``info`` subprocess fails or takes too long.
    """

    def r2t(x: str) -> float:
        return dateutil.parser.parse(x).timestamp()

    Metric = collections.namedtuple("Metric", ["name", "labels", "value"])

    def escape_label_value(value: Any) -> str:
        # The text exposition format requires backslash, double quote and
        # line feed to be escaped inside label values.
        return (
            str(value)
            .replace("\\", "\\\\")
            .replace('"', '\\"')
            .replace("\n", "\\n")
        )

    def fmtlbs(ls: Dict[str, str]) -> str:
        if not ls:
            return ""
        pairs = ('%s="%s"' % (k, escape_label_value(v)) for k, v in ls.items())
        return "{" + ",".join(pairs) + "}"

    def repr_metric(x: Metric) -> str:
        return "borg_" + str(x.name) + fmtlbs(x.labels) + " " + str(x.value)

    metrics: List[Metric] = []
    p = metrics.append

    start = time.time()

    class Repository(TypedDict):
        id: str
        last_modified: str
        location: str

    class ArchiveStats(TypedDict):
        compressed_size: int
        deduplicated_size: int
        nfiles: int
        original_size: int

    class ArchiveDetail(TypedDict):
        comment: str
        duration: int
        end: str
        name: str
        start: str
        stats: ArchiveStats

    class ArchiveInfo(TypedDict):
        archives: List[ArchiveDetail]
        repository: Repository

    data = cast(
        ArchiveInfo,
        json.loads(
            subprocess.check_output(
                [sys.executable, sys.argv[0], "info", "--last=1000", "--json"],
                timeout=telemetry_timeout,
            ).decode("utf-8")
        ),
    )
    p(
        Metric(
            name="repository_last_modified_timestamp_seconds",
            labels={},
            value=r2t(data["repository"]["last_modified"]),
        )
    )

    for arcdata in data["archives"]:
        lbs = {"archive": arcdata["name"]}
        lbs["dataset"] = arcdata["comment"]

        p(
            Metric(
                name="archive_start_timestamp_seconds",
                labels=lbs,
                value=r2t(arcdata["start"]),
            )
        )
        p(
            Metric(
                name="archive_end_timestamp_seconds",
                labels=lbs,
                value=r2t(arcdata["end"]),
            )
        )

        archivestats = arcdata["stats"]
        for k, v in archivestats.items():
            # *_size stats get a _bytes suffix; nfiles is exported as files.
            p(
                Metric(
                    name="archive_%s"
                    % (
                        (k + "_bytes")
                        if ("size" in k)
                        else ("files" if k == "nfiles" else k)
                    ),
                    labels=lbs,
                    value=v,
                )
            )

    end = time.time()
    p(Metric(name="collection_duration_seconds", labels={}, value=(end - start)))

    lines = [repr_metric(m) + "\n" for m in metrics]

    if promfile in ["/dev/stdout", "-"]:
        sys.stdout.writelines(lines)
        return

    if os.path.sep not in promfile:
        # Read --collector.textfile.directory= from the command line of the
        # first process whose command matches the Node Exporter pattern.
        cmd = "cat /proc/$(pgrep -f ^/usr/bin/.*node.*exporter | head -1 | cut -d ' ' -f 1)/cmdline | tr '\\0' '\\n' | grep -- --collector.textfile.directory= | sed 's/^--collector.textfile.directory=//'"
        directory = subprocess.check_output(["bash", "-c", cmd], text=True).rstrip(
            "\n"
        )
        # Explicit raise (same exception type as before) so that the check
        # is not stripped when Python runs with -O.
        if not directory:
            raise AssertionError(
                f"specified --telemetry-file {promfile}"
                " must be an absolute path because Node Exporter"
                " is not running and therefore we cannot detect"
                " the correct textfile collector path"
            )
        promfile = os.path.join(directory, promfile)

    # Write to a per-process temporary name, then rename over the target so
    # that readers never observe a partially written file.
    tmpfile = promfile + "." + str(os.getpid())
    try:
        with open(tmpfile, "w") as f:
            f.writelines(lines)
        os.chmod(tmpfile, 0o644)
        os.rename(tmpfile, promfile)
    except BaseException:
        # Do not leave a stray temporary file behind on failure.
        with contextlib.suppress(OSError):
            os.unlink(tmpfile)
        raise


def parse_args(args: List[str]) -> Tuple[argparse.Namespace, List[str]]:
    """Parse this script's own options out of ``args``.

    Returns:
        The parsed namespace (``telemetry_file``, ``telemetry_timeout``,
        ``config``, ``subcommand``) and the list of arguments that were not
        recognized.
    """
    option_specs: List[Tuple[str, Dict[str, Any]]] = [
        (
            "--telemetry-file",
            {
                "help": "path to a file to save Node Exporter telemetry after creating a backup; if the path is not standard output and does not contain a directory, the Node Exporter collector directory will be used",
            },
        ),
        (
            "--telemetry-timeout",
            {
                "type": int,
                "help": "how long to wait until giving up on collecting telemetry",
                "default": DEFAULT_TELEMETRY_TIMEOUT,
            },
        ),
        (
            "--config",
            {
                "help": "path to configuration file",
                "default": "/etc/borg-offsite-backup.conf",
            },
        ),
        (
            "subcommand",
            {
                "help": "borg execution mode",
            },
        ),
    ]
    parser = argparse.ArgumentParser()
    for flag, settings in option_specs:
        parser.add_argument(flag, **settings)
    return parser.parse_known_args(args)


def main(args: List[str]) -> None:
    """Run the requested subcommand end to end.

    Args:
        args: command-line arguments without the program name.  Options
            not recognized by ``parse_args`` are forwarded to borg.

    The process exits with a non-zero status when the backup, prune or
    compact step reports one; usage and configuration errors exit with
    ``os.EX_USAGE``.
    """
    global terminated

    # Help requests are delegated to borg itself.
    if "--help" in args or "-h" in args or "-?" in args:
        sys.exit(subprocess.call(["borg"] + args))

    opts, borg_args = parse_args(args)

    # SIGTERM only sets the `terminated` flag, which is checked below.
    signal.signal(signal.SIGTERM, sigterm)

    def fatal(exit: int, msg: str, *more: str) -> None:
        # Print the message on stderr and exit with the given status.
        status(msg, *more)
        sys.exit(exit)

    if opts.subcommand == "telemetry" and not opts.telemetry_file:
        fatal(
            os.EX_USAGE,
            "The telemetry subcommand is not supported without a "
            "--telemetry-file argument",
        )

    try:
        c = Cfg.from_file(opts.config)
    except ValueError as e:
        fatal(os.EX_USAGE, "Fatal error: " + str(e))
        return  # appease the type checking gods

    # The archive name is today's date as printed by date(1).
    execdate = output("date", "+%Y-%m-%d").strip()
    # Environment consumed by the borg processes started later.
    os.environ["BORG_KEY_FILE"] = os.path.expanduser("~/.borg-offsite-backup.key")
    os.environ["BORG_PASSPHRASE"] = ""
    os.environ["BORG_REPO"] = "{backup_user}@{backup_server}:{backup_path}".format(
        **c.as_dict()
    )
    if "TMPDIR" not in os.environ:
        os.environ["TMPDIR"] = "/var/tmp"
    # Staging directory under which the data to back up is exposed.
    backup_root = "/run/borg-offsite-backup"

    # Factories for the context managers that stage the data.  They are
    # only needed by `create` and `cleanup`, and are entered in this order.
    backup_contexts: List[Callable[[], Any]] = []
    if opts.subcommand in ["create", "cleanup"]:
        backup_contexts.append(lambda: TmpDirContext(backup_root))
        if c.datasets_to_backup:
            backup_contexts.append(
                lambda: SnapshotContext(c.datasets_to_backup, execdate, backup_root)
            )
        if c.filesystems_to_backup:
            backup_contexts.append(
                lambda: BindMountContext(c.filesystems_to_backup, backup_root)
            )
        # Only `create` runs from inside the staging directory.
        backup_contexts.append(
            lambda: CdContext(backup_root)
            if opts.subcommand == "create"
            else DummyContext()
        )

    # Connect through the bridge VM when one is configured, directly otherwise.
    connect_context: Any = (
        VmBridgeContext(c.bridge_vm, c.backup_server)
        if c.bridge_vm is not None
        else NoBridgeContext(c.backup_server)
    )

    with connect_context:
        with multi_context(*backup_contexts) as context_results:
            # The comprehension variable `c` is local to the comprehension
            # and does not replace the configuration object `c`.
            snapshot_contexts = [
                c for c in context_results if isinstance(c, SnapshotContext)
            ]
            read_special = (
                snapshot_contexts[0].read_special if snapshot_contexts else False
            )
            retval = None
            if not terminated:
                # `cleanup` and `telemetry` run no borg command here.
                retval = (
                    None
                    if opts.subcommand in ["cleanup", "telemetry"]
                    else run(
                        c,
                        execdate,
                        c.compression,
                        opts.subcommand,
                        read_special,
                        borg_args,
                    )
                )
            if terminated and retval:
                retval = 0
                # Now wait so the service manager does not kill our cleanup right away.
                time.sleep(1)

        # Prune and compact only on the configured days of the month, and
        # only after a successful, uninterrupted `create`.
        time_to_prune = (
            isinstance(c.prune_on_days, list)
            and datetime.datetime.today().day in c.prune_on_days
        )
        if time_to_prune:
            pruneretval = (
                run_prune(c.keep_daily, c.keep_weekly, c.keep_monthly)
                if (opts.subcommand == "create" and not terminated and retval == 0)
                else None
            )
            compactretval = (
                run_compact()
                if (
                    opts.subcommand == "create"
                    and not terminated
                    and retval == 0
                    and pruneretval == 0
                )
                else None
            )
        else:
            pruneretval = None
            compactretval = None

        # Refresh telemetry after subcommands listed here, unless terminated.
        if (
            opts.telemetry_file
            and opts.subcommand in ["create", "telemetry", "delete", "rename", "prune"]
            and not terminated
        ):
            run_collector(opts.telemetry_file, opts.telemetry_timeout)

    # Exit with the status of the first step that failed, if any.
    for val, proc in [
        (retval, "Backup"),
        (pruneretval, "Prune"),
        (compactretval, "Compact"),
    ]:
        if val not in (0, None):
            print("%s failed.  Exiting with status %s" % (proc, val), file=sys.stderr)
            sys.exit(val)


# Script entry point: pass the command-line arguments (without the program
# name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
