summaryrefslogtreecommitdiffhomepage
path: root/tracker-neuralnet/model_adapters.cpp
diff options
context:
space:
mode:
authorZhao Zhixu <zzx.2013@qq.com>2023-04-07 21:30:37 +0800
committerGitHub <noreply@github.com>2023-04-07 13:30:37 +0000
commitfc135349a356ee50703361cacb83ea4a83c71936 (patch)
tree8b01ecc0b6d6ed1803f8d79cb356c7627020e45a /tracker-neuralnet/model_adapters.cpp
parent0883d8a05f8f7f16d94516a27bfea9f9913a90be (diff)
tracker/neuralnet: Add support for building on Linux. (#1638)
Diffstat (limited to 'tracker-neuralnet/model_adapters.cpp')
-rw-r--r--tracker-neuralnet/model_adapters.cpp42
1 files changed, 33 insertions, 9 deletions
diff --git a/tracker-neuralnet/model_adapters.cpp b/tracker-neuralnet/model_adapters.cpp
index af599321..a8e55b2a 100644
--- a/tracker-neuralnet/model_adapters.cpp
+++ b/tracker-neuralnet/model_adapters.cpp
@@ -7,7 +7,7 @@
#include <opencv2/imgproc.hpp>
#include <QDebug>
-
+#include <algorithm>
namespace neuralnet_tracker_ns
{
@@ -165,6 +165,24 @@ double Localizer::last_inference_time_millis() const
}
+std::string PoseEstimator::get_network_input_name(size_t i) const
+{
+#if ORT_API_VERSION >= 12
+ return std::string(&*session_.GetInputNameAllocated(i, allocator_));
+#else
+ return std::string(session_.GetInputName(i, allocator_));
+#endif
+}
+
+std::string PoseEstimator::get_network_output_name(size_t i) const
+{
+#if ORT_API_VERSION >= 12
+ return std::string(&*session_.GetOutputNameAllocated(i, allocator_));
+#else
+ return std::string(session_.GetOutputName(i, allocator_));
+#endif
+}
+
PoseEstimator::PoseEstimator(Ort::MemoryInfo &allocator_info, Ort::Session &&session)
: model_version_{session.GetModelMetadata().GetVersion()}
, session_{std::move(session)}
@@ -215,14 +233,16 @@ PoseEstimator::PoseEstimator(Ort::MemoryInfo &allocator_info, Ort::Session &&ses
qDebug() << "Pose model inputs (" << session_.GetInputCount() << ")";
qDebug() << "Pose model outputs (" << session_.GetOutputCount() << "):";
+ output_names_.resize(session_.GetOutputCount());
+ output_c_names_.resize(session_.GetOutputCount());
for (size_t i=0; i<session_.GetOutputCount(); ++i)
{
- const char* name = session_.GetOutputName(i, allocator_);
+ std::string name = get_network_output_name(i);
const auto& output_info = session_.GetOutputTypeInfo(i);
const auto& onnx_tensor_spec = output_info.GetTensorTypeAndShapeInfo();
auto my_tensor_spec = understood_outputs.find(name);
- qDebug() << "\t" << name << " (" << onnx_tensor_spec.GetShape() << ") dtype: " << onnx_tensor_spec.GetElementType() << " " <<
+ qDebug() << "\t" << name.c_str() << " (" << onnx_tensor_spec.GetShape() << ") dtype: " << onnx_tensor_spec.GetElementType() << " " <<
(my_tensor_spec != understood_outputs.end() ? "ok" : "unknown");
if (my_tensor_spec != understood_outputs.end())
@@ -240,7 +260,8 @@ PoseEstimator::PoseEstimator(Ort::MemoryInfo &allocator_info, Ort::Session &&ses
// Create tensor regardless and ignore output
output_val_.push_back(create_tensor(output_info, allocator_));
}
- output_names_.push_back(name);
+ output_names_[i] = name;
+ output_c_names_[i] = output_names_[i].c_str();
}
has_uncertainty_ = understood_outputs.at("rotaxis_scales_tril").available ||
@@ -270,9 +291,12 @@ PoseEstimator::PoseEstimator(Ort::MemoryInfo &allocator_info, Ort::Session &&ses
// output_val_.push_back(create_tensor(output_info, allocator_));
// }
+ input_names_.resize(session_.GetInputCount());
+ input_c_names_.resize(session_.GetInputCount());
for (size_t i = 0; i < session_.GetInputCount(); ++i)
{
- input_names_.push_back(session_.GetInputName(i, allocator_));
+ input_names_[i] = get_network_input_name(i);
+ input_c_names_[i] = input_names_[i].c_str();
}
assert (input_names_.size() == input_val_.size());
@@ -312,11 +336,11 @@ std::optional<PoseEstimator::Face> PoseEstimator::run(
{
session_.Run(
Ort::RunOptions{ nullptr },
- input_names_.data(),
+ input_c_names_.data(),
input_val_.data(),
input_val_.size(),
- output_names_.data(),
- output_val_.data(),
+ output_c_names_.data(),
+ output_val_.data(),
output_val_.size());
}
catch (const Ort::Exception &e)
@@ -430,4 +454,4 @@ double PoseEstimator::last_inference_time_millis() const
-} // namespace neuralnet_tracker_ns \ No newline at end of file
+} // namespace neuralnet_tracker_ns