#!/bin/bash
set -e
# Copyright (C) 2022 Mo Zhou <lumin@debian.org>
# MIT/Expat License.
#
# Nvidia CUDA Deep Neural Network Library installer script (Debian Specific)
# Borrowed bits from Archlinux:
# https://github.com/archlinux/svntogit-community/blob/packages/cudnn/trunk/PKGBUILD
# Useful References:
# https://developer.nvidia.com/cuDNN
# https://developer.nvidia.com/rdp/cudnn-archive
# https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/

## configs ####################################################################
# Scratch directory for the downloaded tarball; replaced by --tmpdir.
TMPDIR="$(mktemp -d)"
TMPDIR_IS_OVERRIDEN=0
# Target architecture / multiarch triplet; replaced by --arch / --multiarch.
ARCH="$(dpkg-architecture -qDEB_HOST_ARCH)"
MULTIARCH="$(dpkg-architecture -qDEB_HOST_MULTIARCH)"
PREFIX="/usr"
# XXX: browse this to update the version numbers below:
#   https://developer.download.nvidia.com/compute/redist/cudnn/
# The per-architecture download URLs are derived from these variables, so a
# version bump only needs to touch CUDA_VER / _CUDNN_VER / CUDNN_VER.
CUDA_VER="11.7"
_CUDNN_VER="8.5.0"
CUDNN_VER="8.5.0.96"
_URL_BASE="https://developer.download.nvidia.com/compute/redist/cudnn/v${_CUDNN_VER}/local_installers/${CUDA_VER}"
# e.g. "8.5.0.96_cuda11-archive.tar.xz" (only the CUDA major version appears here)
_URL_SUFFIX="${CUDNN_VER}_cuda${CUDA_VER%%.*}-archive.tar.xz"
URL_amd64="${_URL_BASE}/cudnn-linux-x86_64-${_URL_SUFFIX}"
URL_ppc64el="${_URL_BASE}/cudnn-linux-ppc64le-${_URL_SUFFIX}"
URL_arm64="${_URL_BASE}/cudnn-linux-sbsa-${_URL_SUFFIX}"

## usage ######################################################################
usage () {
    # Print the command-line help to stdout.
    # The --arch / --multiarch defaults shown are queried from
    # dpkg-architecture at call time; the --tmpdir default is the current
    # value of ${TMPDIR}. Reads globals CUDNN_VER and CUDA_VER for the
    # version line.
    cat << EOF
Usage: $(basename "$0") <-d|-u|-p|-h> [--prefix <prefix>] [--tmpdir <tmpdir>] ...
Arguments:
 -d|--download           download only (default: 0)
 -u|--update             update cudnn installation (default: 0)
 -p|--purge              purge cudnn installation (default: 0)
 -h|--help               display this help message
 --arch <arch>           override architecture (default: $(dpkg-architecture -qDEB_HOST_ARCH))
 --multiarch <multiarch> override multiarch triplet (default: $(dpkg-architecture -qDEB_HOST_MULTIARCH))
 --prefix <path>         override install prefix (default: /usr)
 --tmpdir <dir>          override temporary directory (default: ${TMPDIR})
Testing this script cross-architecture:
 $ update-nvidia-cudnn --arch amd64   --multiarch x86_64-linux-gnu      --tmpdir . --prefix fake {-d,-u,-p}
 $ update-nvidia-cudnn --arch ppc64el --multiarch powerpc64le-linux-gnu --tmpdir . --prefix fake {-d,-u,-p}
 $ update-nvidia-cudnn --arch arm64   --multiarch aarch64-linux-gnu     --tmpdir . --prefix fake {-d,-u,-p}
Version: cuDNN ${CUDNN_VER} for CUDA ${CUDA_VER}
EOF
}

## argument parsing ###########################################################
DOWNLOAD_ONLY=0
DO_UPDATE=0
DO_PURGE=0
# Options that take a value consume two words; an empty value keeps the
# current default. A value-taking option given as the last word is an error
# (previously the second `shift` failed and set -e killed the script silently).
while [[ $# -gt 0 ]]; do
    case "$1" in
        -d|--download)
            DOWNLOAD_ONLY=1; shift;;
        -u|--update)
            DO_UPDATE=1; shift;;
        -p|--purge)
            DO_PURGE=1; shift;;
        --prefix|--arch|--multiarch|--tmpdir)
            if [[ $# -lt 2 ]]; then
                echo "$0: option $1 requires an argument" 1>&2
                usage; exit 1
            fi
            if [[ -n "$2" ]]; then
                case "$1" in
                    --prefix)
                        PREFIX="$2";;
                    --arch)
                        ARCH="$2";;
                    --multiarch)
                        MULTIARCH="$2";;
                    --tmpdir)
                        # The directory created by mktemp in the configs
                        # section is still empty and will not be used;
                        # drop it instead of leaking it (rmdir only removes
                        # empty directories).
                        if [[ ${TMPDIR_IS_OVERRIDEN} -eq 0 ]]; then
                            rmdir "${TMPDIR}" 2>/dev/null || true
                        fi
                        TMPDIR="$2"
                        TMPDIR_IS_OVERRIDEN=1;;
                esac
            fi
            shift 2;;
        -h|--help)
            usage; exit 0;;
        *)
            # unknown option or stray positional argument
            usage; exit 1;;
    esac
done
# post processing: pick the upstream tarball for the target architecture
case "${ARCH}" in
    amd64)
        URL="${URL_amd64}";;
    ppc64el)
        URL="${URL_ppc64el}";;
    arm64)
        URL="${URL_arm64}";;
    *)
        URL=""
        echo "$0: Unsupported architecture ${ARCH}" 1>&2
        # --purge needs no tarball; actions that download cannot proceed.
        if [[ ${DOWNLOAD_ONLY} -ne 0 || ${DO_UPDATE} -ne 0 ]]; then
            exit 1
        fi
        ;;
esac

## functions ##################################################################
download_cudnn () {
    # Download the cudnn tarball into ${TMPDIR}; the download is skipped when
    # a file of the same name is already there.
    # args: $1: URL for upstream tarball
    # globals: TMPDIR (read)
    # output: progress and diagnostics on stderr, the saved file path on stdout
    # return: non-zero if no URL is given or the download fails
    # Note: callers run this inside $(...), where set -e is not inherited, so
    # failures must be reported with an explicit non-zero return.
    local url="${1:-}"
    if test -z "${url}"; then
        echo "download_cudnn(): URL not specified" 1>&2
        return 1
    fi
    test -d "${TMPDIR}" || mkdir -p "${TMPDIR}" || return 1
    local dest
    dest="${TMPDIR}/$(basename "${url}")"
    # Keep the command as an array so it is not re-parsed by a second shell.
    local -a cmd=(wget --verbose --show-progress=on --progress=bar
                  --hsts-file=/tmp/wget-hsts -c "${url}" -O "${dest}")
    if test -f "${dest}"; then
        echo "Skipping download as file already exists: ${dest}" 1>&2
    else
        echo "${cmd[*]}" 1>&2
        if ! "${cmd[@]}" 1>&2; then
            # NOTE(security): this fallback disables TLS certificate
            # verification and nothing here checks the tarball's integrity.
            # Kept for compatibility; consider replacing it with a checksum
            # verification instead.
            echo "WARNING: retrying download without TLS certificate verification" 1>&2
            if ! "${cmd[@]}" --no-check-certificate 1>&2; then
                # wget -O leaves a (possibly empty) file behind on failure;
                # remove it so a later run does not mistake it for a download.
                rm -f -- "${dest}"
                echo "Download failed." 1>&2
                return 1
            fi
        fi
    fi
    if ! test -f "${dest}"; then
        echo "Download failed." 1>&2
        return 1
    fi
    echo "${dest}"
}

install_cudnn () {
    # Install extracted cudnn files from src to the dst prefix.
    # args: $1: source directory holding the extracted upstream archive
    #           (e.g. /tmp/nvidia-cudnn/)
    #       $2: destination prefix (e.g. /usr/local/)
    # globals: MULTIARCH (read)
    # Files are classified by basename:
    #   libcudnn*.so*  -> $2/lib/$MULTIARCH (symlinks copied as symlinks)
    #   libcudnn*.a    -> $2/lib/$MULTIARCH
    #   cudnn*.h       -> $2/include/$MULTIARCH
    #   NVIDIA_SLA_cuDNN_Support.txt -> $2/share/doc/nvidia-cudnn/
    # Anything else is reported as skipped on stdout.
    # return: non-zero on invalid arguments or when a copy fails
    if test -z "${1}" || test -z "${2}"; then
        echo "install_cudnn(): invalid argument" 1>&2
        return 1
    fi
    local src="${1}"
    local dst="${2}"
    local libdir="${dst}/lib/${MULTIARCH}"
    local f name
    # NUL-delimited so paths containing whitespace survive intact.
    while IFS= read -r -d '' f; do
        if [[ ${f} =~ cudnn.txz ]]; then
            continue
        fi
        # Match on the basename only, so parent directory names cannot
        # change how a file is classified.
        name="${f##*/}"
        case "${name}" in
            libcudnn*.so*)
                # shared object file
                if test -L "${f}"; then
                    mkdir -p "${libdir}" || return 1
                    cp -av "${f}" "${libdir}/" || return 1
                else
                    install -vDm0644 -t "${libdir}" "${f}" || return 1
                fi
                ;;
            libcudnn*.a)
                # static library file
                install -vDm0644 -t "${libdir}" "${f}" || return 1
                ;;
            cudnn*.h)
                # header file
                install -vDm0644 -t "${dst}/include/${MULTIARCH}" "${f}" || return 1
                ;;
            NVIDIA_SLA_cuDNN_Support.txt)
                # copyright file
                install -vDm0644 -t "${dst}/share/doc/nvidia-cudnn/" "${f}" || return 1
                ;;
            *)
                echo "Skipped ${f}"
                ;;
        esac
    done < <(find "${src}" -type f,l -print0)
}

purge_cudnn () {
    # Purge cudnn from the given installation prefix: libcudnn*.so* files and
    # symlinks plus libcudnn*.a under lib/${MULTIARCH}, cudnn*.h under
    # include/${MULTIARCH}, and share/doc/nvidia-cudnn/NVIDIA_SLA_cuDNN_Support.txt.
    # args: $1: installation prefix (e.g. /usr)
    # globals: MULTIARCH (read)
    # return: 1 on a missing argument or if any removal failed, 0 otherwise
    # Note: the caller runs this with `|| true`, which disables set -e inside,
    # so every failure path here uses an explicit return.
    if test -z "${1}"; then
        echo "purge_cudnn(): invalid argument" 1>&2
        return 1
    fi
    local dst="${1}"
    local libdir="${dst}/lib/${MULTIARCH}"
    local incdir="${dst}/include/${MULTIARCH}"
    local docdir="${dst}/share/doc/nvidia-cudnn"
    local -a files=()
    local f rc=0
    # Collect NUL-delimited paths; missing directories contribute nothing.
    while IFS= read -r -d '' f; do
        files+=("${f}")
    done < <(
        if test -d "${libdir}"; then
            find "${libdir}" -type f,l -name 'libcudnn*.so*' -print0
            find "${libdir}" -type f -name 'libcudnn*.a' -print0
        fi
        if test -d "${incdir}"; then
            find "${incdir}" -type f -name 'cudnn*.h' -print0
        fi
    )
    files+=("${docdir}/NVIDIA_SLA_cuDNN_Support.txt")
    for f in "${files[@]}"; do
        if test -e "${f}" || test -L "${f}"; then
            rm -rv -- "${f}" || rc=1
        fi
    done
    # Remove the doc directory if it is now empty; rmdir refuses otherwise,
    # which is intentionally ignored.
    if test -d "${docdir}"; then
        rmdir -- "${docdir}" 2>/dev/null || true
    fi
    return "${rc}"
}

# flag check: must select one valid action; otherwise print help and stop.
# (A `( usage; exit 0 )` subshell cannot exit the script, so use a real if.)
if test "${DOWNLOAD_ONLY}" -eq 0 && test "${DO_UPDATE}" -eq 0 && \
        test "${DO_PURGE}" -eq 0; then
    usage
    # The auto-created scratch directory is unused on this path.
    if test 0 -eq "${TMPDIR_IS_OVERRIDEN}"; then
        rmdir "${TMPDIR}" 2>/dev/null || true
    fi
    exit 0
fi

# trigger actions (the first selected one wins: download, update, purge)
if test "${DOWNLOAD_ONLY}" -ne 0; then
    # Keep the tarball and print its path.
    path="$(download_cudnn "${URL}")"
    echo "${path}"
    exit 0
elif test "${DO_UPDATE}" -ne 0; then
    path="$(download_cudnn "${URL}")"
    tmpdir2="$(mktemp -d)"
    # Remove the extraction directory even if tar or install_cudnn fails
    # (set -e would otherwise exit before the rm below runs).
    trap 'rm -rf -- "${tmpdir2}"' EXIT
    tar xvf "${path}" -C "${tmpdir2}/"
    install_cudnn "${tmpdir2}" "${PREFIX}"
    rm -rf -- "${tmpdir2}"
    trap - EXIT
    # cleanup: only remove the tarball from our own scratch directory
    if test 0 -eq "${TMPDIR_IS_OVERRIDEN}"; then
        rm -rv -- "${path}"
        rmdir -v -- "${TMPDIR}"
    fi
elif test "${DO_PURGE}" -ne 0; then
    echo "Purging cuDNN installation from ${PREFIX}"
    # best effort: failures while purging do not abort the script
    purge_cudnn "${PREFIX}" || true
    # The auto-created scratch directory is unused on this path.
    if test 0 -eq "${TMPDIR_IS_OVERRIDEN}"; then
        rmdir "${TMPDIR}" 2>/dev/null || true
    fi
fi

