From 46421cfebc2da828d0bd611567dfdc3f6ce039a2 Mon Sep 17 00:00:00 2001 From: Stanislaw Halik Date: Sun, 2 Oct 2022 19:37:10 +0200 Subject: update msvc --- .../build/native/include/onnxruntime_cxx_api.h | 650 --------------------- 1 file changed, 650 deletions(-) delete mode 100644 onnxruntime-1.8.1/build/native/include/onnxruntime_cxx_api.h (limited to 'onnxruntime-1.8.1/build/native/include/onnxruntime_cxx_api.h') diff --git a/onnxruntime-1.8.1/build/native/include/onnxruntime_cxx_api.h b/onnxruntime-1.8.1/build/native/include/onnxruntime_cxx_api.h deleted file mode 100644 index 4c1b707..0000000 --- a/onnxruntime-1.8.1/build/native/include/onnxruntime_cxx_api.h +++ /dev/null @@ -1,650 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// Summary: The Ort C++ API is a header only wrapper around the Ort C API. -// -// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors -// and automatically releasing resources in the destructors. -// -// Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers. -// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). -// -// Only move assignment between objects is allowed, there are no copy constructors. Some objects have explicit 'Clone' -// methods for this purpose. - -#pragma once -#include "onnxruntime_c_api.h" -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef ORT_NO_EXCEPTIONS -#include -#endif - -namespace Ort { - -// All C++ methods that can fail will throw an exception of this type -struct Exception : std::exception { - Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {} - - OrtErrorCode GetOrtErrorCode() const { return code_; } - const char* what() const noexcept override { return message_.c_str(); } - - private: - std::string message_; - OrtErrorCode code_; -}; - -#ifdef ORT_NO_EXCEPTIONS -#define ORT_CXX_API_THROW(string, code) \ - do { \ - std::cerr << Ort::Exception(string, code) \ - .what() \ - << std::endl; \ - abort(); \ - } while (false) -#else -#define ORT_CXX_API_THROW(string, code) \ - throw Ort::Exception(string, code) -#endif - -// This is used internally by the C++ API. This class holds the global variable that points to the OrtApi, it's in a template so that we can define a global variable in a header and make -// it transparent to the users of the API. -template -struct Global { - static const OrtApi* api_; -}; - -// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it. - -template -#ifdef ORT_API_MANUAL_INIT -const OrtApi* Global::api_{}; -inline void InitApi() { Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); } -#else -const OrtApi* Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); -#endif - -// This returns a reference to the OrtApi interface in use, in case someone wants to use the C API functions -inline const OrtApi& GetApi() { return *Global::api_; } - -// This is a C++ wrapper for GetAvailableProviders() C API and returns -// a vector of strings representing the available execution providers. -std::vector GetAvailableProviders(); - -// This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type -// This can't be done in the C API since C doesn't have function overloading. -#define ORT_DEFINE_RELEASE(NAME) \ - inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); } - -ORT_DEFINE_RELEASE(Allocator); -ORT_DEFINE_RELEASE(MemoryInfo); -ORT_DEFINE_RELEASE(CustomOpDomain); -ORT_DEFINE_RELEASE(Env); -ORT_DEFINE_RELEASE(RunOptions); -ORT_DEFINE_RELEASE(Session); -ORT_DEFINE_RELEASE(SessionOptions); -ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo); -ORT_DEFINE_RELEASE(SequenceTypeInfo); -ORT_DEFINE_RELEASE(MapTypeInfo); -ORT_DEFINE_RELEASE(TypeInfo); -ORT_DEFINE_RELEASE(Value); -ORT_DEFINE_RELEASE(ModelMetadata); -ORT_DEFINE_RELEASE(ThreadingOptions); -ORT_DEFINE_RELEASE(IoBinding); -ORT_DEFINE_RELEASE(ArenaCfg); - -/*! \class Ort::Float16_t - * \brief it is a structure that represents float16 data. - * \details It is necessary for type dispatching to make use of C++ API - * The type is implicitly convertible to/from uint16_t. - * The size of the structure should align with uint16_t and one can freely cast - * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data. - * - * Generally, you can feed any of your types as float16/blfoat16 data to create a tensor - * on top of it, providing it can form a continuous buffer with 16-bit elements with no padding. - * And you can also feed a array of uint16_t elements directly. For example, - * - * \code{.unparsed} - * uint16_t values[] = { 15360, 16384, 16896, 17408, 17664}; - * constexpr size_t values_length = sizeof(values) / sizeof(values[0]); - * std::vector dims = {values_length}; // one dimensional example - * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); - * // Note we are passing bytes count in this api, not number of elements -> sizeof(values) - * auto float16_tensor = Ort::Value::CreateTensor(info, values, sizeof(values), - * dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16); - * \endcode - * - * Here is another example, a little bit more elaborate. Let's assume that you use your own float16 type and you want to use - * a templated version of the API above so the type is automatically set based on your type. You will need to supply an extra - * template specialization. - * - * \code{.unparsed} - * namespace yours { struct half {}; } // assume this is your type, define this: - * namespace Ort { - * template<> - * struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; }; - * } //namespace Ort - * - * std::vector values; - * std::vector dims = {values.size()}; // one dimensional example - * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); - * // Here we are passing element count -> values.size() - * auto float16_tensor = Ort::Value::CreateTensor(info, values.data(), values.size(), dims.data(), dims.size()); - * - * \endcode - */ -struct Float16_t { - uint16_t value; - constexpr Float16_t() noexcept : value(0) {} - constexpr Float16_t(uint16_t v) noexcept : value(v) {} - constexpr operator uint16_t() const noexcept { return value; } - constexpr bool operator==(const Float16_t& rhs) const noexcept { return value == rhs.value; }; - constexpr bool operator!=(const Float16_t& rhs) const noexcept { return value != rhs.value; }; -}; - -static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match"); - -/*! \class Ort::BFloat16_t - * \brief is a structure that represents bfloat16 data. - * \details It is necessary for type dispatching to make use of C++ API - * The type is implicitly convertible to/from uint16_t. - * The size of the structure should align with uint16_t and one can freely cast - * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data. - * - * See also code examples for Float16_t above. - */ -struct BFloat16_t { - uint16_t value; - constexpr BFloat16_t() noexcept : value(0) {} - constexpr BFloat16_t(uint16_t v) noexcept : value(v) {} - constexpr operator uint16_t() const noexcept { return value; } - constexpr bool operator==(const BFloat16_t& rhs) const noexcept { return value == rhs.value; }; - constexpr bool operator!=(const BFloat16_t& rhs) const noexcept { return value != rhs.value; }; -}; - -static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match"); - -// This is used internally by the C++ API. This is the common base class used by the wrapper objects. -template -struct Base { - using contained_type = T; - - Base() = default; - Base(T* p) : p_{p} { - if (!p) - ORT_CXX_API_THROW("Allocation failure", ORT_FAIL); - } - ~Base() { OrtRelease(p_); } - - operator T*() { return p_; } - operator const T*() const { return p_; } - - T* release() { - T* p = p_; - p_ = nullptr; - return p; - } - - protected: - Base(const Base&) = delete; - Base& operator=(const Base&) = delete; - Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } - void operator=(Base&& v) noexcept { - OrtRelease(p_); - p_ = v.p_; - v.p_ = nullptr; - } - - T* p_{}; - - template - friend struct Unowned; // This friend line is needed to keep the centos C++ compiler from giving an error -}; - -template -struct Base { - using contained_type = const T; - - Base() = default; - Base(const T* p) : p_{p} { - if (!p) - ORT_CXX_API_THROW("Invalid instance ptr", ORT_INVALID_ARGUMENT); - } - ~Base() = default; - - operator const T*() const { return p_; } - - protected: - Base(const Base&) = delete; - Base& operator=(const Base&) = delete; - Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } - void operator=(Base&& v) noexcept { - p_ = v.p_; - v.p_ = nullptr; - } - - const T* p_{}; -}; - -template -struct Unowned : T { - Unowned(decltype(T::p_) p) : T{p} {} - Unowned(Unowned&& v) : T{v.p_} {} - ~Unowned() { this->release(); } -}; - -struct AllocatorWithDefaultOptions; -struct MemoryInfo; -struct Env; -struct TypeInfo; -struct Value; -struct ModelMetadata; - -struct Env : Base { - Env(std::nullptr_t) {} - Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); - Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); - Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param); - Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param, - OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); - explicit Env(OrtEnv* p) : Base{p} {} - - Env& EnableTelemetryEvents(); - Env& DisableTelemetryEvents(); - - Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); - - static const OrtApi* s_api; -}; - -struct CustomOpDomain : Base { - explicit CustomOpDomain(std::nullptr_t) {} - explicit CustomOpDomain(const char* domain); - - void Add(OrtCustomOp* op); -}; - -struct RunOptions : Base { - RunOptions(std::nullptr_t) {} - RunOptions(); - - RunOptions& SetRunLogVerbosityLevel(int); - int GetRunLogVerbosityLevel() const; - - RunOptions& SetRunLogSeverityLevel(int); - int GetRunLogSeverityLevel() const; - - RunOptions& SetRunTag(const char* run_tag); - const char* GetRunTag() const; - - RunOptions& AddConfigEntry(const char* config_key, const char* config_value); - - // terminate ALL currently executing Session::Run calls that were made using this RunOptions instance - RunOptions& SetTerminate(); - // unset the terminate flag so this RunOptions instance can be used in a new Session::Run call - RunOptions& UnsetTerminate(); -}; - -struct SessionOptions : Base { - explicit SessionOptions(std::nullptr_t) {} - SessionOptions(); - explicit SessionOptions(OrtSessionOptions* p) : Base{p} {} - - SessionOptions Clone() const; - - SessionOptions& SetIntraOpNumThreads(int intra_op_num_threads); - SessionOptions& SetInterOpNumThreads(int inter_op_num_threads); - SessionOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); - - SessionOptions& EnableCpuMemArena(); - SessionOptions& DisableCpuMemArena(); - - SessionOptions& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); - - SessionOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix); - SessionOptions& DisableProfiling(); - - SessionOptions& EnableMemPattern(); - SessionOptions& DisableMemPattern(); - - SessionOptions& SetExecutionMode(ExecutionMode execution_mode); - - SessionOptions& SetLogId(const char* logid); - SessionOptions& SetLogSeverityLevel(int level); - - SessionOptions& Add(OrtCustomOpDomain* custom_op_domain); - - SessionOptions& DisablePerSessionThreads(); - - SessionOptions& AddConfigEntry(const char* config_key, const char* config_value); - SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val); - - SessionOptions& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); - SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); - SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); - SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); -}; - -struct ModelMetadata : Base { - explicit ModelMetadata(std::nullptr_t) {} - explicit ModelMetadata(OrtModelMetadata* p) : Base{p} {} - - char* GetProducerName(OrtAllocator* allocator) const; - char* GetGraphName(OrtAllocator* allocator) const; - char* GetDomain(OrtAllocator* allocator) const; - char* GetDescription(OrtAllocator* allocator) const; - char* GetGraphDescription(OrtAllocator* allocator) const; - char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const; - char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const; - int64_t GetVersion() const; -}; - -struct Session : Base { - explicit Session(std::nullptr_t) {} - Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); - Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container); - Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); - - // Run that will allocate the output values - std::vector Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, - const char* const* output_names, size_t output_count); - // Run for when there is a list of preallocated outputs - void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, - const char* const* output_names, Value* output_values, size_t output_count); - - void Run(const RunOptions& run_options, const struct IoBinding&); - - size_t GetInputCount() const; - size_t GetOutputCount() const; - size_t GetOverridableInitializerCount() const; - - char* GetInputName(size_t index, OrtAllocator* allocator) const; - char* GetOutputName(size_t index, OrtAllocator* allocator) const; - char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const; - char* EndProfiling(OrtAllocator* allocator) const; - uint64_t GetProfilingStartTimeNs() const; - ModelMetadata GetModelMetadata() const; - - TypeInfo GetInputTypeInfo(size_t index) const; - TypeInfo GetOutputTypeInfo(size_t index) const; - TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; -}; - -struct TensorTypeAndShapeInfo : Base { - explicit TensorTypeAndShapeInfo(std::nullptr_t) {} - explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : Base{p} {} - - ONNXTensorElementDataType GetElementType() const; - size_t GetElementCount() const; - - size_t GetDimensionsCount() const; - void GetDimensions(int64_t* values, size_t values_count) const; - void GetSymbolicDimensions(const char** values, size_t values_count) const; - - std::vector GetShape() const; -}; - -struct SequenceTypeInfo : Base { - explicit SequenceTypeInfo(std::nullptr_t) {} - explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : Base{p} {} - - TypeInfo GetSequenceElementType() const; -}; - -struct MapTypeInfo : Base { - explicit MapTypeInfo(std::nullptr_t) {} - explicit MapTypeInfo(OrtMapTypeInfo* p) : Base{p} {} - - ONNXTensorElementDataType GetMapKeyType() const; - TypeInfo GetMapValueType() const; -}; - -struct TypeInfo : Base { - explicit TypeInfo(std::nullptr_t) {} - explicit TypeInfo(OrtTypeInfo* p) : Base{p} {} - - Unowned GetTensorTypeAndShapeInfo() const; - Unowned GetSequenceTypeInfo() const; - Unowned GetMapTypeInfo() const; - - ONNXType GetONNXType() const; -}; - -struct Value : Base { - template - static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len); - static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, - ONNXTensorElementDataType type); - template - static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len); - static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); - - static Value CreateMap(Value& keys, Value& values); - static Value CreateSequence(std::vector& values); - - template - static Value CreateOpaque(const char* domain, const char* type_name, const T&); - - template - void GetOpaqueData(const char* domain, const char* type_name, T&) const; - - explicit Value(std::nullptr_t) {} - explicit Value(OrtValue* p) : Base{p} {} - Value(Value&&) = default; - Value& operator=(Value&&) = default; - - bool IsTensor() const; - size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements - Value GetValue(int index, OrtAllocator* allocator) const; - - size_t GetStringTensorDataLength() const; - void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const; - - template - T* GetTensorMutableData(); - - template - const T* GetTensorData() const; - - template - T& At(const std::vector& location); - - TypeInfo GetTypeInfo() const; - TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; - - size_t GetStringTensorElementLength(size_t element_index) const; - void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const; - - void FillStringTensor(const char* const* s, size_t s_len); - void FillStringTensorElement(const char* s, size_t index); -}; - -// Represents native memory allocation -struct MemoryAllocation { - MemoryAllocation(OrtAllocator* allocator, void* p, size_t size); - ~MemoryAllocation(); - MemoryAllocation(const MemoryAllocation&) = delete; - MemoryAllocation& operator=(const MemoryAllocation&) = delete; - MemoryAllocation(MemoryAllocation&&); - MemoryAllocation& operator=(MemoryAllocation&&); - - void* get() { return p_; } - size_t size() const { return size_; } - - private: - OrtAllocator* allocator_; - void* p_; - size_t size_; -}; - -struct AllocatorWithDefaultOptions { - AllocatorWithDefaultOptions(); - - operator OrtAllocator*() { return p_; } - operator const OrtAllocator*() const { return p_; } - - void* Alloc(size_t size); - // The return value will own the allocation - MemoryAllocation GetAllocation(size_t size); - void Free(void* p); - - const OrtMemoryInfo* GetInfo() const; - - private: - OrtAllocator* p_{}; -}; - -template -struct BaseMemoryInfo : B { - BaseMemoryInfo() = default; - explicit BaseMemoryInfo(typename B::contained_type* p) : B(p) {} - ~BaseMemoryInfo() = default; - BaseMemoryInfo(BaseMemoryInfo&&) = default; - BaseMemoryInfo& operator=(BaseMemoryInfo&&) = default; - - std::string GetAllocatorName() const; - OrtAllocatorType GetAllocatorType() const; - int GetDeviceId() const; - OrtMemType GetMemoryType() const; - template - bool operator==(const BaseMemoryInfo& o) const; -}; - -struct UnownedMemoryInfo : BaseMemoryInfo > { - explicit UnownedMemoryInfo(std::nullptr_t) {} - explicit UnownedMemoryInfo(const OrtMemoryInfo* p) : BaseMemoryInfo(p) {} -}; - -struct MemoryInfo : BaseMemoryInfo > { - static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); - - explicit MemoryInfo(std::nullptr_t) {} - explicit MemoryInfo(OrtMemoryInfo* p) : BaseMemoryInfo(p) {} - MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); -}; - -struct Allocator : public Base { - Allocator(const Session& session, const MemoryInfo&); - - void* Alloc(size_t size) const; - // The return value will own the allocation - MemoryAllocation GetAllocation(size_t size); - void Free(void* p) const; - UnownedMemoryInfo GetInfo() const; -}; - -struct IoBinding : public Base { - private: - std::vector GetOutputNamesHelper(OrtAllocator*) const; - std::vector GetOutputValuesHelper(OrtAllocator*) const; - - public: - explicit IoBinding(Session& session); - void BindInput(const char* name, const Value&); - void BindOutput(const char* name, const Value&); - void BindOutput(const char* name, const MemoryInfo&); - std::vector GetOutputNames() const; - std::vector GetOutputNames(Allocator&) const; - std::vector GetOutputValues() const; - std::vector GetOutputValues(Allocator&) const; - void ClearBoundInputs(); - void ClearBoundOutputs(); -}; - -/*! \struct Ort::ArenaCfg - * \brief it is a structure that represents the configuration of an arena based allocator - * \details Please see docs/C_API.md for details - */ -struct ArenaCfg : Base { - explicit ArenaCfg(std::nullptr_t) {} - /** - * \param max_mem - use 0 to allow ORT to choose the default - * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested - * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default - * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default - * See docs/C_API.md for details on what the following parameters mean and how to choose these values - */ - ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk); -}; - -// -// Custom OPs (only needed to implement custom OPs) -// - -struct CustomOpApi { - CustomOpApi(const OrtApi& api) : api_(api) {} - - template // T is only implemented for std::vector, std::vector, float, int64_t, and string - T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name); - - OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value); - size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info); - ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info); - size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info); - void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length); - void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count); - - template - T* GetTensorMutableData(_Inout_ OrtValue* value); - template - const T* GetTensorData(_Inout_ const OrtValue* value); - - std::vector GetTensorShape(const OrtTensorTypeAndShapeInfo* info); - void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input); - size_t KernelContext_GetInputCount(const OrtKernelContext* context); - const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index); - size_t KernelContext_GetOutputCount(const OrtKernelContext* context); - OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count); - - void ThrowOnError(OrtStatus* result); - - private: - const OrtApi& api_; -}; - -template -struct CustomOpBase : OrtCustomOp { - CustomOpBase() { - OrtCustomOp::version = ORT_API_VERSION; - OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast(this_)->CreateKernel(*api, info); }; - OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast(this_)->GetName(); }; - - OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast(this_)->GetExecutionProviderType(); }; - - OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast(this_)->GetInputTypeCount(); }; - OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputType(index); }; - - OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast(this_)->GetOutputTypeCount(); }; - OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetOutputType(index); }; - - OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast(op_kernel)->Compute(context); }; - OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast(op_kernel); }; - - OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputCharacteristic(index); }; - OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetOutputCharacteristic(index); }; - } - - // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider - const char* GetExecutionProviderType() const { return nullptr; } - - // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below - // (inputs and outputs are required by default) - OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const { - return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; - } - - OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const { - return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; - } -}; - -} // namespace Ort - -#include "onnxruntime_cxx_inline.h" -- cgit v1.2.3