diff options
Diffstat (limited to 'tracker-neuralnet/model_adapters.cpp')
| -rw-r--r-- | tracker-neuralnet/model_adapters.cpp | 42 | 
1 files changed, 33 insertions, 9 deletions
| diff --git a/tracker-neuralnet/model_adapters.cpp b/tracker-neuralnet/model_adapters.cpp index af599321..a8e55b2a 100644 --- a/tracker-neuralnet/model_adapters.cpp +++ b/tracker-neuralnet/model_adapters.cpp @@ -7,7 +7,7 @@  #include <opencv2/imgproc.hpp>  #include <QDebug> - +#include <algorithm>  namespace neuralnet_tracker_ns  { @@ -165,6 +165,24 @@ double Localizer::last_inference_time_millis() const  } +std::string PoseEstimator::get_network_input_name(size_t i) const +{ +#if ORT_API_VERSION >= 12 +    return std::string(&*session_.GetInputNameAllocated(i, allocator_)); +#else +    return std::string(session_.GetInputName(i, allocator_)); +#endif +} + +std::string PoseEstimator::get_network_output_name(size_t i) const +{ +#if ORT_API_VERSION >= 12 +    return std::string(&*session_.GetOutputNameAllocated(i, allocator_)); +#else +    return std::string(session_.GetOutputName(i, allocator_)); +#endif +} +  PoseEstimator::PoseEstimator(Ort::MemoryInfo &allocator_info, Ort::Session &&session)       : model_version_{session.GetModelMetadata().GetVersion()}      , session_{std::move(session)} @@ -215,14 +233,16 @@ PoseEstimator::PoseEstimator(Ort::MemoryInfo &allocator_info, Ort::Session &&ses      qDebug() << "Pose model inputs (" << session_.GetInputCount() << ")";      qDebug() << "Pose model outputs (" << session_.GetOutputCount() << "):"; +    output_names_.resize(session_.GetOutputCount()); +    output_c_names_.resize(session_.GetOutputCount());      for (size_t i=0; i<session_.GetOutputCount(); ++i)      { -        const char* name = session_.GetOutputName(i, allocator_); +        std::string name = get_network_output_name(i);          const auto& output_info = session_.GetOutputTypeInfo(i);          const auto& onnx_tensor_spec = output_info.GetTensorTypeAndShapeInfo();          auto my_tensor_spec = understood_outputs.find(name); -        qDebug() << "\t" << name << " (" << onnx_tensor_spec.GetShape() << ") dtype: " <<  onnx_tensor_spec.GetElementType() << " " << +        qDebug() << "\t" << name.c_str() << " (" << onnx_tensor_spec.GetShape() << ") dtype: " <<  onnx_tensor_spec.GetElementType() << " " <<              (my_tensor_spec != understood_outputs.end() ? "ok" : "unknown");          if (my_tensor_spec != understood_outputs.end()) @@ -240,7 +260,8 @@ PoseEstimator::PoseEstimator(Ort::MemoryInfo &allocator_info, Ort::Session &&ses              // Create tensor regardless and ignore output              output_val_.push_back(create_tensor(output_info, allocator_));          } -        output_names_.push_back(name); +        output_names_[i] = name; +        output_c_names_[i] = output_names_[i].c_str();      }      has_uncertainty_ = understood_outputs.at("rotaxis_scales_tril").available || @@ -270,9 +291,12 @@ PoseEstimator::PoseEstimator(Ort::MemoryInfo &allocator_info, Ort::Session &&ses      //     output_val_.push_back(create_tensor(output_info, allocator_));      // } +    input_names_.resize(session_.GetInputCount()); +    input_c_names_.resize(session_.GetInputCount());      for (size_t i = 0; i < session_.GetInputCount(); ++i)      { -        input_names_.push_back(session_.GetInputName(i, allocator_)); +        input_names_[i] = get_network_input_name(i); +        input_c_names_[i] = input_names_[i].c_str();      }      assert (input_names_.size() == input_val_.size()); @@ -312,11 +336,11 @@ std::optional<PoseEstimator::Face> PoseEstimator::run(      {          session_.Run(              Ort::RunOptions{ nullptr },  -            input_names_.data(),  +            input_c_names_.data(),              input_val_.data(),               input_val_.size(),  -            output_names_.data(),  -            output_val_.data(),  +            output_c_names_.data(), +            output_val_.data(),              output_val_.size());      }      catch (const Ort::Exception &e) @@ -430,4 +454,4 @@ double PoseEstimator::last_inference_time_millis() const -} // namespace neuralnet_tracker_ns
\ No newline at end of file +} // namespace neuralnet_tracker_ns | 
