diff options
| author | Michael Welter <michael@welter-4d.de> | 2021-07-06 20:55:36 +0200 | 
|---|---|---|
| committer | Michael Welter <michael@welter-4d.de> | 2022-05-15 16:53:20 +0200 | 
| commit | 4697a0e44536ae9dc169aab54df26e5dbebdec7c (patch) | |
| tree | 37a2604459bc4a02acc6b1fe996443cbfdce34b2 /tracker-neuralnet | |
| parent | f58bc9de86e79df8f7a30fe1e21821bdabfb165e (diff) | |
tracker/nn: Don't hardcode the size of the input image for the pose estimator
Diffstat (limited to 'tracker-neuralnet')
| -rw-r--r-- | tracker-neuralnet/ftnoir_tracker_neuralnet.cpp | 37 | ||||
| -rw-r--r-- | tracker-neuralnet/ftnoir_tracker_neuralnet.h | 2 | 
2 files changed, 26 insertions, 13 deletions
| diff --git a/tracker-neuralnet/ftnoir_tracker_neuralnet.cpp b/tracker-neuralnet/ftnoir_tracker_neuralnet.cpp index 50d1e7b4..320c8f23 100644 --- a/tracker-neuralnet/ftnoir_tracker_neuralnet.cpp +++ b/tracker-neuralnet/ftnoir_tracker_neuralnet.cpp @@ -105,10 +105,9 @@ mat33 rotation_from_two_vectors(const vec3 &a, const vec3 &b)  } -/* Computes correction due to head being off screen center. -    x, y: In screen space, i.e. in [-1,1] -    focal_length_x: In screen space -*/ +// Computes correction due to head being off screen center. +// x, y: In screen space, i.e. in [-1,1] +// focal_length_x: In screen space  mat33 compute_rotation_correction(const cv::Point2f &p, float focal_length_x)  {      return rotation_from_two_vectors( @@ -144,6 +143,20 @@ T iou(const cv::Rect_<T> &a, const cv::Rect_<T> &b)      return double{i.area()} / (a.area()+b.area()-i.area());  } +// Returns width and height of the input tensor, or throws. +// Expects the model to take one tensor as input that must +// have the shape B x C x H x W, where B=C=1. +cv::Size get_input_image_shape(const Ort::Session &session) +{ +    if (session.GetInputCount() != 1) +        throw std::invalid_argument("Model must take exactly one input tensor"); +    const std::vector<std::int64_t> shape =  +        session.GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape(); +    if (shape.size() != 4) +        throw std::invalid_argument("Model takes the input tensor in the wrong shape"); +    return { static_cast<int>(shape[3]), static_cast<int>(shape[2]) }; +} +  } // namespace @@ -222,13 +235,16 @@ double Localizer::last_inference_time_millis() const  } -PoseEstimator::PoseEstimator(Ort::MemoryInfo &allocator_info, Ort::Session &&session) : -    session{std::move(session)}, -    scaled_frame(input_img_height, input_img_width, CV_8U), -    input_mat(input_img_height, input_img_width, CV_32F) +PoseEstimator::PoseEstimator(Ort::MemoryInfo &allocator_info, Ort::Session &&_session) : +    session{std::move(_session)}  { +    const cv::Size input_image_shape = get_input_image_shape(session); + +    scaled_frame = cv::Mat(input_image_shape, CV_8U); +    input_mat = cv::Mat(input_image_shape, CV_32F); +      { -        const std::int64_t input_shape[4] = { 1, 1, input_img_height, input_img_width }; +        const std::int64_t input_shape[4] = { 1, 1, input_image_shape.height, input_image_shape.width };          input_val = Ort::Value::CreateTensor<float>(allocator_info, input_mat.ptr<float>(0), input_mat.total(), input_shape, 4);      } @@ -295,7 +311,7 @@ std::optional<PoseEstimator::Face> PoseEstimator::run(      auto p = input_mat.ptr(0); -    cv::resize(cropped, scaled_frame, { input_img_width, input_img_height }, 0, 0, cv::INTER_AREA); +    cv::resize(cropped, scaled_frame, scaled_frame.size(), 0, 0, cv::INTER_AREA);      // Automatic brightness amplification.      const int brightness = find_input_intensity_90_pct_quantile(); @@ -306,7 +322,6 @@ std::optional<PoseEstimator::Face> PoseEstimator::run(      assert (input_mat.ptr(0) == p);      assert (!input_mat.empty() && input_mat.isContinuous()); -    assert (input_mat.cols == input_img_width && input_mat.rows == input_img_height);      const char* input_names[] = {"x"};      const char* output_names[] = {"pos_size", "quat", "box"}; diff --git a/tracker-neuralnet/ftnoir_tracker_neuralnet.h b/tracker-neuralnet/ftnoir_tracker_neuralnet.h index 5f9c6fbe..c74a7b1d 100644 --- a/tracker-neuralnet/ftnoir_tracker_neuralnet.h +++ b/tracker-neuralnet/ftnoir_tracker_neuralnet.h @@ -118,8 +118,6 @@ class PoseEstimator      private:          // Operates on the private image data members          int find_input_intensity_90_pct_quantile() const; -        inline static constexpr int input_img_width = 129; -        inline static constexpr int input_img_height = 129;          Ort::Session session{nullptr};          // Inputs          cv::Mat scaled_frame{}, input_mat{}; | 
