/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_CODEGEN_MATH_LIB_H_
#define XLA_CODEGEN_MATH_LIB_H_

#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "xla/xla_data.pb.h"

namespace xla::codegen {

// Interface representing a single vectorized math function approximation.
// Each implementation may support multiple (primitive type, vector width)
// combinations, as reported by SupportedVectorTypes(). To emit LLVM IR for a
// particular combination, call CreateDefinition() with the desired
// VectorType.
class MathFunction {
 public:
  virtual ~MathFunction() = default;

  // The name of the function being approximated.
  virtual absl::string_view FunctionName() const = 0;

  // Which LLVM intrinsics does this approximation replace?
  virtual std::vector<std::string> TargetFunctions() const = 0;

  // One vectorized variant of the approximation: an element type paired with
  // a vector width.
  //
  // Both fields carry default member initializers so that a
  // default-constructed VectorType never holds indeterminate values. The
  // struct remains an aggregate (C++14 and later), so brace initialization
  // such as `VectorType{dtype, width}` is unaffected.
  struct VectorType {
    // Element type of the vector. Value-initialized (zero) when not given.
    PrimitiveType dtype{};
    // Number of lanes in the vector. Zero when not given.
    size_t width = 0;
  };

  // Returns the vector types supported well by this approximation.
  virtual std::vector<VectorType> SupportedVectorTypes() const = 0;

  // Returns the LLVM IR function definition, named `function_name`, for the
  // approximation specialized to `vector_type` (presumably created inside
  // `module` — confirm against implementations).
  // Reads the target machine and features from the LLVM module.
  virtual llvm::Function* CreateDefinition(llvm::Module& module,
                                           absl::string_view function_name,
                                           VectorType vector_type) const = 0;

  // Returns the vectorized function name for `vector_type`,
  // e.g. "xla.ldexp.v8f64.v8i32".
  virtual std::string GenerateVectorizedFunctionName(
      VectorType vector_type) const = 0;

  // Returns the mangled SIMD name for `vector_type`.
  // NOTE(review): presumably a vector-function-ABI style name for use in
  // llvm::VecDesc entries — confirm against implementations.
  virtual std::string GenerateMangledSimdName(VectorType vector_type) const = 0;
};

// A library of math approximations for use in codegen.
// The library hooks into LLVM compilation in two places:
// 1. It provides a set of VecDescs that LLVM uses to replace math intrinsics
//    with calls to vectorized approximations.
// 2. After optimization has run, the module must be scanned for calls to the
//    approximations, and the matching function definitions must be generated
//    and inserted into the module.
// The instance retains ownership of the strings that the VecDescs reference,
// so it must outlive any use of the VecDescs it returns.
class MathFunctionLib {
 public:
  // NOTE(review): presumably registers the available MathFunction
  // implementations; the definition is not in this header — confirm in the .cc.
  MathFunctionLib();

  // Returns a vector of vectorization information for functions that have
  // vectorized approximations. This enables LLVM vectorization
  // passes to vectorize scalar math functions to custom function calls.
  // No definitions are generated by this function.
  std::vector<llvm::VecDesc> Vectorizations();

  // Inserts xla.* math function definitions into the module.
  // Will insert definitions marked as always inline and then run LLVM inliner,
  // constant propagation and early CSE passes to remove dead code.
  // Returns the set of function names that were replaced.
  // NOTE(review): the returned set holds string_views; the storage they refer
  // to is not visible here — confirm it outlives the caller's use of the set.
  absl::flat_hash_set<absl::string_view> RewriteMathFunctions(
      llvm::Module& module);

 private:
  // The approximations owned by this library.
  std::vector<std::unique_ptr<MathFunction>> math_functions_;
  // Map from string_view to string_view. The viewed strings are not owned by
  // the map itself. NOTE(review): presumably they are backed by storage this
  // instance retains (see class comment), and the key/value meaning is
  // defined in the .cc — confirm there.
  absl::flat_hash_map<absl::string_view, absl::string_view> targets_;
};

}  // namespace xla::codegen

#endif  // XLA_CODEGEN_MATH_LIB_H_
