diff options
author | Zhao Zhixu <zzx.2013@qq.com> | 2023-04-07 21:30:37 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-07 13:30:37 +0000 |
commit | fc135349a356ee50703361cacb83ea4a83c71936 (patch) | |
tree | 8b01ecc0b6d6ed1803f8d79cb356c7627020e45a /tracker-neuralnet/model_adapters.cpp | |
parent | 0883d8a05f8f7f16d94516a27bfea9f9913a90be (diff) |
tracker/neuralnet: Add support for building on Linux. (#1638)
Diffstat (limited to 'tracker-neuralnet/model_adapters.cpp')
-rw-r--r-- | tracker-neuralnet/model_adapters.cpp | 42 |
1 files changed, 33 insertions, 9 deletions
diff --git a/tracker-neuralnet/model_adapters.cpp b/tracker-neuralnet/model_adapters.cpp index af599321..a8e55b2a 100644 --- a/tracker-neuralnet/model_adapters.cpp +++ b/tracker-neuralnet/model_adapters.cpp @@ -7,7 +7,7 @@ #include <opencv2/imgproc.hpp> #include <QDebug> - +#include <algorithm> namespace neuralnet_tracker_ns { @@ -165,6 +165,24 @@ double Localizer::last_inference_time_millis() const } +std::string PoseEstimator::get_network_input_name(size_t i) const +{ +#if ORT_API_VERSION >= 12 + return std::string(&*session_.GetInputNameAllocated(i, allocator_)); +#else + return std::string(session_.GetInputName(i, allocator_)); +#endif +} + +std::string PoseEstimator::get_network_output_name(size_t i) const +{ +#if ORT_API_VERSION >= 12 + return std::string(&*session_.GetOutputNameAllocated(i, allocator_)); +#else + return std::string(session_.GetOutputName(i, allocator_)); +#endif +} + PoseEstimator::PoseEstimator(Ort::MemoryInfo &allocator_info, Ort::Session &&session) : model_version_{session.GetModelMetadata().GetVersion()} , session_{std::move(session)} @@ -215,14 +233,16 @@ PoseEstimator::PoseEstimator(Ort::MemoryInfo &allocator_info, Ort::Session &&ses qDebug() << "Pose model inputs (" << session_.GetInputCount() << ")"; qDebug() << "Pose model outputs (" << session_.GetOutputCount() << "):"; + output_names_.resize(session_.GetOutputCount()); + output_c_names_.resize(session_.GetOutputCount()); for (size_t i=0; i<session_.GetOutputCount(); ++i) { - const char* name = session_.GetOutputName(i, allocator_); + std::string name = get_network_output_name(i); const auto& output_info = session_.GetOutputTypeInfo(i); const auto& onnx_tensor_spec = output_info.GetTensorTypeAndShapeInfo(); auto my_tensor_spec = understood_outputs.find(name); - qDebug() << "\t" << name << " (" << onnx_tensor_spec.GetShape() << ") dtype: " << onnx_tensor_spec.GetElementType() << " " << + qDebug() << "\t" << name.c_str() << " (" << onnx_tensor_spec.GetShape() << ") dtype: " << onnx_tensor_spec.GetElementType() << " " << (my_tensor_spec != understood_outputs.end() ? "ok" : "unknown"); if (my_tensor_spec != understood_outputs.end()) @@ -240,7 +260,8 @@ PoseEstimator::PoseEstimator(Ort::MemoryInfo &allocator_info, Ort::Session &&ses // Create tensor regardless and ignore output output_val_.push_back(create_tensor(output_info, allocator_)); } - output_names_.push_back(name); + output_names_[i] = name; + output_c_names_[i] = output_names_[i].c_str(); } has_uncertainty_ = understood_outputs.at("rotaxis_scales_tril").available || @@ -270,9 +291,12 @@ PoseEstimator::PoseEstimator(Ort::MemoryInfo &allocator_info, Ort::Session &&ses // output_val_.push_back(create_tensor(output_info, allocator_)); // } + input_names_.resize(session_.GetInputCount()); + input_c_names_.resize(session_.GetInputCount()); for (size_t i = 0; i < session_.GetInputCount(); ++i) { - input_names_.push_back(session_.GetInputName(i, allocator_)); + input_names_[i] = get_network_input_name(i); + input_c_names_[i] = input_names_[i].c_str(); } assert (input_names_.size() == input_val_.size()); @@ -312,11 +336,11 @@ std::optional<PoseEstimator::Face> PoseEstimator::run( { session_.Run( Ort::RunOptions{ nullptr }, - input_names_.data(), + input_c_names_.data(), input_val_.data(), input_val_.size(), - output_names_.data(), - output_val_.data(), + output_c_names_.data(), + output_val_.data(), output_val_.size()); } catch (const Ort::Exception &e) @@ -430,4 +454,4 @@ double PoseEstimator::last_inference_time_millis() const -} // namespace neuralnet_tracker_ns
\ No newline at end of file +} // namespace neuralnet_tracker_ns |