/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=8 sts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

/* mfbt maths algorithms. */

#ifndef mozilla_MathAlgorithms_h
#define mozilla_MathAlgorithms_h

#include "mozilla/Assertions.h"

#include <algorithm>
#include <cmath>
#include <climits>
#include <cstdint>
#include <type_traits>

namespace mozilla {

namespace detail {

/**
 * Trait naming the type Abs() returns for an argument of type |T|.
 *
 * Only the partial specializations below define |Type|.  For every other |T|
 * (an unsigned integer, for instance) the trait stays incomplete, so Abs()
 * cannot be called with such a type.
 */
template <typename T, typename = void>
struct AbsReturnType;

/**
 * Signed integers map to their unsigned counterpart, which is wide enough to
 * hold the magnitude of the most negative value of |T|.
 */
template <typename T>
struct AbsReturnType<T, std::enable_if_t<std::conjunction_v<
                            std::is_integral<T>, std::is_signed<T>>>> {
  using Type = std::make_unsigned_t<T>;
};

/** Floating-point types map to themselves. */
template <typename T>
struct AbsReturnType<T, std::enable_if_t<std::is_floating_point_v<T>>> {
  using Type = T;
};

}  // namespace detail

/**
 * Return the absolute value of |aValue|.
 *
 * For signed integers the result has the matching unsigned type (see
 * detail::AbsReturnType).  A negative value is negated on its unsigned
 * representation as |~x + 1|, so no signed negation is ever performed, not
 * even for the minimum value of |T|.
 */
template <typename T>
inline constexpr typename detail::AbsReturnType<T>::Type Abs(const T aValue) {
  using Magnitude = typename detail::AbsReturnType<T>::Type;
  if (aValue < 0) {
    // Two's-complement negation carried out in the unsigned domain.
    return ~Magnitude(aValue) + 1;
  }
  return Magnitude(aValue);
}

/** Abs() for |float|, forwarded to std::fabs. */
template <>
inline float Abs<float>(const float aFloat) {
  return std::fabs(aFloat);
}

/** Abs() for |double|, forwarded to std::fabs. */
template <>
inline double Abs<double>(const double aDouble) {
  return std::fabs(aDouble);
}

/** Abs() for |long double|, forwarded to std::fabs. */
template <>
inline long double Abs<long double>(const long double aLongDouble) {
  return std::fabs(aLongDouble);
}

}  // namespace mozilla

namespace mozilla {

namespace detail {

// Thin wrappers around the compiler's bit-counting intrinsics.  They perform
// no argument checking; the public functions of the same names in the
// enclosing namespace add the non-zero assertions where they are needed.
//
// FIXME: use std::count[lr]_zero once we move to C++20

#if defined(__clang__) || defined(__GNUC__)

#  if defined(__clang__)
#    if !__has_builtin(__builtin_ctz) || !__has_builtin(__builtin_clz)
#      error "A clang providing __builtin_c[lt]z is required to build"
#    endif
#  else
// gcc has had __builtin_clz and friends since 3.4: no need to check.
#  endif

/** Leading zero bits of a 32-bit value, via __builtin_clz. */
constexpr uint_fast8_t CountLeadingZeroes32(uint32_t aValue) {
  const int leadingZeroes = __builtin_clz(aValue);
  return static_cast<uint_fast8_t>(leadingZeroes);
}

/** Leading zero bits of a 64-bit value, via __builtin_clzll. */
constexpr uint_fast8_t CountLeadingZeroes64(uint64_t aValue) {
  const int leadingZeroes = __builtin_clzll(aValue);
  return static_cast<uint_fast8_t>(leadingZeroes);
}

/** Trailing zero bits of a 32-bit value, via __builtin_ctz. */
constexpr uint_fast8_t CountTrailingZeroes32(uint32_t aValue) {
  const int trailingZeroes = __builtin_ctz(aValue);
  return static_cast<uint_fast8_t>(trailingZeroes);
}

/** Trailing zero bits of a 64-bit value, via __builtin_ctzll. */
constexpr uint_fast8_t CountTrailingZeroes64(uint64_t aValue) {
  const int trailingZeroes = __builtin_ctzll(aValue);
  return static_cast<uint_fast8_t>(trailingZeroes);
}

/** Number of set bits in a 32-bit value, via __builtin_popcount. */
constexpr uint_fast8_t CountPopulation32(uint32_t aValue) {
  const int setBits = __builtin_popcount(aValue);
  return static_cast<uint_fast8_t>(setBits);
}

/** Number of set bits in a 64-bit value, via __builtin_popcountll. */
constexpr uint_fast8_t CountPopulation64(uint64_t aValue) {
  const int setBits = __builtin_popcountll(aValue);
  return static_cast<uint_fast8_t>(setBits);
}

#else
// No implementation exists for other compilers: fail the build, and keep the
// names declared (as deleted) so that any use is diagnosed as well.
#  error "Implement these!"
constexpr uint_fast8_t CountLeadingZeroes32(uint32_t aValue) = delete;
constexpr uint_fast8_t CountTrailingZeroes32(uint32_t aValue) = delete;
constexpr uint_fast8_t CountPopulation32(uint32_t aValue) = delete;
constexpr uint_fast8_t CountPopulation64(uint64_t aValue) = delete;
constexpr uint_fast8_t CountLeadingZeroes64(uint64_t aValue) = delete;
constexpr uint_fast8_t CountTrailingZeroes64(uint64_t aValue) = delete;
#endif

}  // namespace detail

/**
 * Count the zero bits above the most significant set bit of the NON-ZERO
 * 32-bit number |aValue| (the non-zero requirement is checked with
 * MOZ_ASSERT).  Reading the bits from the highest-valued end, this is the
 * number of zeroes seen before the first one:
 *
 * CountLeadingZeroes32(0xF0FF1000) is 0;
 * CountLeadingZeroes32(0x7F8F0001) is 1;
 * CountLeadingZeroes32(0x3FFF0100) is 2;
 * CountLeadingZeroes32(0x1FF50010) is 3; and so on.
 */
constexpr uint_fast8_t CountLeadingZeroes32(uint32_t aValue) {
  MOZ_ASSERT(aValue != 0);
  const uint_fast8_t leadingZeroes = detail::CountLeadingZeroes32(aValue);
  return leadingZeroes;
}

/**
 * Count the zero bits below the least significant set bit of the NON-ZERO
 * 32-bit number |aValue| (the non-zero requirement is checked with
 * MOZ_ASSERT).  Reading the bits from the lowest-valued end, this is the
 * number of zeroes seen before the first one:
 *
 * CountTrailingZeroes32(0x0100FFFF) is 0;
 * CountTrailingZeroes32(0x7000FFFE) is 1;
 * CountTrailingZeroes32(0x0080FFFC) is 2;
 * CountTrailingZeroes32(0x0080FFF8) is 3; and so on.
 */
constexpr uint_fast8_t CountTrailingZeroes32(uint32_t aValue) {
  MOZ_ASSERT(aValue != 0);
  const uint_fast8_t trailingZeroes = detail::CountTrailingZeroes32(aValue);
  return trailingZeroes;
}

/**
 * Count the bits that are set to one in the 32-bit number |aValue|.  Unlike
 * the zero-counting functions, any value is accepted.
 */
constexpr uint_fast8_t CountPopulation32(uint32_t aValue) {
  const uint_fast8_t setBits = detail::CountPopulation32(aValue);
  return setBits;
}

/** The 64-bit counterpart of CountPopulation32: count the one bits. */
constexpr uint_fast8_t CountPopulation64(uint64_t aValue) {
  const uint_fast8_t setBits = detail::CountPopulation64(aValue);
  return setBits;
}

/**
 * The 64-bit counterpart of CountLeadingZeroes32.  |aValue| must be non-zero
 * (checked with MOZ_ASSERT).
 */
constexpr uint_fast8_t CountLeadingZeroes64(uint64_t aValue) {
  MOZ_ASSERT(aValue != 0);
  const uint_fast8_t leadingZeroes = detail::CountLeadingZeroes64(aValue);
  return leadingZeroes;
}

/**
 * The 64-bit counterpart of CountTrailingZeroes32.  |aValue| must be non-zero
 * (checked with MOZ_ASSERT).
 */
constexpr uint_fast8_t CountTrailingZeroes64(uint64_t aValue) {
  MOZ_ASSERT(aValue != 0);
  const uint_fast8_t trailingZeroes = detail::CountTrailingZeroes64(aValue);
  return trailingZeroes;
}

namespace detail {

/**
 * Implementation of CeilingLog2(), chosen by the byte size of |T|.  Only the
 * 4-byte and 8-byte specializations below are defined.
 */
template <typename T, size_t Size = sizeof(T)>
class CeilingLog2;

/** CeilingLog2 for 4-byte types. */
template <typename T>
class CeilingLog2<T, 4> {
 public:
  static constexpr uint_fast8_t compute(const T aValue) {
    // Values <= 1 are answered directly.  This also guarantees that
    // |aValue - 1| is non-zero by the time it reaches CountLeadingZeroes32,
    // which requires a non-zero argument.
    if (aValue <= 1) {
      return 0u;
    }
    return 32u - CountLeadingZeroes32(aValue - 1);
  }
};

/** CeilingLog2 for 8-byte types. */
template <typename T>
class CeilingLog2<T, 8> {
 public:
  static constexpr uint_fast8_t compute(const T aValue) {
    // Values <= 1 are answered directly.  This also guarantees that
    // |aValue - 1| is non-zero by the time it reaches CountLeadingZeroes64,
    // which requires a non-zero argument.
    if (aValue <= 1) {
      return 0u;
    }
    return 64u - CountLeadingZeroes64(aValue - 1);
  }
};

}  // namespace detail

/**
 * Return the base-2 logarithm of the smallest power of 2 that is greater than
 * or equal to |aValue|.  The work is delegated to detail::CeilingLog2, which
 * is only defined for 4-byte and 8-byte types.
 *
 * CeilingLog2(0..1) is 0;
 * CeilingLog2(2) is 1;
 * CeilingLog2(3..4) is 2;
 * CeilingLog2(5..8) is 3;
 * CeilingLog2(9..16) is 4; and so on.
 */
template <typename T>
constexpr uint_fast8_t CeilingLog2(const T aValue) {
  const uint_fast8_t exponent = detail::CeilingLog2<T>::compute(aValue);
  return exponent;
}

/**
 * CeilingLog2 restricted to |size_t|: other argument types are converted to
 * |size_t| at the call instead of selecting a different instantiation.
 */
constexpr uint_fast8_t CeilingLog2Size(size_t aValue) {
  const uint_fast8_t exponent = CeilingLog2(aValue);
  return exponent;
}

/**
 * Compute the zero-based bit position of the most significant bit set in
 * |aValue|.  Requires that |aValue| is non-zero (checked with MOZ_ASSERT).
 *
 * |T| must be an integral type of at most 8 bytes.  The position is always
 * relative to |T|'s own width: for a negative value of a signed type the
 * result is the index of |T|'s sign bit (7 for an 8-bit type, 15 for a 16-bit
 * type, 31 for a 32-bit type, 63 for a 64-bit type).
 */
template <typename T>
constexpr uint_fast8_t FindMostSignificantBit(T aValue) {
  static_assert(sizeof(T) <= 8);
  static_assert(std::is_integral_v<T>);
  MOZ_ASSERT(aValue != 0);
  if constexpr (sizeof(T) <= 4) {
    // Types of up to 4 bytes are counted as a 32-bit quantity.  Converting a
    // negative value of a signed type narrower than 32 bits sign-extends it,
    // which would report bit 31 rather than the type's own top bit, so keep
    // only the low sizeof(T) * CHAR_BIT bits.  For 4-byte types the mask is
    // all ones and changes nothing.
    constexpr uint32_t kWidthMask =
        UINT32_MAX >> ((sizeof(uint32_t) - sizeof(T)) * CHAR_BIT);
    return 31u - CountLeadingZeroes32(aValue & kWidthMask);
  } else {
    // Wider types (8 bytes, given the static_assert above) need no masking.
    return 63u - CountLeadingZeroes64(aValue);
  }
}

/**
 * Return the base-2 logarithm of the largest power of 2 that is less than or
 * equal to |aValue|.  Zero is treated like one.
 *
 * FloorLog2(0..1) is 0;
 * FloorLog2(2..3) is 1;
 * FloorLog2(4..7) is 2;
 * FloorLog2(8..15) is 3; and so on.
 */
template <typename T>
constexpr uint_fast8_t FloorLog2(const T aValue) {
  // Setting the lowest bit turns 0 into 1, which satisfies the non-zero
  // requirement of FindMostSignificantBit, and never moves the top set bit of
  // any other value.
  const auto nonZero = aValue | 1;
  return FindMostSignificantBit(nonZero);
}

/**
 * FloorLog2 restricted to |size_t|: other argument types are converted to
 * |size_t| at the call instead of selecting a different instantiation.
 */
constexpr uint_fast8_t FloorLog2Size(size_t aValue) {
  const uint_fast8_t exponent = FloorLog2(aValue);
  return exponent;
}

/*
 * Compute the smallest power of 2 greater than or equal to |aValue|.
 * |aValue| must not exceed the largest power of 2 representable in |size_t|,
 * as the result would otherwise overflow (checked with MOZ_ASSERT).
 */
constexpr size_t RoundUpPow2(size_t aValue) {
  MOZ_ASSERT(aValue <= (size_t(1) << (sizeof(size_t) * CHAR_BIT - 1)),
             "can't round up -- will overflow!");
  const uint_fast8_t exponent = CeilingLog2(aValue);
  return size_t(1) << exponent;
}

/**
 * Rotate the bits of |aValue| left by |aShift| positions: bits shifted out at
 * the top re-enter at the bottom.
 *
 * |T| must be unsigned, and |aShift| must be greater than zero and smaller
 * than the bit width of |T| (both checked with MOZ_ASSERT).
 */
template <typename T>
MOZ_NO_SANITIZE_UNSIGNED_OVERFLOW constexpr T RotateLeft(const T aValue,
                                                         uint_fast8_t aShift) {
  static_assert(std::is_unsigned_v<T>, "Rotates require unsigned values");

  MOZ_ASSERT(aShift < sizeof(T) * CHAR_BIT, "Shift value is too large!");
  MOZ_ASSERT(aShift > 0,
             "Rotation by value length is undefined behavior, but compilers "
             "do not currently fold a test into the rotate instruction. "
             "Please remove this restriction when compilers optimize the "
             "zero case (http://blog.regehr.org/archives/1063).");

  constexpr auto kBitWidth = sizeof(T) * CHAR_BIT;
  // The bits that stay in range move up; the ones pushed past the top are
  // brought back in at the bottom.
  const auto shiftedUp = aValue << aShift;
  const auto wrappedAround = aValue >> (kBitWidth - aShift);
  return shiftedUp | wrappedAround;
}

/**
 * Rotate the bits of |aValue| right by |aShift| positions: bits shifted out
 * at the bottom re-enter at the top.
 *
 * |T| must be unsigned, and |aShift| must be greater than zero and smaller
 * than the bit width of |T| (both checked with MOZ_ASSERT).
 */
template <typename T>
MOZ_NO_SANITIZE_UNSIGNED_OVERFLOW constexpr T RotateRight(const T aValue,
                                                          uint_fast8_t aShift) {
  static_assert(std::is_unsigned_v<T>, "Rotates require unsigned values");

  MOZ_ASSERT(aShift < sizeof(T) * CHAR_BIT, "Shift value is too large!");
  MOZ_ASSERT(aShift > 0,
             "Rotation by value length is undefined behavior, but compilers "
             "do not currently fold a test into the rotate instruction. "
             "Please remove this restriction when compilers optimize the "
             "zero case (http://blog.regehr.org/archives/1063).");

  constexpr auto kBitWidth = sizeof(T) * CHAR_BIT;
  // The bits that stay in range move down; the ones pushed past the bottom
  // are brought back in at the top.
  const auto shiftedDown = aValue >> aShift;
  const auto wrappedAround = aValue << (kBitWidth - aShift);
  return shiftedDown | wrappedAround;
}

/**
 * Returns true if |x| is a power of two.  |T| must be an unsigned type.
 * Zero is not a power of two, so IsPowerOfTwo(0) is false.
 */
template <typename T>
constexpr bool IsPowerOfTwo(T x) {
  static_assert(std::is_unsigned_v<T>, "IsPowerOfTwo requires unsigned values");
  if (x == 0) {
    return false;
  }
  // Subtracting one clears the lowest set bit and sets every bit below it, so
  // the AND is zero exactly when that bit was the only one set.
  return (x & (x - 1)) == 0;
}

/**
 * Count the trailing zero bits of |aValue| for any integral type of at most
 * 8 bytes, dispatching on the size of |T| to CountTrailingZeroes32 or
 * CountTrailingZeroes64.  Both of those require a non-zero argument.
 */
template <typename T>
constexpr uint_fast8_t CountTrailingZeroes(T aValue) {
  static_assert(sizeof(T) <= 8);
  static_assert(std::is_integral_v<T>);
  if constexpr (sizeof(T) <= 4) {
    // The argument is converted to 32 bits by the call.
    return CountTrailingZeroes32(aValue);
  } else if constexpr (sizeof(T) == 8) {
    // Already 64 bits wide: no widening takes place.
    return CountTrailingZeroes64(aValue);
  }
}

// Greatest Common Divisor of two non-negative integers, computed with the
// binary GCD algorithm from
// https://en.wikipedia.org/wiki/Binary_GCD_algorithm#Implementation
//
// Both arguments must be >= 0 (checked with MOZ_ASSERT).  If either argument
// is zero, the other one is returned.
template <typename T>
MOZ_ALWAYS_INLINE T GCD(T aA, T aB) {
  static_assert(std::is_integral_v<T>);

  MOZ_ASSERT(aA >= 0);
  MOZ_ASSERT(aB >= 0);

  if (aA == 0) {
    return aB;
  }
  if (aB == 0) {
    return aA;
  }

  // Strip the factors of two from both operands, remembering how many they
  // have in common; those are restored in the result.
  const T zeroesA = CountTrailingZeroes(aA);
  const T zeroesB = CountTrailingZeroes(aB);
  const T commonShift = std::min<T>(zeroesA, zeroesB);
  aA >>= zeroesA;
  aB >>= zeroesB;

  // Both operands are odd from here on.  Each round replaces the pair by
  // (|aA - aB|, min(aA, aB)) and removes the factors of two from the
  // difference, until the difference reaches zero.
  while (aA != 0) {
    if constexpr (std::is_signed_v<T>) {
      const T delta = aA - aB;
      aB = std::min<T>(aA, aB);
      aA = std::abs(delta);
    } else {
      // Unsigned types cannot hold a negative difference, so put the larger
      // operand first and subtract.
      if (aA < aB) {
        std::swap(aA, aB);
      }
      aA = static_cast<T>(aA - aB);
    }
    if (aA != 0) {
      aA >>= CountTrailingZeroes(aA);
    }
  }

  return aB << commonShift;
}

} /* namespace mozilla */

#endif /* mozilla_MathAlgorithms_h */
