#! /usr/bin/env python

# Copyright (C) 2023- The University of Notre Dame
# This software is distributed under the GNU General Public License.
# See the file COPYING for details.

# Plot the time spent matching tasks to workers through the information in the performance log

import sys
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os

def read_fields(f, lines_patience = 10):
    """Return the column names found on the first '#' line of `f`.

    `f` is iterated line by line. The first line starting with '#' is
    stripped of surrounding '#' and whitespace characters and split on
    whitespace; the resulting list is returned.

    Each line that does not start with '#' uses up one unit of
    `lines_patience`. Once the budget drops below one, or the input ends
    without a header, an error is written to stderr and the process exits
    with status 1.
    """
    misses_left = lines_patience
    for raw in f:
        # Header found: everything after the leading '#' is the field list.
        if raw[0] == '#':
            return raw.strip('#\n\r\t ').split()

        # Not a header; give up once too many such lines have been seen.
        misses_left -= 1
        if misses_left < 1:
            break

    sys.stderr.write("Could not find fields. This is probably not a performance log...\n")
    sys.exit(1)


def _read_scheduling_samples(performance_filename):
    """Read the `tasks_dispatched` and `time_scheduling` columns of the log.

    Returns two parallel lists of ints, one entry per data row. Lines that
    start with '#' and blank lines are skipped. Raises ValueError if either
    column name is missing from the header or a value is not an integer.
    """
    tasks_dispatched = []
    time_scheduling = []
    # The context manager guarantees the log is closed, even when
    # read_fields() exits or a row fails to parse.
    with open(performance_filename) as f:
        fields = read_fields(f)
        f.seek(0)

        dispatch_offset = fields.index("tasks_dispatched")
        scheduling_offset = fields.index("time_scheduling")

        for line in f:
            if line[0] == "#":
                continue
            items = line.split()
            if not items:
                # A blank line carries no sample; skip it instead of failing.
                continue
            tasks_dispatched.append(int(items[dispatch_offset]))
            time_scheduling.append(int(items[scheduling_offset]))

    return tasks_dispatched, time_scheduling


def _select_scheduled(tasks_dispatched, time_scheduling):
    """Keep the rows where the dispatched-task count grew past the previous row.

    Rows whose scheduling time is zero are ignored entirely. Returns
    (tasks_scheduled, accumulated_time) as parallel lists.
    """
    tasks_scheduled = []
    accumulated_time = []
    last_task_id = -1
    for task_id, elapsed in zip(tasks_dispatched, time_scheduling):
        if elapsed == 0:
            continue
        if task_id > last_task_id:
            tasks_scheduled.append(task_id)
            accumulated_time.append(elapsed)
        last_task_id = task_id
    return tasks_scheduled, accumulated_time


def _time_per_task(tasks_scheduled, accumulated_time):
    """Spread the time between consecutive samples evenly over the tasks in between.

    For each pair of consecutive samples, every task id in
    [this sample, next sample) is assigned (time gap / task gap).
    Returns (scheduled_task, scheduled_time) as parallel lists.
    """
    scheduled_task = []
    scheduled_time = []
    samples = list(zip(tasks_scheduled, accumulated_time))
    for (tid, this_time), (next_tid, next_time) in zip(samples, samples[1:]):
        task_gap = next_tid - tid
        if task_gap <= 0:
            # No task ids lie between the two samples; also avoids dividing by zero.
            continue
        time_for_each_task = (next_time - this_time) / task_gap
        scheduled_task.extend(range(tid, next_tid))
        scheduled_time.extend([time_for_each_task] * task_gap)
    return scheduled_task, scheduled_time


def _one_sigma_cutoff(scheduled_time):
    """Return (time_top, top_value) for a non-empty list of per-task times.

    top_value is the percentage of samples lying within one sample standard
    deviation of the mean. time_top is the value found at that rank in the
    sorted samples; it is the threshold separating the "top" samples from
    the "tail".
    """
    n = len(scheduled_time)
    mean_scheduled_time = sum(scheduled_time) / n
    if n > 1:
        # Sample variance: sum of squared deviations divided by (n - 1).
        variance = sum((x - mean_scheduled_time) ** 2 for x in scheduled_time) / (n - 1)
    else:
        # A single sample has no spread.
        variance = 0.0
    std_dev = variance ** 0.5

    lower_bound = mean_scheduled_time - std_dev
    upper_bound = mean_scheduled_time + std_dev
    count = sum(1 for x in scheduled_time if lower_bound <= x <= upper_bound)
    top_value = count / n * 100

    # The rank matching top_value percent of n samples is exactly `count`,
    # so index with it directly rather than rebuilding it from the percentage
    # (which can pick up floating-point error).
    data_sorted = sorted(scheduled_time)
    time_top = data_sorted[max(count, 1) - 1]
    return time_top, top_value


def plot_performance(performance_filename):
    """Plot scheduling time from a performance log and save it as a PNG.

    The figure has three panels:
      - left: accumulated scheduling time against the number of tasks dispatched;
      - right top: per-task scheduling time for tasks below the one-sigma cut-off;
      - right bottom: per-task scheduling time for the remaining (tail) tasks.

    The image is written as 'schedulling_performance.png' in the directory
    of `performance_filename`. If the log does not contain enough samples to
    derive any per-task time, an error is written to stderr and the process
    exits with status 1.
    """
    print("Plotting schedulling performance of task...")

    # STEP1: Get the accumulated time as the number of tasks grows
    tasks_dispatched, time_scheduling = _read_scheduling_samples(performance_filename)
    tasks_scheduled, accumulated_time = _select_scheduled(tasks_dispatched, time_scheduling)

    # STEP2: Generate the scheduled time for each task
    # There are a lot of concurrent tasks in a period, split the execution time evenly among them
    scheduled_task, scheduled_time = _time_per_task(tasks_scheduled, accumulated_time)
    if not scheduled_time:
        sys.stderr.write("Not enough scheduling samples in the performance log to plot.\n")
        sys.exit(1)

    # time_top is the scheduling time below which a task counts as "top";
    # top_value is the percentage of tasks within one standard deviation of the mean
    time_top, top_value = _one_sigma_cutoff(scheduled_time)

    scheduled_task_top, scheduled_time_top = [], []
    scheduled_task_tail, scheduled_time_tail = [], []
    for task, task_time in zip(scheduled_task, scheduled_time):
        if task_time < time_top:
            scheduled_task_top.append(task)
            scheduled_time_top.append(task_time)
        else:
            scheduled_task_tail.append(task)
            scheduled_time_tail.append(task_time)

    # STEP3: Draw three figures: the left one is accumulated time with accumulated tasks
    # the right top one is the scheduling time for each task in most cases
    # the right bottom one is the scheduling time for each task which could have extremely large number in some exceptional cases
    fig = plt.figure(figsize=(10, 6))
    gs = gridspec.GridSpec(2, 2, width_ratios=[2, 1])

    ax1 = plt.subplot(gs[:, 0])
    ax2 = plt.subplot(gs[0, 1])
    ax3 = plt.subplot(gs[1, 1])

    ax1.plot(tasks_scheduled, accumulated_time)
    ax1.set_ylim(ymin=0)
    ax1.set_xlabel("Scheduled Tasks")
    ax1.set_ylabel("Accumulated Time (us)")

    ax2.plot(scheduled_task_top, scheduled_time_top)
    ax2.set_ylim(ymin=0)
    ax2.set_xlabel(f"Scheduled Task (top {top_value:.2f}%)")
    ax2.set_ylabel("Schedulling Time (us)")

    ax3.plot(scheduled_task_tail, scheduled_time_tail)
    ax3.set_ylim(ymin=0)
    ax3.set_xlabel(f"Scheduled Task (tail {100 - top_value:.2f}%)")
    ax3.set_ylabel("Schedulling Time (us)")

    plt.tight_layout()

    log_dir = os.path.dirname(performance_filename)
    save_png = os.path.join(log_dir, 'schedulling_performance.png')
    plt.savefig(save_png)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print("Plotting done, result image has been saved in the log directory.")

    return


if __name__ == "__main__":
    # The log path is the first command-line argument; report a usage error
    # instead of failing with an IndexError when it is missing.
    if len(sys.argv) < 2:
        sys.stderr.write(f"Usage: {sys.argv[0]} <performance-log>\n")
        sys.exit(1)
    performance_filename = sys.argv[1]
    plot_performance(performance_filename)
