I wrote a C++ implementation of a segment tree that works with an arbitrary combining function (not the "lazy" version), and I would really appreciate a review. I'd like to know how I can make it better, especially regarding the consistency of the code and its design.

/// Segment tree over a fixed-size sequence supporting
///   * range assignment: set every element of [left, right) to one value, and
///   * range query: combine the elements of [left, right) with `Function`.
///
/// Template parameters:
///   T_Type           - element / result type.
///   Function         - binary combiner `T_Type(T_Type, T_Type)`; assumed to be
///                      associative with `default_val` as its identity
///                      (e.g. std::plus<int> with 0).
///   MultipleFunction - called as `multiple_func(value, count)`; must return
///                      `value` combined with itself `count` times
///                      (e.g. std::multiplies<int> when Function is std::plus<int>).
///
/// All ranges are half-open, STL style: [left, right). Positions at or past
/// the end are ignored; an empty range yields `default_val`.
template<typename T_Type, typename Function, typename MultipleFunction>
class SegmentTree {
private:
    // Node storage; the children of node `v` live at 2v+1 and 2v+2.
    // Invariant: if has_changed_[v] is set, data_[v] is the single value that
    // every element of v's segment was assigned (the children may be stale);
    // otherwise data_[v] is the combined value of v's segment.
    std::vector<T_Type> data_;
    // std::vector<char> rather than std::vector<bool>: avoids the packed-bit
    // proxy specialization.
    std::vector<char> has_changed_;
    size_t real_size_;
    Function func_;
    MultipleFunction multiple_func_;
    T_Type default_val_;

    // Moves a pending assignment of `vertex` down to its two children.
    // Only called on internal nodes (a leaf that overlaps a request is always
    // fully covered by it), so the child indices stay inside data_.
    // data_[vertex] is left holding the element value; callers re-combine it
    // with pull() afterwards.
    void push(size_t vertex) {
        if (has_changed_[vertex]) {
            data_[2 * vertex + 1] = data_[2 * vertex + 2] = data_[vertex];
            has_changed_[2 * vertex + 1] = has_changed_[2 * vertex + 2] = 1;
            has_changed_[vertex] = 0;
        }
    }

    // Combined value of the segment [left, right) that `vertex` covers.
    T_Type getData(size_t vertex, size_t left, size_t right) const {
        if (has_changed_[vertex]) {
            return multiple_func_(data_[vertex], right - left);
        }
        return data_[vertex];
    }

    // Recomputes internal node `vertex` ([left, right), split at mid) from its children.
    void pull(size_t vertex, size_t left, size_t mid, size_t right) {
        data_[vertex] = func_(getData(2 * vertex + 1, left, mid),
                              getData(2 * vertex + 2, mid, right));
    }

    // Assigns `val` to every position of [from, to) that lies inside the
    // node segment [left, right) covered by `vertex`.
    void setRecursive(size_t vertex, size_t left, size_t right,
                      size_t from, size_t to, const T_Type &val) {
        if (to <= left || right <= from) {
            return;  // disjoint
        }
        if (from <= left && right <= to) {
            // Whole node covered: record the assignment lazily.
            data_[vertex] = val;
            has_changed_[vertex] = 1;
            return;
        }
        // Partial overlap, so this node is internal.
        push(vertex);
        const size_t mid = left + (right - left) / 2;
        setRecursive(2 * vertex + 1, left, mid, from, to, val);
        setRecursive(2 * vertex + 2, mid, right, from, to, val);
        pull(vertex, left, mid, right);
    }

    // Combined value of the positions of [from, to) inside [left, right).
    // Not const: pending assignments are pushed down on the way.
    T_Type getRecursive(size_t vertex, size_t left, size_t right,
                        size_t from, size_t to) {
        if (to <= left || right <= from) {
            return default_val_;  // disjoint
        }
        if (from <= left && right <= to) {
            return getData(vertex, left, right);
        }
        // Partial overlap, so this node is internal.
        push(vertex);
        const size_t mid = left + (right - left) / 2;
        T_Type left_ans = getRecursive(2 * vertex + 1, left, mid, from, to);
        T_Type right_ans = getRecursive(2 * vertex + 2, mid, right, from, to);
        // push() left data_[vertex] holding an element value; restore the aggregate.
        pull(vertex, left, mid, right);
        return func_(left_ans, right_ans);
    }

    // Fills node `vertex` covering the non-empty segment [left, right) from `data`.
    void buildRecursive(size_t vertex, size_t left, size_t right,
                        const std::vector<T_Type> &data) {
        if (left + 1 == right) {
            data_[vertex] = data[left];
            return;
        }
        const size_t mid = left + (right - left) / 2;
        buildRecursive(2 * vertex + 1, left, mid, data);
        buildRecursive(2 * vertex + 2, mid, right, data);
        data_[vertex] = func_(data_[2 * vertex + 1], data_[2 * vertex + 2]);
    }

public:
    /// Builds the tree over `data` (which may be empty).
    /// @param func          associative combiner.
    /// @param multiple_func `value` combined with itself `count` times.
    /// @param default_val   identity of `func`; returned for empty queries.
    SegmentTree(const std::vector<T_Type> &data, Function func = Function(),
                MultipleFunction multiple_func = MultipleFunction(),
                T_Type default_val = T_Type(0))
            : data_(), has_changed_(), real_size_(data.size()), func_(func),
              multiple_func_(multiple_func), default_val_(default_val) {
        // Smallest power of two >= 2 * n is enough for the midpoint-split tree.
        size_t good_size = 1;
        while (good_size < 2 * real_size_) {
            good_size <<= 1;
        }
        data_.assign(good_size, default_val_);
        has_changed_.assign(good_size, 0);

        // buildRecursive requires a non-empty segment (it would never reach a leaf).
        if (real_size_ > 0) {
            buildRecursive(0, 0, real_size_, data);
        }
    }

    /// Assigns `val` to every position in [left, right); positions past the end are ignored.
    void set(size_t left, size_t right, T_Type val) {
        if (left >= right) {
            return;
        }
        setRecursive(0, 0, real_size_, left, right, val);
    }

    /// Assigns `val` to position `pos` (no-op if out of range).
    void set(size_t pos, T_Type val) {
        set(pos, pos + 1, val);
    }

    /// Combined value of [left, right); `default_val` for an empty range.
    T_Type get(size_t left, size_t right) {
        if (left >= right) {
            return default_val_;
        }
        return getRecursive(0, 0, real_size_, left, right);
    }

    /// Value at position `pos` (`default_val` if out of range).
    T_Type get(size_t pos) {
        return get(pos, pos + 1);
    }

    /// Combined value of the whole sequence.
    T_Type get() {
        return get(0, real_size_);
    }
};

An example:

SegmentTree<int, std::plus<int>, std::multiplies<int>> tree({5, 7, 3});
std::cout << tree.get(0, 2) << std::endl;
return 0;
share|improve this question

Your Answer

 
discard

By posting your answer, you agree to the privacy policy and terms of service.

Browse other questions tagged or ask your own question.