diff options
author | Michael Welter <michael@welter-4d.de> | 2021-07-06 20:55:36 +0200 |
---|---|---|
committer | Michael Welter <michael@welter-4d.de> | 2022-05-15 16:53:20 +0200 |
commit | 4697a0e44536ae9dc169aab54df26e5dbebdec7c (patch) | |
tree | 37a2604459bc4a02acc6b1fe996443cbfdce34b2 | |
parent | f58bc9de86e79df8f7a30fe1e21821bdabfb165e (diff) |
tracker/nn: Don't hardcode the size of the input image for the pose estimator
-rw-r--r-- | tracker-neuralnet/ftnoir_tracker_neuralnet.cpp | 37 | ||||
-rw-r--r-- | tracker-neuralnet/ftnoir_tracker_neuralnet.h | 2 |
2 files changed, 26 insertions, 13 deletions
diff --git a/tracker-neuralnet/ftnoir_tracker_neuralnet.cpp b/tracker-neuralnet/ftnoir_tracker_neuralnet.cpp
index 50d1e7b4..320c8f23 100644
--- a/tracker-neuralnet/ftnoir_tracker_neuralnet.cpp
+++ b/tracker-neuralnet/ftnoir_tracker_neuralnet.cpp
@@ -105,10 +105,9 @@ mat33 rotation_from_two_vectors(const vec3 &a, const vec3 &b)
 }
 
 
-/* Computes correction due to head being off screen center.
-    x, y: In screen space, i.e. in [-1,1]
-    focal_length_x: In screen space
-*/
+// Computes correction due to head being off screen center.
+// x, y: In screen space, i.e. in [-1,1]
+// focal_length_x: In screen space
 mat33 compute_rotation_correction(const cv::Point2f &p, float focal_length_x)
 {
     return rotation_from_two_vectors(
@@ -144,6 +143,20 @@ T iou(const cv::Rect_<T> &a, const cv::Rect_<T> &b)
     return double{i.area()} / (a.area()+b.area()-i.area());
 }
 
+// Returns width and height of the input tensor, or throws.
+// Expects the model to take one tensor as input that must
+// have the shape B x C x H x W, where B=C=1.
+cv::Size get_input_image_shape(const Ort::Session &session)
+{
+    if (session.GetInputCount() != 1)
+        throw std::invalid_argument("Model must take exactly one input tensor");
+    const std::vector<std::int64_t> shape =
+        session.GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
+    if (shape.size() != 4)
+        throw std::invalid_argument("Model takes the input tensor in the wrong shape");
+    return { static_cast<int>(shape[3]), static_cast<int>(shape[2]) };
+}
+
 } // namespace
 
 
@@ -222,13 +235,16 @@ double Localizer::last_inference_time_millis() const
 }
 
 
-PoseEstimator::PoseEstimator(Ort::MemoryInfo &allocator_info, Ort::Session &&session) :
-    session{std::move(session)},
-    scaled_frame(input_img_height, input_img_width, CV_8U),
-    input_mat(input_img_height, input_img_width, CV_32F)
+PoseEstimator::PoseEstimator(Ort::MemoryInfo &allocator_info, Ort::Session &&_session) :
+    session{std::move(_session)}
 {
+    const cv::Size input_image_shape = get_input_image_shape(session);
+
+    scaled_frame = cv::Mat(input_image_shape, CV_8U);
+    input_mat = cv::Mat(input_image_shape, CV_32F);
+
     {
-        const std::int64_t input_shape[4] = { 1, 1, input_img_height, input_img_width };
+        const std::int64_t input_shape[4] = { 1, 1, input_image_shape.height, input_image_shape.width };
         input_val = Ort::Value::CreateTensor<float>(allocator_info, input_mat.ptr<float>(0), input_mat.total(), input_shape, 4);
     }
 
@@ -295,7 +311,7 @@ std::optional<PoseEstimator::Face> PoseEstimator::run(
 
     auto p = input_mat.ptr(0);
 
-    cv::resize(cropped, scaled_frame, { input_img_width, input_img_height }, 0, 0, cv::INTER_AREA);
+    cv::resize(cropped, scaled_frame, scaled_frame.size(), 0, 0, cv::INTER_AREA);
 
     // Automatic brightness amplification.
     const int brightness = find_input_intensity_90_pct_quantile();
@@ -306,7 +322,6 @@ std::optional<PoseEstimator::Face> PoseEstimator::run(
 
     assert (input_mat.ptr(0) == p);
     assert (!input_mat.empty() && input_mat.isContinuous());
-    assert (input_mat.cols == input_img_width && input_mat.rows == input_img_height);
 
     const char* input_names[] = {"x"};
     const char* output_names[] = {"pos_size", "quat", "box"};
diff --git a/tracker-neuralnet/ftnoir_tracker_neuralnet.h b/tracker-neuralnet/ftnoir_tracker_neuralnet.h
index 5f9c6fbe..c74a7b1d 100644
--- a/tracker-neuralnet/ftnoir_tracker_neuralnet.h
+++ b/tracker-neuralnet/ftnoir_tracker_neuralnet.h
@@ -118,8 +118,6 @@ class PoseEstimator
     private:
         // Operates on the private image data members
         int find_input_intensity_90_pct_quantile() const;
-        inline static constexpr int input_img_width = 129;
-        inline static constexpr int input_img_height = 129;
         Ort::Session session{nullptr};
         // Inputs
         cv::Mat scaled_frame{}, input_mat{};