#!/usr/bin/python3 -Es

# ==============================================================
# Daemon that gathers statistical information from all 
# Nvidia GPUs in the current system. Every second, it gathers 
# the statistics of every GPU and maintains cumulative counters,
# globally and per process.
#
# Client processes can connect to this daemon on TCP port 59123.
# Clients can send requests of two bytes, consisting of one byte
# request code followed by one byte integer version number.
# The request code can be 'T' to obtain the GPU types or 'S' to
# obtain all statistical counters.
# The response of the daemon starts with a 4-byte integer. The
# first byte is the version of the response format and the 
# subsequent three bytes indicate the length (big endian) of the
# response string that follows. See the formatters for the layout
# of the response string, later on in this source code.
#
# Dependencies: pip/pip3 install nvidia-ml-py
#
# This program can be executed by python2 or python3 (just change
# the first line of this source file).
# --------------------------------------------------------------
# Author: Gerlof Langeveld
# Date:   July 2018 (initial)
#
# Copyright (C) 2018 Gerlof Langeveld
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
# ==============================================================

import os
import sys
import time
import socket
import struct
import logging
import logging.handlers as loghand
import threading

# TCP port number on which the server socket of this daemon is bound
GPUDPORT     = 59123

# bit values that are or'ed into the per-GPU 'tasksupport' mask
COMPUTE      = 1    # list of running compute processes can be obtained
ACCOUNT      = 2    # accounting (and persistence) mode could be enabled

# =================================
# GPU related bookkeeping
# =================================

# list with one GpuProp object per GPU (filled once by main)
gpulist = []

# dict with one entry per connected client (client socket as key);
# every entry is itself a dict (pid of terminated process as key) that
# holds the per-process bookkeeping of the processes that have
# terminated and that still have to be transmitted to this client
cliterm = {}

# mutex for access to gpulist/cliterm, shared by the scanner thread
# and the client threads
gpulock = threading.Lock()


# =================================
# per-GPU class
# =================================
class Stats(object):
    """Generic statistics container.

    The class defines nothing itself: counters and properties are
    stored on its instances as dynamically assigned attributes.
    """

class GpuProp(object):
    """Bookkeeping of one GPU.

    Holds the NVML device handle and a Stats container with
    - the static properties (bus id, device name, task support bits),
    - the current values and cumulative counters of the GPU itself,
    - one Stats container per process that is active on the GPU
      (attribute 'procstats', pid as key).
    """

    ###############################
    # initialization method to setup
    # properties
    ###############################
    def __init__(self, num):
        """Set up the bookkeeping for the GPU with NVML index `num`.

        Exceptions raised while obtaining the device handle, the PCI
        info or the device name are not caught here.
        """
        gpuhandle = pynvml.nvmlDeviceGetHandleByIndex(num)
        pciInfo   = pynvml.nvmlDeviceGetPciInfo(gpuhandle)

        self.gpuhandle         = gpuhandle
        self.stats             = Stats()

        self.stats.busid       = pciInfo.busId
        self.stats.devname     = pynvml.nvmlDeviceGetName(gpuhandle)

        # nvml backward compatibility: convert bytes into strings
        if isinstance(self.stats.busid, bytes):
            self.stats.busid = self.stats.busid.decode('ascii',
                                                       errors='replace')
        if isinstance(self.stats.devname, bytes):
            self.stats.devname = self.stats.devname.decode('ascii',
                                                       errors='replace')

        # the device name is transmitted as one space-separated field
        self.stats.devname = self.stats.devname.replace(' ', '_')

        self.stats.tasksupport = 0          # process stats support

        # probe: can the list of compute processes be obtained?
        try:
            pynvml.nvmlDeviceGetComputeRunningProcesses(gpuhandle)
            self.stats.tasksupport |= COMPUTE   # compute support
        except Exception:
            pass                                # no compute support

        # probe: can accounting be enabled? (best effort)
        try:
            pynvml.nvmlDeviceSetAccountingMode(gpuhandle, True)
            pynvml.nvmlDeviceSetPersistenceMode(gpuhandle, True)  # NVIDIA advise
            self.stats.tasksupport |= ACCOUNT   # account support
        except Exception:
            pass                                # no account support

        self.stats.gpupercnow   = 0     # perc of time that GPU was busy
        self.stats.mempercnow   = 0     # perc of time that memory was rd/wr

        self.stats.memtotalnow  = 0     # in KiB
        self.stats.memusednow   = 0     # in KiB

        self.stats.gpusamples   = 0
        self.stats.gpuperccum   = 0     # perc of time that GPU was busy
        self.stats.memperccum   = 0     # perc of time that memory was rd/wr
        self.stats.memusedcum   = 0     # in KiB

        self.stats.procstats    = {}    # stats of active processes (key = pid)

    ###############################
    # method to fetch counters and values
    ###############################
    def readstats(self):
        """Take one sample of this GPU and update all counters.

        NVML errors are handled per category (utilization, memory,
        processes), so a failing category does not prevent the others
        from being sampled. Modifies the global 'cliterm' when
        processes have terminated; the caller is expected to hold the
        lock that protects it.
        """
        self.stats.gpusamples += 1

        self._readrates()
        self._readmemory()
        self._readprocs()

    # -----------------------------
    # get rates (utilization percentages)
    # -----------------------------
    def _readrates(self):
        """Sample the GPU and memory utilization percentages."""
        try:
            rates = pynvml.nvmlDeviceGetUtilizationRates(self.gpuhandle)
        except pynvml.NVMLError:
            # rates can not be obtained: store -1 in the current
            # values and in the cumulative counters
            self.stats.gpupercnow  = -1
            self.stats.mempercnow  = -1
            self.stats.gpuperccum  = -1
            self.stats.memperccum  = -1
            return

        self.stats.gpupercnow  = rates.gpu
        self.stats.mempercnow  = rates.memory
        self.stats.gpuperccum += rates.gpu
        self.stats.memperccum += rates.memory

    # -----------------------------
    # get memory occupation GPU-wide
    # -----------------------------
    def _readmemory(self):
        """Sample the GPU-wide memory occupation (stored in KiB)."""
        try:
            meminfo = pynvml.nvmlDeviceGetMemoryInfo(self.gpuhandle)
        except pynvml.NVMLError:
            return                      # keep the previous values

        self.stats.memtotalnow = meminfo.total // 1024
        self.stats.memusednow  = meminfo.used  // 1024
        self.stats.memusedcum += meminfo.used  // 1024

    # -----------------------------
    # get per-process statistics
    # -----------------------------
    def _readprocs(self):
        """Sample the per-process statistics.

        Processes that were known from the previous sample but are not
        reported any more are considered terminated: their bookkeeping
        is handed over to every client in 'cliterm' and removed from
        the active processes.
        """
        try:
            procinfo = pynvml.nvmlDeviceGetComputeRunningProcesses(
                                                        self.gpuhandle)
        except pynvml.NVMLError:
            return

        procstats = self.stats.procstats

        # pids known from the previous interval that have not been
        # reported (yet) in this interval
        unseen = set(procstats)

        # -------------------------
        # handle proc stats of this
        # interval
        # -------------------------
        for proc in procinfo:
            pid = proc.pid
            unseen.discard(pid)

            # ---------------------
            # new process?
            #     create new stats
            # test against the bookkeeping itself, so a pid that is
            # reported more than once in the same sample does not get
            # its cumulative counters reset
            # ---------------------
            pstat = procstats.get(pid)

            if pstat is None:
                pstat = Stats()

                pstat.memnow  = 0       # in KiB
                pstat.memcum  = 0       # in KiB
                pstat.sample  = 0

                pstat.gpubusy = -1      # -1 until accounting stats are read
                pstat.membusy = -1
                pstat.timems  = -1

                procstats[pid] = pstat

            # ---------------------
            # maintain proc stats
            # (skipped when the used memory is None or zero)
            # ---------------------
            if proc.usedGpuMemory:
                usedkib = proc.usedGpuMemory // 1024

                pstat.memnow  = usedkib
                pstat.memcum += usedkib
                pstat.sample += 1

            if self.stats.tasksupport & ACCOUNT:
                try:
                    acct = pynvml.nvmlDeviceGetAccountingStats(
                                                    self.gpuhandle, pid)

                    pstat.gpubusy = acct.gpuUtilization
                    pstat.membusy = acct.memoryUtilization
                    pstat.timems  = acct.time
                except Exception:
                    pass        # best effort: keep the previous values

        # -------------------------
        # determine which processes
        # have terminated since
        # previous sample
        # (handled in the order in which they were registered)
        # -------------------------
        for pid in [p for p in procstats if p in unseen]:
            ended = procstats.pop(pid)

            for client in cliterm:
                cliterm[client][pid] = ended

    ###############################
    # obtain current statistics
    ###############################
    def getstats(self):
        """Return the Stats container of this GPU (not a copy)."""
        return self.stats


# =================================
# Main function
# =================================
def main():
    """Start the daemon: open the server socket, daemonize, serve clients.

    Steps:
    - verify that NVML can be initialized,
    - create the listening TCP socket (IPv6, with IPv4 as fallback)
      and bind it to port GPUDPORT,
    - fork, after which the parent process exits,
    - set up the bookkeeping of every GPU,
    - start the scanner thread that samples the GPUs every second,
    - accept client connections forever, serving every client in a
      thread of its own.

    Does not return. The process is terminated via sys.exit() when
    the initialization fails or when no GPUs are found.
    """
    # -----------------------------
    # initialize GPU access,
    # specifically to detect if it
    # succeeds (before the parent
    # process is released)
    # -----------------------------
    try:
        pynvml.nvmlInit()
    except Exception:
        logging.error("Shared lib 'libnvidia-ml' probably not installed!")
        sys.exit()

    # -----------------------------
    #   open IPv6 stream socket
    #   (fallback: IPv4)
    # -----------------------------
    try:
        mysock = socket.socket(socket.AF_INET6,  socket.SOCK_STREAM, 0)
    except Exception:
        try:
            mysock = socket.socket(socket.AF_INET,  socket.SOCK_STREAM, 0)
        except Exception:
            logging.error("Socket creation fails")
            sys.exit(1)

    # -----------------------------
    #   allow rebinding of the port
    #   while connections of a
    #   previous daemon instance are
    #   still in TIME_WAIT state
    #   (best effort)
    # -----------------------------
    try:
        mysock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    except Exception as sockex:
        logging.warning("Socket option SO_REUSEADDR not set: %s", sockex)

    # -----------------------------
    #   bind to local port and
    #   make socket passive
    # -----------------------------
    try:
        mysock.bind( ("", GPUDPORT) )
        mysock.listen(32)
    except Exception:
        logging.error("Socket binding to port %d fails", GPUDPORT)
        sys.exit(1)

    # -----------------------------
    # release parent process
    # (daemonize)
    #
    # when the fork fails, the
    # error is logged and this
    # process continues itself
    # -----------------------------
    try:
        if os.fork():
            sys.exit(0)     # parent process exits; child continues...
    except Exception:
        logging.error("Failed to fork child")

    # -----------------------------
    # initialize GPU access for the
    # child process
    # (a failure shows up in the
    # next step)
    # -----------------------------
    try:
        pynvml.nvmlInit()
    except Exception:
        pass

    # -----------------------------
    # determine number of GPUs in
    # this system
    # -----------------------------
    try:
        gpunum = pynvml.nvmlDeviceGetCount()
    except Exception as nvmlex:
        logging.error("Number of GPUs can not be determined: %s", nvmlex)
        sys.exit(1)

    logging.info("Number of GPUs: %d", gpunum)

    if gpunum == 0:
        logging.info("Terminated (no GPUs available)")
        sys.exit()

    # -----------------------------
    # initialize per-GPU bookkeeping
    # -----------------------------
    try:
        for i in range(gpunum):
            gpulist.append( GpuProp(i) )
    except Exception as nvmlex:
        logging.error("Initialization of GPU bookkeeping fails: %s", nvmlex)
        sys.exit(1)

    # -----------------------------
    # kick off new thread to fetch
    # statistics periodically
    # (interval: 1 second)
    # -----------------------------
    t = threading.Thread(target=gpuscanner, args=(1,))
    t.daemon = True
    t.start()

    logging.info("Initialization succeeded")

    # -----------------------------
    # main thread:
    #   await connect of client
    # -----------------------------
    while True:
        newsock, peeraddr = mysock.accept()

        # -------------------------
        # create new thread to
        # serve this client
        # -------------------------
        t = threading.Thread(target=serveclient, args=(newsock, peeraddr))
        t.daemon = True
        t.start()

 
# ===========================================
# Thread start function:
# Serve new client that has just
# connected.
#
# -------------------------------------------
# Protocol between client and server:
#
# - client transmits request 
#   consisting of two bytes
#
#     byte 0: type of request
#             'S' get statistical counters
#             'T' get type of each GPU
#
#     byte 1: integer version number
#             response layout might change
#             so the client asks for a
#             specific response version
#
# - server transmits response
#   consisting of a four-byte integer
#   in big endian byte order
#
#     byte 0:   version number, preferably
#               as requested by the client
#
#     byte 1-3: length of the response string
#               that follows
#
#   followed by the response string that is
#   version specific (see gpuformatters)
# ===========================================
def serveclient(sock, peer):
    """Serve the requests of one connected client until it disconnects.

    sock -- connected client socket; also used as key in 'cliterm'
    peer -- address of the client as delivered by accept()
            (currently not used)

    Every request must consist of exactly two bytes (command and
    version). The connection is closed when the client disconnects,
    when a request is invalid or when an error occurs. In all cases
    the socket is closed and the per-client bookkeeping of terminated
    processes is removed before this function returns.
    """
    # -----------------------------
    # create per client bookkeeping
    # for terminated processes
    # -----------------------------
    with gpulock:
        cliterm[sock] = {}

    try:
        # -----------------------------
        # main loop
        # -----------------------------
        while True:
            # -------------------------
            # wait for request
            # -------------------------
            try:
                rcvbuf = sock.recv(20)
            except Exception as sockex:
                logging.error("Receive error: %s", sockex)
                break

            # -------------------------
            # connection closed by peer?
            # -------------------------
            if not rcvbuf:
                break

            logging.debug("Received: %s", rcvbuf)

            # -------------------------
            # request has wrong length?
            # -------------------------
            if len(rcvbuf) != 2:
                logging.error('Wrong request length: %d', len(rcvbuf))
                break

            # -------------------------
            # split request in command
            # and version
            # -------------------------
            try:
                command = chr(rcvbuf[0])    # Python3: indexing yields int
                version = rcvbuf[1]
            except Exception:
                command = rcvbuf[0]         # Python2: indexing yields str
                version = ord(rcvbuf[1])

            # -------------------------
            # valid request:
            #     get statistical counters?
            #
            # version 0 or an unknown
            # version is answered with
            # the most recent layout
            # -------------------------
            if command == 'S':
                if version == 0 or version >= len(gpuformatters):
                    version = len(gpuformatters)-1

                xmitbuf = gpuformatters[version](sock).encode('ascii',
                                                        errors='replace')

            # -------------------------
            # valid request:
            #     get GPU types?
            # -------------------------
            elif command == 'T':
                if version == 0 or version >= len(gpudevnames):
                    version = len(gpudevnames)-1

                xmitbuf = gpudevnames[version]().encode('ascii',
                                                        errors='replace')

            # -------------------------
            # invalid request!
            # -------------------------
            else:
                logging.error('Wrong request from client: %s', command)
                break

            # -------------------------
            # transmit GPU statistics
            # as bytes, preceded by
            # version and length
            # -------------------------
            logging.debug("Send: %s", xmitbuf)

            prelude = struct.pack(">I", (version << 24) + len(xmitbuf))

            try:
                # sendall: send() might transmit only a part of the buffer
                sock.sendall(prelude + xmitbuf)
            except Exception as sockex:
                logging.error("Send error: %s", sockex)
                break

    except Exception as srvex:
        # e.g. a failing formatter: log it instead of losing the
        # traceback on the standard error of the daemon
        logging.error("Unexpected error while serving client: %s", srvex)

    finally:
        # -----------------------------
        # close connection
        # -----------------------------
        try:
            sock.close()
        except Exception:
            pass

        # -----------------------------
        # delete per client bookkeeping
        # of terminated processes
        # -----------------------------
        with gpulock:
            cliterm.pop(sock, None)

    # -----------------------------
    # END OF CLIENT THREAD
    # -----------------------------

    # -----------------------------
    # END OF CLIENT THREAD
    # -----------------------------


# =================================
# Generate sequence of device names
# =================================
def gpudevname_v1():
    """Return the properties of all GPUs as one string (layout version 1).

    Layout:
        numgpus@busid devname tasksupport@busid devname tasksupport@...
    """
    # the string starts with the number of GPUs ...
    pieces = [str(len(gpulist))]

    # ... followed by one '@'-introduced section per GPU
    with gpulock:
        for device in gpulist:
            props = device.getstats()
            pieces.append("@{:s} {:s} {:d}".format(
                                    props.busid, props.devname,
                                    props.tasksupport))

    return "".join(pieces)

# table with the device name generators, indexed by layout version
# (version 0 has no generator of its own)
gpudevnames = [None, gpudevname_v1]


# =================================
# Convert statistics of all GPUs
# into parseable string
# =================================
def gpuformatter_v1(clisock):
    """Return the statistics of all GPUs as one string (layout version 1).

    clisock -- socket of the requesting client, used as key to find
               the terminated processes still to be reported to it

    Layout (every GPU section starts with '@', every process section
    with '#A' for an active or '#E' for a terminated process):
        @gpu0 stats#A pid stats#E pid stats@gpu1 stats#A pid stats@....

    The terminated processes of the client are cleared once reported.
    """
    gpulayout = "@{:d} {:d} {:d} {:d} {:d} {:d} {:d} {:d}"
    actlayout = "#A {:d} {:d} {:d} {:d} {:d} {:d} {:d}"
    endlayout = "#E {:d} {:d} {:d} {:d} {:d} {:d} {:d}"

    def procfields(pid, pstat):
        # values of one process in the order of the layout
        return (pid, pstat.gpubusy, pstat.membusy, pstat.timems,
                     pstat.memnow,  pstat.memcum,  pstat.sample)

    pieces = []

    with gpulock:
        for device in gpulist:
            gstat = device.getstats()

            # ---------------------
            # generic GPU stats
            # ---------------------
            pieces.append(gpulayout.format(
                       gstat.gpupercnow,  gstat.mempercnow,
                       gstat.memtotalnow, gstat.memusednow,
                       gstat.gpusamples,  gstat.gpuperccum,
                       gstat.memperccum,  gstat.memusedcum))

            # ---------------------
            # active processes for
            # this GPU
            # ---------------------
            pieces.extend(actlayout.format(*procfields(pid, pstat))
                          for pid, pstat in gstat.procstats.items())

            # ---------------------
            # terminated processes
            #
            # NOTE(review): these are registered per client and not
            # per GPU, and are cleared after the first GPU section,
            # so they are all reported in the first GPU section
            # ---------------------
            finished = cliterm[clisock]

            pieces.extend(endlayout.format(*procfields(pid, pstat))
                          for pid, pstat in finished.items())

            finished.clear()

    return "".join(pieces)


# table with the statistics formatters, indexed by layout version
# (version 0 has no formatter of its own)
gpuformatters = [None, gpuformatter_v1]


# =================================
# Thread start function:
# Scan all GPUs with a particular
# interval to obtain their stats
# =================================
def gpuscanner(scaninterval):
    """Sample all GPUs forever, pausing `scaninterval` seconds in between.

    The lock is held while the GPUs are sampled and is released
    before sleeping. This function never returns by itself.
    """
    while 1:
        # -------------------------
        # one round: sample every GPU
        # -------------------------
        gpulock.acquire()

        try:
            for device in gpulist:
                device.readstats()
        finally:
            gpulock.release()

        # -------------------------
        # wait for the next round
        # -------------------------
        time.sleep(scaninterval)

# ==========================================================================

# -----------------------------
# initialize logging:
# all messages are sent to syslog
# (facility 'daemon'); flag -v
# enables the debug messages
# -----------------------------
loglevel = logging.DEBUG if '-v' in sys.argv else logging.INFO

fm = logging.Formatter('atopgpud %(levelname)s: %(message)s')

fh = loghand.SysLogHandler(address='/dev/log',
             facility=loghand.SysLogHandler.LOG_DAEMON)
fh.setLevel(loglevel)
fh.setFormatter(fm)

lg = logging.getLogger()		# root logger
lg.setLevel(loglevel)
lg.addHandler(fh)

# -----------------------------
# load module pynvml
# (after the logging setup, so
# a failure can be logged)
# -----------------------------
try:
    import pynvml
except Exception:
    logging.error("Python module 'pynvml' not installed!")
    sys.exit(1)


# -----------------------------
# call main function and
# shutdown GPU access afterwards,
# however the main function ends
# -----------------------------
try:
    main()

finally:
    try:
        pynvml.nvmlShutdown()
    except Exception:
        pass			# best effort
