From aa284ed32b6058dd664ba5abd84e1d3f21b71598 Mon Sep 17 00:00:00 2001 From: Stanislaw Halik Date: Sun, 29 Aug 2021 02:27:01 +0200 Subject: add onnxruntime --- .../build/native/include/onnxruntime_cxx_api.h | 650 +++++++++++++++++++++ 1 file changed, 650 insertions(+) create mode 100644 onnxruntime-1.8.1/build/native/include/onnxruntime_cxx_api.h (limited to 'onnxruntime-1.8.1/build/native/include/onnxruntime_cxx_api.h') diff --git a/onnxruntime-1.8.1/build/native/include/onnxruntime_cxx_api.h b/onnxruntime-1.8.1/build/native/include/onnxruntime_cxx_api.h new file mode 100644 index 0000000..4c1b707 --- /dev/null +++ b/onnxruntime-1.8.1/build/native/include/onnxruntime_cxx_api.h @@ -0,0 +1,650 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Summary: The Ort C++ API is a header only wrapper around the Ort C API. +// +// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors +// and automatically releasing resources in the destructors. +// +// Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers. +// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). +// +// Only move assignment between objects is allowed, there are no copy constructors. Some objects have explicit 'Clone' +// methods for this purpose. 
+ +#pragma once +#include "onnxruntime_c_api.h" +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef ORT_NO_EXCEPTIONS +#include +#endif + +namespace Ort { + +// All C++ methods that can fail will throw an exception of this type +struct Exception : std::exception { + Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {} + + OrtErrorCode GetOrtErrorCode() const { return code_; } + const char* what() const noexcept override { return message_.c_str(); } + + private: + std::string message_; + OrtErrorCode code_; +}; + +#ifdef ORT_NO_EXCEPTIONS +#define ORT_CXX_API_THROW(string, code) \ + do { \ + std::cerr << Ort::Exception(string, code) \ + .what() \ + << std::endl; \ + abort(); \ + } while (false) +#else +#define ORT_CXX_API_THROW(string, code) \ + throw Ort::Exception(string, code) +#endif + +// This is used internally by the C++ API. This class holds the global variable that points to the OrtApi, it's in a template so that we can define a global variable in a header and make +// it transparent to the users of the API. +template +struct Global { + static const OrtApi* api_; +}; + +// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it. + +template +#ifdef ORT_API_MANUAL_INIT +const OrtApi* Global::api_{}; +inline void InitApi() { Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); } +#else +const OrtApi* Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); +#endif + +// This returns a reference to the OrtApi interface in use, in case someone wants to use the C API functions +inline const OrtApi& GetApi() { return *Global::api_; } + +// This is a C++ wrapper for GetAvailableProviders() C API and returns +// a vector of strings representing the available execution providers. +std::vector GetAvailableProviders(); + +// This is used internally by the C++ API. 
This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type +// This can't be done in the C API since C doesn't have function overloading. +#define ORT_DEFINE_RELEASE(NAME) \ + inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); } + +ORT_DEFINE_RELEASE(Allocator); +ORT_DEFINE_RELEASE(MemoryInfo); +ORT_DEFINE_RELEASE(CustomOpDomain); +ORT_DEFINE_RELEASE(Env); +ORT_DEFINE_RELEASE(RunOptions); +ORT_DEFINE_RELEASE(Session); +ORT_DEFINE_RELEASE(SessionOptions); +ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo); +ORT_DEFINE_RELEASE(SequenceTypeInfo); +ORT_DEFINE_RELEASE(MapTypeInfo); +ORT_DEFINE_RELEASE(TypeInfo); +ORT_DEFINE_RELEASE(Value); +ORT_DEFINE_RELEASE(ModelMetadata); +ORT_DEFINE_RELEASE(ThreadingOptions); +ORT_DEFINE_RELEASE(IoBinding); +ORT_DEFINE_RELEASE(ArenaCfg); + +/*! \class Ort::Float16_t + * \brief it is a structure that represents float16 data. + * \details It is necessary for type dispatching to make use of C++ API + * The type is implicitly convertible to/from uint16_t. + * The size of the structure should align with uint16_t and one can freely cast + * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data. + * + * Generally, you can feed any of your types as float16/blfoat16 data to create a tensor + * on top of it, providing it can form a continuous buffer with 16-bit elements with no padding. + * And you can also feed a array of uint16_t elements directly. 
/*
 * For example,
 *
 * \code{.unparsed}
 * uint16_t values[] = { 15360, 16384, 16896, 17408, 17664};
 * constexpr size_t values_length = sizeof(values) / sizeof(values[0]);
 * std::vector<int64_t> dims = {values_length};  // one dimensional example
 * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
 * // Note we are passing bytes count in this api, not number of elements -> sizeof(values)
 * auto float16_tensor = Ort::Value::CreateTensor(info, values, sizeof(values),
 *                                                dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
 * \endcode
 *
 * Here is another example, a little bit more elaborate. Let's assume that you use your own float16 type and you want to use
 * a templated version of the API above so the type is automatically set based on your type. You will need to supply an extra
 * template specialization.
 *
 * \code{.unparsed}
 * namespace yours { struct half {}; } // assume this is your type, define this:
 * namespace Ort {
 * template<>
 * struct TypeToTensorType<yours::half> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; };
 * } //namespace Ort
 *
 * std::vector<yours::half> values;
 * std::vector<int64_t> dims = {values.size()}; // one dimensional example
 * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
 * // Here we are passing element count -> values.size()
 * auto float16_tensor = Ort::Value::CreateTensor(info, values.data(), values.size(), dims.data(), dims.size());
 *
 * \endcode
 */
struct Float16_t {
  uint16_t value;  // raw IEEE-754 half-precision bit pattern
  constexpr Float16_t() noexcept : value(0) {}
  constexpr Float16_t(uint16_t v) noexcept : value(v) {}
  constexpr operator uint16_t() const noexcept { return value; }
  constexpr bool operator==(const Float16_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const Float16_t& rhs) const noexcept { return value != rhs.value; }
};

static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");

/*! \class Ort::BFloat16_t
 * \brief is a structure that represents bfloat16 data.
 * \details It is necessary for type dispatching to make use of C++ API
 * The type is implicitly convertible to/from uint16_t.
 * The size of the structure should align with uint16_t and one can freely cast
 * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data.
 *
 * See also code examples for Float16_t above.
 */
struct BFloat16_t {
  uint16_t value;  // raw bfloat16 bit pattern
  constexpr BFloat16_t() noexcept : value(0) {}
  constexpr BFloat16_t(uint16_t v) noexcept : value(v) {}
  constexpr operator uint16_t() const noexcept { return value; }
  constexpr bool operator==(const BFloat16_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const BFloat16_t& rhs) const noexcept { return value != rhs.value; }
};

static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");

// This is used internally by the C++ API. This is the common base class used by the wrapper objects.
+template +struct Base { + using contained_type = T; + + Base() = default; + Base(T* p) : p_{p} { + if (!p) + ORT_CXX_API_THROW("Allocation failure", ORT_FAIL); + } + ~Base() { OrtRelease(p_); } + + operator T*() { return p_; } + operator const T*() const { return p_; } + + T* release() { + T* p = p_; + p_ = nullptr; + return p; + } + + protected: + Base(const Base&) = delete; + Base& operator=(const Base&) = delete; + Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } + void operator=(Base&& v) noexcept { + OrtRelease(p_); + p_ = v.p_; + v.p_ = nullptr; + } + + T* p_{}; + + template + friend struct Unowned; // This friend line is needed to keep the centos C++ compiler from giving an error +}; + +template +struct Base { + using contained_type = const T; + + Base() = default; + Base(const T* p) : p_{p} { + if (!p) + ORT_CXX_API_THROW("Invalid instance ptr", ORT_INVALID_ARGUMENT); + } + ~Base() = default; + + operator const T*() const { return p_; } + + protected: + Base(const Base&) = delete; + Base& operator=(const Base&) = delete; + Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } + void operator=(Base&& v) noexcept { + p_ = v.p_; + v.p_ = nullptr; + } + + const T* p_{}; +}; + +template +struct Unowned : T { + Unowned(decltype(T::p_) p) : T{p} {} + Unowned(Unowned&& v) : T{v.p_} {} + ~Unowned() { this->release(); } +}; + +struct AllocatorWithDefaultOptions; +struct MemoryInfo; +struct Env; +struct TypeInfo; +struct Value; +struct ModelMetadata; + +struct Env : Base { + Env(std::nullptr_t) {} + Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); + Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); + Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param); + Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param, + OrtLoggingLevel 
logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); + explicit Env(OrtEnv* p) : Base{p} {} + + Env& EnableTelemetryEvents(); + Env& DisableTelemetryEvents(); + + Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); + + static const OrtApi* s_api; +}; + +struct CustomOpDomain : Base { + explicit CustomOpDomain(std::nullptr_t) {} + explicit CustomOpDomain(const char* domain); + + void Add(OrtCustomOp* op); +}; + +struct RunOptions : Base { + RunOptions(std::nullptr_t) {} + RunOptions(); + + RunOptions& SetRunLogVerbosityLevel(int); + int GetRunLogVerbosityLevel() const; + + RunOptions& SetRunLogSeverityLevel(int); + int GetRunLogSeverityLevel() const; + + RunOptions& SetRunTag(const char* run_tag); + const char* GetRunTag() const; + + RunOptions& AddConfigEntry(const char* config_key, const char* config_value); + + // terminate ALL currently executing Session::Run calls that were made using this RunOptions instance + RunOptions& SetTerminate(); + // unset the terminate flag so this RunOptions instance can be used in a new Session::Run call + RunOptions& UnsetTerminate(); +}; + +struct SessionOptions : Base { + explicit SessionOptions(std::nullptr_t) {} + SessionOptions(); + explicit SessionOptions(OrtSessionOptions* p) : Base{p} {} + + SessionOptions Clone() const; + + SessionOptions& SetIntraOpNumThreads(int intra_op_num_threads); + SessionOptions& SetInterOpNumThreads(int inter_op_num_threads); + SessionOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); + + SessionOptions& EnableCpuMemArena(); + SessionOptions& DisableCpuMemArena(); + + SessionOptions& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); + + SessionOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix); + SessionOptions& DisableProfiling(); + + SessionOptions& EnableMemPattern(); + SessionOptions& DisableMemPattern(); + + SessionOptions& SetExecutionMode(ExecutionMode 
execution_mode); + + SessionOptions& SetLogId(const char* logid); + SessionOptions& SetLogSeverityLevel(int level); + + SessionOptions& Add(OrtCustomOpDomain* custom_op_domain); + + SessionOptions& DisablePerSessionThreads(); + + SessionOptions& AddConfigEntry(const char* config_key, const char* config_value); + SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val); + + SessionOptions& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); + SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); + SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); + SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); +}; + +struct ModelMetadata : Base { + explicit ModelMetadata(std::nullptr_t) {} + explicit ModelMetadata(OrtModelMetadata* p) : Base{p} {} + + char* GetProducerName(OrtAllocator* allocator) const; + char* GetGraphName(OrtAllocator* allocator) const; + char* GetDomain(OrtAllocator* allocator) const; + char* GetDescription(OrtAllocator* allocator) const; + char* GetGraphDescription(OrtAllocator* allocator) const; + char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const; + char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const; + int64_t GetVersion() const; +}; + +struct Session : Base { + explicit Session(std::nullptr_t) {} + Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); + Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container); + Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); + + // Run that will allocate the output values + std::vector Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* 
output_names, size_t output_count); + // Run for when there is a list of preallocated outputs + void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, Value* output_values, size_t output_count); + + void Run(const RunOptions& run_options, const struct IoBinding&); + + size_t GetInputCount() const; + size_t GetOutputCount() const; + size_t GetOverridableInitializerCount() const; + + char* GetInputName(size_t index, OrtAllocator* allocator) const; + char* GetOutputName(size_t index, OrtAllocator* allocator) const; + char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const; + char* EndProfiling(OrtAllocator* allocator) const; + uint64_t GetProfilingStartTimeNs() const; + ModelMetadata GetModelMetadata() const; + + TypeInfo GetInputTypeInfo(size_t index) const; + TypeInfo GetOutputTypeInfo(size_t index) const; + TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; +}; + +struct TensorTypeAndShapeInfo : Base { + explicit TensorTypeAndShapeInfo(std::nullptr_t) {} + explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : Base{p} {} + + ONNXTensorElementDataType GetElementType() const; + size_t GetElementCount() const; + + size_t GetDimensionsCount() const; + void GetDimensions(int64_t* values, size_t values_count) const; + void GetSymbolicDimensions(const char** values, size_t values_count) const; + + std::vector GetShape() const; +}; + +struct SequenceTypeInfo : Base { + explicit SequenceTypeInfo(std::nullptr_t) {} + explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : Base{p} {} + + TypeInfo GetSequenceElementType() const; +}; + +struct MapTypeInfo : Base { + explicit MapTypeInfo(std::nullptr_t) {} + explicit MapTypeInfo(OrtMapTypeInfo* p) : Base{p} {} + + ONNXTensorElementDataType GetMapKeyType() const; + TypeInfo GetMapValueType() const; +}; + +struct TypeInfo : Base { + explicit TypeInfo(std::nullptr_t) {} + explicit 
TypeInfo(OrtTypeInfo* p) : Base{p} {} + + Unowned GetTensorTypeAndShapeInfo() const; + Unowned GetSequenceTypeInfo() const; + Unowned GetMapTypeInfo() const; + + ONNXType GetONNXType() const; +}; + +struct Value : Base { + template + static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len); + static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type); + template + static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len); + static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); + + static Value CreateMap(Value& keys, Value& values); + static Value CreateSequence(std::vector& values); + + template + static Value CreateOpaque(const char* domain, const char* type_name, const T&); + + template + void GetOpaqueData(const char* domain, const char* type_name, T&) const; + + explicit Value(std::nullptr_t) {} + explicit Value(OrtValue* p) : Base{p} {} + Value(Value&&) = default; + Value& operator=(Value&&) = default; + + bool IsTensor() const; + size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements + Value GetValue(int index, OrtAllocator* allocator) const; + + size_t GetStringTensorDataLength() const; + void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const; + + template + T* GetTensorMutableData(); + + template + const T* GetTensorData() const; + + template + T& At(const std::vector& location); + + TypeInfo GetTypeInfo() const; + TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; + + size_t GetStringTensorElementLength(size_t element_index) const; + void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const; + + void 
FillStringTensor(const char* const* s, size_t s_len); + void FillStringTensorElement(const char* s, size_t index); +}; + +// Represents native memory allocation +struct MemoryAllocation { + MemoryAllocation(OrtAllocator* allocator, void* p, size_t size); + ~MemoryAllocation(); + MemoryAllocation(const MemoryAllocation&) = delete; + MemoryAllocation& operator=(const MemoryAllocation&) = delete; + MemoryAllocation(MemoryAllocation&&); + MemoryAllocation& operator=(MemoryAllocation&&); + + void* get() { return p_; } + size_t size() const { return size_; } + + private: + OrtAllocator* allocator_; + void* p_; + size_t size_; +}; + +struct AllocatorWithDefaultOptions { + AllocatorWithDefaultOptions(); + + operator OrtAllocator*() { return p_; } + operator const OrtAllocator*() const { return p_; } + + void* Alloc(size_t size); + // The return value will own the allocation + MemoryAllocation GetAllocation(size_t size); + void Free(void* p); + + const OrtMemoryInfo* GetInfo() const; + + private: + OrtAllocator* p_{}; +}; + +template +struct BaseMemoryInfo : B { + BaseMemoryInfo() = default; + explicit BaseMemoryInfo(typename B::contained_type* p) : B(p) {} + ~BaseMemoryInfo() = default; + BaseMemoryInfo(BaseMemoryInfo&&) = default; + BaseMemoryInfo& operator=(BaseMemoryInfo&&) = default; + + std::string GetAllocatorName() const; + OrtAllocatorType GetAllocatorType() const; + int GetDeviceId() const; + OrtMemType GetMemoryType() const; + template + bool operator==(const BaseMemoryInfo& o) const; +}; + +struct UnownedMemoryInfo : BaseMemoryInfo > { + explicit UnownedMemoryInfo(std::nullptr_t) {} + explicit UnownedMemoryInfo(const OrtMemoryInfo* p) : BaseMemoryInfo(p) {} +}; + +struct MemoryInfo : BaseMemoryInfo > { + static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); + + explicit MemoryInfo(std::nullptr_t) {} + explicit MemoryInfo(OrtMemoryInfo* p) : BaseMemoryInfo(p) {} + MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType 
mem_type); +}; + +struct Allocator : public Base { + Allocator(const Session& session, const MemoryInfo&); + + void* Alloc(size_t size) const; + // The return value will own the allocation + MemoryAllocation GetAllocation(size_t size); + void Free(void* p) const; + UnownedMemoryInfo GetInfo() const; +}; + +struct IoBinding : public Base { + private: + std::vector GetOutputNamesHelper(OrtAllocator*) const; + std::vector GetOutputValuesHelper(OrtAllocator*) const; + + public: + explicit IoBinding(Session& session); + void BindInput(const char* name, const Value&); + void BindOutput(const char* name, const Value&); + void BindOutput(const char* name, const MemoryInfo&); + std::vector GetOutputNames() const; + std::vector GetOutputNames(Allocator&) const; + std::vector GetOutputValues() const; + std::vector GetOutputValues(Allocator&) const; + void ClearBoundInputs(); + void ClearBoundOutputs(); +}; + +/*! \struct Ort::ArenaCfg + * \brief it is a structure that represents the configuration of an arena based allocator + * \details Please see docs/C_API.md for details + */ +struct ArenaCfg : Base { + explicit ArenaCfg(std::nullptr_t) {} + /** + * \param max_mem - use 0 to allow ORT to choose the default + * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested + * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default + * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default + * See docs/C_API.md for details on what the following parameters mean and how to choose these values + */ + ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk); +}; + +// +// Custom OPs (only needed to implement custom OPs) +// + +struct CustomOpApi { + CustomOpApi(const OrtApi& api) : api_(api) {} + + template // T is only implemented for std::vector, std::vector, float, int64_t, and string + T KernelInfoGetAttribute(_In_ const 
OrtKernelInfo* info, _In_ const char* name); + + OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value); + size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info); + ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info); + size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info); + void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length); + void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count); + + template + T* GetTensorMutableData(_Inout_ OrtValue* value); + template + const T* GetTensorData(_Inout_ const OrtValue* value); + + std::vector GetTensorShape(const OrtTensorTypeAndShapeInfo* info); + void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input); + size_t KernelContext_GetInputCount(const OrtKernelContext* context); + const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index); + size_t KernelContext_GetOutputCount(const OrtKernelContext* context); + OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count); + + void ThrowOnError(OrtStatus* result); + + private: + const OrtApi& api_; +}; + +template +struct CustomOpBase : OrtCustomOp { + CustomOpBase() { + OrtCustomOp::version = ORT_API_VERSION; + OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast(this_)->CreateKernel(*api, info); }; + OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast(this_)->GetName(); }; + + OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast(this_)->GetExecutionProviderType(); }; + + OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast(this_)->GetInputTypeCount(); }; + OrtCustomOp::GetInputType = [](const OrtCustomOp* 
this_, size_t index) { return static_cast(this_)->GetInputType(index); }; + + OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast(this_)->GetOutputTypeCount(); }; + OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetOutputType(index); }; + + OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast(op_kernel)->Compute(context); }; + OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast(op_kernel); }; + + OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputCharacteristic(index); }; + OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetOutputCharacteristic(index); }; + } + + // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider + const char* GetExecutionProviderType() const { return nullptr; } + + // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below + // (inputs and outputs are required by default) + OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const { + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; + } + + OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const { + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; + } +}; + +} // namespace Ort + +#include "onnxruntime_cxx_inline.h" -- cgit v1.2.3