// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project

#ifndef KOKKOS_IMPL_PUBLIC_INCLUDE
#define KOKKOS_IMPL_PUBLIC_INCLUDE
#endif

#include <Kokkos_Macros.hpp>
#ifdef KOKKOS_ENABLE_EXPERIMENTAL_CXX20_MODULES
import kokkos.core;
#else
#include <Kokkos_Core.hpp>
#endif

#include <Serial/Kokkos_Serial.hpp>
#include <impl/Kokkos_CheckUsage.hpp>
#include <impl/Kokkos_Error.hpp>
#include <impl/Kokkos_ExecSpaceManager.hpp>
#include <impl/Kokkos_SharedAlloc.hpp>
#include <impl/Kokkos_Traits.hpp>

#include <cstdlib>
#include <iostream>
#include <sstream>

/*--------------------------------------------------------------------------*/

namespace Kokkos {
namespace Impl {

// Registry of live SerialInternal objects: the constructor appends `this`
// and the destructor removes it again.
std::vector<SerialInternal*> SerialInternal::all_instances;
// Held by the constructor and destructor while they modify all_instances.
std::mutex SerialInternal::all_instances_mutex;

// Instance used by default-constructed Serial handles; assigned in
// Serial::impl_initialize() and reset to null in Serial::impl_finalize().
HostSharedPtr<SerialInternal> SerialInternal::default_instance;

// Enables shared-allocation tracking and adds the new instance to the
// all_instances registry.
SerialInternal::SerialInternal() {
  SharedAllocationRecord<void, void>::tracking_enable();

  // The registry is only modified with all_instances_mutex held; the lock is
  // released when the constructor returns.
  std::scoped_lock registry_guard(all_instances_mutex);
  all_instances.emplace_back(this);
}

// Releases the team scratch buffer, if one is currently assigned, and removes
// this instance from the all_instances registry. Aborts if the instance is
// not found in the registry.
SerialInternal::~SerialInternal() {
  if (m_thread_team_data.scratch_buffer()) {
    m_thread_team_data.disband_team();
    m_thread_team_data.disband_pool();

    Kokkos::HostSpace space;

    // Pass the label the buffer was allocated under in
    // resize_thread_team_data() so that the allocation and its deallocation
    // are reported under the same name.
    space.deallocate("Kokkos::Serial::scratch_mem",
                     m_thread_team_data.scratch_buffer(),
                     m_thread_team_data.scratch_bytes());

    // Clear the bookkeeping so no pointer to the freed buffer remains.
    m_thread_team_data.scratch_assign(nullptr, 0, 0, 0, 0, 0);
  }

  // guard erasing from all_instances
  {
    std::scoped_lock lock(all_instances_mutex);
    auto it = std::find(all_instances.begin(), all_instances.end(), this);
    if (it == all_instances.end())
      Kokkos::abort(
          "Execution space instance to be removed couldn't be found!");
    // Swap the entry with the last element and pop it, so no other elements
    // need to be shifted.
    std::swap(*it, all_instances.back());
    all_instances.pop_back();
  }
}

// Resize thread team data scratch memory.
//
// pool_reduce_bytes and team_reduce_bytes are raised to at least 512. The
// scratch buffer is reallocated only when at least one requested size exceeds
// the corresponding current size; in that case every region of the new buffer
// is at least as large as it was before, so the buffer never shrinks.
void SerialInternal::resize_thread_team_data(size_t pool_reduce_bytes,
                                             size_t team_reduce_bytes,
                                             size_t team_shared_bytes,
                                             size_t thread_local_bytes) {
  if (pool_reduce_bytes < 512) pool_reduce_bytes = 512;
  if (team_reduce_bytes < 512) team_reduce_bytes = 512;

  const size_t old_pool_reduce  = m_thread_team_data.pool_reduce_bytes();
  const size_t old_team_reduce  = m_thread_team_data.team_reduce_bytes();
  const size_t old_team_shared  = m_thread_team_data.team_shared_bytes();
  const size_t old_thread_local = m_thread_team_data.thread_local_bytes();
  const size_t old_alloc_bytes  = m_thread_team_data.scratch_bytes();

  // Allocate if any part of the old allocation is too small:

  const bool allocate = (old_pool_reduce < pool_reduce_bytes) ||
                        (old_team_reduce < team_reduce_bytes) ||
                        (old_team_shared < team_shared_bytes) ||
                        (old_thread_local < thread_local_bytes);

  if (allocate) {
    Kokkos::HostSpace space;

    if (old_alloc_bytes) {
      m_thread_team_data.disband_team();
      m_thread_team_data.disband_pool();

      // impl_deallocate doesn't fence. A fence is avoided here since it
      // would interfere with using m_instance_mutex to ensure proper kernel
      // enqueuing.
      space.impl_deallocate("Kokkos::Serial::scratch_mem",
                            m_thread_team_data.scratch_buffer(),
                            m_thread_team_data.scratch_bytes());

      // Forget the freed buffer immediately: should the allocation below
      // fail, neither the destructor nor a later resize may free it a second
      // time. The old sizes were captured above, so nothing is lost.
      m_thread_team_data.scratch_assign(nullptr, 0, 0, 0, 0, 0);
    }

    // Never shrink a region below its previous size.
    if (pool_reduce_bytes < old_pool_reduce) {
      pool_reduce_bytes = old_pool_reduce;
    }
    if (team_reduce_bytes < old_team_reduce) {
      team_reduce_bytes = old_team_reduce;
    }
    if (team_shared_bytes < old_team_shared) {
      team_shared_bytes = old_team_shared;
    }
    if (thread_local_bytes < old_thread_local) {
      thread_local_bytes = old_thread_local;
    }

    const size_t alloc_bytes =
        HostThreadTeamData::scratch_size(pool_reduce_bytes, team_reduce_bytes,
                                         team_shared_bytes, thread_local_bytes);

    void* ptr = space.allocate("Kokkos::Serial::scratch_mem", alloc_bytes);

    m_thread_team_data.scratch_assign(static_cast<char*>(ptr), alloc_bytes,
                                      pool_reduce_bytes, team_reduce_bytes,
                                      team_shared_bytes, thread_local_bytes);

    // Re-form the pool and team around this instance's single
    // HostThreadTeamData.
    HostThreadTeamData* pool[1] = {&m_thread_team_data};

    m_thread_team_data.organize_pool(pool, 1);
    m_thread_team_data.organize_team(1);
  }
}
}  // namespace Impl

// Runs the execution-space destructor precondition check for this space
// before the handle's members are destroyed.
Serial::~Serial() {
  Impl::check_execution_space_destructor_precondition(name());
}

// Default handle: refers to SerialInternal::default_instance.
// The comma expression makes the constructor precondition check run before
// default_instance is read to initialize m_space_instance.
Serial::Serial()
    : m_space_instance(
          (Impl::check_execution_space_constructor_precondition(name()),
           Impl::SerialInternal::default_instance)) {}

// Handle backed by a newly allocated SerialInternal instead of the default
// instance. The comma expression makes the constructor precondition check run
// before the new SerialInternal is created.
// NOTE(review): m_space_instance presumably takes ownership of the raw
// pointer -- its type is declared elsewhere; confirm.
Serial::Serial(NewInstance)
    : m_space_instance(
          (Impl::check_execution_space_constructor_precondition(name()),
           new Impl::SerialInternal)) {}

// Writes a short description of the Serial execution space to `os`.
// The verbosity flag is ignored.
void Serial::print_configuration(std::ostream& os, bool /*verbose*/) const {
  // Build-time configuration.
  os << "Host Serial Execution Space:\n"
     << "  KOKKOS_ENABLE_SERIAL: yes\n";
#ifdef KOKKOS_ENABLE_ATOMICS_BYPASS
  os << "Kokkos atomics disabled\n";
#endif

  // Runtime configuration header; no entries are printed beneath it.
  os << "\nSerial Runtime Configuration:\n";
}

// Creates a new SerialInternal and stores it as the default instance used by
// default-constructed Serial handles. The settings argument is unused.
void Serial::impl_initialize(InitializationSettings const&) {
  Impl::SerialInternal::default_instance =
      Impl::HostSharedPtr(new Impl::SerialInternal);
}

// Resets the default instance pointer to null.
// NOTE(review): presumably this destroys the default SerialInternal once no
// other handle still refers to it -- confirm against HostSharedPtr.
void Serial::impl_finalize() {
  Impl::SerialInternal::default_instance = nullptr;
}

// Returns the name of this execution space, "Serial".
const char* Serial::name() { return "Serial"; }

namespace Impl {

// Initialized from initialize_space_factory<Serial>("100_Serial"), so that
// call runs during static initialization of this translation unit.
// NOTE(review): presumably this registers Serial with the execution space
// manager and the "100_" prefix determines its ordering among spaces --
// confirm in impl/Kokkos_ExecSpaceManager.hpp.
int g_serial_space_factory_initialized =
    initialize_space_factory<Serial>("100_Serial");

}  // namespace Impl

}  // namespace Kokkos
