diff --git a/src/common/intrusive_red_black_tree.h b/src/common/intrusive_red_black_tree.h index 3173cc449..b296b639e 100644 --- a/src/common/intrusive_red_black_tree.h +++ b/src/common/intrusive_red_black_tree.h @@ -4,6 +4,8 @@ #pragma once +#include "common/alignment.h" +#include "common/common_funcs.h" #include "common/parent_of_member.h" #include "common/tree.h" @@ -15,32 +17,33 @@ class IntrusiveRedBlackTreeImpl; } +#pragma pack(push, 4) struct IntrusiveRedBlackTreeNode { + YUZU_NON_COPYABLE(IntrusiveRedBlackTreeNode); + public: - using EntryType = RBEntry; - - constexpr IntrusiveRedBlackTreeNode() = default; - - void SetEntry(const EntryType& new_entry) { - entry = new_entry; - } - - [[nodiscard]] EntryType& GetEntry() { - return entry; - } - - [[nodiscard]] const EntryType& GetEntry() const { - return entry; - } + using RBEntry = freebsd::RBEntry; private: - EntryType entry{}; + RBEntry m_entry; - friend class impl::IntrusiveRedBlackTreeImpl; +public: + explicit IntrusiveRedBlackTreeNode() = default; - template - friend class IntrusiveRedBlackTree; + [[nodiscard]] constexpr RBEntry& GetRBEntry() { + return m_entry; + } + [[nodiscard]] constexpr const RBEntry& GetRBEntry() const { + return m_entry; + } + + constexpr void SetRBEntry(const RBEntry& entry) { + m_entry = entry; + } }; +static_assert(sizeof(IntrusiveRedBlackTreeNode) == + 3 * sizeof(void*) + std::max(sizeof(freebsd::RBColor), 4)); +#pragma pack(pop) template class IntrusiveRedBlackTree; @@ -48,12 +51,17 @@ class IntrusiveRedBlackTree; namespace impl { class IntrusiveRedBlackTreeImpl { + YUZU_NON_COPYABLE(IntrusiveRedBlackTreeImpl); + private: template friend class ::Common::IntrusiveRedBlackTree; - using RootType = RBHead; - RootType root; +private: + using RootType = freebsd::RBHead; + +private: + RootType m_root; public: template @@ -81,149 +89,150 @@ public: IntrusiveRedBlackTreeImpl::reference>; private: - pointer node; + pointer m_node; public: - explicit Iterator(pointer n) : node(n) {} + constexpr explicit Iterator(pointer n) : m_node(n) {} - bool operator==(const Iterator& rhs) const { - return this->node == rhs.node; + constexpr bool operator==(const Iterator& rhs) const { + return m_node == rhs.m_node; } - bool operator!=(const Iterator& rhs) const { + constexpr bool operator!=(const Iterator& rhs) const { return !(*this == rhs); } - pointer operator->() const { - return this->node; + constexpr pointer operator->() const { + return m_node; } - reference operator*() const { - return *this->node; + constexpr reference operator*() const { + return *m_node; } - Iterator& operator++() { - this->node = GetNext(this->node); + constexpr Iterator& operator++() { + m_node = GetNext(m_node); return *this; } - Iterator& operator--() { - this->node = GetPrev(this->node); + constexpr Iterator& operator--() { + m_node = GetPrev(m_node); return *this; } - Iterator operator++(int) { + constexpr Iterator operator++(int) { const Iterator it{*this}; ++(*this); return it; } - Iterator operator--(int) { + constexpr Iterator operator--(int) { const Iterator it{*this}; --(*this); return it; } - operator Iterator() const { - return Iterator(this->node); + constexpr operator Iterator() const { + return Iterator(m_node); } }; private: - // Define accessors using RB_* functions. - bool EmptyImpl() const { - return root.IsEmpty(); + constexpr bool EmptyImpl() const { + return m_root.IsEmpty(); } - IntrusiveRedBlackTreeNode* GetMinImpl() const { - return RB_MIN(const_cast(&root)); + constexpr IntrusiveRedBlackTreeNode* GetMinImpl() const { + return freebsd::RB_MIN(const_cast(m_root)); } - IntrusiveRedBlackTreeNode* GetMaxImpl() const { - return RB_MAX(const_cast(&root)); + constexpr IntrusiveRedBlackTreeNode* GetMaxImpl() const { + return freebsd::RB_MAX(const_cast(m_root)); } - IntrusiveRedBlackTreeNode* RemoveImpl(IntrusiveRedBlackTreeNode* node) { - return RB_REMOVE(&root, node); + constexpr IntrusiveRedBlackTreeNode* RemoveImpl(IntrusiveRedBlackTreeNode* node) { + return freebsd::RB_REMOVE(m_root, node); } public: - static IntrusiveRedBlackTreeNode* GetNext(IntrusiveRedBlackTreeNode* node) { - return RB_NEXT(node); + static constexpr IntrusiveRedBlackTreeNode* GetNext(IntrusiveRedBlackTreeNode* node) { + return freebsd::RB_NEXT(node); } - static IntrusiveRedBlackTreeNode* GetPrev(IntrusiveRedBlackTreeNode* node) { - return RB_PREV(node); + static constexpr IntrusiveRedBlackTreeNode* GetPrev(IntrusiveRedBlackTreeNode* node) { + return freebsd::RB_PREV(node); } - static const IntrusiveRedBlackTreeNode* GetNext(const IntrusiveRedBlackTreeNode* node) { + static constexpr IntrusiveRedBlackTreeNode const* GetNext( + IntrusiveRedBlackTreeNode const* node) { return static_cast( GetNext(const_cast(node))); } - static const IntrusiveRedBlackTreeNode* GetPrev(const IntrusiveRedBlackTreeNode* node) { + static constexpr IntrusiveRedBlackTreeNode const* GetPrev( + IntrusiveRedBlackTreeNode const* node) { return static_cast( GetPrev(const_cast(node))); } public: - constexpr IntrusiveRedBlackTreeImpl() {} + constexpr IntrusiveRedBlackTreeImpl() = default; // Iterator accessors. - iterator begin() { + constexpr iterator begin() { return iterator(this->GetMinImpl()); } - const_iterator begin() const { + constexpr const_iterator begin() const { return const_iterator(this->GetMinImpl()); } - iterator end() { + constexpr iterator end() { return iterator(static_cast(nullptr)); } - const_iterator end() const { + constexpr const_iterator end() const { return const_iterator(static_cast(nullptr)); } - const_iterator cbegin() const { + constexpr const_iterator cbegin() const { return this->begin(); } - const_iterator cend() const { + constexpr const_iterator cend() const { return this->end(); } - iterator iterator_to(reference ref) { - return iterator(&ref); + constexpr iterator iterator_to(reference ref) { + return iterator(std::addressof(ref)); } - const_iterator iterator_to(const_reference ref) const { - return const_iterator(&ref); + constexpr const_iterator iterator_to(const_reference ref) const { + return const_iterator(std::addressof(ref)); } // Content management. - bool empty() const { + constexpr bool empty() const { return this->EmptyImpl(); } - reference back() { + constexpr reference back() { return *this->GetMaxImpl(); } - const_reference back() const { + constexpr const_reference back() const { return *this->GetMaxImpl(); } - reference front() { + constexpr reference front() { return *this->GetMinImpl(); } - const_reference front() const { + constexpr const_reference front() const { return *this->GetMinImpl(); } - iterator erase(iterator it) { + constexpr iterator erase(iterator it) { auto cur = std::addressof(*it); auto next = GetNext(cur); this->RemoveImpl(cur); @@ -234,16 +243,16 @@ public: } // namespace impl template -concept HasLightCompareType = requires { - { std::is_same::value } -> std::convertible_to; +concept HasRedBlackKeyType = requires { + { std::is_same::value } -> std::convertible_to; }; namespace impl { template - consteval auto* GetLightCompareType() { - if constexpr (HasLightCompareType) { - return static_cast(nullptr); + consteval auto* GetRedBlackKeyType() { + if constexpr (HasRedBlackKeyType) { + return static_cast(nullptr); } else { return static_cast(nullptr); } @@ -252,16 +261,17 @@ namespace impl { } // namespace impl template -using LightCompareType = std::remove_pointer_t())>; +using RedBlackKeyType = std::remove_pointer_t())>; template class IntrusiveRedBlackTree { + YUZU_NON_COPYABLE(IntrusiveRedBlackTree); public: using ImplType = impl::IntrusiveRedBlackTreeImpl; private: - ImplType impl{}; + ImplType m_impl; public: template @@ -277,9 +287,9 @@ public: using iterator = Iterator; using const_iterator = Iterator; - using light_value_type = LightCompareType; - using const_light_pointer = const light_value_type*; - using const_light_reference = const light_value_type&; + using key_type = RedBlackKeyType; + using const_key_pointer = const key_type*; + using const_key_reference = const key_type&; template class Iterator { @@ -298,183 +308,201 @@ public: IntrusiveRedBlackTree::reference>; private: - ImplIterator iterator; + ImplIterator m_impl; private: - explicit Iterator(ImplIterator it) : iterator(it) {} + constexpr explicit Iterator(ImplIterator it) : m_impl(it) {} - explicit Iterator(typename std::conditional::type::pointer ptr) - : iterator(ptr) {} + constexpr explicit Iterator(typename ImplIterator::pointer p) : m_impl(p) {} - ImplIterator GetImplIterator() const { - return this->iterator; + constexpr ImplIterator GetImplIterator() const { + return m_impl; } public: - bool operator==(const Iterator& rhs) const { - return this->iterator == rhs.iterator; + constexpr bool operator==(const Iterator& rhs) const { + return m_impl == rhs.m_impl; } - bool operator!=(const Iterator& rhs) const { + constexpr bool operator!=(const Iterator& rhs) const { return !(*this == rhs); } - pointer operator->() const { - return Traits::GetParent(std::addressof(*this->iterator)); + constexpr pointer operator->() const { + return Traits::GetParent(std::addressof(*m_impl)); } - reference operator*() const { - return *Traits::GetParent(std::addressof(*this->iterator)); + constexpr reference operator*() const { + return *Traits::GetParent(std::addressof(*m_impl)); } - Iterator& operator++() { - ++this->iterator; + constexpr Iterator& operator++() { + ++m_impl; return *this; } - Iterator& operator--() { - --this->iterator; + constexpr Iterator& operator--() { + --m_impl; return *this; } - Iterator operator++(int) { + constexpr Iterator operator++(int) { const Iterator it{*this}; - ++this->iterator; + ++m_impl; return it; } - Iterator operator--(int) { + constexpr Iterator operator--(int) { const Iterator it{*this}; - --this->iterator; + --m_impl; return it; } - operator Iterator() const { - return Iterator(this->iterator); + constexpr operator Iterator() const { + return Iterator(m_impl); } }; private: - static int CompareImpl(const IntrusiveRedBlackTreeNode* lhs, - const IntrusiveRedBlackTreeNode* rhs) { + static constexpr int CompareImpl(const IntrusiveRedBlackTreeNode* lhs, + const IntrusiveRedBlackTreeNode* rhs) { return Comparator::Compare(*Traits::GetParent(lhs), *Traits::GetParent(rhs)); } - static int LightCompareImpl(const void* elm, const IntrusiveRedBlackTreeNode* rhs) { - return Comparator::Compare(*static_cast(elm), *Traits::GetParent(rhs)); + static constexpr int CompareKeyImpl(const_key_reference key, + const IntrusiveRedBlackTreeNode* rhs) { + return Comparator::Compare(key, *Traits::GetParent(rhs)); } // Define accessors using RB_* functions. - IntrusiveRedBlackTreeNode* InsertImpl(IntrusiveRedBlackTreeNode* node) { - return RB_INSERT(&impl.root, node, CompareImpl); + constexpr IntrusiveRedBlackTreeNode* InsertImpl(IntrusiveRedBlackTreeNode* node) { + return freebsd::RB_INSERT(m_impl.m_root, node, CompareImpl); } - IntrusiveRedBlackTreeNode* FindImpl(const IntrusiveRedBlackTreeNode* node) const { - return RB_FIND(const_cast(&impl.root), - const_cast(node), CompareImpl); + constexpr IntrusiveRedBlackTreeNode* FindImpl(IntrusiveRedBlackTreeNode const* node) const { + return freebsd::RB_FIND(const_cast(m_impl.m_root), + const_cast(node), CompareImpl); } - IntrusiveRedBlackTreeNode* NFindImpl(const IntrusiveRedBlackTreeNode* node) const { - return RB_NFIND(const_cast(&impl.root), - const_cast(node), CompareImpl); + constexpr IntrusiveRedBlackTreeNode* NFindImpl(IntrusiveRedBlackTreeNode const* node) const { + return freebsd::RB_NFIND(const_cast(m_impl.m_root), + const_cast(node), CompareImpl); } - IntrusiveRedBlackTreeNode* FindLightImpl(const_light_pointer lelm) const { - return RB_FIND_LIGHT(const_cast(&impl.root), - static_cast(lelm), LightCompareImpl); + constexpr IntrusiveRedBlackTreeNode* FindKeyImpl(const_key_reference key) const { + return freebsd::RB_FIND_KEY(const_cast(m_impl.m_root), key, + CompareKeyImpl); } - IntrusiveRedBlackTreeNode* NFindLightImpl(const_light_pointer lelm) const { - return RB_NFIND_LIGHT(const_cast(&impl.root), - static_cast(lelm), LightCompareImpl); + constexpr IntrusiveRedBlackTreeNode* NFindKeyImpl(const_key_reference key) const { + return freebsd::RB_NFIND_KEY(const_cast(m_impl.m_root), key, + CompareKeyImpl); + } + + constexpr IntrusiveRedBlackTreeNode* FindExistingImpl( + IntrusiveRedBlackTreeNode const* node) const { + return freebsd::RB_FIND_EXISTING(const_cast(m_impl.m_root), + const_cast(node), CompareImpl); + } + + constexpr IntrusiveRedBlackTreeNode* FindExistingKeyImpl(const_key_reference key) const { + return freebsd::RB_FIND_EXISTING_KEY(const_cast(m_impl.m_root), key, + CompareKeyImpl); } public: constexpr IntrusiveRedBlackTree() = default; // Iterator accessors. - iterator begin() { - return iterator(this->impl.begin()); + constexpr iterator begin() { + return iterator(m_impl.begin()); } - const_iterator begin() const { - return const_iterator(this->impl.begin()); + constexpr const_iterator begin() const { + return const_iterator(m_impl.begin()); } - iterator end() { - return iterator(this->impl.end()); + constexpr iterator end() { + return iterator(m_impl.end()); } - const_iterator end() const { - return const_iterator(this->impl.end()); + constexpr const_iterator end() const { + return const_iterator(m_impl.end()); } - const_iterator cbegin() const { + constexpr const_iterator cbegin() const { return this->begin(); } - const_iterator cend() const { + constexpr const_iterator cend() const { return this->end(); } - iterator iterator_to(reference ref) { - return iterator(this->impl.iterator_to(*Traits::GetNode(std::addressof(ref)))); + constexpr iterator iterator_to(reference ref) { + return iterator(m_impl.iterator_to(*Traits::GetNode(std::addressof(ref)))); } - const_iterator iterator_to(const_reference ref) const { - return const_iterator(this->impl.iterator_to(*Traits::GetNode(std::addressof(ref)))); + constexpr const_iterator iterator_to(const_reference ref) const { + return const_iterator(m_impl.iterator_to(*Traits::GetNode(std::addressof(ref)))); } // Content management. - bool empty() const { - return this->impl.empty(); + constexpr bool empty() const { + return m_impl.empty(); } - reference back() { - return *Traits::GetParent(std::addressof(this->impl.back())); + constexpr reference back() { + return *Traits::GetParent(std::addressof(m_impl.back())); } - const_reference back() const { - return *Traits::GetParent(std::addressof(this->impl.back())); + constexpr const_reference back() const { + return *Traits::GetParent(std::addressof(m_impl.back())); } - reference front() { - return *Traits::GetParent(std::addressof(this->impl.front())); + constexpr reference front() { + return *Traits::GetParent(std::addressof(m_impl.front())); } - const_reference front() const { - return *Traits::GetParent(std::addressof(this->impl.front())); + constexpr const_reference front() const { + return *Traits::GetParent(std::addressof(m_impl.front())); } - iterator erase(iterator it) { - return iterator(this->impl.erase(it.GetImplIterator())); + constexpr iterator erase(iterator it) { + return iterator(m_impl.erase(it.GetImplIterator())); } - iterator insert(reference ref) { + constexpr iterator insert(reference ref) { ImplType::pointer node = Traits::GetNode(std::addressof(ref)); this->InsertImpl(node); return iterator(node); } - iterator find(const_reference ref) const { + constexpr iterator find(const_reference ref) const { return iterator(this->FindImpl(Traits::GetNode(std::addressof(ref)))); } - iterator nfind(const_reference ref) const { + constexpr iterator nfind(const_reference ref) const { return iterator(this->NFindImpl(Traits::GetNode(std::addressof(ref)))); } - iterator find_light(const_light_reference ref) const { - return iterator(this->FindLightImpl(std::addressof(ref))); + constexpr iterator find_key(const_key_reference ref) const { + return iterator(this->FindKeyImpl(ref)); } - iterator nfind_light(const_light_reference ref) const { - return iterator(this->NFindLightImpl(std::addressof(ref))); + constexpr iterator nfind_key(const_key_reference ref) const { + return iterator(this->NFindKeyImpl(ref)); + } + + constexpr iterator find_existing(const_reference ref) const { + return iterator(this->FindExistingImpl(Traits::GetNode(std::addressof(ref)))); + } + + constexpr iterator find_existing_key(const_key_reference ref) const { + return iterator(this->FindExistingKeyImpl(ref)); } }; -template > +template > class IntrusiveRedBlackTreeMemberTraits; template @@ -498,19 +526,16 @@ private: return std::addressof(parent->*Member); } - static constexpr Derived* GetParent(IntrusiveRedBlackTreeNode* node) { - return GetParentPointer(node); + static Derived* GetParent(IntrusiveRedBlackTreeNode* node) { + return Common::GetParentPointer(node); } - static constexpr Derived const* GetParent(const IntrusiveRedBlackTreeNode* node) { - return GetParentPointer(node); + static Derived const* GetParent(IntrusiveRedBlackTreeNode const* node) { + return Common::GetParentPointer(node); } - -private: - static constexpr TypedStorage DerivedStorage = {}; }; -template > +template > class IntrusiveRedBlackTreeMemberTraitsDeferredAssert; template @@ -521,11 +546,6 @@ public: IntrusiveRedBlackTree; using TreeTypeImpl = impl::IntrusiveRedBlackTreeImpl; - static constexpr bool IsValid() { - TypedStorage DerivedStorage = {}; - return GetParent(GetNode(GetPointer(DerivedStorage))) == GetPointer(DerivedStorage); - } - private: template friend class IntrusiveRedBlackTree; @@ -540,30 +560,36 @@ private: return std::addressof(parent->*Member); } - static constexpr Derived* GetParent(IntrusiveRedBlackTreeNode* node) { - return GetParentPointer(node); + static Derived* GetParent(IntrusiveRedBlackTreeNode* node) { + return Common::GetParentPointer(node); } - static constexpr Derived const* GetParent(const IntrusiveRedBlackTreeNode* node) { - return GetParentPointer(node); + static Derived const* GetParent(IntrusiveRedBlackTreeNode const* node) { + return Common::GetParentPointer(node); } }; template -class IntrusiveRedBlackTreeBaseNode : public IntrusiveRedBlackTreeNode { +class alignas(void*) IntrusiveRedBlackTreeBaseNode : public IntrusiveRedBlackTreeNode { public: + using IntrusiveRedBlackTreeNode::IntrusiveRedBlackTreeNode; + constexpr Derived* GetPrev() { - return static_cast(impl::IntrusiveRedBlackTreeImpl::GetPrev(this)); + return static_cast(static_cast( + impl::IntrusiveRedBlackTreeImpl::GetPrev(this))); } constexpr const Derived* GetPrev() const { - return static_cast(impl::IntrusiveRedBlackTreeImpl::GetPrev(this)); + return static_cast(static_cast( + impl::IntrusiveRedBlackTreeImpl::GetPrev(this))); } constexpr Derived* GetNext() { - return static_cast(impl::IntrusiveRedBlackTreeImpl::GetNext(this)); + return static_cast(static_cast( + impl::IntrusiveRedBlackTreeImpl::GetNext(this))); } constexpr const Derived* GetNext() const { - return static_cast(impl::IntrusiveRedBlackTreeImpl::GetNext(this)); + return static_cast(static_cast( + impl::IntrusiveRedBlackTreeImpl::GetNext(this))); } }; @@ -581,19 +607,22 @@ private: friend class impl::IntrusiveRedBlackTreeImpl; static constexpr IntrusiveRedBlackTreeNode* GetNode(Derived* parent) { - return static_cast(parent); + return static_cast( + static_cast*>(parent)); } static constexpr IntrusiveRedBlackTreeNode const* GetNode(Derived const* parent) { - return static_cast(parent); + return static_cast( + static_cast*>(parent)); } static constexpr Derived* GetParent(IntrusiveRedBlackTreeNode* node) { - return static_cast(node); + return static_cast(static_cast*>(node)); } - static constexpr Derived const* GetParent(const IntrusiveRedBlackTreeNode* node) { - return static_cast(node); + static constexpr Derived const* GetParent(IntrusiveRedBlackTreeNode const* node) { + return static_cast( + static_cast*>(node)); } }; diff --git a/src/common/tree.h b/src/common/tree.h index 18faa4a48..28370e343 100644 --- a/src/common/tree.h +++ b/src/common/tree.h @@ -43,246 +43,445 @@ * The maximum height of a red-black tree is 2lg (n+1). */ -#include "common/assert.h" +namespace Common::freebsd { -namespace Common { -template -class RBHead { -public: - [[nodiscard]] T* Root() { - return rbh_root; - } - - [[nodiscard]] const T* Root() const { - return rbh_root; - } - - void SetRoot(T* root) { - rbh_root = root; - } - - [[nodiscard]] bool IsEmpty() const { - return Root() == nullptr; - } - -private: - T* rbh_root = nullptr; -}; - -enum class EntryColor { - Black, - Red, +enum class RBColor { + RB_BLACK = 0, + RB_RED = 1, }; +#pragma pack(push, 4) template class RBEntry { public: - [[nodiscard]] T* Left() { - return rbe_left; + constexpr RBEntry() = default; + + [[nodiscard]] constexpr T* Left() { + return m_rbe_left; + } + [[nodiscard]] constexpr const T* Left() const { + return m_rbe_left; } - [[nodiscard]] const T* Left() const { - return rbe_left; + constexpr void SetLeft(T* e) { + m_rbe_left = e; } - void SetLeft(T* left) { - rbe_left = left; + [[nodiscard]] constexpr T* Right() { + return m_rbe_right; + } + [[nodiscard]] constexpr const T* Right() const { + return m_rbe_right; } - [[nodiscard]] T* Right() { - return rbe_right; + constexpr void SetRight(T* e) { + m_rbe_right = e; } - [[nodiscard]] const T* Right() const { - return rbe_right; + [[nodiscard]] constexpr T* Parent() { + return m_rbe_parent; + } + [[nodiscard]] constexpr const T* Parent() const { + return m_rbe_parent; } - void SetRight(T* right) { - rbe_right = right; + constexpr void SetParent(T* e) { + m_rbe_parent = e; } - [[nodiscard]] T* Parent() { - return rbe_parent; + [[nodiscard]] constexpr bool IsBlack() const { + return m_rbe_color == RBColor::RB_BLACK; + } + [[nodiscard]] constexpr bool IsRed() const { + return m_rbe_color == RBColor::RB_RED; + } + [[nodiscard]] constexpr RBColor Color() const { + return m_rbe_color; } - [[nodiscard]] const T* Parent() const { - return rbe_parent; - } - - void SetParent(T* parent) { - rbe_parent = parent; - } - - [[nodiscard]] bool IsBlack() const { - return rbe_color == EntryColor::Black; - } - - [[nodiscard]] bool IsRed() const { - return rbe_color == EntryColor::Red; - } - - [[nodiscard]] EntryColor Color() const { - return rbe_color; - } - - void SetColor(EntryColor color) { - rbe_color = color; + constexpr void SetColor(RBColor c) { + m_rbe_color = c; } private: - T* rbe_left = nullptr; - T* rbe_right = nullptr; - T* rbe_parent = nullptr; - EntryColor rbe_color{}; + T* m_rbe_left{}; + T* m_rbe_right{}; + T* m_rbe_parent{}; + RBColor m_rbe_color{RBColor::RB_BLACK}; +}; +#pragma pack(pop) + +template +struct CheckRBEntry { + static constexpr bool value = false; +}; +template +struct CheckRBEntry> { + static constexpr bool value = true; }; -template -[[nodiscard]] RBEntry& RB_ENTRY(Node* node) { - return node->GetEntry(); +template +concept IsRBEntry = CheckRBEntry::value; + +template +concept HasRBEntry = requires(T& t, const T& ct) { + { t.GetRBEntry() } -> std::same_as&>; + { ct.GetRBEntry() } -> std::same_as&>; +}; + +template +requires HasRBEntry +class RBHead { +private: + T* m_rbh_root = nullptr; + +public: + [[nodiscard]] constexpr T* Root() { + return m_rbh_root; + } + [[nodiscard]] constexpr const T* Root() const { + return m_rbh_root; + } + constexpr void SetRoot(T* root) { + m_rbh_root = root; + } + + [[nodiscard]] constexpr bool IsEmpty() const { + return this->Root() == nullptr; + } +}; + +template +requires HasRBEntry +[[nodiscard]] constexpr RBEntry& RB_ENTRY(T* t) { + return t->GetRBEntry(); +} +template +requires HasRBEntry +[[nodiscard]] constexpr const RBEntry& RB_ENTRY(const T* t) { + return t->GetRBEntry(); } -template -[[nodiscard]] const RBEntry& RB_ENTRY(const Node* node) { - return node->GetEntry(); +template +requires HasRBEntry +[[nodiscard]] constexpr T* RB_LEFT(T* t) { + return RB_ENTRY(t).Left(); +} +template +requires HasRBEntry +[[nodiscard]] constexpr const T* RB_LEFT(const T* t) { + return RB_ENTRY(t).Left(); } -template -[[nodiscard]] Node* RB_PARENT(Node* node) { - return RB_ENTRY(node).Parent(); +template +requires HasRBEntry +[[nodiscard]] constexpr T* RB_RIGHT(T* t) { + return RB_ENTRY(t).Right(); +} +template +requires HasRBEntry +[[nodiscard]] constexpr const T* RB_RIGHT(const T* t) { + return RB_ENTRY(t).Right(); } -template -[[nodiscard]] const Node* RB_PARENT(const Node* node) { - return RB_ENTRY(node).Parent(); +template +requires HasRBEntry +[[nodiscard]] constexpr T* RB_PARENT(T* t) { + return RB_ENTRY(t).Parent(); +} +template +requires HasRBEntry +[[nodiscard]] constexpr const T* RB_PARENT(const T* t) { + return RB_ENTRY(t).Parent(); } -template -void RB_SET_PARENT(Node* node, Node* parent) { - return RB_ENTRY(node).SetParent(parent); +template +requires HasRBEntry +constexpr void RB_SET_LEFT(T* t, T* e) { + RB_ENTRY(t).SetLeft(e); +} +template +requires HasRBEntry +constexpr void RB_SET_RIGHT(T* t, T* e) { + RB_ENTRY(t).SetRight(e); +} +template +requires HasRBEntry +constexpr void RB_SET_PARENT(T* t, T* e) { + RB_ENTRY(t).SetParent(e); } -template -[[nodiscard]] Node* RB_LEFT(Node* node) { - return RB_ENTRY(node).Left(); +template +requires HasRBEntry +[[nodiscard]] constexpr bool RB_IS_BLACK(const T* t) { + return RB_ENTRY(t).IsBlack(); +} +template +requires HasRBEntry +[[nodiscard]] constexpr bool RB_IS_RED(const T* t) { + return RB_ENTRY(t).IsRed(); } -template -[[nodiscard]] const Node* RB_LEFT(const Node* node) { - return RB_ENTRY(node).Left(); +template +requires HasRBEntry +[[nodiscard]] constexpr RBColor RB_COLOR(const T* t) { + return RB_ENTRY(t).Color(); } -template -void RB_SET_LEFT(Node* node, Node* left) { - return RB_ENTRY(node).SetLeft(left); +template +requires HasRBEntry +constexpr void RB_SET_COLOR(T* t, RBColor c) { + RB_ENTRY(t).SetColor(c); } -template -[[nodiscard]] Node* RB_RIGHT(Node* node) { - return RB_ENTRY(node).Right(); +template +requires HasRBEntry +constexpr void RB_SET(T* elm, T* parent) { + auto& rb_entry = RB_ENTRY(elm); + rb_entry.SetParent(parent); + rb_entry.SetLeft(nullptr); + rb_entry.SetRight(nullptr); + rb_entry.SetColor(RBColor::RB_RED); } -template -[[nodiscard]] const Node* RB_RIGHT(const Node* node) { - return RB_ENTRY(node).Right(); +template +requires HasRBEntry +constexpr void RB_SET_BLACKRED(T* black, T* red) { + RB_SET_COLOR(black, RBColor::RB_BLACK); + RB_SET_COLOR(red, RBColor::RB_RED); } -template -void RB_SET_RIGHT(Node* node, Node* right) { - return RB_ENTRY(node).SetRight(right); -} - -template -[[nodiscard]] bool RB_IS_BLACK(const Node* node) { - return RB_ENTRY(node).IsBlack(); -} - -template -[[nodiscard]] bool RB_IS_RED(const Node* node) { - return RB_ENTRY(node).IsRed(); -} - -template -[[nodiscard]] EntryColor RB_COLOR(const Node* node) { - return RB_ENTRY(node).Color(); -} - -template -void RB_SET_COLOR(Node* node, EntryColor color) { - return RB_ENTRY(node).SetColor(color); -} - -template -void RB_SET(Node* node, Node* parent) { - auto& entry = RB_ENTRY(node); - entry.SetParent(parent); - entry.SetLeft(nullptr); - entry.SetRight(nullptr); - entry.SetColor(EntryColor::Red); -} - -template -void RB_SET_BLACKRED(Node* black, Node* red) { - RB_SET_COLOR(black, EntryColor::Black); - RB_SET_COLOR(red, EntryColor::Red); -} - -template -void RB_ROTATE_LEFT(RBHead* head, Node* elm, Node*& tmp) { +template +requires HasRBEntry +constexpr void RB_ROTATE_LEFT(RBHead& head, T* elm, T*& tmp) { tmp = RB_RIGHT(elm); - RB_SET_RIGHT(elm, RB_LEFT(tmp)); - if (RB_RIGHT(elm) != nullptr) { + if (RB_SET_RIGHT(elm, RB_LEFT(tmp)); RB_RIGHT(elm) != nullptr) { RB_SET_PARENT(RB_LEFT(tmp), elm); } - RB_SET_PARENT(tmp, RB_PARENT(elm)); - if (RB_PARENT(tmp) != nullptr) { + if (RB_SET_PARENT(tmp, RB_PARENT(elm)); RB_PARENT(tmp) != nullptr) { if (elm == RB_LEFT(RB_PARENT(elm))) { RB_SET_LEFT(RB_PARENT(elm), tmp); } else { RB_SET_RIGHT(RB_PARENT(elm), tmp); } } else { - head->SetRoot(tmp); + head.SetRoot(tmp); } RB_SET_LEFT(tmp, elm); RB_SET_PARENT(elm, tmp); } -template -void RB_ROTATE_RIGHT(RBHead* head, Node* elm, Node*& tmp) { +template +requires HasRBEntry +constexpr void RB_ROTATE_RIGHT(RBHead& head, T* elm, T*& tmp) { tmp = RB_LEFT(elm); - RB_SET_LEFT(elm, RB_RIGHT(tmp)); - if (RB_LEFT(elm) != nullptr) { + if (RB_SET_LEFT(elm, RB_RIGHT(tmp)); RB_LEFT(elm) != nullptr) { RB_SET_PARENT(RB_RIGHT(tmp), elm); } - RB_SET_PARENT(tmp, RB_PARENT(elm)); - if (RB_PARENT(tmp) != nullptr) { + if (RB_SET_PARENT(tmp, RB_PARENT(elm)); RB_PARENT(tmp) != nullptr) { if (elm == RB_LEFT(RB_PARENT(elm))) { RB_SET_LEFT(RB_PARENT(elm), tmp); } else { RB_SET_RIGHT(RB_PARENT(elm), tmp); } } else { - head->SetRoot(tmp); + head.SetRoot(tmp); } RB_SET_RIGHT(tmp, elm); RB_SET_PARENT(elm, tmp); } -template -void RB_INSERT_COLOR(RBHead* head, Node* elm) { - Node* parent = nullptr; - Node* tmp = nullptr; +template +requires HasRBEntry +constexpr void RB_REMOVE_COLOR(RBHead& head, T* parent, T* elm) { + T* tmp; + while ((elm == nullptr || RB_IS_BLACK(elm)) && elm != head.Root()) { + if (RB_LEFT(parent) == elm) { + tmp = RB_RIGHT(parent); + if (RB_IS_RED(tmp)) { + RB_SET_BLACKRED(tmp, parent); + RB_ROTATE_LEFT(head, parent, tmp); + tmp = RB_RIGHT(parent); + } + if ((RB_LEFT(tmp) == nullptr || RB_IS_BLACK(RB_LEFT(tmp))) && + (RB_RIGHT(tmp) == nullptr || RB_IS_BLACK(RB_RIGHT(tmp)))) { + RB_SET_COLOR(tmp, RBColor::RB_RED); + elm = parent; + parent = RB_PARENT(elm); + } else { + if (RB_RIGHT(tmp) == nullptr || RB_IS_BLACK(RB_RIGHT(tmp))) { + T* oleft; + if ((oleft = RB_LEFT(tmp)) != nullptr) { + RB_SET_COLOR(oleft, RBColor::RB_BLACK); + } + + RB_SET_COLOR(tmp, RBColor::RB_RED); + RB_ROTATE_RIGHT(head, tmp, oleft); + tmp = RB_RIGHT(parent); + } + + RB_SET_COLOR(tmp, RB_COLOR(parent)); + RB_SET_COLOR(parent, RBColor::RB_BLACK); + if (RB_RIGHT(tmp)) { + RB_SET_COLOR(RB_RIGHT(tmp), RBColor::RB_BLACK); + } + + RB_ROTATE_LEFT(head, parent, tmp); + elm = head.Root(); + break; + } + } else { + tmp = RB_LEFT(parent); + if (RB_IS_RED(tmp)) { + RB_SET_BLACKRED(tmp, parent); + RB_ROTATE_RIGHT(head, parent, tmp); + tmp = RB_LEFT(parent); + } + + if ((RB_LEFT(tmp) == nullptr || RB_IS_BLACK(RB_LEFT(tmp))) && + (RB_RIGHT(tmp) == nullptr || RB_IS_BLACK(RB_RIGHT(tmp)))) { + RB_SET_COLOR(tmp, RBColor::RB_RED); + elm = parent; + parent = RB_PARENT(elm); + } else { + if (RB_LEFT(tmp) == nullptr || RB_IS_BLACK(RB_LEFT(tmp))) { + T* oright; + if ((oright = RB_RIGHT(tmp)) != nullptr) { + RB_SET_COLOR(oright, RBColor::RB_BLACK); + } + + RB_SET_COLOR(tmp, RBColor::RB_RED); + RB_ROTATE_LEFT(head, tmp, oright); + tmp = RB_LEFT(parent); + } + + RB_SET_COLOR(tmp, RB_COLOR(parent)); + RB_SET_COLOR(parent, RBColor::RB_BLACK); + + if (RB_LEFT(tmp)) { + RB_SET_COLOR(RB_LEFT(tmp), RBColor::RB_BLACK); + } + + RB_ROTATE_RIGHT(head, parent, tmp); + elm = head.Root(); + break; + } + } + } + + if (elm) { + RB_SET_COLOR(elm, RBColor::RB_BLACK); + } +} + +template +requires HasRBEntry +constexpr T* RB_REMOVE(RBHead& head, T* elm) { + T* child = nullptr; + T* parent = nullptr; + T* old = elm; + RBColor color = RBColor::RB_BLACK; + + if (RB_LEFT(elm) == nullptr) { + child = RB_RIGHT(elm); + } else if (RB_RIGHT(elm) == nullptr) { + child = RB_LEFT(elm); + } else { + T* left; + elm = RB_RIGHT(elm); + while ((left = RB_LEFT(elm)) != nullptr) { + elm = left; + } + + child = RB_RIGHT(elm); + parent = RB_PARENT(elm); + color = RB_COLOR(elm); + + if (child) { + RB_SET_PARENT(child, parent); + } + + if (parent) { + if (RB_LEFT(parent) == elm) { + RB_SET_LEFT(parent, child); + } else { + RB_SET_RIGHT(parent, child); + } + } else { + head.SetRoot(child); + } + + if (RB_PARENT(elm) == old) { + parent = elm; + } + + elm->SetRBEntry(old->GetRBEntry()); + + if (RB_PARENT(old)) { + if (RB_LEFT(RB_PARENT(old)) == old) { + RB_SET_LEFT(RB_PARENT(old), elm); + } else { + RB_SET_RIGHT(RB_PARENT(old), elm); + } + } else { + head.SetRoot(elm); + } + + RB_SET_PARENT(RB_LEFT(old), elm); + + if (RB_RIGHT(old)) { + RB_SET_PARENT(RB_RIGHT(old), elm); + } + + if (parent) { + left = parent; + } + + if (color == RBColor::RB_BLACK) { + RB_REMOVE_COLOR(head, parent, child); + } + + return old; + } + + parent = RB_PARENT(elm); + color = RB_COLOR(elm); + + if (child) { + RB_SET_PARENT(child, parent); + } + if (parent) { + if (RB_LEFT(parent) == elm) { + RB_SET_LEFT(parent, child); + } else { + RB_SET_RIGHT(parent, child); + } + } else { + head.SetRoot(child); + } + + if (color == RBColor::RB_BLACK) { + RB_REMOVE_COLOR(head, parent, child); + } + + return old; +} + +template +requires HasRBEntry +constexpr void RB_INSERT_COLOR(RBHead& head, T* elm) { + T *parent = nullptr, *tmp = nullptr; while ((parent = RB_PARENT(elm)) != nullptr && RB_IS_RED(parent)) { - Node* gparent = RB_PARENT(parent); + T* gparent = RB_PARENT(parent); if (parent == RB_LEFT(gparent)) { tmp = RB_RIGHT(gparent); if (tmp && RB_IS_RED(tmp)) { - RB_SET_COLOR(tmp, EntryColor::Black); + RB_SET_COLOR(tmp, RBColor::RB_BLACK); RB_SET_BLACKRED(parent, gparent); elm = gparent; continue; @@ -300,7 +499,7 @@ void RB_INSERT_COLOR(RBHead* head, Node* elm) { } else { tmp = RB_LEFT(gparent); if (tmp && RB_IS_RED(tmp)) { - RB_SET_COLOR(tmp, EntryColor::Black); + RB_SET_COLOR(tmp, RBColor::RB_BLACK); RB_SET_BLACKRED(parent, gparent); elm = gparent; continue; @@ -318,194 +517,14 @@ void RB_INSERT_COLOR(RBHead* head, Node* elm) { } } - RB_SET_COLOR(head->Root(), EntryColor::Black); + RB_SET_COLOR(head.Root(), RBColor::RB_BLACK); } -template -void RB_REMOVE_COLOR(RBHead* head, Node* parent, Node* elm) { - Node* tmp; - while ((elm == nullptr || RB_IS_BLACK(elm)) && elm != head->Root() && parent != nullptr) { - if (RB_LEFT(parent) == elm) { - tmp = RB_RIGHT(parent); - if (!tmp) { - ASSERT_MSG(false, "tmp is invalid!"); - break; - } - if (RB_IS_RED(tmp)) { - RB_SET_BLACKRED(tmp, parent); - RB_ROTATE_LEFT(head, parent, tmp); - tmp = RB_RIGHT(parent); - } - - if ((RB_LEFT(tmp) == nullptr || RB_IS_BLACK(RB_LEFT(tmp))) && - (RB_RIGHT(tmp) == nullptr || RB_IS_BLACK(RB_RIGHT(tmp)))) { - RB_SET_COLOR(tmp, EntryColor::Red); - elm = parent; - parent = RB_PARENT(elm); - } else { - if (RB_RIGHT(tmp) == nullptr || RB_IS_BLACK(RB_RIGHT(tmp))) { - Node* oleft; - if ((oleft = RB_LEFT(tmp)) != nullptr) { - RB_SET_COLOR(oleft, EntryColor::Black); - } - - RB_SET_COLOR(tmp, EntryColor::Red); - RB_ROTATE_RIGHT(head, tmp, oleft); - tmp = RB_RIGHT(parent); - } - - RB_SET_COLOR(tmp, RB_COLOR(parent)); - RB_SET_COLOR(parent, EntryColor::Black); - if (RB_RIGHT(tmp)) { - RB_SET_COLOR(RB_RIGHT(tmp), EntryColor::Black); - } - - RB_ROTATE_LEFT(head, parent, tmp); - elm = head->Root(); - break; - } - } else { - tmp = RB_LEFT(parent); - if (RB_IS_RED(tmp)) { - RB_SET_BLACKRED(tmp, parent); - RB_ROTATE_RIGHT(head, parent, tmp); - tmp = RB_LEFT(parent); - } - - if (!tmp) { - ASSERT_MSG(false, "tmp is invalid!"); - break; - } - - if ((RB_LEFT(tmp) == nullptr || RB_IS_BLACK(RB_LEFT(tmp))) && - (RB_RIGHT(tmp) == nullptr || RB_IS_BLACK(RB_RIGHT(tmp)))) { - RB_SET_COLOR(tmp, EntryColor::Red); - elm = parent; - parent = RB_PARENT(elm); - } else { - if (RB_LEFT(tmp) == nullptr || RB_IS_BLACK(RB_LEFT(tmp))) { - Node* oright; - if ((oright = RB_RIGHT(tmp)) != nullptr) { - RB_SET_COLOR(oright, EntryColor::Black); - } - - RB_SET_COLOR(tmp, EntryColor::Red); - RB_ROTATE_LEFT(head, tmp, oright); - tmp = RB_LEFT(parent); - } - - RB_SET_COLOR(tmp, RB_COLOR(parent)); - RB_SET_COLOR(parent, EntryColor::Black); - - if (RB_LEFT(tmp)) { - RB_SET_COLOR(RB_LEFT(tmp), EntryColor::Black); - } - - RB_ROTATE_RIGHT(head, parent, tmp); - elm = head->Root(); - break; - } - } - } - - if (elm) { - RB_SET_COLOR(elm, EntryColor::Black); - } -} - -template -Node* RB_REMOVE(RBHead* head, Node* elm) { - Node* child = nullptr; - Node* parent = nullptr; - Node* old = elm; - EntryColor color{}; - - const auto finalize = [&] { - if (color == EntryColor::Black) { - RB_REMOVE_COLOR(head, parent, child); - } - - return old; - }; - - if (RB_LEFT(elm) == nullptr) { - child = RB_RIGHT(elm); - } else if (RB_RIGHT(elm) == nullptr) { - child = RB_LEFT(elm); - } else { - Node* left; - elm = RB_RIGHT(elm); - while ((left = RB_LEFT(elm)) != nullptr) { - elm = left; - } - - child = RB_RIGHT(elm); - parent = RB_PARENT(elm); - color = RB_COLOR(elm); - - if (child) { - RB_SET_PARENT(child, parent); - } - if (parent) { - if (RB_LEFT(parent) == elm) { - RB_SET_LEFT(parent, child); - } else { - RB_SET_RIGHT(parent, child); - } - } else { - head->SetRoot(child); - } - - if (RB_PARENT(elm) == old) { - parent = elm; - } - - elm->SetEntry(old->GetEntry()); - - if (RB_PARENT(old)) { - if (RB_LEFT(RB_PARENT(old)) == old) { - RB_SET_LEFT(RB_PARENT(old), elm); - } else { - RB_SET_RIGHT(RB_PARENT(old), elm); - } - } else { - head->SetRoot(elm); - } - RB_SET_PARENT(RB_LEFT(old), elm); - if (RB_RIGHT(old)) { - RB_SET_PARENT(RB_RIGHT(old), elm); - } - if (parent) { - left = parent; - } - - return finalize(); - } - - parent = RB_PARENT(elm); - color = RB_COLOR(elm); - - if (child) { - RB_SET_PARENT(child, parent); - } - if (parent) { - if (RB_LEFT(parent) == elm) { - RB_SET_LEFT(parent, child); - } else { - RB_SET_RIGHT(parent, child); - } - } else { - head->SetRoot(child); - } - - return finalize(); -} - -// Inserts a node into the RB tree -template -Node* RB_INSERT(RBHead* head, Node* elm, CompareFunction cmp) { - Node* parent = nullptr; - Node* tmp = head->Root(); +template +requires HasRBEntry +constexpr T* RB_INSERT(RBHead& head, T* elm, Compare cmp) { + T* parent = nullptr; + T* tmp = head.Root(); int comp = 0; while (tmp) { @@ -529,17 +548,17 @@ Node* RB_INSERT(RBHead* head, Node* elm, CompareFunction cmp) { RB_SET_RIGHT(parent, elm); } } else { - head->SetRoot(elm); + head.SetRoot(elm); } RB_INSERT_COLOR(head, elm); return nullptr; } -// Finds the node with the same key as elm -template -Node* RB_FIND(RBHead* head, Node* elm, CompareFunction cmp) { - Node* tmp = head->Root(); +template +requires HasRBEntry +constexpr T* RB_FIND(RBHead& head, T* elm, Compare cmp) { + T* tmp = head.Root(); while (tmp) { const int comp = cmp(elm, tmp); @@ -555,11 +574,11 @@ Node* RB_FIND(RBHead* head, Node* elm, CompareFunction cmp) { return nullptr; } -// Finds the first node greater than or equal to the search key -template -Node* RB_NFIND(RBHead* head, Node* elm, CompareFunction cmp) { - Node* tmp = head->Root(); - Node* res = nullptr; +template +requires HasRBEntry +constexpr T* RB_NFIND(RBHead& head, T* elm, Compare cmp) { + T* tmp = head.Root(); + T* res = nullptr; while (tmp) { const int comp = cmp(elm, tmp); @@ -576,13 +595,13 @@ Node* RB_NFIND(RBHead* head, Node* elm, CompareFunction cmp) { return res; } -// Finds the node with the same key as lelm -template -Node* RB_FIND_LIGHT(RBHead* head, const void* lelm, CompareFunction lcmp) { - Node* tmp = head->Root(); +template +requires HasRBEntry +constexpr T* RB_FIND_KEY(RBHead& head, const U& key, Compare cmp) { + T* tmp = head.Root(); while (tmp) { - const int comp = lcmp(lelm, tmp); + const int comp = cmp(key, tmp); if (comp < 0) { tmp = RB_LEFT(tmp); } else if (comp > 0) { @@ -595,14 +614,14 @@ Node* RB_FIND_LIGHT(RBHead* head, const void* lelm, CompareFunction lcmp) return nullptr; } -// Finds the first node greater than or equal to the search key -template -Node* RB_NFIND_LIGHT(RBHead* head, const void* lelm, CompareFunction lcmp) { - Node* tmp = head->Root(); - Node* res = nullptr; +template +requires HasRBEntry +constexpr T* RB_NFIND_KEY(RBHead& head, const U& key, Compare cmp) { + T* tmp = head.Root(); + T* res = nullptr; while (tmp) { - const int comp = lcmp(lelm, tmp); + const int comp = cmp(key, tmp); if (comp < 0) { res = tmp; tmp = RB_LEFT(tmp); @@ -616,8 +635,43 @@ Node* RB_NFIND_LIGHT(RBHead* head, const void* lelm, CompareFunction lcmp) return res; } -template -Node* RB_NEXT(Node* elm) { +template +requires HasRBEntry +constexpr T* RB_FIND_EXISTING(RBHead& head, T* elm, Compare cmp) { + T* tmp = head.Root(); + + while (true) { + const int comp = cmp(elm, tmp); + if (comp < 0) { + tmp = RB_LEFT(tmp); + } else if (comp > 0) { + tmp = RB_RIGHT(tmp); + } else { + return tmp; + } + } +} + +template +requires HasRBEntry +constexpr T* RB_FIND_EXISTING_KEY(RBHead& head, const U& key, Compare cmp) { + T* tmp = head.Root(); + + while (true) { + const int comp = cmp(key, tmp); + if (comp < 0) { + tmp = RB_LEFT(tmp); + } else if (comp > 0) { + tmp = RB_RIGHT(tmp); + } else { + return tmp; + } + } +} + +template +requires HasRBEntry +constexpr T* RB_NEXT(T* elm) { if (RB_RIGHT(elm)) { elm = RB_RIGHT(elm); while (RB_LEFT(elm)) { @@ -636,8 +690,9 @@ Node* RB_NEXT(Node* elm) { return elm; } -template -Node* RB_PREV(Node* elm) { +template +requires HasRBEntry +constexpr T* RB_PREV(T* elm) { if (RB_LEFT(elm)) { elm = RB_LEFT(elm); while (RB_RIGHT(elm)) { @@ -656,30 +711,32 @@ Node* RB_PREV(Node* elm) { return elm; } -template -Node* RB_MINMAX(RBHead* head, bool is_min) { - Node* tmp = head->Root(); - Node* parent = nullptr; +template +requires HasRBEntry +constexpr T* RB_MIN(RBHead& head) { + T* tmp = head.Root(); + T* parent = nullptr; while (tmp) { parent = tmp; - if (is_min) { - tmp = RB_LEFT(tmp); - } else { - tmp = RB_RIGHT(tmp); - } + tmp = RB_LEFT(tmp); } return parent; } -template -Node* RB_MIN(RBHead* head) { - return RB_MINMAX(head, true); +template +requires HasRBEntry +constexpr T* RB_MAX(RBHead& head) { + T* tmp = head.Root(); + T* parent = nullptr; + + while (tmp) { + parent = tmp; + tmp = RB_RIGHT(tmp); + } + + return parent; } -template -Node* RB_MAX(RBHead* head) { - return RB_MINMAX(head, false); -} -} // namespace Common +} // namespace Common::freebsd diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 5db6a1b3a..6a83d5ceb 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -207,6 +207,7 @@ add_library(core STATIC hle/kernel/k_memory_region.h hle/kernel/k_memory_region_type.h hle/kernel/k_page_bitmap.h + hle/kernel/k_page_buffer.h hle/kernel/k_page_heap.cpp hle/kernel/k_page_heap.h hle/kernel/k_page_linked_list.h @@ -244,6 +245,8 @@ add_library(core STATIC hle/kernel/k_system_control.h hle/kernel/k_thread.cpp hle/kernel/k_thread.h + hle/kernel/k_thread_local_page.cpp + hle/kernel/k_thread_local_page.h hle/kernel/k_thread_queue.cpp hle/kernel/k_thread_queue.h hle/kernel/k_trace.h diff --git a/src/core/hle/ipc_helpers.h b/src/core/hle/ipc_helpers.h index 026257115..3c4e45fcd 100644 --- a/src/core/hle/ipc_helpers.h +++ b/src/core/hle/ipc_helpers.h @@ -385,7 +385,7 @@ public: T PopRaw(); template - std::shared_ptr PopIpcInterface() { + std::weak_ptr PopIpcInterface() { ASSERT(context->Session()->IsDomain()); ASSERT(context->GetDomainMessageHeader().input_object_count > 0); return context->GetDomainHandler(Pop() - 1); diff --git a/src/core/hle/kernel/hle_ipc.cpp b/src/core/hle/kernel/hle_ipc.cpp index e19544c54..9f2175f82 100644 --- a/src/core/hle/kernel/hle_ipc.cpp +++ b/src/core/hle/kernel/hle_ipc.cpp @@ -45,7 +45,7 @@ bool SessionRequestManager::HasSessionRequestHandler(const HLERequestContext& co LOG_CRITICAL(IPC, "object_id {} is too big!", object_id); return false; } - return DomainHandler(object_id - 1) != nullptr; + return DomainHandler(object_id - 1).lock() != nullptr; } else { return session_handler != nullptr; } @@ -53,9 +53,6 @@ bool SessionRequestManager::HasSessionRequestHandler(const HLERequestContext& co void SessionRequestHandler::ClientConnected(KServerSession* session) { session->ClientConnected(shared_from_this()); - - // Ensure our server session is tracked globally. - kernel.RegisterServerSession(session); } void SessionRequestHandler::ClientDisconnected(KServerSession* session) { diff --git a/src/core/hle/kernel/hle_ipc.h b/src/core/hle/kernel/hle_ipc.h index 754b41ff6..670cc741c 100644 --- a/src/core/hle/kernel/hle_ipc.h +++ b/src/core/hle/kernel/hle_ipc.h @@ -94,6 +94,7 @@ protected: std::weak_ptr service_thread; }; +using SessionRequestHandlerWeakPtr = std::weak_ptr; using SessionRequestHandlerPtr = std::shared_ptr; /** @@ -139,7 +140,7 @@ public: } } - SessionRequestHandlerPtr DomainHandler(std::size_t index) const { + SessionRequestHandlerWeakPtr DomainHandler(std::size_t index) const { ASSERT_MSG(index < DomainHandlerCount(), "Unexpected handler index {}", index); return domain_handlers.at(index); } @@ -328,10 +329,10 @@ public: template std::shared_ptr GetDomainHandler(std::size_t index) const { - return std::static_pointer_cast(manager->DomainHandler(index)); + return std::static_pointer_cast(manager.lock()->DomainHandler(index).lock()); } - void SetSessionRequestManager(std::shared_ptr manager_) { + void SetSessionRequestManager(std::weak_ptr manager_) { manager = std::move(manager_); } @@ -374,7 +375,7 @@ private: u32 handles_offset{}; u32 domain_offset{}; - std::shared_ptr manager; + std::weak_ptr manager; KernelCore& kernel; Core::Memory::Memory& memory; diff --git a/src/core/hle/kernel/init/init_slab_setup.cpp b/src/core/hle/kernel/init/init_slab_setup.cpp index 36fc0944a..b0f773ee0 100644 --- a/src/core/hle/kernel/init/init_slab_setup.cpp +++ b/src/core/hle/kernel/init/init_slab_setup.cpp @@ -7,19 +7,23 @@ #include "common/common_funcs.h" #include "common/common_types.h" #include "core/core.h" +#include "core/device_memory.h" #include "core/hardware_properties.h" #include "core/hle/kernel/init/init_slab_setup.h" #include "core/hle/kernel/k_code_memory.h" #include "core/hle/kernel/k_event.h" #include "core/hle/kernel/k_memory_layout.h" #include "core/hle/kernel/k_memory_manager.h" +#include "core/hle/kernel/k_page_buffer.h" #include "core/hle/kernel/k_port.h" #include "core/hle/kernel/k_process.h" #include "core/hle/kernel/k_resource_limit.h" #include "core/hle/kernel/k_session.h" #include "core/hle/kernel/k_shared_memory.h" +#include "core/hle/kernel/k_shared_memory_info.h" #include "core/hle/kernel/k_system_control.h" #include "core/hle/kernel/k_thread.h" +#include "core/hle/kernel/k_thread_local_page.h" #include "core/hle/kernel/k_transfer_memory.h" namespace Kernel::Init { @@ -32,9 +36,13 @@ namespace Kernel::Init { HANDLER(KEvent, (SLAB_COUNT(KEvent)), ##__VA_ARGS__) \ HANDLER(KPort, (SLAB_COUNT(KPort)), ##__VA_ARGS__) \ HANDLER(KSharedMemory, (SLAB_COUNT(KSharedMemory)), ##__VA_ARGS__) \ + HANDLER(KSharedMemoryInfo, (SLAB_COUNT(KSharedMemory) * 8), ##__VA_ARGS__) \ HANDLER(KTransferMemory, (SLAB_COUNT(KTransferMemory)), ##__VA_ARGS__) \ HANDLER(KCodeMemory, (SLAB_COUNT(KCodeMemory)), ##__VA_ARGS__) \ HANDLER(KSession, (SLAB_COUNT(KSession)), ##__VA_ARGS__) \ + HANDLER(KThreadLocalPage, \ + (SLAB_COUNT(KProcess) + (SLAB_COUNT(KProcess) + SLAB_COUNT(KThread)) / 8), \ + ##__VA_ARGS__) \ HANDLER(KResourceLimit, (SLAB_COUNT(KResourceLimit)), ##__VA_ARGS__) namespace { @@ -50,38 +58,46 @@ enum KSlabType : u32 { // Constexpr counts. constexpr size_t SlabCountKProcess = 80; constexpr size_t SlabCountKThread = 800; -constexpr size_t SlabCountKEvent = 700; +constexpr size_t SlabCountKEvent = 900; constexpr size_t SlabCountKInterruptEvent = 100; -constexpr size_t SlabCountKPort = 256 + 0x20; // Extra 0x20 ports over Nintendo for homebrew. +constexpr size_t SlabCountKPort = 384; constexpr size_t SlabCountKSharedMemory = 80; constexpr size_t SlabCountKTransferMemory = 200; constexpr size_t SlabCountKCodeMemory = 10; constexpr size_t SlabCountKDeviceAddressSpace = 300; -constexpr size_t SlabCountKSession = 933; +constexpr size_t SlabCountKSession = 1133; constexpr size_t SlabCountKLightSession = 100; constexpr size_t SlabCountKObjectName = 7; constexpr size_t SlabCountKResourceLimit = 5; constexpr size_t SlabCountKDebug = Core::Hardware::NUM_CPU_CORES; -constexpr size_t SlabCountKAlpha = 1; -constexpr size_t SlabCountKBeta = 6; +constexpr size_t SlabCountKIoPool = 1; +constexpr size_t SlabCountKIoRegion = 6; constexpr size_t SlabCountExtraKThread = 160; +/// Helper function to translate from the slab virtual address to the reserved location in physical +/// memory. +static PAddr TranslateSlabAddrToPhysical(KMemoryLayout& memory_layout, VAddr slab_addr) { + slab_addr -= memory_layout.GetSlabRegionAddress(); + return slab_addr + Core::DramMemoryMap::SlabHeapBase; +} + template VAddr InitializeSlabHeap(Core::System& system, KMemoryLayout& memory_layout, VAddr address, size_t num_objects) { - // TODO(bunnei): This is just a place holder. We should initialize the appropriate KSlabHeap for - // kernel object type T with the backing kernel memory pointer once we emulate kernel memory. const size_t size = Common::AlignUp(sizeof(T) * num_objects, alignof(void*)); VAddr start = Common::AlignUp(address, alignof(T)); - // This is intentionally empty. Once KSlabHeap is fully implemented, we can replace this with - // the pointer to emulated memory to pass along. Until then, KSlabHeap will just allocate/free - // host memory. - void* backing_kernel_memory{}; + // This should use the virtual memory address passed in, but currently, we do not setup the + // kernel virtual memory layout. Instead, we simply map these at a region of physical memory + // that we reserve for the slab heaps. + // TODO(bunnei): Fix this once we support the kernel virtual memory layout. if (size > 0) { + void* backing_kernel_memory{ + system.DeviceMemory().GetPointer(TranslateSlabAddrToPhysical(memory_layout, start))}; + const KMemoryRegion* region = memory_layout.FindVirtual(start + size - 1); ASSERT(region != nullptr); ASSERT(region->IsDerivedFrom(KMemoryRegionType_KernelSlab)); @@ -91,6 +107,12 @@ VAddr InitializeSlabHeap(Core::System& system, KMemoryLayout& memory_layout, VAd return start + size; } +size_t CalculateSlabHeapGapSize() { + constexpr size_t KernelSlabHeapGapSize = 2_MiB - 296_KiB; + static_assert(KernelSlabHeapGapSize <= KernelSlabHeapGapsSizeMax); + return KernelSlabHeapGapSize; +} + } // namespace KSlabResourceCounts KSlabResourceCounts::CreateDefault() { @@ -109,8 +131,8 @@ KSlabResourceCounts KSlabResourceCounts::CreateDefault() { .num_KObjectName = SlabCountKObjectName, .num_KResourceLimit = SlabCountKResourceLimit, .num_KDebug = SlabCountKDebug, - .num_KAlpha = SlabCountKAlpha, - .num_KBeta = SlabCountKBeta, + .num_KIoPool = SlabCountKIoPool, + .num_KIoRegion = SlabCountKIoRegion, }; } @@ -136,11 +158,34 @@ size_t CalculateTotalSlabHeapSize(const KernelCore& kernel) { #undef ADD_SLAB_SIZE // Add the reserved size. - size += KernelSlabHeapGapsSize; + size += CalculateSlabHeapGapSize(); return size; } +void InitializeKPageBufferSlabHeap(Core::System& system) { + auto& kernel = system.Kernel(); + + const auto& counts = kernel.SlabResourceCounts(); + const size_t num_pages = + counts.num_KProcess + counts.num_KThread + (counts.num_KProcess + counts.num_KThread) / 8; + const size_t slab_size = num_pages * PageSize; + + // Reserve memory from the system resource limit. + ASSERT(kernel.GetSystemResourceLimit()->Reserve(LimitableResource::PhysicalMemory, slab_size)); + + // Allocate memory for the slab. + constexpr auto AllocateOption = KMemoryManager::EncodeOption( + KMemoryManager::Pool::System, KMemoryManager::Direction::FromFront); + const PAddr slab_address = + kernel.MemoryManager().AllocateAndOpenContinuous(num_pages, 1, AllocateOption); + ASSERT(slab_address != 0); + + // Initialize the slabheap. + KPageBuffer::InitializeSlabHeap(kernel, system.DeviceMemory().GetPointer(slab_address), + slab_size); +} + void InitializeSlabHeaps(Core::System& system, KMemoryLayout& memory_layout) { auto& kernel = system.Kernel(); @@ -160,13 +205,13 @@ void InitializeSlabHeaps(Core::System& system, KMemoryLayout& memory_layout) { } // Create an array to represent the gaps between the slabs. - const size_t total_gap_size = KernelSlabHeapGapsSize; + const size_t total_gap_size = CalculateSlabHeapGapSize(); std::array slab_gaps; - for (size_t i = 0; i < slab_gaps.size(); i++) { + for (auto& slab_gap : slab_gaps) { // Note: This is an off-by-one error from Nintendo's intention, because GenerateRandomRange // is inclusive. However, Nintendo also has the off-by-one error, and it's "harmless", so we // will include it ourselves. - slab_gaps[i] = KSystemControl::GenerateRandomRange(0, total_gap_size); + slab_gap = KSystemControl::GenerateRandomRange(0, total_gap_size); } // Sort the array, so that we can treat differences between values as offsets to the starts of @@ -177,13 +222,21 @@ void InitializeSlabHeaps(Core::System& system, KMemoryLayout& memory_layout) { } } - for (size_t i = 0; i < slab_types.size(); i++) { + // Track the gaps, so that we can free them to the unused slab tree. + VAddr gap_start = address; + size_t gap_size = 0; + + for (size_t i = 0; i < slab_gaps.size(); i++) { // Add the random gap to the address. - address += (i == 0) ? slab_gaps[0] : slab_gaps[i] - slab_gaps[i - 1]; + const auto cur_gap = (i == 0) ? slab_gaps[0] : slab_gaps[i] - slab_gaps[i - 1]; + address += cur_gap; + gap_size += cur_gap; #define INITIALIZE_SLAB_HEAP(NAME, COUNT, ...) \ case KSlabType_##NAME: \ - address = InitializeSlabHeap(system, memory_layout, address, COUNT); \ + if (COUNT > 0) { \ + address = InitializeSlabHeap(system, memory_layout, address, COUNT); \ + } \ break; // Initialize the slabheap. @@ -192,7 +245,13 @@ void InitializeSlabHeaps(Core::System& system, KMemoryLayout& memory_layout) { FOREACH_SLAB_TYPE(INITIALIZE_SLAB_HEAP) // If we somehow get an invalid type, abort. default: - UNREACHABLE(); + UNREACHABLE_MSG("Unknown slab type: {}", slab_types[i]); + } + + // If we've hit the end of a gap, free it. + if (gap_start + gap_size != address) { + gap_start = address; + gap_size = 0; } } } diff --git a/src/core/hle/kernel/init/init_slab_setup.h b/src/core/hle/kernel/init/init_slab_setup.h index a8f7e0918..f54b67d02 100644 --- a/src/core/hle/kernel/init/init_slab_setup.h +++ b/src/core/hle/kernel/init/init_slab_setup.h @@ -32,12 +32,13 @@ struct KSlabResourceCounts { size_t num_KObjectName; size_t num_KResourceLimit; size_t num_KDebug; - size_t num_KAlpha; - size_t num_KBeta; + size_t num_KIoPool; + size_t num_KIoRegion; }; void InitializeSlabResourceCounts(KernelCore& kernel); size_t CalculateTotalSlabHeapSize(const KernelCore& kernel); +void InitializeKPageBufferSlabHeap(Core::System& system); void InitializeSlabHeaps(Core::System& system, KMemoryLayout& memory_layout); } // namespace Kernel::Init diff --git a/src/core/hle/kernel/k_address_arbiter.cpp b/src/core/hle/kernel/k_address_arbiter.cpp index 1d1f5e5f8..8cdd0490f 100644 --- a/src/core/hle/kernel/k_address_arbiter.cpp +++ b/src/core/hle/kernel/k_address_arbiter.cpp @@ -115,7 +115,7 @@ ResultCode KAddressArbiter::Signal(VAddr addr, s32 count) { { KScopedSchedulerLock sl(kernel); - auto it = thread_tree.nfind_light({addr, -1}); + auto it = thread_tree.nfind_key({addr, -1}); while ((it != thread_tree.end()) && (count <= 0 || num_waiters < count) && (it->GetAddressArbiterKey() == addr)) { // End the thread's wait. @@ -148,7 +148,7 @@ ResultCode KAddressArbiter::SignalAndIncrementIfEqual(VAddr addr, s32 value, s32 return ResultInvalidState; } - auto it = thread_tree.nfind_light({addr, -1}); + auto it = thread_tree.nfind_key({addr, -1}); while ((it != thread_tree.end()) && (count <= 0 || num_waiters < count) && (it->GetAddressArbiterKey() == addr)) { // End the thread's wait. @@ -171,7 +171,7 @@ ResultCode KAddressArbiter::SignalAndModifyByWaitingCountIfEqual(VAddr addr, s32 { [[maybe_unused]] const KScopedSchedulerLock sl(kernel); - auto it = thread_tree.nfind_light({addr, -1}); + auto it = thread_tree.nfind_key({addr, -1}); // Determine the updated value. s32 new_value{}; if (count <= 0) { diff --git a/src/core/hle/kernel/k_condition_variable.cpp b/src/core/hle/kernel/k_condition_variable.cpp index aadcc297a..8e2a9593c 100644 --- a/src/core/hle/kernel/k_condition_variable.cpp +++ b/src/core/hle/kernel/k_condition_variable.cpp @@ -244,7 +244,7 @@ void KConditionVariable::Signal(u64 cv_key, s32 count) { { KScopedSchedulerLock sl(kernel); - auto it = thread_tree.nfind_light({cv_key, -1}); + auto it = thread_tree.nfind_key({cv_key, -1}); while ((it != thread_tree.end()) && (count <= 0 || num_waiters < count) && (it->GetConditionVariableKey() == cv_key)) { KThread* target_thread = std::addressof(*it); diff --git a/src/core/hle/kernel/k_memory_layout.h b/src/core/hle/kernel/k_memory_layout.h index bcddb0d62..0858827b6 100644 --- a/src/core/hle/kernel/k_memory_layout.h +++ b/src/core/hle/kernel/k_memory_layout.h @@ -57,11 +57,11 @@ constexpr std::size_t KernelPageTableHeapSize = GetMaximumOverheadSize(MainMemor constexpr std::size_t KernelInitialPageHeapSize = 128_KiB; constexpr std::size_t KernelSlabHeapDataSize = 5_MiB; -constexpr std::size_t KernelSlabHeapGapsSize = 2_MiB - 64_KiB; -constexpr std::size_t KernelSlabHeapSize = KernelSlabHeapDataSize + KernelSlabHeapGapsSize; +constexpr std::size_t KernelSlabHeapGapsSizeMax = 2_MiB - 64_KiB; +constexpr std::size_t KernelSlabHeapSize = KernelSlabHeapDataSize + KernelSlabHeapGapsSizeMax; // NOTE: This is calculated from KThread slab counts, assuming KThread size <= 0x860. -constexpr std::size_t KernelSlabHeapAdditionalSize = 416_KiB; +constexpr std::size_t KernelSlabHeapAdditionalSize = 0x68000; constexpr std::size_t KernelResourceSize = KernelPageTableHeapSize + KernelInitialPageHeapSize + KernelSlabHeapSize; diff --git a/src/core/hle/kernel/k_page_buffer.h b/src/core/hle/kernel/k_page_buffer.h new file mode 100644 index 000000000..0a9451228 --- /dev/null +++ b/src/core/hle/kernel/k_page_buffer.h @@ -0,0 +1,34 @@ +// Copyright 2022 yuzu Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#pragma once + +#include + +#include "common/alignment.h" +#include "common/assert.h" +#include "common/common_types.h" +#include "core/core.h" +#include "core/device_memory.h" +#include "core/hle/kernel/memory_types.h" + +namespace Kernel { + +class KPageBuffer final : public KSlabAllocated { +public: + KPageBuffer() = default; + + static KPageBuffer* FromPhysicalAddress(Core::System& system, PAddr phys_addr) { + ASSERT(Common::IsAligned(phys_addr, PageSize)); + return reinterpret_cast(system.DeviceMemory().GetPointer(phys_addr)); + } + +private: + [[maybe_unused]] alignas(PageSize) std::array m_buffer{}; +}; + +static_assert(sizeof(KPageBuffer) == PageSize); +static_assert(alignof(KPageBuffer) == PageSize); + +} // namespace Kernel diff --git a/src/core/hle/kernel/k_page_table.cpp b/src/core/hle/kernel/k_page_table.cpp index 0602de1f7..02d93b12e 100644 --- a/src/core/hle/kernel/k_page_table.cpp +++ b/src/core/hle/kernel/k_page_table.cpp @@ -424,6 +424,68 @@ ResultCode KPageTable::UnmapCodeMemory(VAddr dst_address, VAddr src_address, std return ResultSuccess; } +VAddr KPageTable::FindFreeArea(VAddr region_start, std::size_t region_num_pages, + std::size_t num_pages, std::size_t alignment, std::size_t offset, + std::size_t guard_pages) { + VAddr address = 0; + + if (num_pages <= region_num_pages) { + if (this->IsAslrEnabled()) { + // Try to directly find a free area up to 8 times. + for (std::size_t i = 0; i < 8; i++) { + const std::size_t random_offset = + KSystemControl::GenerateRandomRange( + 0, (region_num_pages - num_pages - guard_pages) * PageSize / alignment) * + alignment; + const VAddr candidate = + Common::AlignDown((region_start + random_offset), alignment) + offset; + + KMemoryInfo info = this->QueryInfoImpl(candidate); + + if (info.state != KMemoryState::Free) { + continue; + } + if (region_start > candidate) { + continue; + } + if (info.GetAddress() + guard_pages * PageSize > candidate) { + continue; + } + + const VAddr candidate_end = candidate + (num_pages + guard_pages) * PageSize - 1; + if (candidate_end > info.GetLastAddress()) { + continue; + } + if (candidate_end > region_start + region_num_pages * PageSize - 1) { + continue; + } + + address = candidate; + break; + } + // Fall back to finding the first free area with a random offset. + if (address == 0) { + // NOTE: Nintendo does not account for guard pages here. + // This may theoretically cause an offset to be chosen that cannot be mapped. We + // will account for guard pages. + const std::size_t offset_pages = KSystemControl::GenerateRandomRange( + 0, region_num_pages - num_pages - guard_pages); + address = block_manager->FindFreeArea(region_start + offset_pages * PageSize, + region_num_pages - offset_pages, num_pages, + alignment, offset, guard_pages); + } + } + + // Find the first free area. + if (address == 0) { + address = block_manager->FindFreeArea(region_start, region_num_pages, num_pages, + alignment, offset, guard_pages); + } + } + + return address; +} + ResultCode KPageTable::UnmapProcessMemory(VAddr dst_addr, std::size_t size, KPageTable& src_page_table, VAddr src_addr) { KScopedLightLock lk(general_lock); @@ -1055,6 +1117,46 @@ ResultCode KPageTable::MapPages(VAddr address, KPageLinkedList& page_linked_list return ResultSuccess; } +ResultCode KPageTable::MapPages(VAddr* out_addr, std::size_t num_pages, std::size_t alignment, + PAddr phys_addr, bool is_pa_valid, VAddr region_start, + std::size_t region_num_pages, KMemoryState state, + KMemoryPermission perm) { + ASSERT(Common::IsAligned(alignment, PageSize) && alignment >= PageSize); + + // Ensure this is a valid map request. + R_UNLESS(this->CanContain(region_start, region_num_pages * PageSize, state), + ResultInvalidCurrentMemory); + R_UNLESS(num_pages < region_num_pages, ResultOutOfMemory); + + // Lock the table. + KScopedLightLock lk(general_lock); + + // Find a random address to map at. + VAddr addr = this->FindFreeArea(region_start, region_num_pages, num_pages, alignment, 0, + this->GetNumGuardPages()); + R_UNLESS(addr != 0, ResultOutOfMemory); + ASSERT(Common::IsAligned(addr, alignment)); + ASSERT(this->CanContain(addr, num_pages * PageSize, state)); + ASSERT(this->CheckMemoryState(addr, num_pages * PageSize, KMemoryState::All, KMemoryState::Free, + KMemoryPermission::None, KMemoryPermission::None, + KMemoryAttribute::None, KMemoryAttribute::None) + .IsSuccess()); + + // Perform mapping operation. + if (is_pa_valid) { + R_TRY(this->Operate(addr, num_pages, perm, OperationType::Map, phys_addr)); + } else { + UNIMPLEMENTED(); + } + + // Update the blocks. + block_manager->Update(addr, num_pages, state, perm); + + // We successfully mapped the pages. + *out_addr = addr; + return ResultSuccess; +} + ResultCode KPageTable::UnmapPages(VAddr addr, const KPageLinkedList& page_linked_list) { ASSERT(this->IsLockedByCurrentThread()); @@ -1097,6 +1199,30 @@ ResultCode KPageTable::UnmapPages(VAddr addr, KPageLinkedList& page_linked_list, return ResultSuccess; } +ResultCode KPageTable::UnmapPages(VAddr address, std::size_t num_pages, KMemoryState state) { + // Check that the unmap is in range. + const std::size_t size = num_pages * PageSize; + R_UNLESS(this->Contains(address, size), ResultInvalidCurrentMemory); + + // Lock the table. + KScopedLightLock lk(general_lock); + + // Check the memory state. + std::size_t num_allocator_blocks{}; + R_TRY(this->CheckMemoryState(std::addressof(num_allocator_blocks), address, size, + KMemoryState::All, state, KMemoryPermission::None, + KMemoryPermission::None, KMemoryAttribute::All, + KMemoryAttribute::None)); + + // Perform the unmap. + R_TRY(Operate(address, num_pages, KMemoryPermission::None, OperationType::Unmap)); + + // Update the blocks. + block_manager->Update(address, num_pages, KMemoryState::Free, KMemoryPermission::None); + + return ResultSuccess; +} + ResultCode KPageTable::SetProcessMemoryPermission(VAddr addr, std::size_t size, Svc::MemoryPermission svc_perm) { const size_t num_pages = size / PageSize; diff --git a/src/core/hle/kernel/k_page_table.h b/src/core/hle/kernel/k_page_table.h index e99abe36a..54c6adf8d 100644 --- a/src/core/hle/kernel/k_page_table.h +++ b/src/core/hle/kernel/k_page_table.h @@ -46,7 +46,14 @@ public: ResultCode UnmapMemory(VAddr dst_addr, VAddr src_addr, std::size_t size); ResultCode MapPages(VAddr addr, KPageLinkedList& page_linked_list, KMemoryState state, KMemoryPermission perm); + ResultCode MapPages(VAddr* out_addr, std::size_t num_pages, std::size_t alignment, + PAddr phys_addr, KMemoryState state, KMemoryPermission perm) { + return this->MapPages(out_addr, num_pages, alignment, phys_addr, true, + this->GetRegionAddress(state), this->GetRegionSize(state) / PageSize, + state, perm); + } ResultCode UnmapPages(VAddr addr, KPageLinkedList& page_linked_list, KMemoryState state); + ResultCode UnmapPages(VAddr address, std::size_t num_pages, KMemoryState state); ResultCode SetProcessMemoryPermission(VAddr addr, std::size_t size, Svc::MemoryPermission svc_perm); KMemoryInfo QueryInfo(VAddr addr); @@ -91,6 +98,9 @@ private: ResultCode InitializeMemoryLayout(VAddr start, VAddr end); ResultCode MapPages(VAddr addr, const KPageLinkedList& page_linked_list, KMemoryPermission perm); + ResultCode MapPages(VAddr* out_addr, std::size_t num_pages, std::size_t alignment, + PAddr phys_addr, bool is_pa_valid, VAddr region_start, + std::size_t region_num_pages, KMemoryState state, KMemoryPermission perm); ResultCode UnmapPages(VAddr addr, const KPageLinkedList& page_linked_list); bool IsRegionMapped(VAddr address, u64 size); bool IsRegionContiguous(VAddr addr, u64 size) const; @@ -105,6 +115,9 @@ private: VAddr GetRegionAddress(KMemoryState state) const; std::size_t GetRegionSize(KMemoryState state) const; + VAddr FindFreeArea(VAddr region_start, std::size_t region_num_pages, std::size_t num_pages, + std::size_t alignment, std::size_t offset, std::size_t guard_pages); + ResultCode CheckMemoryStateContiguous(std::size_t* out_blocks_needed, VAddr addr, std::size_t size, KMemoryState state_mask, KMemoryState state, KMemoryPermission perm_mask, @@ -137,7 +150,7 @@ private: return CheckMemoryState(nullptr, nullptr, nullptr, out_blocks_needed, addr, size, state_mask, state, perm_mask, perm, attr_mask, attr, ignore_attr); } - ResultCode CheckMemoryState(VAddr addr, size_t size, KMemoryState state_mask, + ResultCode CheckMemoryState(VAddr addr, std::size_t size, KMemoryState state_mask, KMemoryState state, KMemoryPermission perm_mask, KMemoryPermission perm, KMemoryAttribute attr_mask, KMemoryAttribute attr, @@ -210,7 +223,7 @@ public: constexpr VAddr GetAliasCodeRegionSize() const { return alias_code_region_end - alias_code_region_start; } - size_t GetNormalMemorySize() { + std::size_t GetNormalMemorySize() { KScopedLightLock lk(general_lock); return GetHeapSize() + mapped_physical_memory_size; } diff --git a/src/core/hle/kernel/k_port.cpp b/src/core/hle/kernel/k_port.cpp index a8ba09c4a..ceb98709f 100644 --- a/src/core/hle/kernel/k_port.cpp +++ b/src/core/hle/kernel/k_port.cpp @@ -57,7 +57,12 @@ ResultCode KPort::EnqueueSession(KServerSession* session) { R_UNLESS(state == State::Normal, ResultPortClosed); server.EnqueueSession(session); - server.GetSessionRequestHandler()->ClientConnected(server.AcceptSession()); + + if (auto session_ptr = server.GetSessionRequestHandler().lock()) { + session_ptr->ClientConnected(server.AcceptSession()); + } else { + UNREACHABLE(); + } return ResultSuccess; } diff --git a/src/core/hle/kernel/k_process.cpp b/src/core/hle/kernel/k_process.cpp index 9233261cd..b39405496 100644 --- a/src/core/hle/kernel/k_process.cpp +++ b/src/core/hle/kernel/k_process.cpp @@ -70,58 +70,6 @@ void SetupMainThread(Core::System& system, KProcess& owner_process, u32 priority } } // Anonymous namespace -// Represents a page used for thread-local storage. -// -// Each TLS page contains slots that may be used by processes and threads. -// Every process and thread is created with a slot in some arbitrary page -// (whichever page happens to have an available slot). -class TLSPage { -public: - static constexpr std::size_t num_slot_entries = - Core::Memory::PAGE_SIZE / Core::Memory::TLS_ENTRY_SIZE; - - explicit TLSPage(VAddr address) : base_address{address} {} - - bool HasAvailableSlots() const { - return !is_slot_used.all(); - } - - VAddr GetBaseAddress() const { - return base_address; - } - - std::optional ReserveSlot() { - for (std::size_t i = 0; i < is_slot_used.size(); i++) { - if (is_slot_used[i]) { - continue; - } - - is_slot_used[i] = true; - return base_address + (i * Core::Memory::TLS_ENTRY_SIZE); - } - - return std::nullopt; - } - - void ReleaseSlot(VAddr address) { - // Ensure that all given addresses are consistent with how TLS pages - // are intended to be used when releasing slots. - ASSERT(IsWithinPage(address)); - ASSERT((address % Core::Memory::TLS_ENTRY_SIZE) == 0); - - const std::size_t index = (address - base_address) / Core::Memory::TLS_ENTRY_SIZE; - is_slot_used[index] = false; - } - -private: - bool IsWithinPage(VAddr address) const { - return base_address <= address && address < base_address + Core::Memory::PAGE_SIZE; - } - - VAddr base_address; - std::bitset is_slot_used; -}; - ResultCode KProcess::Initialize(KProcess* process, Core::System& system, std::string process_name, ProcessType type, KResourceLimit* res_limit) { auto& kernel = system.Kernel(); @@ -404,7 +352,7 @@ ResultCode KProcess::LoadFromMetadata(const FileSys::ProgramMetadata& metadata, } // Create TLS region - tls_region_address = CreateTLSRegion(); + R_TRY(this->CreateThreadLocalRegion(std::addressof(tls_region_address))); memory_reservation.Commit(); return handle_table.Initialize(capabilities.GetHandleTableSize()); @@ -444,7 +392,7 @@ void KProcess::PrepareForTermination() { stop_threads(kernel.System().GlobalSchedulerContext().GetThreadList()); - FreeTLSRegion(tls_region_address); + this->DeleteThreadLocalRegion(tls_region_address); tls_region_address = 0; if (resource_limit) { @@ -456,9 +404,6 @@ void KProcess::PrepareForTermination() { } void KProcess::Finalize() { - // Finalize the handle table and close any open handles. - handle_table.Finalize(); - // Free all shared memory infos. { auto it = shared_memory_list.begin(); @@ -483,67 +428,110 @@ void KProcess::Finalize() { resource_limit = nullptr; } + // Finalize the page table. + page_table.reset(); + // Perform inherited finalization. KAutoObjectWithSlabHeapAndContainer::Finalize(); } -/** - * Attempts to find a TLS page that contains a free slot for - * use by a thread. - * - * @returns If a page with an available slot is found, then an iterator - * pointing to the page is returned. Otherwise the end iterator - * is returned instead. - */ -static auto FindTLSPageWithAvailableSlots(std::vector& tls_pages) { - return std::find_if(tls_pages.begin(), tls_pages.end(), - [](const auto& page) { return page.HasAvailableSlots(); }); -} +ResultCode KProcess::CreateThreadLocalRegion(VAddr* out) { + KThreadLocalPage* tlp = nullptr; + VAddr tlr = 0; -VAddr KProcess::CreateTLSRegion() { - KScopedSchedulerLock lock(kernel); - if (auto tls_page_iter{FindTLSPageWithAvailableSlots(tls_pages)}; - tls_page_iter != tls_pages.cend()) { - return *tls_page_iter->ReserveSlot(); + // See if we can get a region from a partially used TLP. + { + KScopedSchedulerLock sl{kernel}; + + if (auto it = partially_used_tlp_tree.begin(); it != partially_used_tlp_tree.end()) { + tlr = it->Reserve(); + ASSERT(tlr != 0); + + if (it->IsAllUsed()) { + tlp = std::addressof(*it); + partially_used_tlp_tree.erase(it); + fully_used_tlp_tree.insert(*tlp); + } + + *out = tlr; + return ResultSuccess; + } } - Page* const tls_page_ptr{kernel.GetUserSlabHeapPages().Allocate()}; - ASSERT(tls_page_ptr); + // Allocate a new page. + tlp = KThreadLocalPage::Allocate(kernel); + R_UNLESS(tlp != nullptr, ResultOutOfMemory); + auto tlp_guard = SCOPE_GUARD({ KThreadLocalPage::Free(kernel, tlp); }); - const VAddr start{page_table->GetKernelMapRegionStart()}; - const VAddr size{page_table->GetKernelMapRegionEnd() - start}; - const PAddr tls_map_addr{kernel.System().DeviceMemory().GetPhysicalAddr(tls_page_ptr)}; - const VAddr tls_page_addr{page_table - ->AllocateAndMapMemory(1, PageSize, true, start, size / PageSize, - KMemoryState::ThreadLocal, - KMemoryPermission::UserReadWrite, - tls_map_addr) - .ValueOr(0)}; + // Initialize the new page. + R_TRY(tlp->Initialize(kernel, this)); - ASSERT(tls_page_addr); + // Reserve a TLR. + tlr = tlp->Reserve(); + ASSERT(tlr != 0); - std::memset(tls_page_ptr, 0, PageSize); - tls_pages.emplace_back(tls_page_addr); + // Insert into our tree. + { + KScopedSchedulerLock sl{kernel}; + if (tlp->IsAllUsed()) { + fully_used_tlp_tree.insert(*tlp); + } else { + partially_used_tlp_tree.insert(*tlp); + } + } - const auto reserve_result{tls_pages.back().ReserveSlot()}; - ASSERT(reserve_result.has_value()); - - return *reserve_result; + // We succeeded! + tlp_guard.Cancel(); + *out = tlr; + return ResultSuccess; } -void KProcess::FreeTLSRegion(VAddr tls_address) { - KScopedSchedulerLock lock(kernel); - const VAddr aligned_address = Common::AlignDown(tls_address, Core::Memory::PAGE_SIZE); - auto iter = - std::find_if(tls_pages.begin(), tls_pages.end(), [aligned_address](const auto& page) { - return page.GetBaseAddress() == aligned_address; - }); +ResultCode KProcess::DeleteThreadLocalRegion(VAddr addr) { + KThreadLocalPage* page_to_free = nullptr; - // Something has gone very wrong if we're freeing a region - // with no actual page available. - ASSERT(iter != tls_pages.cend()); + // Release the region. + { + KScopedSchedulerLock sl{kernel}; - iter->ReleaseSlot(tls_address); + // Try to find the page in the partially used list. + auto it = partially_used_tlp_tree.find_key(Common::AlignDown(addr, PageSize)); + if (it == partially_used_tlp_tree.end()) { + // If we don't find it, it has to be in the fully used list. + it = fully_used_tlp_tree.find_key(Common::AlignDown(addr, PageSize)); + R_UNLESS(it != fully_used_tlp_tree.end(), ResultInvalidAddress); + + // Release the region. + it->Release(addr); + + // Move the page out of the fully used list. + KThreadLocalPage* tlp = std::addressof(*it); + fully_used_tlp_tree.erase(it); + if (tlp->IsAllFree()) { + page_to_free = tlp; + } else { + partially_used_tlp_tree.insert(*tlp); + } + } else { + // Release the region. + it->Release(addr); + + // Handle the all-free case. + KThreadLocalPage* tlp = std::addressof(*it); + if (tlp->IsAllFree()) { + partially_used_tlp_tree.erase(it); + page_to_free = tlp; + } + } + } + + // If we should free the page it was in, do so. + if (page_to_free != nullptr) { + page_to_free->Finalize(); + + KThreadLocalPage::Free(kernel, page_to_free); + } + + return ResultSuccess; } void KProcess::LoadModule(CodeSet code_set, VAddr base_addr) { diff --git a/src/core/hle/kernel/k_process.h b/src/core/hle/kernel/k_process.h index cf1b67428..5ed0f2d83 100644 --- a/src/core/hle/kernel/k_process.h +++ b/src/core/hle/kernel/k_process.h @@ -15,6 +15,7 @@ #include "core/hle/kernel/k_condition_variable.h" #include "core/hle/kernel/k_handle_table.h" #include "core/hle/kernel/k_synchronization_object.h" +#include "core/hle/kernel/k_thread_local_page.h" #include "core/hle/kernel/k_worker_task.h" #include "core/hle/kernel/process_capability.h" #include "core/hle/kernel/slab_helpers.h" @@ -362,10 +363,10 @@ public: // Thread-local storage management // Marks the next available region as used and returns the address of the slot. - [[nodiscard]] VAddr CreateTLSRegion(); + [[nodiscard]] ResultCode CreateThreadLocalRegion(VAddr* out); // Frees a used TLS slot identified by the given address - void FreeTLSRegion(VAddr tls_address); + ResultCode DeleteThreadLocalRegion(VAddr addr); private: void PinThread(s32 core_id, KThread* thread) { @@ -413,13 +414,6 @@ private: /// The ideal CPU core for this process, threads are scheduled on this core by default. u8 ideal_core = 0; - /// The Thread Local Storage area is allocated as processes create threads, - /// each TLS area is 0x200 bytes, so one page (0x1000) is split up in 8 parts, and each part - /// holds the TLS for a specific thread. This vector contains which parts are in use for each - /// page as a bitmask. - /// This vector will grow as more pages are allocated for new threads. - std::vector tls_pages; - /// Contains the parsed process capability descriptors. ProcessCapabilities capabilities; @@ -482,6 +476,12 @@ private: KThread* exception_thread{}; KLightLock state_lock; + + using TLPTree = + Common::IntrusiveRedBlackTreeBaseTraits::TreeType; + using TLPIterator = TLPTree::iterator; + TLPTree fully_used_tlp_tree; + TLPTree partially_used_tlp_tree; }; } // namespace Kernel diff --git a/src/core/hle/kernel/k_server_port.h b/src/core/hle/kernel/k_server_port.h index 6302d5e61..2185736be 100644 --- a/src/core/hle/kernel/k_server_port.h +++ b/src/core/hle/kernel/k_server_port.h @@ -30,11 +30,11 @@ public: /// Whether or not this server port has an HLE handler available. bool HasSessionRequestHandler() const { - return session_handler != nullptr; + return !session_handler.expired(); } /// Gets the HLE handler for this port. - SessionRequestHandlerPtr GetSessionRequestHandler() const { + SessionRequestHandlerWeakPtr GetSessionRequestHandler() const { return session_handler; } @@ -42,7 +42,7 @@ public: * Sets the HLE handler template for the port. ServerSessions crated by connecting to this port * will inherit a reference to this handler. */ - void SetSessionHandler(SessionRequestHandlerPtr&& handler) { + void SetSessionHandler(SessionRequestHandlerWeakPtr&& handler) { session_handler = std::move(handler); } @@ -66,7 +66,7 @@ private: void CleanupSessions(); SessionList session_list; - SessionRequestHandlerPtr session_handler; + SessionRequestHandlerWeakPtr session_handler; KPort* parent{}; }; diff --git a/src/core/hle/kernel/k_server_session.cpp b/src/core/hle/kernel/k_server_session.cpp index 4d94eb9cf..30c56ff29 100644 --- a/src/core/hle/kernel/k_server_session.cpp +++ b/src/core/hle/kernel/k_server_session.cpp @@ -27,10 +27,7 @@ namespace Kernel { KServerSession::KServerSession(KernelCore& kernel_) : KSynchronizationObject{kernel_} {} -KServerSession::~KServerSession() { - // Ensure that the global list tracking server sessions does not hold on to a reference. - kernel.UnregisterServerSession(this); -} +KServerSession::~KServerSession() = default; void KServerSession::Initialize(KSession* parent_session_, std::string&& name_, std::shared_ptr manager_) { @@ -49,6 +46,9 @@ void KServerSession::Destroy() { parent->OnServerClosed(); parent->Close(); + + // Release host emulation members. + manager.reset(); } void KServerSession::OnClientClosed() { @@ -98,7 +98,12 @@ ResultCode KServerSession::HandleDomainSyncRequest(Kernel::HLERequestContext& co UNREACHABLE(); return ResultSuccess; // Ignore error if asserts are off } - return manager->DomainHandler(object_id - 1)->HandleSyncRequest(*this, context); + if (auto strong_ptr = manager->DomainHandler(object_id - 1).lock()) { + return strong_ptr->HandleSyncRequest(*this, context); + } else { + UNREACHABLE(); + return ResultSuccess; + } case IPC::DomainMessageHeader::CommandType::CloseVirtualHandle: { LOG_DEBUG(IPC, "CloseVirtualHandle, object_id=0x{:08X}", object_id); diff --git a/src/core/hle/kernel/k_slab_heap.h b/src/core/hle/kernel/k_slab_heap.h index 05c0bec9c..5690cc757 100644 --- a/src/core/hle/kernel/k_slab_heap.h +++ b/src/core/hle/kernel/k_slab_heap.h @@ -16,39 +16,34 @@ class KernelCore; namespace impl { -class KSlabHeapImpl final { -public: +class KSlabHeapImpl { YUZU_NON_COPYABLE(KSlabHeapImpl); YUZU_NON_MOVEABLE(KSlabHeapImpl); +public: struct Node { Node* next{}; }; +public: constexpr KSlabHeapImpl() = default; - constexpr ~KSlabHeapImpl() = default; - void Initialize(std::size_t size) { - ASSERT(head == nullptr); - obj_size = size; - } - - constexpr std::size_t GetObjectSize() const { - return obj_size; + void Initialize() { + ASSERT(m_head == nullptr); } Node* GetHead() const { - return head; + return m_head; } void* Allocate() { - Node* ret = head.load(); + Node* ret = m_head.load(); do { if (ret == nullptr) { break; } - } while (!head.compare_exchange_weak(ret, ret->next)); + } while (!m_head.compare_exchange_weak(ret, ret->next)); return ret; } @@ -56,170 +51,157 @@ public: void Free(void* obj) { Node* node = static_cast(obj); - Node* cur_head = head.load(); + Node* cur_head = m_head.load(); do { node->next = cur_head; - } while (!head.compare_exchange_weak(cur_head, node)); + } while (!m_head.compare_exchange_weak(cur_head, node)); } private: - std::atomic head{}; - std::size_t obj_size{}; + std::atomic m_head{}; }; } // namespace impl -class KSlabHeapBase { -public: +template +class KSlabHeapBase : protected impl::KSlabHeapImpl { YUZU_NON_COPYABLE(KSlabHeapBase); YUZU_NON_MOVEABLE(KSlabHeapBase); +private: + size_t m_obj_size{}; + uintptr_t m_peak{}; + uintptr_t m_start{}; + uintptr_t m_end{}; + +private: + void UpdatePeakImpl(uintptr_t obj) { + static_assert(std::atomic_ref::is_always_lock_free); + std::atomic_ref peak_ref(m_peak); + + const uintptr_t alloc_peak = obj + this->GetObjectSize(); + uintptr_t cur_peak = m_peak; + do { + if (alloc_peak <= cur_peak) { + break; + } + } while (!peak_ref.compare_exchange_strong(cur_peak, alloc_peak)); + } + +public: constexpr KSlabHeapBase() = default; - constexpr ~KSlabHeapBase() = default; - constexpr bool Contains(uintptr_t addr) const { - return start <= addr && addr < end; + bool Contains(uintptr_t address) const { + return m_start <= address && address < m_end; } - constexpr std::size_t GetSlabHeapSize() const { - return (end - start) / GetObjectSize(); - } - - constexpr std::size_t GetObjectSize() const { - return impl.GetObjectSize(); - } - - constexpr uintptr_t GetSlabHeapAddress() const { - return start; - } - - std::size_t GetObjectIndexImpl(const void* obj) const { - return (reinterpret_cast(obj) - start) / GetObjectSize(); - } - - std::size_t GetPeakIndex() const { - return GetObjectIndexImpl(reinterpret_cast(peak)); - } - - void* AllocateImpl() { - return impl.Allocate(); - } - - void FreeImpl(void* obj) { - // Don't allow freeing an object that wasn't allocated from this heap - ASSERT(Contains(reinterpret_cast(obj))); - - impl.Free(obj); - } - - void InitializeImpl(std::size_t obj_size, void* memory, std::size_t memory_size) { - // Ensure we don't initialize a slab using null memory + void Initialize(size_t obj_size, void* memory, size_t memory_size) { + // Ensure we don't initialize a slab using null memory. ASSERT(memory != nullptr); - // Initialize the base allocator - impl.Initialize(obj_size); + // Set our object size. + m_obj_size = obj_size; - // Set our tracking variables - const std::size_t num_obj = (memory_size / obj_size); - start = reinterpret_cast(memory); - end = start + num_obj * obj_size; - peak = start; + // Initialize the base allocator. + KSlabHeapImpl::Initialize(); - // Free the objects - u8* cur = reinterpret_cast(end); + // Set our tracking variables. + const size_t num_obj = (memory_size / obj_size); + m_start = reinterpret_cast(memory); + m_end = m_start + num_obj * obj_size; + m_peak = m_start; - for (std::size_t i{}; i < num_obj; i++) { + // Free the objects. + u8* cur = reinterpret_cast(m_end); + + for (size_t i = 0; i < num_obj; i++) { cur -= obj_size; - impl.Free(cur); + KSlabHeapImpl::Free(cur); } } -private: - using Impl = impl::KSlabHeapImpl; + size_t GetSlabHeapSize() const { + return (m_end - m_start) / this->GetObjectSize(); + } - Impl impl; - uintptr_t peak{}; - uintptr_t start{}; - uintptr_t end{}; + size_t GetObjectSize() const { + return m_obj_size; + } + + void* Allocate() { + void* obj = KSlabHeapImpl::Allocate(); + + return obj; + } + + void Free(void* obj) { + // Don't allow freeing an object that wasn't allocated from this heap. + const bool contained = this->Contains(reinterpret_cast(obj)); + ASSERT(contained); + KSlabHeapImpl::Free(obj); + } + + size_t GetObjectIndex(const void* obj) const { + if constexpr (SupportDynamicExpansion) { + if (!this->Contains(reinterpret_cast(obj))) { + return std::numeric_limits::max(); + } + } + + return (reinterpret_cast(obj) - m_start) / this->GetObjectSize(); + } + + size_t GetPeakIndex() const { + return this->GetObjectIndex(reinterpret_cast(m_peak)); + } + + uintptr_t GetSlabHeapAddress() const { + return m_start; + } + + size_t GetNumRemaining() const { + // Only calculate the number of remaining objects under debug configuration. + return 0; + } }; template -class KSlabHeap final : public KSlabHeapBase { +class KSlabHeap final : public KSlabHeapBase { +private: + using BaseHeap = KSlabHeapBase; + public: - enum class AllocationType { - Host, - Guest, - }; + constexpr KSlabHeap() = default; - explicit constexpr KSlabHeap(AllocationType allocation_type_ = AllocationType::Host) - : KSlabHeapBase(), allocation_type{allocation_type_} {} - - void Initialize(void* memory, std::size_t memory_size) { - if (allocation_type == AllocationType::Guest) { - InitializeImpl(sizeof(T), memory, memory_size); - } + void Initialize(void* memory, size_t memory_size) { + BaseHeap::Initialize(sizeof(T), memory, memory_size); } T* Allocate() { - switch (allocation_type) { - case AllocationType::Host: - // Fallback for cases where we do not yet support allocating guest memory from the slab - // heap, such as for kernel memory regions. - return new T; + T* obj = static_cast(BaseHeap::Allocate()); - case AllocationType::Guest: - T* obj = static_cast(AllocateImpl()); - if (obj != nullptr) { - new (obj) T(); - } - return obj; + if (obj != nullptr) [[likely]] { + std::construct_at(obj); } - - UNREACHABLE_MSG("Invalid AllocationType {}", allocation_type); - return nullptr; + return obj; } - T* AllocateWithKernel(KernelCore& kernel) { - switch (allocation_type) { - case AllocationType::Host: - // Fallback for cases where we do not yet support allocating guest memory from the slab - // heap, such as for kernel memory regions. - return new T(kernel); + T* Allocate(KernelCore& kernel) { + T* obj = static_cast(BaseHeap::Allocate()); - case AllocationType::Guest: - T* obj = static_cast(AllocateImpl()); - if (obj != nullptr) { - new (obj) T(kernel); - } - return obj; + if (obj != nullptr) [[likely]] { + std::construct_at(obj, kernel); } - - UNREACHABLE_MSG("Invalid AllocationType {}", allocation_type); - return nullptr; + return obj; } void Free(T* obj) { - switch (allocation_type) { - case AllocationType::Host: - // Fallback for cases where we do not yet support allocating guest memory from the slab - // heap, such as for kernel memory regions. - delete obj; - return; - - case AllocationType::Guest: - FreeImpl(obj); - return; - } - - UNREACHABLE_MSG("Invalid AllocationType {}", allocation_type); + BaseHeap::Free(obj); } - constexpr std::size_t GetObjectIndex(const T* obj) const { - return GetObjectIndexImpl(obj); + size_t GetObjectIndex(const T* obj) const { + return BaseHeap::GetObjectIndex(obj); } - -private: - const AllocationType allocation_type; }; } // namespace Kernel diff --git a/src/core/hle/kernel/k_thread.cpp b/src/core/hle/kernel/k_thread.cpp index de3ffe0c7..ba7f72c6b 100644 --- a/src/core/hle/kernel/k_thread.cpp +++ b/src/core/hle/kernel/k_thread.cpp @@ -210,7 +210,7 @@ ResultCode KThread::Initialize(KThreadFunction func, uintptr_t arg, VAddr user_s if (owner != nullptr) { // Setup the TLS, if needed. if (type == ThreadType::User) { - tls_address = owner->CreateTLSRegion(); + R_TRY(owner->CreateThreadLocalRegion(std::addressof(tls_address))); } parent = owner; @@ -305,7 +305,7 @@ void KThread::Finalize() { // If the thread has a local region, delete it. if (tls_address != 0) { - parent->FreeTLSRegion(tls_address); + ASSERT(parent->DeleteThreadLocalRegion(tls_address).IsSuccess()); } // Release any waiters. @@ -326,6 +326,9 @@ void KThread::Finalize() { } } + // Release host emulation members. + host_context.reset(); + // Perform inherited finalization. KSynchronizationObject::Finalize(); } diff --git a/src/core/hle/kernel/k_thread.h b/src/core/hle/kernel/k_thread.h index d058db62c..f46db7298 100644 --- a/src/core/hle/kernel/k_thread.h +++ b/src/core/hle/kernel/k_thread.h @@ -656,7 +656,7 @@ private: static_assert(sizeof(SyncObjectBuffer::sync_objects) == sizeof(SyncObjectBuffer::handles)); struct ConditionVariableComparator { - struct LightCompareType { + struct RedBlackKeyType { u64 cv_key{}; s32 priority{}; @@ -672,8 +672,8 @@ private: template requires( std::same_as || - std::same_as) static constexpr int Compare(const T& lhs, - const KThread& rhs) { + std::same_as) static constexpr int Compare(const T& lhs, + const KThread& rhs) { const u64 l_key = lhs.GetConditionVariableKey(); const u64 r_key = rhs.GetConditionVariableKey(); diff --git a/src/core/hle/kernel/k_thread_local_page.cpp b/src/core/hle/kernel/k_thread_local_page.cpp new file mode 100644 index 000000000..4653c29f6 --- /dev/null +++ b/src/core/hle/kernel/k_thread_local_page.cpp @@ -0,0 +1,65 @@ +// Copyright 2022 yuzu Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#include "common/scope_exit.h" +#include "core/hle/kernel/k_memory_block.h" +#include "core/hle/kernel/k_page_table.h" +#include "core/hle/kernel/k_process.h" +#include "core/hle/kernel/k_thread_local_page.h" +#include "core/hle/kernel/kernel.h" + +namespace Kernel { + +ResultCode KThreadLocalPage::Initialize(KernelCore& kernel, KProcess* process) { + // Set that this process owns us. + m_owner = process; + m_kernel = &kernel; + + // Allocate a new page. + KPageBuffer* page_buf = KPageBuffer::Allocate(kernel); + R_UNLESS(page_buf != nullptr, ResultOutOfMemory); + auto page_buf_guard = SCOPE_GUARD({ KPageBuffer::Free(kernel, page_buf); }); + + // Map the address in. + const auto phys_addr = kernel.System().DeviceMemory().GetPhysicalAddr(page_buf); + R_TRY(m_owner->PageTable().MapPages(std::addressof(m_virt_addr), 1, PageSize, phys_addr, + KMemoryState::ThreadLocal, + KMemoryPermission::UserReadWrite)); + + // We succeeded. + page_buf_guard.Cancel(); + + return ResultSuccess; +} + +ResultCode KThreadLocalPage::Finalize() { + // Get the physical address of the page. + const PAddr phys_addr = m_owner->PageTable().GetPhysicalAddr(m_virt_addr); + ASSERT(phys_addr); + + // Unmap the page. + R_TRY(m_owner->PageTable().UnmapPages(this->GetAddress(), 1, KMemoryState::ThreadLocal)); + + // Free the page. + KPageBuffer::Free(*m_kernel, KPageBuffer::FromPhysicalAddress(m_kernel->System(), phys_addr)); + + return ResultSuccess; +} + +VAddr KThreadLocalPage::Reserve() { + for (size_t i = 0; i < m_is_region_free.size(); i++) { + if (m_is_region_free[i]) { + m_is_region_free[i] = false; + return this->GetRegionAddress(i); + } + } + + return 0; +} + +void KThreadLocalPage::Release(VAddr addr) { + m_is_region_free[this->GetRegionIndex(addr)] = true; +} + +} // namespace Kernel diff --git a/src/core/hle/kernel/k_thread_local_page.h b/src/core/hle/kernel/k_thread_local_page.h new file mode 100644 index 000000000..658c67e94 --- /dev/null +++ b/src/core/hle/kernel/k_thread_local_page.h @@ -0,0 +1,112 @@ +// Copyright 2022 yuzu Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#pragma once + +#include +#include + +#include "common/alignment.h" +#include "common/assert.h" +#include "common/common_types.h" +#include "common/intrusive_red_black_tree.h" +#include "core/hle/kernel/k_page_buffer.h" +#include "core/hle/kernel/memory_types.h" +#include "core/hle/kernel/slab_helpers.h" +#include "core/hle/result.h" + +namespace Kernel { + +class KernelCore; +class KProcess; + +class KThreadLocalPage final : public Common::IntrusiveRedBlackTreeBaseNode, + public KSlabAllocated { +public: + static constexpr size_t RegionsPerPage = PageSize / Svc::ThreadLocalRegionSize; + static_assert(RegionsPerPage > 0); + +public: + constexpr explicit KThreadLocalPage(VAddr addr = {}) : m_virt_addr(addr) { + m_is_region_free.fill(true); + } + + constexpr VAddr GetAddress() const { + return m_virt_addr; + } + + ResultCode Initialize(KernelCore& kernel, KProcess* process); + ResultCode Finalize(); + + VAddr Reserve(); + void Release(VAddr addr); + + bool IsAllUsed() const { + return std::ranges::all_of(m_is_region_free.begin(), m_is_region_free.end(), + [](bool is_free) { return !is_free; }); + } + + bool IsAllFree() const { + return std::ranges::all_of(m_is_region_free.begin(), m_is_region_free.end(), + [](bool is_free) { return is_free; }); + } + + bool IsAnyUsed() const { + return !this->IsAllFree(); + } + + bool IsAnyFree() const { + return !this->IsAllUsed(); + } + +public: + using RedBlackKeyType = VAddr; + + static constexpr RedBlackKeyType GetRedBlackKey(const RedBlackKeyType& v) { + return v; + } + static constexpr RedBlackKeyType GetRedBlackKey(const KThreadLocalPage& v) { + return v.GetAddress(); + } + + template + requires(std::same_as || + std::same_as) static constexpr int Compare(const T& lhs, + const KThreadLocalPage& + rhs) { + const VAddr lval = GetRedBlackKey(lhs); + const VAddr rval = GetRedBlackKey(rhs); + + if (lval < rval) { + return -1; + } else if (lval == rval) { + return 0; + } else { + return 1; + } + } + +private: + constexpr VAddr GetRegionAddress(size_t i) const { + return this->GetAddress() + i * Svc::ThreadLocalRegionSize; + } + + constexpr bool Contains(VAddr addr) const { + return this->GetAddress() <= addr && addr < this->GetAddress() + PageSize; + } + + constexpr size_t GetRegionIndex(VAddr addr) const { + ASSERT(Common::IsAligned(addr, Svc::ThreadLocalRegionSize)); + ASSERT(this->Contains(addr)); + return (addr - this->GetAddress()) / Svc::ThreadLocalRegionSize; + } + +private: + VAddr m_virt_addr{}; + KProcess* m_owner{}; + KernelCore* m_kernel{}; + std::array m_is_region_free{}; +}; + +} // namespace Kernel diff --git a/src/core/hle/kernel/kernel.cpp b/src/core/hle/kernel/kernel.cpp index 71bd466cf..f9828bc43 100644 --- a/src/core/hle/kernel/kernel.cpp +++ b/src/core/hle/kernel/kernel.cpp @@ -52,7 +52,7 @@ namespace Kernel { struct KernelCore::Impl { explicit Impl(Core::System& system_, KernelCore& kernel_) - : time_manager{system_}, object_list_container{kernel_}, + : time_manager{system_}, service_threads_manager{1, "yuzu:ServiceThreadsManager"}, system{system_} {} void SetMulticore(bool is_multi) { @@ -60,6 +60,7 @@ struct KernelCore::Impl { } void Initialize(KernelCore& kernel) { + global_object_list_container = std::make_unique(kernel); global_scheduler_context = std::make_unique(kernel); global_handle_table = std::make_unique(kernel); global_handle_table->Initialize(KHandleTable::MaxTableSize); @@ -76,7 +77,7 @@ struct KernelCore::Impl { // Initialize kernel memory and resources. InitializeSystemResourceLimit(kernel, system.CoreTiming()); InitializeMemoryLayout(); - InitializePageSlab(); + Init::InitializeKPageBufferSlabHeap(system); InitializeSchedulers(); InitializeSuspendThreads(); InitializePreemption(kernel); @@ -107,19 +108,6 @@ struct KernelCore::Impl { for (auto* server_port : server_ports_) { server_port->Close(); } - // Close all open server sessions. - std::unordered_set server_sessions_; - { - std::lock_guard lk(server_sessions_lock); - server_sessions_ = server_sessions; - server_sessions.clear(); - } - for (auto* server_session : server_sessions_) { - server_session->Close(); - } - - // Ensure that the object list container is finalized and properly shutdown. - object_list_container.Finalize(); // Ensures all service threads gracefully shutdown. ClearServiceThreads(); @@ -194,11 +182,15 @@ struct KernelCore::Impl { { std::lock_guard lk(registered_objects_lock); if (registered_objects.size()) { - LOG_WARNING(Kernel, "{} kernel objects were dangling on shutdown!", - registered_objects.size()); + LOG_DEBUG(Kernel, "{} kernel objects were dangling on shutdown!", + registered_objects.size()); registered_objects.clear(); } } + + // Ensure that the object list container is finalized and properly shutdown. + global_object_list_container->Finalize(); + global_object_list_container.reset(); } void InitializePhysicalCores() { @@ -291,15 +283,16 @@ struct KernelCore::Impl { // Gets the dummy KThread for the caller, allocating a new one if this is the first time KThread* GetHostDummyThread() { - auto make_thread = [this]() { - KThread* thread = KThread::Create(system.Kernel()); + auto initialize = [this](KThread* thread) { ASSERT(KThread::InitializeDummyThread(thread).IsSuccess()); thread->SetName(fmt::format("DummyThread:{}", GetHostThreadId())); return thread; }; - thread_local KThread* saved_thread = make_thread(); - return saved_thread; + thread_local auto raw_thread = KThread(system.Kernel()); + thread_local auto thread = initialize(&raw_thread); + + return thread; } /// Registers a CPU core thread by allocating a host thread ID for it @@ -660,22 +653,6 @@ struct KernelCore::Impl { time_phys_addr, time_size, "Time:SharedMemory"); } - void InitializePageSlab() { - // Allocate slab heaps - user_slab_heap_pages = - std::make_unique>(KSlabHeap::AllocationType::Guest); - - // TODO(ameerj): This should be derived, not hardcoded within the kernel - constexpr u64 user_slab_heap_size{0x3de000}; - // Reserve slab heaps - ASSERT( - system_resource_limit->Reserve(LimitableResource::PhysicalMemory, user_slab_heap_size)); - // Initialize slab heap - user_slab_heap_pages->Initialize( - system.DeviceMemory().GetPointer(Core::DramMemoryMap::SlabHeapBase), - user_slab_heap_size); - } - KClientPort* CreateNamedServicePort(std::string name) { auto search = service_interface_factory.find(name); if (search == service_interface_factory.end()) { @@ -713,7 +690,6 @@ struct KernelCore::Impl { } std::mutex server_ports_lock; - std::mutex server_sessions_lock; std::mutex registered_objects_lock; std::mutex registered_in_use_objects_lock; @@ -737,14 +713,13 @@ struct KernelCore::Impl { // stores all the objects in place. std::unique_ptr global_handle_table; - KAutoObjectWithListContainer object_list_container; + std::unique_ptr global_object_list_container; /// Map of named ports managed by the kernel, which can be retrieved using /// the ConnectToPort SVC. std::unordered_map service_interface_factory; NamedPortTable named_ports; std::unordered_set server_ports; - std::unordered_set server_sessions; std::unordered_set registered_objects; std::unordered_set registered_in_use_objects; @@ -756,7 +731,6 @@ struct KernelCore::Impl { // Kernel memory management std::unique_ptr memory_manager; - std::unique_ptr> user_slab_heap_pages; // Shared memory for services Kernel::KSharedMemory* hid_shared_mem{}; @@ -915,11 +889,11 @@ const Core::ExclusiveMonitor& KernelCore::GetExclusiveMonitor() const { } KAutoObjectWithListContainer& KernelCore::ObjectListContainer() { - return impl->object_list_container; + return *impl->global_object_list_container; } const KAutoObjectWithListContainer& KernelCore::ObjectListContainer() const { - return impl->object_list_container; + return *impl->global_object_list_container; } void KernelCore::InvalidateAllInstructionCaches() { @@ -949,16 +923,6 @@ KClientPort* KernelCore::CreateNamedServicePort(std::string name) { return impl->CreateNamedServicePort(std::move(name)); } -void KernelCore::RegisterServerSession(KServerSession* server_session) { - std::lock_guard lk(impl->server_sessions_lock); - impl->server_sessions.insert(server_session); -} - -void KernelCore::UnregisterServerSession(KServerSession* server_session) { - std::lock_guard lk(impl->server_sessions_lock); - impl->server_sessions.erase(server_session); -} - void KernelCore::RegisterKernelObject(KAutoObject* object) { std::lock_guard lk(impl->registered_objects_lock); impl->registered_objects.insert(object); @@ -1031,14 +995,6 @@ const KMemoryManager& KernelCore::MemoryManager() const { return *impl->memory_manager; } -KSlabHeap& KernelCore::GetUserSlabHeapPages() { - return *impl->user_slab_heap_pages; -} - -const KSlabHeap& KernelCore::GetUserSlabHeapPages() const { - return *impl->user_slab_heap_pages; -} - Kernel::KSharedMemory& KernelCore::GetHidSharedMem() { return *impl->hid_shared_mem; } diff --git a/src/core/hle/kernel/kernel.h b/src/core/hle/kernel/kernel.h index c1254b18d..7087bbda6 100644 --- a/src/core/hle/kernel/kernel.h +++ b/src/core/hle/kernel/kernel.h @@ -43,6 +43,7 @@ class KHandleTable; class KLinkedListNode; class KMemoryLayout; class KMemoryManager; +class KPageBuffer; class KPort; class KProcess; class KResourceLimit; @@ -52,6 +53,7 @@ class KSession; class KSharedMemory; class KSharedMemoryInfo; class KThread; +class KThreadLocalPage; class KTransferMemory; class KWorkerTaskManager; class KWritableEvent; @@ -194,14 +196,6 @@ public: /// Opens a port to a service previously registered with RegisterNamedService. KClientPort* CreateNamedServicePort(std::string name); - /// Registers a server session with the gobal emulation state, to be freed on shutdown. This is - /// necessary because we do not emulate processes for HLE sessions. - void RegisterServerSession(KServerSession* server_session); - - /// Unregisters a server session previously registered with RegisterServerSession when it was - /// destroyed during the current emulation session. - void UnregisterServerSession(KServerSession* server_session); - /// Registers all kernel objects with the global emulation state, this is purely for tracking /// leaks after emulation has been shutdown. void RegisterKernelObject(KAutoObject* object); @@ -239,12 +233,6 @@ public: /// Gets the virtual memory manager for the kernel. const KMemoryManager& MemoryManager() const; - /// Gets the slab heap allocated for user space pages. - KSlabHeap& GetUserSlabHeapPages(); - - /// Gets the slab heap allocated for user space pages. - const KSlabHeap& GetUserSlabHeapPages() const; - /// Gets the shared memory object for HID services. Kernel::KSharedMemory& GetHidSharedMem(); @@ -336,6 +324,10 @@ public: return slab_heap_container->writeable_event; } else if constexpr (std::is_same_v) { return slab_heap_container->code_memory; + } else if constexpr (std::is_same_v) { + return slab_heap_container->page_buffer; + } else if constexpr (std::is_same_v) { + return slab_heap_container->thread_local_page; } } @@ -397,6 +389,8 @@ private: KSlabHeap transfer_memory; KSlabHeap writeable_event; KSlabHeap code_memory; + KSlabHeap page_buffer; + KSlabHeap thread_local_page; }; std::unique_ptr slab_heap_container; diff --git a/src/core/hle/kernel/service_thread.cpp b/src/core/hle/kernel/service_thread.cpp index 4eb3a5988..52d25b837 100644 --- a/src/core/hle/kernel/service_thread.cpp +++ b/src/core/hle/kernel/service_thread.cpp @@ -49,12 +49,9 @@ ServiceThread::Impl::Impl(KernelCore& kernel, std::size_t num_threads, const std return; } + // Allocate a dummy guest thread for this host thread. kernel.RegisterHostThread(); - // Ensure the dummy thread allocated for this host thread is closed on exit. - auto* dummy_thread = kernel.GetCurrentEmuThread(); - SCOPE_EXIT({ dummy_thread->Close(); }); - while (true) { std::function task; diff --git a/src/core/hle/kernel/slab_helpers.h b/src/core/hle/kernel/slab_helpers.h index f1c11256e..dc1e48fc9 100644 --- a/src/core/hle/kernel/slab_helpers.h +++ b/src/core/hle/kernel/slab_helpers.h @@ -59,7 +59,7 @@ class KAutoObjectWithSlabHeapAndContainer : public Base { private: static Derived* Allocate(KernelCore& kernel) { - return kernel.SlabHeap().AllocateWithKernel(kernel); + return kernel.SlabHeap().Allocate(kernel); } static void Free(KernelCore& kernel, Derived* obj) { diff --git a/src/core/hle/kernel/svc_types.h b/src/core/hle/kernel/svc_types.h index 365e22e4e..b2e9ec092 100644 --- a/src/core/hle/kernel/svc_types.h +++ b/src/core/hle/kernel/svc_types.h @@ -96,4 +96,6 @@ constexpr inline s32 IdealCoreNoUpdate = -3; constexpr inline s32 LowestThreadPriority = 63; constexpr inline s32 HighestThreadPriority = 0; +constexpr inline size_t ThreadLocalRegionSize = 0x200; + } // namespace Kernel::Svc diff --git a/src/core/hle/service/am/am.cpp b/src/core/hle/service/am/am.cpp index 2f8e21568..420de3c54 100644 --- a/src/core/hle/service/am/am.cpp +++ b/src/core/hle/service/am/am.cpp @@ -980,7 +980,7 @@ private: LOG_DEBUG(Service_AM, "called"); IPC::RequestParser rp{ctx}; - applet->GetBroker().PushNormalDataFromGame(rp.PopIpcInterface()); + applet->GetBroker().PushNormalDataFromGame(rp.PopIpcInterface().lock()); IPC::ResponseBuilder rb{ctx, 2}; rb.Push(ResultSuccess); @@ -1007,7 +1007,7 @@ private: LOG_DEBUG(Service_AM, "called"); IPC::RequestParser rp{ctx}; - applet->GetBroker().PushInteractiveDataFromGame(rp.PopIpcInterface()); + applet->GetBroker().PushInteractiveDataFromGame(rp.PopIpcInterface().lock()); ASSERT(applet->IsInitialized()); applet->ExecuteInteractive(); diff --git a/src/core/hle/service/kernel_helpers.cpp b/src/core/hle/service/kernel_helpers.cpp index b8c2c6e51..ff0bbb788 100644 --- a/src/core/hle/service/kernel_helpers.cpp +++ b/src/core/hle/service/kernel_helpers.cpp @@ -17,21 +17,12 @@ namespace Service::KernelHelpers { ServiceContext::ServiceContext(Core::System& system_, std::string name_) : kernel(system_.Kernel()) { - - // Create a resource limit for the process. - const auto physical_memory_size = - kernel.MemoryManager().GetSize(Kernel::KMemoryManager::Pool::System); - auto* resource_limit = Kernel::CreateResourceLimitForProcess(system_, physical_memory_size); - // Create the process. process = Kernel::KProcess::Create(kernel); ASSERT(Kernel::KProcess::Initialize(process, system_, std::move(name_), Kernel::KProcess::ProcessType::KernelInternal, - resource_limit) + kernel.GetSystemResourceLimit()) .IsSuccess()); - - // Close reference to our resource limit, as the process opens one. - resource_limit->Close(); } ServiceContext::~ServiceContext() { diff --git a/src/core/hle/service/sm/sm.cpp b/src/core/hle/service/sm/sm.cpp index eaa172595..695a1faa6 100644 --- a/src/core/hle/service/sm/sm.cpp +++ b/src/core/hle/service/sm/sm.cpp @@ -81,6 +81,8 @@ ResultVal ServiceManager::GetServicePort(const std::string& name } auto* port = Kernel::KPort::Create(kernel); + SCOPE_EXIT({ port->Close(); }); + port->Initialize(ServerSessionCountMax, false, name); auto handler = it->second; port->GetServerPort().SetSessionHandler(std::move(handler));