summaryrefslogtreecommitdiffhomepage
path: root/tracker-neuralnet
diff options
context:
space:
mode:
authorMichael Welter <michael@welter-4d.de>2021-07-06 20:55:36 +0200
committerMichael Welter <michael@welter-4d.de>2022-05-15 16:53:20 +0200
commit4697a0e44536ae9dc169aab54df26e5dbebdec7c (patch)
tree37a2604459bc4a02acc6b1fe996443cbfdce34b2 /tracker-neuralnet
parentf58bc9de86e79df8f7a30fe1e21821bdabfb165e (diff)
tracker/nn: Don't hardcode the size of the input image for the pose estimator
Diffstat (limited to 'tracker-neuralnet')
-rw-r--r--tracker-neuralnet/ftnoir_tracker_neuralnet.cpp37
-rw-r--r--tracker-neuralnet/ftnoir_tracker_neuralnet.h2
2 files changed, 26 insertions, 13 deletions
diff --git a/tracker-neuralnet/ftnoir_tracker_neuralnet.cpp b/tracker-neuralnet/ftnoir_tracker_neuralnet.cpp
index 50d1e7b4..320c8f23 100644
--- a/tracker-neuralnet/ftnoir_tracker_neuralnet.cpp
+++ b/tracker-neuralnet/ftnoir_tracker_neuralnet.cpp
@@ -105,10 +105,9 @@ mat33 rotation_from_two_vectors(const vec3 &a, const vec3 &b)
}
-/* Computes correction due to head being off screen center.
- x, y: In screen space, i.e. in [-1,1]
- focal_length_x: In screen space
-*/
+// Computes correction due to head being off screen center.
+// x, y: In screen space, i.e. in [-1,1]
+// focal_length_x: In screen space
mat33 compute_rotation_correction(const cv::Point2f &p, float focal_length_x)
{
return rotation_from_two_vectors(
@@ -144,6 +143,20 @@ T iou(const cv::Rect_<T> &a, const cv::Rect_<T> &b)
return double{i.area()} / (a.area()+b.area()-i.area());
}
+// Returns width and height of the input tensor, or throws.
+// Expects the model to take one tensor as input that must
+// have the shape B x C x H x W, where B=C=1.
+cv::Size get_input_image_shape(const Ort::Session &session)
+{
+ if (session.GetInputCount() != 1)
+ throw std::invalid_argument("Model must take exactly one input tensor");
+ const std::vector<std::int64_t> shape =
+ session.GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
+ if (shape.size() != 4)
+ throw std::invalid_argument("Model takes the input tensor in the wrong shape");
+ return { static_cast<int>(shape[3]), static_cast<int>(shape[2]) };
+}
+
} // namespace
@@ -222,13 +235,16 @@ double Localizer::last_inference_time_millis() const
}
-PoseEstimator::PoseEstimator(Ort::MemoryInfo &allocator_info, Ort::Session &&session) :
- session{std::move(session)},
- scaled_frame(input_img_height, input_img_width, CV_8U),
- input_mat(input_img_height, input_img_width, CV_32F)
+PoseEstimator::PoseEstimator(Ort::MemoryInfo &allocator_info, Ort::Session &&_session) :
+ session{std::move(_session)}
{
+ const cv::Size input_image_shape = get_input_image_shape(session);
+
+ scaled_frame = cv::Mat(input_image_shape, CV_8U);
+ input_mat = cv::Mat(input_image_shape, CV_32F);
+
{
- const std::int64_t input_shape[4] = { 1, 1, input_img_height, input_img_width };
+ const std::int64_t input_shape[4] = { 1, 1, input_image_shape.height, input_image_shape.width };
input_val = Ort::Value::CreateTensor<float>(allocator_info, input_mat.ptr<float>(0), input_mat.total(), input_shape, 4);
}
@@ -295,7 +311,7 @@ std::optional<PoseEstimator::Face> PoseEstimator::run(
auto p = input_mat.ptr(0);
- cv::resize(cropped, scaled_frame, { input_img_width, input_img_height }, 0, 0, cv::INTER_AREA);
+ cv::resize(cropped, scaled_frame, scaled_frame.size(), 0, 0, cv::INTER_AREA);
// Automatic brightness amplification.
const int brightness = find_input_intensity_90_pct_quantile();
@@ -306,7 +322,6 @@ std::optional<PoseEstimator::Face> PoseEstimator::run(
assert (input_mat.ptr(0) == p);
assert (!input_mat.empty() && input_mat.isContinuous());
- assert (input_mat.cols == input_img_width && input_mat.rows == input_img_height);
const char* input_names[] = {"x"};
const char* output_names[] = {"pos_size", "quat", "box"};
diff --git a/tracker-neuralnet/ftnoir_tracker_neuralnet.h b/tracker-neuralnet/ftnoir_tracker_neuralnet.h
index 5f9c6fbe..c74a7b1d 100644
--- a/tracker-neuralnet/ftnoir_tracker_neuralnet.h
+++ b/tracker-neuralnet/ftnoir_tracker_neuralnet.h
@@ -118,8 +118,6 @@ class PoseEstimator
private:
// Operates on the private image data members
int find_input_intensity_90_pct_quantile() const;
- inline static constexpr int input_img_width = 129;
- inline static constexpr int input_img_height = 129;
Ort::Session session{nullptr};
// Inputs
cv::Mat scaled_frame{}, input_mat{};