Take the 2-minute tour ×
Code Review Stack Exchange is a question and answer site for peer programmer code reviews. It's 100% free, no registration required.

I found myself in need of a readers-writer mutex. With C++17 library support not yet available in our compiler, I set out to implement std::shared_mutex itself rather than rolling my own API, so that we have an easy upgrade path to the standard library implementation once we get C++17 support.

I put all classes intended to implement or supplement STL functionality in a namespace called xtd, short for "eXtended sTD". The reason is that when (or if) proper support arrives, we can just swap xtd for std and be running the STL implementation.

In addition to std::shared_mutex, we also need a readers-writer mutex that allows recursive locking for writers (readers are always recursive anyway). This is implemented as xtd::recursive_shared_mutex. The class has no equivalent in standard C++, but it has the same API as std::shared_mutex, plus some extensions.

In the code below I use a custom class called xtd::fast_recursive_mutex. It is a fully compatible drop-in replacement for std::recursive_mutex, but it uses a CRITICAL_SECTION on Windows for faster locking than std::recursive_mutex (at least with our compiler).

I'm interested in a review of the classes for correct operation and for any gross inefficiencies.

xtd/shared_mutex.hpp

#pragma once
#include "fast_recursive_mutex.hpp"

#include <climits>
#include <condition_variable>
#include <thread>

namespace xtd {

    namespace detail {
        /// <summary> Shared implementation of shared_mutex and recursive_shared_mutex.
        ///
        /// All lock state is packed into m_state and guarded by m_mutex:
        ///  - the most significant bit (m_write_entered) is set once a writer has entered, i.e. it
        ///    holds the exclusive lock or has claimed it and is waiting for readers to drain;
        ///  - the remaining bits (m_num_readers) count the current shared owners.
        /// m_exclusive_release is signalled when the writer bit clears or a reader slot frees up;
        /// m_shared_release is signalled when the last reader leaves while a writer waits.
        /// </summary>
        class shared_mutex_base {
        public:
            shared_mutex_base() = default;
            shared_mutex_base(const shared_mutex_base&) = delete;
            shared_mutex_base& operator = (const shared_mutex_base&) = delete;

        protected:
            // The destructor is non-virtual, so it is protected: deleting a derived mutex through
            // a shared_mutex_base* would otherwise compile and be undefined behaviour.
            ~shared_mutex_base() = default;

            using unique_lock = std::unique_lock < xtd::fast_recursive_mutex >;
            using scoped_lock = std::lock_guard < xtd::fast_recursive_mutex >;

            // NOTE(review): the implementation never acquires m_mutex re-entrantly (each public
            // member locks it once and the helpers expect it already held), so a plain std::mutex
            // with std::condition_variable would likely be cheaper. Bear in mind that
            // condition_variable_any::wait() releases only one level of a recursive mutex, so
            // holding m_mutex recursively across a wait would deadlock.
            xtd::fast_recursive_mutex m_mutex;
            std::condition_variable_any m_exclusive_release;
            std::condition_variable_any m_shared_release;
            unsigned m_state = 0;

            // The helpers below must be called with m_mutex held. The lock object is passed in
            // where a helper may wait on it or must check whether it actually owns it.
            void do_exclusive_lock(unique_lock& lk);
            bool do_exclusive_trylock(unique_lock& lk);
            void do_lock_shared(unique_lock& lk);
            bool do_try_lock_shared(unique_lock& lk);
            void do_unlock_shared(scoped_lock& lk);

            void take_exclusive_lock();
            bool someone_has_exclusive_lock() const;
            bool no_one_has_any_lock() const;
            unsigned number_of_readers() const;
            bool maximal_number_of_readers_reached() const;
            void clear_lock_status();
            void increment_readers();
            void decrement_readers();

            // Top bit of m_state: "a writer has entered".
            static const unsigned m_write_entered = 1U << (sizeof(unsigned)*CHAR_BIT - 1);
            // Remaining bits of m_state: reader count (also its maximum value).
            static const unsigned m_num_readers = ~m_write_entered;
        };
    }

    /// <summary> Readers-writer mutex modelled on the C++17 std::shared_mutex specification.
    ///
    /// Writers get priority: once a thread has started acquiring the exclusive lock, new
    /// shared-lock requests wait until that exclusive lock has been released. Consequently shared
    /// locks are not safely re-entrant: if a writer is waiting, a second lock_shared() from a
    /// thread that already holds a shared lock blocks behind the writer, which in turn waits for
    /// the first shared lock, and the two deadlock.
    ///
    /// The optional native_handle_type / native_handle members are not provided.
    /// Reference: http://en.cppreference.com/w/cpp/thread/shared_mutex. </summary>
    class shared_mutex : public detail::shared_mutex_base {
    public:
        shared_mutex() = default;
        ~shared_mutex() = default;

        // A mutex can be neither copied nor assigned.
        shared_mutex(const shared_mutex&) = delete;
        shared_mutex& operator = (const shared_mutex&) = delete;

        // ---- exclusive (writer) ownership --------------------------------------

        /// <summary> Blocks until the calling thread owns the exclusive lock. </summary>
        void lock();

        /// <summary> Tries to take the exclusive lock without waiting. </summary>
        /// <returns> true if the exclusive lock was obtained, false otherwise. </returns>
        bool try_lock();

        /// <summary> Releases the exclusive lock held by the calling thread. </summary>
        void unlock();

        // ---- shared (reader) ownership ------------------------------------------

        /// <summary> Blocks until the calling thread holds a shared lock; any number of threads
        /// may hold shared locks at the same time. </summary>
        void lock_shared();

        /// <summary> Tries to take a shared lock without waiting. </summary>
        /// <returns> true if a shared lock was obtained, false otherwise. </returns>
        bool try_lock_shared();

        /// <summary> Releases a shared lock held by the calling thread. </summary>
        void unlock_shared();
    };

    /// <summary> Non-standard variant of shared_mutex whose exclusive lock is re-entrant.
    ///
    /// The thread that owns the exclusive lock may call lock() / try_lock() again. Every
    /// successful call must be balanced by an unlock(), and other threads can acquire the mutex
    /// only after the count has returned to zero. Shared locking behaves exactly as in
    /// shared_mutex.
    ///
    /// Caveats that follow from the writer-priority policy:
    ///  - A thread that owns the exclusive lock must not call lock_shared(); it would wait for
    ///    its own exclusive lock to be released.
    ///  - Shared locks are not safely re-entrant: if a writer is waiting, a second lock_shared()
    ///    from a thread that already holds a shared lock blocks behind that writer, which in turn
    ///    waits for the first shared lock to be released. </summary>
    class recursive_shared_mutex : public detail::shared_mutex_base {
    public:
        recursive_shared_mutex() = default;
        ~recursive_shared_mutex() = default;

        // A mutex can be neither copied nor assigned.
        recursive_shared_mutex(const recursive_shared_mutex&) = delete;
        recursive_shared_mutex& operator = (const recursive_shared_mutex&) = delete;

        // ---- exclusive (writer) ownership, re-entrant ----------------------------

        /// <summary> Blocks until the calling thread owns the exclusive lock. If it already owns
        /// it, this adds one recursion level instead. Throws std::system_error if the recursion
        /// count would overflow. </summary>
        void lock();

        /// <summary> Tries to take the exclusive lock without waiting for other lock holders.
        /// A successful call by the thread that already owns it adds one recursion level. </summary>
        /// <returns> true if the lock (or another recursion level) was obtained, false otherwise. </returns>
        bool try_lock();

        /// <summary> Releases one recursion level of the exclusive lock. When the last level is
        /// released, waiting threads are woken. Throws std::system_error if the mutex is not
        /// exclusively locked. </summary>
        void unlock();

        // ---- shared (reader) ownership ------------------------------------------

        /// <summary> Blocks until the calling thread holds a shared lock; any number of threads
        /// may hold shared locks at the same time. </summary>
        void lock_shared();

        /// <summary> Tries to take a shared lock without waiting. </summary>
        /// <returns> true if a shared lock was obtained, false otherwise. </returns>
        bool try_lock_shared();

        /// <summary> Releases a shared lock held by the calling thread. </summary>
        void unlock_shared();

        // ---- introspection ------------------------------------------------------

        /// <summary> Recursion depth of the exclusive lock. </summary>
        /// <returns> The number of unreleased exclusive locks; 0 when it is not held. </returns>
        int num_write_locks();

        /// <summary> Tells whether the calling thread owns the exclusive lock. </summary>
        /// <returns> true if the calling thread owns it, false otherwise. </returns>
        bool is_locked_by_me();

    private:
        std::thread::id m_write_thread;  // Exclusive owner; meaningful only while m_write_recurses > 0.
        int m_write_recurses = 0;        // Recursion depth of the exclusive lock.
    };
}

shared_mutex.cpp

#include "pch/pch.hpp"
#include "xtd/shared_mutex.hpp"

#include <thread>

namespace xtd {

    // ------------------------------------------------------------------------
    // class: shared_mutex_base
    // ------------------------------------------------------------------------
    namespace detail {

        void shared_mutex_base::do_exclusive_lock(unique_lock &lk){
            // Gate 1: wait until no other writer has entered.
            m_exclusive_release.wait(lk, [this] { return !someone_has_exclusive_lock(); });

            // Claim the writer bit while still holding m_mutex. From here on, new readers are
            // held back in do_lock_shared().
            take_exclusive_lock();

            // Gate 2: wait for the readers that were already inside to drain.
            m_shared_release.wait(lk, [this] { return number_of_readers() == 0; });
        }

        bool shared_mutex_base::do_exclusive_trylock(unique_lock &lk){
            // Without m_mutex we must not touch m_state at all, so report failure.
            if (!lk.owns_lock()) {
                return false;
            }
            // Take the writer bit only if no_one_has_any_lock() reports no current holder.
            if (!no_one_has_any_lock()) {
                return false;
            }
            take_exclusive_lock();
            return true;
        }

        void shared_mutex_base::do_lock_shared(unique_lock& lk) {
            // Readers yield to any writer that has entered (writer priority) and to a saturated
            // reader count. Both conditions are re-signalled through m_exclusive_release.
            m_exclusive_release.wait(lk, [this] {
                return !(someone_has_exclusive_lock() || maximal_number_of_readers_reached());
            });
            increment_readers();
        }

        bool shared_mutex_base::do_try_lock_shared(unique_lock& lk) {
            // Each guard is a reason to give up immediately instead of waiting.
            if (!lk.owns_lock()) {
                return false;                       // could not even get m_mutex
            }
            if (someone_has_exclusive_lock()) {
                return false;                       // a writer is in (or entering)
            }
            if (maximal_number_of_readers_reached()) {
                return false;                       // reader counter is saturated
            }
            increment_readers();
            return true;
        }

        void shared_mutex_base::do_unlock_shared(scoped_lock& lk) {
            // `lk` is not used. Requiring it documents, at the call site, that m_mutex is held.
            (void)lk;

            // An unlock_shared() without a matching shared lock must not reach
            // decrement_readers(): with zero readers the subtraction wraps around and sets every
            // bit of m_state, including the writer bit, leaving the mutex locked forever.
            if (number_of_readers() == 0) {
                throw std::system_error(ENOLCK, std::system_category(),
                                        "Unlocking a mutex without a shared lock!");
            }

            decrement_readers();

            if (someone_has_exclusive_lock()) {
                // A writer has entered and is waiting in do_exclusive_lock() for readers to drain.
                // Only that single writer ever waits on m_shared_release, so when the last reader
                // leaves, waking one thread is enough.
                if (number_of_readers() == 0) {
                    m_shared_release.notify_one();
                }
            }
            else if (number_of_readers() == m_num_readers - 1) {
                // No writer has entered. If the reader count was saturated before this unlock,
                // a thread may be blocked in do_lock_shared(); release one of them.
                m_exclusive_release.notify_one();
            }
        }

        void shared_mutex_base::take_exclusive_lock() {
            // Raise the top "writer entered" bit; the reader count in the low bits is untouched.
            m_state = m_state | m_write_entered;
        }

        bool shared_mutex_base::someone_has_exclusive_lock() const {
            // True once a writer has entered: it either holds the exclusive lock or has claimed
            // it and is still waiting for readers to drain.
            return (m_state & m_write_entered) == m_write_entered;
        }

        // True only when the mutex is completely free: no writer has entered and the reader count
        // is zero. try_lock() relies on this to decide whether it may take the exclusive lock.
        bool shared_mutex_base::no_one_has_any_lock() const { return m_state == 0; }

        unsigned shared_mutex_base::number_of_readers() const {
            return m_state & m_num_readers;
        }

        bool shared_mutex_base::maximal_number_of_readers_reached() const {
            // The masked reader count can never exceed m_num_readers, so ">=" means "saturated".
            return number_of_readers() >= m_num_readers;
        }

        void shared_mutex_base::clear_lock_status() {
            // Drop both the writer bit and the reader count in a single store.
            m_state = 0U;
        }

        void shared_mutex_base::increment_readers() {
            unsigned num_readers = number_of_readers() + 1;
            m_state &= ~m_num_readers;
            m_state |= num_readers;
        }

        void shared_mutex_base::decrement_readers() {
            unsigned num_readers = number_of_readers() - 1;
            m_state &= ~m_num_readers;
            m_state |= num_readers;
        }
    }

    // ------------------------------------------------------------------------
    // class: shared_mutex
    // ------------------------------------------------------------------------
    static_assert(std::is_standard_layout<shared_mutex>::value,
                  "Shared mutex must be standard layout");

    void shared_mutex::lock() {
        std::unique_lock<xtd::fast_recursive_mutex> lk(m_mutex);
        do_exclusive_lock(lk);
    }

    bool shared_mutex::try_lock() {
        std::unique_lock<xtd::fast_recursive_mutex> lk(m_mutex, std::try_to_lock);
        return do_exclusive_trylock(lk);
    }

    void shared_mutex::unlock() {
        {
            std::lock_guard<xtd::fast_recursive_mutex> lg(m_mutex);
            // We released an exclusive lock, no one else has a lock.
            clear_lock_status();
        }
        m_exclusive_release.notify_all();
    }

    void shared_mutex::lock_shared() {
        std::unique_lock<xtd::fast_recursive_mutex> lk(m_mutex);
        do_lock_shared(lk);
    }

    bool shared_mutex::try_lock_shared() {
        std::unique_lock<xtd::fast_recursive_mutex> lk(m_mutex, std::try_to_lock);
        return do_try_lock_shared(lk);
    }

    void shared_mutex::unlock_shared() {
        std::lock_guard<xtd::fast_recursive_mutex> _(m_mutex);
        do_unlock_shared(_);
    }

    // ------------------------------------------------------------------------
    // class: recursive_shared_mutex
    // ------------------------------------------------------------------------
    void recursive_shared_mutex::lock() {
        std::unique_lock<xtd::fast_recursive_mutex> lk(m_mutex);
        if (m_write_recurses == 0) {
            do_exclusive_lock(lk);
        }
        else {
            if (m_write_thread == std::this_thread::get_id()) {
                if (m_write_recurses ==
                    std::numeric_limits<decltype(m_write_recurses)>::max()) {
                    throw std::system_error(
                        EOVERFLOW, std::system_category(),
                        "Too many recursions in recursive_shared_mutex!");
                }
            }
            else {
                // Different thread trying to get a lock.
                do_exclusive_lock(lk);
                assert(m_write_recurses == 0);
            }
        }
        m_write_recurses++;
        m_write_thread = std::this_thread::get_id();
    }

    bool recursive_shared_mutex::try_lock() {
        // Acquire the internal mutex unconditionally. It is only held for short critical
        // sections (the condition-variable waits release it), so this never waits on other
        // *lock holders*. With std::try_to_lock, a momentarily busy internal mutex made this
        // fail spuriously, even for the thread that already owns the exclusive lock.
        std::unique_lock<xtd::fast_recursive_mutex> lk(m_mutex);
        const std::thread::id me = std::this_thread::get_id();

        if (m_write_recurses > 0 && m_write_thread == me) {
            // Re-entrant acquisition by the owner. Like std::recursive_mutex::try_lock, refuse
            // (rather than overflow the signed counter) once the maximum depth is reached.
            if (m_write_recurses == std::numeric_limits<decltype(m_write_recurses)>::max()) {
                return false;
            }
        }
        else if (!do_exclusive_trylock(lk)) {
            // Held by another writer or by readers (or a writer is draining readers).
            return false;
        }

        m_write_recurses++;
        m_write_thread = me;
        return true;
    }

    void recursive_shared_mutex::unlock() {
        bool notify_them = false;
        {
            std::lock_guard<xtd::fast_recursive_mutex> lg(m_mutex);
            if (m_write_recurses == 0) {
                throw std::system_error(ENOLCK, std::system_category(),
                                        "Unlocking a unlocked mutex!");
            }
            m_write_recurses--;
            if (m_write_recurses == 0) {
                // We released an exclusive lock, no one else has a lock.
                clear_lock_status();
                notify_them = true;
            }
        }
        if (notify_them) {
            m_exclusive_release.notify_all();
        }
    }

    void recursive_shared_mutex::lock_shared() {
        std::unique_lock<xtd::fast_recursive_mutex> lk(m_mutex);
        do_lock_shared(lk);
    }

    bool recursive_shared_mutex::try_lock_shared() {
        std::unique_lock<xtd::fast_recursive_mutex> lk(m_mutex, std::try_to_lock);
        return do_try_lock_shared(lk);
    }

    void recursive_shared_mutex::unlock_shared() {
        std::lock_guard<xtd::fast_recursive_mutex> _(m_mutex);
        return do_unlock_shared(_);
    }

    int recursive_shared_mutex::num_write_locks() {
        std::lock_guard<xtd::fast_recursive_mutex> _(m_mutex);
        return m_write_recurses;
    }

    bool recursive_shared_mutex::is_locked_by_me() {
        std::lock_guard<xtd::fast_recursive_mutex> _(m_mutex);
        return m_write_recurses > 0 && m_write_thread == std::this_thread::get_id();
    }
}

The implementation is based on the reference implementation in this working paper.

share|improve this question

Your Answer

 
discard

By posting your answer, you agree to the privacy policy and terms of service.

Browse other questions tagged or ask your own question.