#ifndef POOL_ALLOCATOR_H
#define POOL_ALLOCATOR_H

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <android/log.h>
#include <sys/mman.h>

#include "StackLinkedList.h"

#if defined(HWASAN_ADVANCED) && defined(HWASAN_MALLOC)
# error Incompatible HWAsan flags are passed.
#endif

#if defined(HWASAN_ADVANCED)
#include "sanitizer/hwasan_interface.h"
#endif

// Abstract interface for allocators that hand out storage for objects of type T.
// Implementations must be initialized through Init() before Allocate()/Free() are used.
template<class T>
class Allocator {
public:
    Allocator() = default;
    virtual ~Allocator() = default;

    // One-time setup; call it before any Allocate()/Free().
    virtual void Init() = 0;

    // Hands out storage for one T.
    virtual T* Allocate() = 0;

    // Gives the storage behind |ptr| back to the allocator.
    virtual void Free(T* ptr) = 0;
};

// A pool allocator that pre-allocates a block of memory and returns fixed-size objects
// from that memory for faster "allocation".
//
// Allocate() returns raw storage for one T; no constructor or destructor is run by the
// pool, so callers construct/destroy objects in that storage themselves.
//
// The pool is non-copyable: it owns the memory reserved by Init(), and a copy would
// release that memory a second time in its destructor.
template<class T>
class PoolAllocator : public Allocator<T> {
public:
    PoolAllocator(const std::size_t numObjects);

    // Releases the memory reserved by Init(). In HWASAN_MALLOC builds only the chunks
    // currently sitting in the pool are freed; objects still held by callers are not.
    ~PoolAllocator() override;

    PoolAllocator(const PoolAllocator&) = delete;
    PoolAllocator& operator=(const PoolAllocator&) = delete;

    // Initializes the internal memory of the allocator and prepares the objects in the pool.
    // Must be called exactly once before using the allocator.
    void Init() override;

    // Returns one object from the pool, or nullptr if there are no objects
    // remaining in the pool.
    T* Allocate() override;

    // Puts the allocated memory back to the pool.
    void Free(T* ptr) override;

    // Returns the maximum number of elements this allocator can allocate. Because chunk
    // and pool sizes are rounded up, it can be more than the numObjects passed in the
    // constructor.
    std::size_t GetCapacity() const {
        return m_totalSize / m_chunkSize;
    }

private:
    // Payload type of the intrusive free list; free chunks are reinterpreted as Nodes.
    struct FreeHeader {
    };
    using Node = typename StackLinkedList<FreeHeader>::Node;
    StackLinkedList<FreeHeader> m_freeList;

#ifndef HWASAN_MALLOC
    // Pointer to the start address of the allocated buffer (nullptr until Init() succeeds).
    void* m_start_ptr = nullptr;
#endif
    // Number of bytes used for each object in the pool.
    // If the user asks for a pool of very small objects, we might have to bump
    // it up (e.g., to be a round number, or up to minimum required by HWAsan).
    std::size_t m_chunkSize = 0;

    // Total size of the pool in bytes.
    std::size_t m_totalSize = 0;
};

// Rounds |size| up to the next multiple of |block_size| (returns |size| unchanged when it
// already is one). |block_size| must be non-zero.
//
// `static inline` because this lives in a header: each including translation unit gets
// its own internal copy without "defined but not used" warnings in units that never
// instantiate PoolAllocator. Using the remainder (instead of adding block_size - 1 up
// front) only overflows when the rounded result itself cannot be represented.
static inline size_t roundUpTo(size_t size, size_t block_size) {
    const size_t remainder = size % block_size;
    return remainder == 0 ? size : size + (block_size - remainder);
}

// Computes the per-object chunk size and the total pool size. No memory is reserved
// here; that happens in Init().
//
// @param numObjects Number of objects the pool should be able to hold at once.
template<class T>
PoolAllocator<T>::PoolAllocator(const std::size_t numObjects) {
    // A free chunk stores a free-list Node in place (Init()/Free() push the chunk address
    // as a Node*), so each chunk must be able to hold a whole Node, not just a pointer.
    std::size_t chunkSize = std::max(sizeof(T), sizeof(Node));

    // Chunks are laid out back-to-back, so round the size up to the stricter of the two
    // alignments; otherwise a chunk (and the T or Node stored in it) could be misaligned.
    chunkSize = roundUpTo(chunkSize, std::max(alignof(T), alignof(Node)));

#ifdef HWASAN_ADVANCED
    // HWAsan requires each allocation to be a multiple of 16 bytes.
    chunkSize = roundUpTo(chunkSize, 16);
#endif

    m_chunkSize = chunkSize;

    // Size the buffer from the actual chunk size (not sizeof(T)) so that the pool holds at
    // least numObjects chunks even when chunkSize > sizeof(T); round up to a multiple of
    // 4096 bytes.
    m_totalSize = roundUpTo(numObjects * chunkSize, 4096);
}


#if defined(HWASAN_ADVANCED)
// Returns the next HWASan tag from a round-robin sequence over [16, 255]: 17, 18, ...,
// 255, 16, 17, ...
//
// Tags 0-15 are never returned ("don't use the lower 16 entries"). The counter is
// incremented *before* the range check so that the wrap-around from 255 lands on 16;
// checking first (as before) let the wrapped value 0 escape once per cycle.
//
// NOTE(review): the counter is a function-local static in a `static inline` header
// function, so each translation unit has its own sequence, and it is unsynchronized —
// presumably pools are used from one thread at a time; confirm with callers.
static inline uint8_t GenerateHwasanTag() {
    static uint8_t tag = 16;
    ++tag;  // uint8_t arithmetic: 255 wraps to 0.
    if (tag < 16) {
        tag = 16;  // Don't use the lower 16 entries.
    }
    return tag;
}
#endif

// Reserves the pool's memory and pushes every chunk onto the free list.
//
// Default and HWASAN_ADVANCED builds: one anonymous private mapping of m_totalSize bytes
// is split into m_totalSize / m_chunkSize consecutive chunks. HWASAN_MALLOC builds: each
// chunk is obtained with its own malloc() call. HWASAN_ADVANCED builds additionally
// pre-tag every chunk (memory and pointer) here, so Allocate() does not have to.
//
// If memory cannot be obtained, the pool is left empty (or, for HWASAN_MALLOC, holds
// only the chunks obtained so far), and Allocate() returns nullptr once it runs dry.
// Must be called exactly once.
template<class T>
void PoolAllocator<T>::Init() {
    const std::size_t nChunks = m_totalSize / m_chunkSize;
    if (nChunks == 0) {
        // Empty pool: nothing to reserve (mmap() rejects a zero length anyway).
        return;
    }

#ifndef HWASAN_MALLOC
    void* const mapping = mmap(nullptr, m_totalSize, PROT_READ | PROT_WRITE,
                               MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    if (mapping == MAP_FAILED) {
        // mmap() reports failure with MAP_FAILED, not nullptr. Leave the pool empty
        // instead of building a free list out of addresses derived from MAP_FAILED.
        __android_log_print(ANDROID_LOG_ERROR, "PoolAllocator",
                            "PoolAllocator::Init() failed to mmap %zu bytes", m_totalSize);
        return;
    }
    m_start_ptr = mapping;
#endif

    // Create a linked-list with all free positions.
    for (std::size_t i = 0; i < nChunks; ++i) {
#ifdef HWASAN_MALLOC
        void* const chunk = malloc(m_chunkSize);
        if (chunk == nullptr) {
            break;  // Out of memory: keep the chunks obtained so far.
        }
        std::uintptr_t address = reinterpret_cast<std::uintptr_t>(chunk);
#else
        std::uintptr_t address = reinterpret_cast<std::uintptr_t>(m_start_ptr) + i * m_chunkSize;
#endif

#if defined(HWASAN_ADVANCED)
        // This pre-tags all objects in the pool (and all pointers) so that the
        // Allocate() method does not need to worry about tagging them.
        //
        // Memory allocators that tag pointers/memory in their allocate method
        // do not need to re-tag them here.
        const uint8_t tag = GenerateHwasanTag();
        __hwasan_tag_memory(reinterpret_cast<void*>(address), tag, m_chunkSize);
        address = reinterpret_cast<std::uintptr_t>(
            __hwasan_tag_pointer(reinterpret_cast<void*>(address), tag));
#endif

        m_freeList.push(reinterpret_cast<Node*>(address));
    }
}

// Releases the memory obtained by Init().
//
// HWASAN_MALLOC builds: frees every chunk currently on the free list. Chunks still held
// by callers are not on the list and are therefore not freed here.
// Other builds: unmaps the single buffer, which also invalidates any objects callers
// still hold.
template<class T>
PoolAllocator<T>::~PoolAllocator() {
#ifdef HWASAN_MALLOC
    for (Node* node = m_freeList.pop(); node != nullptr; node = m_freeList.pop()) {
        free(reinterpret_cast<void*>(node));
    }
#else
    // Nothing to release if Init() never ran or never obtained a mapping. MAP_FAILED is
    // rejected too, since that is what mmap() returns on failure.
    if (m_start_ptr != nullptr && m_start_ptr != MAP_FAILED) {
#if defined(HWASAN_ADVANCED)
        // Init()/Free() left non-zero tags on this range; reset them to 0 so the address
        // range does not keep this pool's tags after it is unmapped.
        __hwasan_tag_memory(m_start_ptr, 0, m_totalSize);
#endif
        munmap(m_start_ptr, m_totalSize);
    }
#endif
}

template<class T>
T* PoolAllocator<T>::Allocate() {
    // Take the chunk at the head of the free list and hand it out as storage for a T.
    // StackLinkedList::pop() is expected to yield nullptr when the list is empty (the
    // destructor's drain loop relies on the same), and that nullptr is returned as-is.
    //
    // HWASAN_ADVANCED note: nothing is tagged here. Init() pre-tags every chunk and Free()
    // re-tags a chunk as it re-enters the pool, so the pointer on the list already carries
    // the tag of its memory. An allocator that instead carves blocks out of preallocated
    // memory on the fly, without knowing up front which parts serve which objects, should
    // tag at this point, for instance:
    //
    //    __hwasan_tag_memory(ptr, tag, m_chunkSize);
    //    ptr = (T*)__hwasan_tag_pointer((void*)ptr, tag);
    //
    // using 16-byte-aligned pointers and 16-byte chunks before calling these APIs.
    Node* const head = m_freeList.pop();
    return reinterpret_cast<T*>(head);
}

// Returns |ptr| to the pool so a later Allocate() can hand it out again.
//
// A null |ptr| is ignored, like free(). HWASAN_MALLOC builds free the chunk and put a
// freshly malloc'd chunk into the pool in its place. HWASAN_ADVANCED builds check that
// the memory tag matches the pointer tag (catching double frees and pointers that did not
// come from Allocate()), then re-tag the chunk so stale copies of |ptr| no longer match.
template<class T>
void PoolAllocator<T>::Free(T* ptr) {
    if (ptr == nullptr) {
        // Never put nullptr on the free list: Allocate() would later hand it out as a
        // chunk, and push() presumably writes through the node pointer.
        return;
    }

#ifdef HWASAN_MALLOC
    // Free and reallocate so that HWASan creates tags for the new pointer/memory.
    free(reinterpret_cast<void*>(ptr));
    void* const replacement = malloc(m_chunkSize);
    if (replacement == nullptr) {
        return;  // Out of memory: the pool simply shrinks by one chunk.
    }
    ptr = reinterpret_cast<T*>(replacement);
#elif defined(HWASAN_ADVANCED)
    // Check that the memory tag of the whole chunk matches the tag of ptr, without
    // actually accessing the memory. -1 means the entire range is accessible via ptr.
    intptr_t offset = __hwasan_test_shadow(ptr, m_chunkSize);
    if (offset != -1) {
        // This looks like a double-free or free before allocate, starting at address 'ptr+offset'.
        __android_log_print(ANDROID_LOG_FATAL, "hwasan", "PoolAllocator::Free() called with invalid pointer %p",
                            static_cast<const void*>(ptr));

        // We don't use assert here to make sure HWAsan prints a nicer error message for us.
        uint8_t* byte_ptr = reinterpret_cast<uint8_t*>(ptr) + offset;
        *byte_ptr = 0xFF;  // Write arbitrary value to trigger HWAsan error.

        // If HWASan is configured to continue after reporting, do not push the chunk:
        // re-inserting an already-free chunk would put it on the free list twice.
        return;
    }

    // Set a new tag as we return the memory back into the pool.
    uint8_t tag = GenerateHwasanTag();
    __hwasan_tag_memory(__hwasan_tag_pointer(ptr, 0), tag, m_chunkSize);
    ptr = (T*)__hwasan_tag_pointer((void*)ptr, tag);
#endif

    m_freeList.push(reinterpret_cast<Node*>(ptr));
}

#endif // POOL_ALLOCATOR_H