SourceXtractorPlusPlus  0.15
Please provide a description of the project.
OnnxTaskFactory.cpp
Go to the documentation of this file.
1 
23 #include <NdArray/NdArray.h>
25 #include <onnxruntime_cxx_api.h>
26 
27 namespace SourceXtractor {
28 
29 // There can be only one!
30 static Ort::Env ORT_ENV;
31 
35 static std::string generatePropertyName(const OnnxModel& model_info, OrtAllocator* allocator) {
36  std::stringstream prop_name;
37 
38  std::string domain = model_info.m_session->GetModelMetadata().GetDomain(allocator);
39  if (!domain.empty()) {
40  prop_name << domain << '.';
41  }
42 
43  std::string graph_name = model_info.m_session->GetModelMetadata().GetGraphName(allocator);
44  if (!graph_name.empty()) {
45  prop_name << graph_name << '.';
46  }
47 
48  prop_name << model_info.m_output_name;
49 
50  return prop_name.str();
51 }
52 
57  std::ostringstream stream;
58  for (auto i = shape.begin(); i != shape.end() - 1; ++i) {
59  stream << *i << " x ";
60  }
61  stream << shape.back();
62  return stream.str();
63 }
64 
66 
68  if (property_id == PropertyId::create<OnnxProperty>()) {
69  return std::make_shared<OnnxSourceTask>(m_models);
70  }
71  return nullptr;
72 }
73 
76 }
77 
79  auto allocator = Ort::AllocatorWithDefaultOptions();
80 
81  const auto& onnx_config = manager.getConfiguration<OnnxConfig>();
82  const auto& models = onnx_config.getModels();
83 
84  for (auto model_path : models) {
85  OnnxModel model_info;
86  model_info.m_model_path = model_path;
87 
88  onnx_logger.info() << "Loading ONNX model " << model_path;
89  model_info.m_session = Euclid::make_unique<Ort::Session>(ORT_ENV, model_path.c_str(), Ort::SessionOptions{nullptr});
90 
91  if (model_info.m_session->GetInputCount() != 1) {
92  throw Elements::Exception() << "Only ONNX models with a single input tensor are supported";
93  }
94  if (model_info.m_session->GetOutputCount() != 1) {
95  throw Elements::Exception() << "Only ONNX models with a single output tensor are supported";
96  }
97 
98  model_info.m_input_name = model_info.m_session->GetInputName(0, allocator);
99  model_info.m_output_name = model_info.m_session->GetOutputName(0, allocator);
100 
101  model_info.m_prop_name = generatePropertyName(model_info, allocator);
102 
103  onnx_logger.info() << "Output name will be " << model_info.m_prop_name;
104 
105  auto input_type = model_info.m_session->GetInputTypeInfo(0);
106  auto output_type = model_info.m_session->GetOutputTypeInfo(0);
107 
108  model_info.m_input_shape = input_type.GetTensorTypeAndShapeInfo().GetShape();
109  model_info.m_input_type = input_type.GetTensorTypeAndShapeInfo().GetElementType();
110  model_info.m_output_shape = output_type.GetTensorTypeAndShapeInfo().GetShape();
111  model_info.m_output_type = output_type.GetTensorTypeAndShapeInfo().GetElementType();
112 
113  if (model_info.m_input_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
114  throw Elements::Exception() << "Only ONNX models with float input are supported";
115  }
116  if (model_info.m_input_shape.size() != 4) {
117  throw Elements::Exception() << "Expected 4 axes for the input layer, got " << model_info.m_input_shape.size();
118  }
119 
120  onnx_logger.info() << "ONNX model with input of " << formatShape(model_info.m_input_shape);
121  onnx_logger.info() << "ONNX model with output of " << formatShape(model_info.m_output_shape);
122 
123  m_models.emplace_back(std::move(model_info));
124  }
125 }
126 
127 template<typename T>
128 static void registerColumnConverter(OutputRegistry& registry, const OnnxModel& model) {
129  auto key = model.m_prop_name;
130 
132  model.m_prop_name, [key](const OnnxProperty& prop) {
133  return prop.getData<T>(key);
134  }, "", model.m_model_path
135  );
136 }
137 
139  for (const auto& model : m_models) {
140  switch (model.m_output_type) {
141  case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
142  registerColumnConverter<float>(registry, model);
143  break;
144  case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
145  registerColumnConverter<int32_t>(registry, model);
146  break;
147  default:
148  throw Elements::Exception() << "Unsupported output type: " << model.m_output_type;
149  }
150  }
151 }
152 
153 } // end of namespace SourceXtractor
OnnxTaskFactory.h
SourceXtractor::generatePropertyName
static std::string generatePropertyName(const OnnxModel &model_info, OrtAllocator *allocator)
Definition: OnnxTaskFactory.cpp:35
Euclid::Configuration::ConfigManager::registerConfiguration
void registerConfiguration()
std::string
STL class.
std::shared_ptr
STL class.
SourceXtractor::OnnxModel::m_input_shape
std::vector< std::int64_t > m_input_shape
Input tensor shape.
Definition: OnnxModel.h:38
std::move
T move(T... args)
SourceXtractor::OnnxTaskFactory::m_models
std::vector< OnnxModel > m_models
Definition: OnnxTaskFactory.h:48
SourceXtractor::OnnxModel::m_output_name
std::string m_output_name
Output tensor name.
Definition: OnnxModel.h:35
SourceXtractor::OnnxTaskFactory::registerPropertyInstances
void registerPropertyInstances(OutputRegistry &registry) override
Definition: OnnxTaskFactory.cpp:138
Euclid::Configuration::ConfigManager
std::vector
STL class.
std::vector::size
T size(T... args)
SourceXtractor::OutputRegistry
Definition: OutputRegistry.h:36
SourceXtractor::registerColumnConverter
static void registerColumnConverter(OutputRegistry &registry, const OnnxModel &model)
Definition: OnnxTaskFactory.cpp:128
std::stringstream
STL class.
Euclid::NdArray::NdArray
Euclid::Configuration::ConfigManager::getConfiguration
T & getConfiguration()
std::vector::back
T back(T... args)
SourceXtractor::OnnxTaskFactory::createTask
std::shared_ptr< Task > createTask(const PropertyId &property_id) const override
Returns a Task producing a Property corresponding to the given PropertyId.
Definition: OnnxTaskFactory.cpp:67
SourceXtractor::OnnxConfig::getModels
const std::vector< std::string > & getModels() const
Definition: OnnxConfig.h:44
SourceXtractor::PropertyId
Identifier used to set and retrieve properties.
Definition: PropertyId.h:40
SourceXtractor
Definition: Aperture.h:30
SourceXtractor::OnnxTaskFactory::OnnxTaskFactory
OnnxTaskFactory()
Definition: OnnxTaskFactory.cpp:65
OnnxSourceTask.h
Elements::Exception
SourceXtractor::OnnxConfig
Definition: OnnxConfig.h:28
SourceXtractor::OnnxModel::m_model_path
std::string m_model_path
Path to the ONNX model.
Definition: OnnxModel.h:40
SourceXtractor::onnx_logger
Elements::Logging onnx_logger
Logger for the ONNX plugin.
Definition: OnnxPlugin.cpp:26
Elements::Logging::info
void info(const std::string &logMessage)
SourceXtractor::OnnxModel::m_input_name
std::string m_input_name
Input tensor name.
Definition: OnnxModel.h:34
SourceXtractor::OnnxProperty
Definition: OnnxProperty.h:30
NdArray.h
SourceXtractor::OnnxModel::m_prop_name
std::string m_prop_name
Name that will be written into the catalog.
Definition: OnnxModel.h:33
std::ostringstream
STL class.
SourceXtractor::OnnxModel::m_output_shape
std::vector< std::int64_t > m_output_shape
Output tensor shape.
Definition: OnnxModel.h:39
SourceXtractor::OnnxModel
Definition: OnnxModel.h:32
SourceXtractor::OnnxModel::m_session
std::unique_ptr< Ort::Session > m_session
Session, one per model. In theory, it is thread-safe.
Definition: OnnxModel.h:41
OnnxPlugin.h
std::vector::begin
T begin(T... args)
SourceXtractor::OnnxTaskFactory::configure
void configure(Euclid::Configuration::ConfigManager &manager) override
Method which should initialize the object.
Definition: OnnxTaskFactory.cpp:78
SourceXtractor::OnnxTaskFactory::reportConfigDependencies
void reportConfigDependencies(Euclid::Configuration::ConfigManager &manager) const override
Registers all the Configuration dependencies.
Definition: OnnxTaskFactory.cpp:74
SourceXtractor::ORT_ENV
static Ort::Env ORT_ENV
Definition: OnnxTaskFactory.cpp:30
SourceXtractor::OnnxModel::m_output_type
ONNXTensorElementDataType m_output_type
Output type.
Definition: OnnxModel.h:37
std::string::empty
T empty(T... args)
std::allocator
STL class.
std::stringstream::str
T str(T... args)
std::vector::end
T end(T... args)
OnnxProperty.h
OnnxConfig.h
memory_tools.h
SourceXtractor::formatShape
static std::string formatShape(const std::vector< int64_t > &shape)
Definition: OnnxTaskFactory.cpp:56
SourceXtractor::OutputRegistry::registerColumnConverter
void registerColumnConverter(std::string column_name, ColumnConverter< PropertyType, OutType > converter, std::string column_unit="", std::string column_description="")
Definition: OutputRegistry.h:46
SourceXtractor::OnnxModel::m_input_type
ONNXTensorElementDataType m_input_type
Input type.
Definition: OnnxModel.h:36