/*
 * SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
 * Copyright (c) 2018-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * See file LICENSE for terms.
 */

#ifndef PARALLEL_TREE_FINDER_H
#define PARALLEL_TREE_FINDER_H

#include <cstdint>
#include <memory>
#include <vector>

#include "agg_tree.h"
#include "agg_types.h"
#include "dump_file.h"
#include "sub_tree_info.h"
#include "sub_tree_score.h"
#include "thread_pool.h"

class ParallelTreeFinder;

// A single task of tree finding: checks whether a single aggregation tree can
// match a job request. The task is bound to one AggTree and one
// ParallelTreeFinder for its whole lifetime (both are held by reference), so
// both must outlive the task.
class TreeFinderTask : public ThreadPoolTask
{
    AggTree& m_agg_tree_;                 // tree evaluated by this task (non-owning)
    ParallelTreeFinder& m_tree_finder_;   // finder this task belongs to (non-owning)

    // Per-task outputs; presumably filled by Run() -- TODO confirm against the .cpp.
    JobSubTreeInfo m_job_sub_tree_info_;
    JobSubTreeScore m_sub_tree_result_;

    SharpExtJobId m_external_job_id_;   // set via SetJobId()
    AggTreeType m_tree_type_;           // set via SetTreeType(); defaults to AGG_TREE_TYPE_SHARP

   public:
    /// @param agg_tree    Tree evaluated by this task; must outlive the task.
    /// @param tree_finder Finder this task belongs to; must outlive the task.
    TreeFinderTask(AggTree& agg_tree, ParallelTreeFinder& tree_finder)
        : m_agg_tree_(agg_tree),
          m_tree_finder_(tree_finder),
          // Value-initialize so the id is never indeterminate before SetJobId()
          // (matters when SharpExtJobId is a plain integer type).
          m_external_job_id_(),
          m_tree_type_(AGG_TREE_TYPE_SHARP)
    {}

    virtual ~TreeFinderTask() = default;

    /// Access to this task's sub-tree info (mutable and read-only views).
    JobSubTreeInfo& GetTaskJobSubTreeInfo() { return m_job_sub_tree_info_; }
    const JobSubTreeInfo& GetTaskJobSubTreeInfo() const { return m_job_sub_tree_info_; }

    /// Access to this task's sub-tree score (mutable and read-only views).
    JobSubTreeScore& GetTaskJobSubTreeScore() { return m_sub_tree_result_; }
    const JobSubTreeScore& GetTaskJobSubTreeScore() const { return m_sub_tree_result_; }

    /// @return Id of the AggTree this task is bound to.
    sharp_trees_t GetTreeId() const { return m_agg_tree_.GetId(); }

    /// Sets the external job id this task searches for.
    void SetJobId(const SharpExtJobId& external_job_id) { m_external_job_id_ = external_job_id; }

    /// Sets the kind of tree requested from this task.
    void SetTreeType(AggTreeType tree_type) { m_tree_type_ = tree_type; }

    /// Task body executed by the thread pool (defined out of line).
    virtual void Run();
};

// Tasks are stored by value. Each one references its AggTree and the owning
// ParallelTreeFinder, so those must stay alive while the tasks exist.
using TreeFinderTasksVector = std::vector<TreeFinderTask>;

// Invoke method on all trees In parallel using multi threading
class ParallelTreeFinder : public ThreadPoolTasksCollection
{
    ThreadPool<default_task_queues_size> m_thread_pool_;
    TreeFinderTasksVector m_tasks_;

    // search data
    SetPortDataConstPtr const* m_compute_ports_;
    JobResource const* m_job_resource_;

    std::unique_ptr<file_utils::DumpFile> m_job_failures_dump_file_ptr_;

    void CreateTasks();

   public:
    ParallelTreeFinder()
        : ThreadPoolTasksCollection(),
          m_thread_pool_(),
          m_compute_ports_(nullptr),
          m_job_resource_(nullptr),
          m_job_failures_dump_file_ptr_{
              file_utils::GetDumpFileIfEnabled("Failed Job Requests", "sharp_am_failed_job_requests_details.dump")}
    {}

    ~ParallelTreeFinder();

    void Init();

    bool FindBestSharpTree(const SetPortDataConstPtr& compute_ports,
                           const JobResource& job_resource,
                           const SharpJob& p_job,
                           JobSubTreeInfo& best_tree_info,
                           JobSubTreeScore& best_tree_result);

    bool FindBestMulticastTree(const SetPortDataConstPtr& compute_ports,
                               const JobResource& job_resource,
                               const SharpJob& p_job,
                               JobSubTreeInfo& best_tree_info,
                               JobSubTreeScore& best_tree_result,
                               SharableMlidPtr& sharable_mlid);

    void DumpJobRequestFailure(const uint64_t job_id,
                               const uint64_t external_job_id,
                               JobSubTreeInfo* best_tree_info = nullptr,
                               char const* const failure_reason = nullptr);

    inline const SetPortDataConstPtr* GetComputePorts() { return m_compute_ports_; }
    inline const JobResource* GetJobResource() { return m_job_resource_; }

    int CreateLltTreeForJob(const SharpJob& p_job, JobSubTreeInfo& job_sub_tree_info);

    int CreateSatTreeForJob(const SharpJob& p_job,
                            const JobResource& job_resource,
                            JobSubTreeInfo& job_sub_tree_info,
                            JobSubTreeScore& sub_tree_result);

    int CreateMulticastTreeForJob(const SharpJob& p_job, JobSubTreeInfo& job_sub_tree_info, SharableMlidPtr& sharable_mlid);

    void ReCreateTasks();

   private:
    void RunTasks(const SetPortDataConstPtr& compute_ports, const JobResource& job_resource, const SharpJob& p_job, AggTreeType tree_type);
};

#endif   // PARALLEL_TREE_FINDER_H
