/*-------------------------------------------------------------------------------
 This file is part of unityForest.

 Copyright (c) [2014-2018] [Marvin N. Wright]
 Modifications and extensions by Roman Hornung

 This software may be modified and distributed under the terms of the MIT license.

 Please note that the C++ core of unityForest is distributed under MIT license and the
 R package "unityForest" under GPL3 license.
 #-------------------------------------------------------------------------------*/

#ifndef TREE_H_
#define TREE_H_

#include <set>
#include <vector>
#include <random>
#include <iostream>
#include <stdexcept>
#include <memory>

#include "globals.h"
#include "Data.h"

namespace unityForest
{

  // Plain record identifying one split by the index of its tree and the index
  // of its node, together with a split-criterion value for that split.
  struct SplitData
  {
    // Default member initializers give every field a well-defined value even
    // under default-initialization (`SplitData s;`); without them the fields
    // would be indeterminate and reading them would be undefined behaviour.
    size_t tree_idx = 0;
    size_t node_idx = 0;
    double split_criterion = 0.0;

    // Default constructor: all fields take the initializers above (0, 0, 0.0).
    SplitData() = default;

    // Parameterized constructor: sets all three fields explicitly.
    SplitData(size_t tree_idx, size_t node_idx, double split_criterion)
        : tree_idx(tree_idx), node_idx(node_idx), split_criterion(split_criterion) {}
  };

  // Abstract base class for a single tree of the forest.
  //
  // The class cannot be instantiated directly: it declares several pure
  // virtual functions (clone(), allocateMemory(), splitNodeInternal(), ...)
  // that every concrete subclass has to implement. Only the inline accessors
  // are defined in this header; all other member functions are defined
  // elsewhere.
  //
  // Several members are raw pointers (data, case_weights, manual_inbag, ...).
  // The destructor is defaulted, so the tree never frees them, and the
  // defaulted copy operations copy the pointer values only (shallow copy).
  class Tree
  {
  public:
    Tree();

    // Create from loaded forest
    Tree(std::vector<std::vector<size_t>> &child_nodeIDs, std::vector<size_t> &split_varIDs,
         std::vector<double> &split_values);

    // Constructor for repr_tree_mode:
    // same arguments as the loaded-forest constructor plus a pointer to the data.
    Tree(std::vector<std::vector<size_t>> &child_nodeIDs, std::vector<size_t> &split_varIDs,
         std::vector<double> &split_values,
         const Data *data_ptr);

    // Virtual so that subclasses can be destroyed through a Tree pointer.
    virtual ~Tree() = default;

    // Polymorphic deep-copy, implemented in every subclass
    virtual std::unique_ptr<Tree> clone() const = 0;

    // Copying is enabled (member-wise); the deleted variants below are kept
    // only as a commented-out alternative.
    // Tree(const Tree&) = delete;
    // Tree& operator=(const Tree&) = delete;
    Tree(const Tree &) = default;
    Tree &operator=(const Tree &) = default;

    // Configure the tree before growing. The parameter names mirror the
    // protected data members declared below; presumably each argument is
    // stored in the member of the same name (definition not visible here).
    // NOTE(review): repr_vars is taken by value, i.e. copied on every call.
    void init(const Data *data, uint mtry, double prop_var_root, size_t dependent_varID, size_t num_samples, uint seed,
              std::vector<size_t> *deterministic_varIDs, std::vector<size_t> *split_select_varIDs,
              std::vector<double> *split_select_weights, ImportanceMode importance_mode, uint min_node_size, uint min_node_size_root,
              bool sample_with_replacement, bool memory_saving_splitting, SplitRule splitrule,
              std::vector<double> *case_weights, std::vector<size_t> *manual_inbag, bool keep_inbag,
              std::vector<double> *sample_fraction, double alpha, double minprop, bool holdout, uint num_random_splits,
              uint max_depth, uint max_depth_root, uint num_cand_trees,
              std::vector<size_t> repr_vars);

    // Subclass hook for allocating subclass-specific storage.
    virtual void allocateMemory() = 0;

    // Grow the tree; variable_importance is an output/accumulator vector
    // supplied by the caller (presumably kept in the member of the same name).
    void grow(std::vector<double> *variable_importance);

    // Predict for prediction_data; oob_prediction selects out-of-bag
    // prediction (exact semantics are in the definition, not visible here).
    void predict(const Data *prediction_data, bool oob_prediction);

    // Compute this tree's contribution to the importance measure and write it
    // into forest_importance (passed by non-const reference).
    void computeUFImportance(std::vector<double> &forest_importance);

    // Importance contribution of a single node, given the OOB sample IDs in it.
    // NOTE(review): the sample ID vector is passed by value (copied per call).
    double computeUFNodeImportance(size_t nodeID, std::vector<size_t> oob_sampleIDs_nodeID);

    void computeSplitCriterionValues();

    void computeOOBSplitCriterionValues();

    // Collect the splits of this tree into all_splits_per_variable (outer
    // index presumably the variable); tree_idx identifies this tree in the
    // SplitData records.
    void collectSplits(size_t tree_idx, std::vector<std::vector<SplitData>> &all_splits_per_variable);

    // OOB counterpart of collectSplits().
    void collectOOBSplits(size_t tree_idx, std::vector<std::vector<SplitData>> &all_splits_per_variable);

    // Update var_counts (passed by non-const reference) with variable counts
    // from this tree.
    void countVariables(std::vector<size_t> &var_counts);

    void computeUv(size_t tree_ind, std::vector<std::vector<double>> &Uv);

    // Set the per-tree score vector (see score_values / getScoreValues()).
    void setScoreVector(std::vector<double> scores_tree);

    // Mark the splits of this tree that are contained in bestSplits; the sets
    // hold pairs of size_t, presumably (tree index, node index) - TODO confirm.
    // The parameter name "tree_idex" is spelled as in the original declaration.
    void markBestSplits(size_t tree_idex, const std::vector<std::set<std::pair<size_t, size_t>>> &bestSplits);

    // OOB counterpart of markBestSplits().
    void markBestOOBSplits(size_t tree_idex, const std::vector<std::set<std::pair<size_t, size_t>>> &bestSplits);

    // The following three functions are virtual but not pure, so a base
    // implementation exists elsewhere and subclasses may override them.
    // NOTE(review): all vector arguments are passed by value.
    virtual double computeSplitCriterion(std::vector<size_t> sampleIDs_left_child, std::vector<size_t> sampleIDs_right_child);

    virtual double computeOOBSplitCriterionValue(size_t nodeID, std::vector<size_t> oob_sampleIDs_nodeID);

    virtual double computeOOBSplitCriterionValuePermuted(size_t nodeID, std::vector<size_t> oob_sampleIDs_nodeID, std::vector<size_t> permutations);

    // Read-only accessors. Each returns a const reference to the member (or,
    // for getNumSamplesOob(), a copy of the value); references stay valid only
    // as long as this Tree object lives.
    const std::vector<std::vector<size_t>> &getChildNodeIDs() const
    {
      return child_nodeIDs;
    }

    const std::vector<double> &getSplitValues() const
    {
      return split_values;
    }

    const std::vector<size_t> &getSplitVarIDs() const
    {
      return split_varIDs;
    }

    const std::vector<size_t> &getOobSampleIDs() const
    {
      return oob_sampleIDs;
    }

    size_t getNumSamplesOob() const
    {
      return num_samples_oob;
    }

    const std::vector<size_t> &getInbagCounts() const
    {
      return inbag_counts;
    }

    const std::vector<bool> &getIsInBest() const
    {
      return is_in_best;
    }

    const std::vector<size_t> &getNodeIDInRoot() const
    {
      return nodeID_in_root;
    }

    const std::vector<double> &getScoreValues() const
    {
      return score_values;
    }

    // Store a pointer to the list of allowed variable IDs. Only the pointer is
    // kept (no copy), so the caller must keep the vector alive while the tree
    // uses it. The member defaults to nullptr.
    void setAllowedVarIDs(const std::vector<size_t> *ids)
    {
      allowedVarIDs_ = ids;
    }

  protected:
    // Fill result with the candidate split variables (output parameter).
    void createPossibleSplitVarSubset(std::vector<size_t> &result);

    // Function to evaluate a random tree:
    // returns a score for the tree given the IDs of its terminal nodes.
    virtual double evaluateRandomTree(const std::vector<size_t> &terminal_nodes) = 0;

    // Subclass-specific split of node nodeID using the given candidate
    // variables; the meaning of the bool result is defined by the
    // implementations (not visible here).
    virtual bool splitNodeInternal(size_t nodeID, std::vector<size_t> &possible_split_varIDs) = 0;

    // Split node in (full) tree
    bool splitNodeFullTree(size_t nodeID);

    // Split node in random tree
    bool splitNodeRandom(size_t nodeID, const std::vector<size_t> &varIDs_root);

    // Presumably checks whether variable varID takes at least two different
    // values among the samples in node nodeID - TODO confirm against definition.
    bool twoDifferentValues(size_t nodeID, size_t varID);

    // Create an empty node in a (full) tree
    void createEmptyNodeFullTree();

    // Create an empty node in a random tree
    void createEmptyNodeRandomTree();

    // Subclass part of creating an empty node in a random tree
    virtual void createEmptyNodeRandomTreeInternal() = 0;

    // Subclass part of creating an empty node in a (full) tree
    virtual void createEmptyNodeFullTreeInternal() = 0;

    // Check whether the current node in a random tree is final
    virtual bool checkWhetherFinalRandom(size_t nodeID) = 0;

    // Function used to clear some objects from the random trees
    void clearRandomTree();
    // Subclass part of clearing objects from the random trees
    virtual void clearRandomTreeInternal() = 0;

    // Sampling routines that determine the in-bag / OOB samples of the tree:
    // with or without replacement, optionally weighted or class-wise.
    void bootstrap();
    void bootstrapWithoutReplacement();

    void bootstrapWeighted();
    void bootstrapWithoutReplacementWeighted();

    // Virtual (not pure): subclasses may override the class-wise variants.
    virtual void bootstrapClassWise();
    virtual void bootstrapWithoutReplacementClassWise();

    // Use the pre-selected samples in manual_inbag instead of drawing them.
    void setManualInbag();

    // Subclass hook for releasing subclass-specific data.
    virtual void cleanUpInternal() = 0;

    // ID of the dependent (response) variable
    size_t dependent_varID;
    // Number of candidate variables per split (mtry)
    uint mtry;
    // Presumably the proportion of variables used for the tree roots - TODO confirm
    double prop_var_root;

    // Number of samples (all samples, not only inbag for this tree)
    size_t num_samples;

    // Number of OOB samples
    size_t num_samples_oob;

    // Minimum node size to split, like in original RF nodes of smaller size can be produced
    uint min_node_size;

    // Minimum node size to split in the tree roots
    uint min_node_size_root;

    // Variables for which representative trees should be found
    std::vector<size_t> repr_vars;

    // Weight vector for selecting possible split variables, one weight between 0 (never select) and 1 (always select) for each variable
    // Deterministic variables are always selected
    // (raw pointers; the tree does not free them)
    const std::vector<size_t> *deterministic_varIDs;
    const std::vector<size_t> *split_select_varIDs;
    const std::vector<double> *split_select_weights;

    // Bootstrap weights
    const std::vector<double> *case_weights;

    // Pre-selected bootstrap samples
    const std::vector<size_t> *manual_inbag;

    // Splitting variable for each node
    std::vector<size_t> split_varIDs;

    // Value to split at for each node, for now only binary split
    // For terminal nodes the prediction value is saved here
    std::vector<double> split_values;

    // These objects are needed in the loop generating the random trees and choosing the best out of them:
    std::vector<size_t> split_varIDs_loop;
    std::vector<double> split_values_loop;
    std::vector<size_t> start_pos_loop;
    std::vector<size_t> end_pos_loop;

    // These objects store the information on the best random trees, that is, the tree roots:
    std::vector<size_t> split_varIDs_best;
    std::vector<double> split_values_best;

    // Child node IDs of the best random tree (tree root)
    std::vector<std::vector<size_t>> child_nodeIDs_best;

    // Child node IDs of the random tree currently handled in the loop
    std::vector<std::vector<size_t>> child_nodeIDs_loop;

    // Scratch buffer of values; its exact use is defined in the implementation
    std::vector<double> values_buffer;

    // Vector of left and right child node IDs, 0 for no child
    std::vector<std::vector<size_t>> child_nodeIDs;

    // All sampleIDs in the tree, will be re-ordered while splitting
    std::vector<size_t> sampleIDs;

    // For each node a vector with start and end positions
    std::vector<size_t> start_pos;
    std::vector<size_t> end_pos;

    // IDs of OOB individuals, sorted
    std::vector<size_t> oob_sampleIDs;

    // Holdout mode
    bool holdout;

    // Inbag counts (keep_inbag: whether the counts are retained)
    bool keep_inbag;
    std::vector<size_t> inbag_counts;

    // Random number generator
    std::mt19937_64 random_number_generator;

    // Pointer to original data
    const Data *data;

    // Variable importance for all variables
    // (non-const raw pointer: copies of a Tree point to the same vector)
    std::vector<double> *variable_importance;
    ImportanceMode importance_mode;

    // When growing here the OOB set is used
    // Terminal nodeIDs for prediction samples
    std::vector<size_t> prediction_terminal_nodeIDs;

    // Sampling configuration
    bool sample_with_replacement;
    const std::vector<double> *sample_fraction;

    // Splitting configuration
    bool memory_saving_splitting;
    SplitRule splitrule;
    double alpha;
    double minprop;
    uint num_random_splits;
    // Maximum depth of (full) trees
    uint max_depth;
    // Maximum depth of random trees
    uint max_depth_root;
    // Number of random candidate trees to try for each tree root:
    uint num_cand_trees;
    // Bookkeeping used while growing: current depth and the last left node ID
    // (the latter separately for the full tree and the random-tree loop)
    uint depth;
    size_t last_left_nodeID;
    size_t last_left_nodeID_loop;

    // Vector, which gives for every nodeID in a (full) tree, the corresponding nodeID in the tree root (0 for nodes not in the tree root)
    std::vector<size_t> nodeID_in_root;

    // Split criterion values, presumably one per node - TODO confirm
    std::vector<double> split_criterion;

    // Flags set by markBestSplits() / markBestOOBSplits(), presumably one per node
    std::vector<bool> is_in_best;
    std::vector<bool> is_in_oob_best;

    // Representative trees
    std::vector<double> score_values;

    // Read-only pointer; Tree never modifies the vector, it just reads:
    const std::vector<size_t> *allowedVarIDs_{nullptr};
  };

} // namespace unityForest

#endif /* TREE_H_ */
