#include "xla/codegen/emitters/kernel_arguments.h"
/* Copyright 2018 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_GPU_GPU_CONSTANTS_H_
#define XLA_SERVICE_GPU_GPU_CONSTANTS_H_

#include <cstdint>

namespace xla {
namespace gpu {

// Lower bound, in bytes, on the alignment of buffers that TensorFlow passes in
// as incoming arguments.
//
// The value matches EIGEN_MAX_ALIGN_BYTES. It is spelled out here instead of
// being pulled from the Eigen headers because including them just for that
// symbol may not be a good idea: EIGEN_MAX_ALIGN_BYTES may differ between
// CUDA-enabled and CUDA-disabled builds, and the IR generated by XLA:GPU
// should not depend on that.
//
// TODO(b/111767313): Consider raising EIGEN_MAX_ALIGN_BYTES if it helps.
inline constexpr int64_t kEntryParameterAlignBytes{16};

// Lower bound, in bytes, on the alignment of buffers that XLA allocates
// itself: the temp buffers and the live out (result) buffers.
//
// Library constraints behind the chosen value:
//
//  * cudnn requires 128-bit (16-byte) alignment for TensorCore operations, but
//    says that 1024-bit (128-byte) alignment "may deliver better performance".
//    https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#tensor-ops-guidelines-for-dl-compiler
//
//  * cublas requires 256-byte alignment as of v12.9.1.4.
//    https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cublas-release-12-9
inline constexpr int64_t kXlaAllocatedBufferAlignBytes{256};

// Lower bound, in bytes, on the alignment of constant buffers. Defined in
// terms of kXlaAllocatedBufferAlignBytes, so the two values stay equal.
inline constexpr int64_t kConstantBufferAlignBytes{
    kXlaAllocatedBufferAlignBytes};

// Returns a BufferAlignment whose entry-parameter, XLA-allocated-buffer and
// constant-buffer alignment fields are set from the corresponding constants
// declared in this header (kEntryParameterAlignBytes,
// kXlaAllocatedBufferAlignBytes and kConstantBufferAlignBytes).
inline emitters::KernelArguments::BufferAlignment GetDefaultBufferAlignment() {
  emitters::KernelArguments::BufferAlignment defaults;

  // Each assignment writes a distinct field, so their order is irrelevant.
  defaults.constant_buffer_align_bytes = kConstantBufferAlignBytes;
  defaults.xla_allocated_buffer_align_bytes = kXlaAllocatedBufferAlignBytes;
  defaults.entry_parameter_align_bytes = kEntryParameterAlignBytes;
  return defaults;
}

}  // namespace gpu
}  // namespace xla

#endif  // XLA_SERVICE_GPU_GPU_CONSTANTS_H_
