SourceXtractorPlusPlus  0.15
Please provide a description of the project.
OnnxSourceTask.cpp
Go to the documentation of this file.
1 
22 #include <NdArray/NdArray.h>
24 #include <onnxruntime_cxx_api.h>
25 
26 namespace NdArray = Euclid::NdArray;
27 
28 namespace SourceXtractor {
29 
30 
31 template<typename T>
32 static void fillCutout(const Image<T>& image, int center_x, int center_y, int width, int height, std::vector<T>& out) {
33  int x_start = center_x - width / 2;
34  int y_start = center_y - height / 2;
35  int x_end = x_start + width;
36  int y_end = y_start + height;
37 
38  ImageAccessor<T> accessor(image);
39 
40  int index = 0;
41  for (int iy = y_start; iy < y_end; iy++) {
42  for (int ix = x_start; ix < x_end; ix++, index++) {
43  if (ix >= 0 && iy >= 0 && ix < image.getWidth() && iy < image.getHeight()) {
44  out[index] = accessor.getValue(ix, iy);
45  }
46  }
47  }
48 }
49 
50 OnnxSourceTask::OnnxSourceTask(const std::vector<OnnxModel>& models) : m_models(models) {}
51 
59 template<typename O>
61 computePropertiesSpecialized(const OnnxModel& model, const DetectionFrameImages& detection_frame_images,
62  const PixelCentroid& centroid) {
63  Ort::RunOptions run_options;
64  auto mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
65 
66  const int center_x = static_cast<int>(centroid.getCentroidX() + 0.5);
67  const int center_y = static_cast<int>(centroid.getCentroidY() + 0.5);
68 
69  // Allocate memory
70  std::vector<int64_t> input_shape(model.m_input_shape.begin(), model.m_input_shape.end());
71  input_shape[0] = 1;
72  size_t input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1u, std::multiplies<size_t>());
73  std::vector<float> input_data(input_size);
74 
75  std::vector<int64_t> output_shape(model.m_output_shape.begin(), model.m_output_shape.end());
76  output_shape[0] = 1;
77  size_t output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1u, std::multiplies<size_t>());
78  std::vector<O> output_data(output_size);
79 
80  // Cut the needed area
81  {
82  const auto& image = detection_frame_images.getLockedImage(LayerSubtractedImage);
83  fillCutout(*image, center_x, center_y, input_shape[2], input_shape[3], input_data);
84  }
85 
86  // Setup input/output tensors
87  auto input_tensor = Ort::Value::CreateTensor<float>(
88  mem_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size());
89  auto output_tensor = Ort::Value::CreateTensor<O>(
90  mem_info, output_data.data(), output_data.size(), output_shape.data(), output_shape.size());
91 
92  // Run the model
93  const char *input_name = model.m_input_name.c_str();
94  const char *output_name = model.m_output_name.c_str();
95  model.m_session->Run(run_options,
96  &input_name, &input_tensor, 1,
97  &output_name, &output_tensor, 1);
98 
99  // Set the output
100  std::vector<size_t> catalog_shape{model.m_output_shape.begin() + 1, model.m_output_shape.end()};
101  return Euclid::make_unique<OnnxProperty::NdWrapper<O>>(catalog_shape, output_data);
102 }
103 
// OnnxSourceTask::computeProperties(SourceInterface& source) const -- body only.
// The signature line and the declarations of `output_dict` and `result` were not captured in
// this listing.
// Runs every configured model on the source and collects each output under the model's
// m_prop_name. All outputs are attached to the source as a single OnnxProperty.
// Every model is centred on the source's PixelCentroid and reads the DetectionFrameImages.
105  const auto& detection_frame_images = source.getProperty<DetectionFrameImages>();
106  const auto& centroid = source.getProperty<PixelCentroid>();
107 
109 
110  for (const auto& model : m_models) {
112 
// Dispatch on the model's output element type: only float and int32 outputs are supported
113  switch (model.m_output_type) {
114  case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
115  result = computePropertiesSpecialized<float>(model, detection_frame_images, centroid);
116  break;
117  case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
118  result = computePropertiesSpecialized<int32_t>(model, detection_frame_images, centroid);
119  break;
120  default:
// Any other output type raises Elements::Exception carrying the numeric type code.
// NOTE(review): the message has no separator before the type code; presumably
// unsupported types are rejected earlier (e.g. at model load), so this branch is
// believed unreachable -- confirm.
121  throw Elements::Exception() << "This should have not happened!" << model.m_output_type;
122  }
123 
// NOTE(review): if output_dict is a std::map (std::map appears in this file's references),
// emplace keeps the first entry for a duplicated m_prop_name and drops the later result.
124  output_dict.emplace(model.m_prop_name, std::move(result));
125  }
126 
127  source.setProperty<OnnxProperty>(std::move(output_dict));
128 }
129 
130 } // end of namespace SourceXtractor
SourceXtractor::DetectionFrameImages
Definition: DetectionFrameImages.h:30
SourceXtractor::OnnxSourceTask::m_models
const std::vector< OnnxModel > & m_models
Definition: OnnxSourceTask.h:48
SourceXtractor::ImageAccessor
Definition: ImageAccessor.h:41
Euclid::NdArray
NdArray
Euclid::NdArray::NdArray< T > NdArray
Definition: GrowthCurvePlugin.cpp:26
SourceXtractor::fillCutout
static void fillCutout(const Image< T > &image, int center_x, int center_y, int width, int height, std::vector< T > &out)
Definition: OnnxSourceTask.cpp:32
SourceXtractor::OnnxModel::m_input_shape
std::vector< std::int64_t > m_input_shape
Input tensor shape.
Definition: OnnxModel.h:38
SourceXtractor::DetectionFrameImages::getLockedImage
std::shared_ptr< ImageAccessor< SeFloat > > getLockedImage(FrameImageLayer layer) const
Definition: DetectionFrameImages.h:38
SourceXtractor::Image::getWidth
virtual int getWidth() const =0
Returns the width of the image in pixels.
std::move
T move(T... args)
SourceXtractor::OnnxModel::m_output_name
std::string m_output_name
Output tensor name.
Definition: OnnxModel.h:35
SourceXtractor::PixelCentroid
The centroid of all the pixels in the source, weighted by their DetectionImage pixel values.
Definition: PixelCentroid.h:37
SourceXtractor::PixelCentroid::getCentroidX
SeFloat getCentroidX() const
X coordinate of centroid.
Definition: PixelCentroid.h:48
std::vector
STL class.
std::vector::size
T size(T... args)
SourceXtractor::Image::getHeight
virtual int getHeight() const =0
Returns the height of the image in pixels.
SourceXtractor::Image
Interface representing an image.
Definition: Image.h:43
SourceXtractor::PixelCentroid::getCentroidY
SeFloat getCentroidY() const
Y coordinate of centroid.
Definition: PixelCentroid.h:53
std::multiplies
SourceXtractor
Definition: Aperture.h:30
std::string::c_str
T c_str(T... args)
SourceXtractor::computePropertiesSpecialized
static std::unique_ptr< OnnxProperty::NdWrapperBase > computePropertiesSpecialized(const OnnxModel &model, const DetectionFrameImages &detection_frame_images, const PixelCentroid &centroid)
Definition: OnnxSourceTask.cpp:61
OnnxSourceTask.h
Elements::Exception
std::accumulate
T accumulate(T... args)
std::map
STL class.
SourceXtractor::OnnxSourceTask::OnnxSourceTask
OnnxSourceTask(const std::vector< OnnxModel > &models)
Definition: OnnxSourceTask.cpp:50
SourceXtractor::ImageAccessor::getValue
T getValue(int x, int y)
Definition: ImageAccessor.h:100
SourceXtractor::OnnxModel::m_input_name
std::string m_input_name
Input tensor name.
Definition: OnnxModel.h:34
SourceXtractor::OnnxProperty
Definition: OnnxProperty.h:30
NdArray.h
SourceXtractor::OnnxModel::m_output_shape
std::vector< std::int64_t > m_output_shape
Output tensor shape.
Definition: OnnxModel.h:39
SourceXtractor::OnnxModel
Definition: OnnxModel.h:32
SourceXtractor::OnnxModel::m_session
std::unique_ptr< Ort::Session > m_session
Session, one per model. In theory, it is thread-safe.
Definition: OnnxModel.h:41
std::vector::begin
T begin(T... args)
SourceXtractor::SourceInterface::getProperty
const PropertyType & getProperty(unsigned int index=0) const
Convenience template method to call getProperty() with a more user-friendly syntax.
Definition: SourceInterface.h:57
std::vector::end
T end(T... args)
SourceXtractor::LayerSubtractedImage
@ LayerSubtractedImage
Definition: Frame.h:38
SourceXtractor::SourceInterface
The SourceInterface is an abstract "source" that has properties attached to it.
Definition: SourceInterface.h:46
OnnxProperty.h
std::unique_ptr
STL class.
memory_tools.h
SourceXtractor::SourceInterface::setProperty
void setProperty(Args... args)
Definition: SourceInterface.h:72
std::vector::data
T data(T... args)
PixelCentroid.h
SourceXtractor::OnnxSourceTask::computeProperties
void computeProperties(SourceInterface &source) const override
Computes one or more properties for the Source.
Definition: OnnxSourceTask.cpp:104
DetectionFrameImages.h