//===----------------------------------------------------------------------===//
//
// Part of the CUDA Toolkit, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef _CUDA_ATOMIC
#define _CUDA_ATOMIC

#ifndef __CUDACC_RTC__
    // Not compiling under NVRTC: the host standard library's <atomic> is
    // included, and the build is rejected unless the host reports every one of
    // these ATOMIC_*_LOCK_FREE macros as 2.
    #include <atomic>
    static_assert(ATOMIC_BOOL_LOCK_FREE == 2, "");
    static_assert(ATOMIC_CHAR_LOCK_FREE == 2, "");
    static_assert(ATOMIC_CHAR16_T_LOCK_FREE == 2, "");
    static_assert(ATOMIC_CHAR32_T_LOCK_FREE == 2, "");
    static_assert(ATOMIC_WCHAR_T_LOCK_FREE == 2, "");
    static_assert(ATOMIC_SHORT_LOCK_FREE == 2, "");
    static_assert(ATOMIC_INT_LOCK_FREE == 2, "");
    static_assert(ATOMIC_LONG_LOCK_FREE == 2, "");
    static_assert(ATOMIC_LLONG_LOCK_FREE == 2, "");
    static_assert(ATOMIC_POINTER_LOCK_FREE == 2, "");
    // Drop the host header's macro definitions (each exactly once).
    // NOTE(review): presumably this lets the headers included further down
    // define these macros themselves without a redefinition clash - confirm.
    #undef ATOMIC_BOOL_LOCK_FREE
    #undef ATOMIC_CHAR_LOCK_FREE
    #undef ATOMIC_CHAR16_T_LOCK_FREE
    #undef ATOMIC_CHAR32_T_LOCK_FREE
    #undef ATOMIC_WCHAR_T_LOCK_FREE
    #undef ATOMIC_SHORT_LOCK_FREE
    #undef ATOMIC_INT_LOCK_FREE
    #undef ATOMIC_LONG_LOCK_FREE
    #undef ATOMIC_LLONG_LOCK_FREE
    #undef ATOMIC_POINTER_LOCK_FREE
    #undef ATOMIC_FLAG_INIT
    #undef ATOMIC_VAR_INIT
#endif //__CUDACC_RTC__

#include "cstddef"
#include "cstdint"
#include "type_traits"
#include "version"

#include "detail/__config"

#include "detail/__pragma_push"

#include "detail/__atomic"
#include "detail/__threading_support"

#undef _LIBCUDACXX_HAS_GCC_ATOMIC_IMP
#undef _LIBCUDACXX_HAS_C_ATOMIC_IMP

#include "detail/libcxx/include/atomic"

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA

// Make `memory_order` and its six ordering constants available directly in
// this namespace. `std` is looked up from inside the namespace opened by
// _LIBCUDACXX_BEGIN_NAMESPACE_CUDA.
// NOTE(review): the fence functions in this file cast to `::std::memory_order`,
// which suggests this `std` is not the global one - confirm against
// detail/__config.
using memory_order = std::memory_order;

// Namespace-scope constants mirroring the corresponding std::memory_order_*
// values one for one.
constexpr memory_order memory_order_relaxed = std::memory_order_relaxed;
constexpr memory_order memory_order_consume = std::memory_order_consume;
constexpr memory_order memory_order_acquire = std::memory_order_acquire;
constexpr memory_order memory_order_release = std::memory_order_release;
constexpr memory_order memory_order_acq_rel = std::memory_order_acq_rel;
constexpr memory_order memory_order_seq_cst = std::memory_order_seq_cst;

// atomic<T>

// Atomic object holding a _Tp, parameterised on a thread scope _Sco that
// defaults to thread_scope_system. Everything not declared here is inherited
// from std::__atomic_base<_Tp, _Sco>.
template <class _Tp, thread_scope _Sco = thread_scope_system>
struct atomic : public std::__atomic_base<_Tp, _Sco>
{
    using __base = std::__atomic_base<_Tp, _Sco>;

    // Default construction: explicitly invokes the base default constructor.
    __host__ __device__ atomic() noexcept
        : __base()
    {
    }

    // Construction from an initial value, handed to the base constructor.
    __host__ __device__ constexpr atomic(_Tp __d) noexcept
        : __base(__d)
    {
    }

    // Assignment from a plain value: calls the base store() with only the
    // value argument and returns the assigned value by value (not a
    // reference to *this). Provided for volatile and non-volatile objects.
    __host__ __device__ _Tp operator=(_Tp __d) volatile noexcept
    {
        __base::store(__d);
        return __d;
    }

    __host__ __device__ _Tp operator=(_Tp __d) noexcept
    {
        __base::store(__d);
        return __d;
    }

    // fetch_max / fetch_min hand the address of the stored value, the operand,
    // the memory order (seq_cst unless specified) and a tag derived from _Sco
    // to the matching detail::__atomic_fetch_{max,min}_cuda helper and return
    // its result unchanged. Only a volatile-qualified overload is declared;
    // non-volatile objects call it as well.
    // NOTE(review): the result is presumably the value held before the
    // operation - confirm against the helper's definition.
    __host__ __device__ _Tp fetch_max(const _Tp & __op,
                                      memory_order __m = memory_order_seq_cst) volatile noexcept
    {
        return detail::__atomic_fetch_max_cuda(
            &this->__a_.__a_value, __op, __m, detail::__scope_tag<_Sco>());
    }

    __host__ __device__ _Tp fetch_min(const _Tp & __op,
                                      memory_order __m = memory_order_seq_cst) volatile noexcept
    {
        return detail::__atomic_fetch_min_cuda(
            &this->__a_.__a_value, __op, __m, detail::__scope_tag<_Sco>());
    }
};

// atomic<T*>

// Partial specialisation for pointer types. In addition to what
// std::__atomic_base<_Tp*, _Sco> supplies, it offers pointer arithmetic:
// fetch_add / fetch_sub and the increment, decrement and compound-assignment
// operators built on top of them. Every member comes in a volatile and a
// non-volatile flavour.
template <class _Tp, thread_scope _Sco>
struct atomic<_Tp*, _Sco> : public std::__atomic_base<_Tp*, _Sco>
{
    using __base = std::__atomic_base<_Tp*, _Sco>;

    // Default construction: explicitly invokes the base default constructor.
    __host__ __device__ atomic() noexcept
        : __base()
    {
    }

    // Construction from an initial pointer, handed to the base constructor.
    __host__ __device__ constexpr atomic(_Tp* __d) noexcept
        : __base(__d)
    {
    }

    // Assignment from a raw pointer: calls the base store() with only the
    // pointer argument and returns the assigned pointer by value.
    __host__ __device__ _Tp* operator=(_Tp* __d) volatile noexcept
    {
        __base::store(__d);
        return __d;
    }

    __host__ __device__ _Tp* operator=(_Tp* __d) noexcept
    {
        __base::store(__d);
        return __d;
    }

    // fetch_add / fetch_sub forward the base's __a_ member, the offset and the
    // memory order (seq_cst unless specified) to __cxx_atomic_fetch_add /
    // __cxx_atomic_fetch_sub and return the result unchanged. The operators
    // below treat that result as the pointer held before the update.
    __host__ __device__ _Tp* fetch_add(ptrdiff_t __op,
                                       memory_order __m = memory_order_seq_cst) volatile noexcept
    {
        return __cxx_atomic_fetch_add(&this->__a_, __op, __m);
    }

    __host__ __device__ _Tp* fetch_add(ptrdiff_t __op,
                                       memory_order __m = memory_order_seq_cst) noexcept
    {
        return __cxx_atomic_fetch_add(&this->__a_, __op, __m);
    }

    __host__ __device__ _Tp* fetch_sub(ptrdiff_t __op,
                                       memory_order __m = memory_order_seq_cst) volatile noexcept
    {
        return __cxx_atomic_fetch_sub(&this->__a_, __op, __m);
    }

    __host__ __device__ _Tp* fetch_sub(ptrdiff_t __op,
                                       memory_order __m = memory_order_seq_cst) noexcept
    {
        return __cxx_atomic_fetch_sub(&this->__a_, __op, __m);
    }

    // Postfix ++ / --: return the fetch result as is.
    __host__ __device__ _Tp* operator++(int) volatile noexcept
    {
        return this->fetch_add(1);
    }

    __host__ __device__ _Tp* operator++(int) noexcept
    {
        return this->fetch_add(1);
    }

    __host__ __device__ _Tp* operator--(int) volatile noexcept
    {
        return this->fetch_sub(1);
    }

    __host__ __device__ _Tp* operator--(int) noexcept
    {
        return this->fetch_sub(1);
    }

    // Prefix ++ / --: re-apply the step to the fetch result so the updated
    // pointer is returned (by value).
    __host__ __device__ _Tp* operator++() volatile noexcept
    {
        return this->fetch_add(1) + 1;
    }

    __host__ __device__ _Tp* operator++() noexcept
    {
        return this->fetch_add(1) + 1;
    }

    __host__ __device__ _Tp* operator--() volatile noexcept
    {
        return this->fetch_sub(1) - 1;
    }

    __host__ __device__ _Tp* operator--() noexcept
    {
        return this->fetch_sub(1) - 1;
    }

    // Compound assignment: same scheme as prefix ++ / --, with a
    // caller-supplied offset.
    __host__ __device__ _Tp* operator+=(ptrdiff_t __op) volatile noexcept
    {
        return this->fetch_add(__op) + __op;
    }

    __host__ __device__ _Tp* operator+=(ptrdiff_t __op) noexcept
    {
        return this->fetch_add(__op) + __op;
    }

    __host__ __device__ _Tp* operator-=(ptrdiff_t __op) volatile noexcept
    {
        return this->fetch_sub(__op) - __op;
    }

    __host__ __device__ _Tp* operator-=(ptrdiff_t __op) noexcept
    {
        return this->fetch_sub(__op) - __op;
    }
};

// Issues a thread fence with memory ordering __m.
//
// Device compilation (__CUDA_ARCH__ defined): selects, from the runtime value
// of _Scope, the system / device / block scope tag and forwards the ordering
// (as int) together with that tag to detail::__atomic_thread_fence_cuda. A
// _Scope value matching none of the three cases performs no fence.
//
// Host compilation: forwards to ::std::atomic_thread_fence; _Scope has no
// effect on this path.
inline __host__ __device__ void atomic_thread_fence(memory_order __m, thread_scope _Scope = thread_scope_system) {
#ifdef __CUDA_ARCH__
    switch(_Scope) {
    case thread_scope_system:
        detail::__atomic_thread_fence_cuda((int)__m, detail::__thread_scope_system_tag());
        break;
    case thread_scope_device:
        detail::__atomic_thread_fence_cuda((int)__m, detail::__thread_scope_device_tag());
        break;
    case thread_scope_block:
        detail::__atomic_thread_fence_cuda((int)__m, detail::__thread_scope_block_tag());
        break;
    }
#else
    // The scope is not consulted on the host. Reference the parameter
    // explicitly so host builds with unused-parameter warnings enabled do not
    // warn in every translation unit that includes this header.
    (void)_Scope;
    ::std::atomic_thread_fence((::std::memory_order)__m);
#endif
}

// Issues a signal fence with memory ordering __m: device compilation forwards
// the ordering (as int) to detail::__atomic_signal_fence_cuda, host
// compilation forwards it to ::std::atomic_signal_fence.
inline __host__ __device__ void atomic_signal_fence(memory_order __m)
{
#if defined(__CUDA_ARCH__)
    detail::__atomic_signal_fence_cuda(static_cast<int>(__m));
#else
    ::std::atomic_signal_fence(static_cast< ::std::memory_order >(__m));
#endif
}

_LIBCUDACXX_END_NAMESPACE_CUDA

#include "detail/__pragma_pop"

#endif //_CUDA_ATOMIC
