Skip to content

Commit 4944e5f

Browse files
committed
Updating OrtInterface with const inference; implementing cluster error parametrization by NN
1 parent 301b925 commit 4944e5f

File tree

6 files changed

+63
-25
lines changed

6 files changed

+63
-25
lines changed

Common/ML/include/ML/OrtInterface.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,16 @@ class OrtModel
9191

9292
// Inferencing
9393
template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. OrtDataType::Float16_t from O2/Common/ML/include/ML/GPUORTFloat16.h
94-
std::vector<O> inference(std::vector<I>&);
94+
std::vector<O> inference(std::vector<I>&) const;
9595

9696
template <class I, class O>
97-
std::vector<O> inference(std::vector<std::vector<I>>&);
97+
std::vector<O> inference(std::vector<std::vector<I>>&) const;
9898

9999
template <class I, class O>
100-
void inference(I*, int64_t, O*);
100+
void inference(I*, int64_t, O*) const;
101101

102102
template <class I, class O>
103-
void inference(I**, int64_t, O*);
103+
void inference(I**, int64_t, O*) const;
104104

105105
void release(bool = false);
106106

@@ -112,7 +112,8 @@ class OrtModel
112112
// Input & Output specifications of the loaded network
113113
std::vector<const char*> mInputNamesChar, mOutputNamesChar;
114114
std::vector<std::string> mInputNames, mOutputNames;
115-
std::vector<std::vector<int64_t>> mInputShapes, mOutputShapes, mInputShapesCopy, mOutputShapesCopy; // Input shapes
115+
std::vector<std::vector<int64_t>> mInputShapes, mOutputShapes;
116+
mutable std::vector<std::vector<int64_t>> mInputShapesCopy, mOutputShapesCopy; // Input shapes
116117
std::vector<int64_t> mInputSizePerNode, mOutputSizePerNode; // Output shapes
117118
int32_t mInputsTotal = 0, mOutputsTotal = 0; // Total number of inputs and outputs
118119

Common/ML/src/OrtInterface.cxx

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ void OrtModel::setEnv(Ort::Env* env)
289289

290290
// Inference
291291
template <class I, class O>
292-
std::vector<O> OrtModel::inference(std::vector<I>& input)
292+
std::vector<O> OrtModel::inference(std::vector<I>& input) const
293293
{
294294
std::vector<int64_t> inputShape = mInputShapes[0];
295295
inputShape[0] = input.size();
@@ -310,12 +310,12 @@ std::vector<O> OrtModel::inference(std::vector<I>& input)
310310
return outputValuesVec;
311311
}
312312

313-
template std::vector<float> o2::ml::OrtModel::inference<float, float>(std::vector<float>&);
314-
template std::vector<float> o2::ml::OrtModel::inference<OrtDataType::Float16_t, float>(std::vector<OrtDataType::Float16_t>&);
315-
template std::vector<OrtDataType::Float16_t> o2::ml::OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>&);
313+
template std::vector<float> o2::ml::OrtModel::inference<float, float>(std::vector<float>&) const;
314+
template std::vector<float> o2::ml::OrtModel::inference<OrtDataType::Float16_t, float>(std::vector<OrtDataType::Float16_t>&) const;
315+
template std::vector<OrtDataType::Float16_t> o2::ml::OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>&) const;
316316

317317
template <class I, class O>
318-
void OrtModel::inference(I* input, int64_t input_size, O* output)
318+
void OrtModel::inference(I* input, int64_t input_size, O* output) const
319319
{
320320
// std::vector<std::string> providers = Ort::GetAvailableProviders();
321321
// for (const auto& provider : providers) {
@@ -350,13 +350,13 @@ void OrtModel::inference(I* input, int64_t input_size, O* output)
350350
// mOutputNamesChar.size());
351351
}
352352

353-
template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t*, int64_t, OrtDataType::Float16_t*);
354-
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t*, int64_t, float*);
355-
template void OrtModel::inference<float, OrtDataType::Float16_t>(float*, int64_t, OrtDataType::Float16_t*);
356-
template void OrtModel::inference<float, float>(float*, int64_t, float*);
353+
template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t*, int64_t, OrtDataType::Float16_t*) const;
354+
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t*, int64_t, float*) const;
355+
template void OrtModel::inference<float, OrtDataType::Float16_t>(float*, int64_t, OrtDataType::Float16_t*) const;
356+
template void OrtModel::inference<float, float>(float*, int64_t, float*) const;
357357

358358
template <class I, class O>
359-
void OrtModel::inference(I** input, int64_t input_size, O* output)
359+
void OrtModel::inference(I** input, int64_t input_size, O* output) const
360360
{
361361
std::vector<Ort::Value> inputTensors(mInputShapesCopy.size());
362362

@@ -410,13 +410,13 @@ void OrtModel::inference(I** input, int64_t input_size, O* output)
410410
mOutputNamesChar.size());
411411
}
412412

413-
template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t**, int64_t, OrtDataType::Float16_t*);
414-
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t**, int64_t, float*);
415-
template void OrtModel::inference<float, OrtDataType::Float16_t>(float**, int64_t, OrtDataType::Float16_t*);
416-
template void OrtModel::inference<float, float>(float**, int64_t, float*);
413+
template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t**, int64_t, OrtDataType::Float16_t*) const;
414+
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t**, int64_t, float*) const;
415+
template void OrtModel::inference<float, OrtDataType::Float16_t>(float**, int64_t, OrtDataType::Float16_t*) const;
416+
template void OrtModel::inference<float, float>(float**, int64_t, float*) const;
417417

418418
template <class I, class O>
419-
std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& inputs)
419+
std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& inputs) const
420420
{
421421
std::vector<Ort::Value> input_tensors;
422422

@@ -461,8 +461,8 @@ std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& inputs)
461461
return output_vec;
462462
}
463463

464-
template std::vector<float> OrtModel::inference<float, float>(std::vector<std::vector<float>>&);
465-
template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<std::vector<OrtDataType::Float16_t>>&);
464+
template std::vector<float> OrtModel::inference<float, float>(std::vector<std::vector<float>>&) const;
465+
template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<std::vector<OrtDataType::Float16_t>>&) const;
466466

467467
// Release session
468468
void OrtModel::release(bool profilingEnabled)

GPU/GPUTracking/Base/GPUParam.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#include "GPUSettings.h"
2222
#include "GPUTPCGMPolynomialField.h"
2323

24+
#include "ML/OrtInterface.h"
25+
2426
#if !defined(GPUCA_GPUCODE)
2527
namespace o2::base
2628
{
@@ -67,6 +69,10 @@ struct GPUParam_t {
6769

6870
GPUParamSector SectorParam[GPUCA_NSECTORS];
6971

72+
std::unordered_map<std::string, std::string> mOrtOptions;
73+
o2::ml::OrtModel mModelClusterErrors; // For cluster error estimation
74+
bool useClusterErrorNetwork = false; // Whether to use the cluster error network at all, can be set to false to save time if not needed
75+
7076
protected:
7177
#ifdef GPUCA_TPC_GEOMETRY_O2
7278
float ParamErrors[2][4][4]; // cluster error parameterization used during seeding and fit
@@ -87,6 +93,15 @@ struct GPUParam : public internal::GPUParam_t<GPUSettingsRec, GPUSettingsParam>
8793
void UpdateRun3ClusterErrors(const float* yErrorParam, const float* zErrorParam);
8894
#endif
8995

96+
void initClusterErrorModel(const GPUSettingsProcessing* p) {
97+
useClusterErrorNetwork = p->nn.nnUseClusterErrorNetwork;
98+
mOrtOptions["model_path"] = p->nn.nnClusterErrorModelPath;
99+
mModelClusterErrors.initOptions(mOrtOptions);
100+
mModelClusterErrors.setIntraOpNumThreads(1);
101+
mModelClusterErrors.initEnvironment();
102+
mModelClusterErrors.initSession();
103+
}
104+
90105
GPUd() float Alpha(int32_t iSector) const
91106
{
92107
if (iSector >= GPUCA_NSECTORS / 2) {

GPU/GPUTracking/Definitions/GPUSettingsList.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,8 @@ AddOption(nnSigmoidTrafoClassThreshold, int, 1, "", 0, "If true (default), then
283283
AddOption(nnEvalMode, std::string, "c1:r1", "", 0, "Concatention of modes, e.g. c1:r1 (classification class 1, regression class 1)")
284284
AddOption(nnClusterizerUseClassification, int, 1, "", 0, "If 1, the classification output of the network is used to select clusters, else only the regression output is used and no clusters are rejected by classification")
285285
AddOption(nnClusterizerForceGpuInputFill, int, 0, "", 0, "Forces to use the fillInputNNGPU function")
286+
AddOption(nnUseClusterErrorNetwork, int, 1, "", 0, "If 1, the cluster error network is used to parametrize the cluster errors, else a fixed parametrization is used")
287+
AddOption(nnClusterErrorModelPath, std::string, "network_cluster_error.onnx", "", 0, "Network for cluster error parameterization")
286288
// CCDB
287289
AddOption(nnLoadFromCCDB, int, 0, "", 0, "If 1 networks are fetched from ccdb, else locally")
288290
AddOption(nnCCDBDumpToFile, int, 0, "", 0, "If 1, additionally dump fetched CCDB networks to nnLocalFolder")

GPU/GPUTracking/Global/GPUChainTracking.cxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ void GPUChainTracking::RegisterPermanentMemoryAndProcessors()
7878
{
7979
fpdumperr = fopen("dump_cluster_error.csv", "a");
8080
fpdumptrk = fopen("dump_trk_index.csv", "a");
81-
fprintf(fpdumperr, "internal_trkid,cluster.num,err2Y,err2Z,clusterState,clusterY,clusterZ,mP[0],mP[1],mP[2],mP[3],mP[4],mC[0],mC[2],mC[5],mC[9],mC[14]\n");
81+
fprintf(fpdumperr, "internal_trkid,cluster.num,err2Y,err2Z,clusterState,xx,yy,zz,mP[0],mP[1],mP[2],mP[3],mP[4],mC[0],mC[2],mC[5],mC[9],mC[14]\n");
8282
fprintf(fpdumptrk, "internal_trkid,trkid\n");
8383
if (mRec->IsGPU()) {
8484
mFlatObjectsShadow.InitGPUProcessor(mRec, GPUProcessor::PROCESSOR_TYPE_SLAVE);

GPU/GPUTracking/Merger/GPUTPCGMTrackParam.cxx

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,30 @@ GPUd() bool GPUTPCGMTrackParam::Fit(GPUTPCGMMerger* GPUrestrict() merger, int32_
279279
const float invCharge = merger->GetConstantMem()->ioPtrs.clustersNative ? (1.f / merger->GetConstantMem()->ioPtrs.clustersNative->clustersLinear[cluster.num].qMax) : 0.f;
280280
float invAvgCharge = (sumInvSqrtCharge += invSqrtCharge) / ++nAvgCharge;
281281
invAvgCharge *= invAvgCharge;
282-
prop.GetErr2(err2Y, err2Z, param, zz, cluster.row, clusterState, cluster.sector, time, invAvgCharge, invCharge);
282+
if(param.useClusterErrorNetwork){
283+
float inputFeatures[12];
284+
float outputFeatures[2];
285+
inputFeatures[0] = static_cast<float>(clusterState);
286+
inputFeatures[1] = xx;
287+
inputFeatures[2] = yy;
288+
inputFeatures[3] = zz;
289+
inputFeatures[4] = mP[2];
290+
inputFeatures[5] = mP[3];
291+
inputFeatures[6] = mP[4];
292+
inputFeatures[7] = mC[0];
293+
inputFeatures[8] = mC[2];
294+
inputFeatures[9] = mC[5];
295+
inputFeatures[10] = mC[9];
296+
inputFeatures[11] = mC[14];
297+
param.mModelClusterErrors.inference(inputFeatures, (int64_t)1, outputFeatures);
298+
err2Y = outputFeatures[0];
299+
err2Z = outputFeatures[1];
300+
} else {
301+
prop.GetErr2(err2Y, err2Z, param, zz, cluster.row, clusterState, cluster.sector, time, invAvgCharge, invCharge);
302+
}
283303

284304
#ifndef GPUCA_GPUCODE
285-
fprintf(fpdumperr, "%d,%d,%f,%f,%d,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f\n", iTrk, cluster.num, err2Y, err2Z, clusterState, yy, zz, mP[0], mP[1], mP[2], mP[3], mP[4], mC[0], mC[2], mC[5], mC[9], mC[14]);
305+
fprintf(fpdumperr, "%d,%d,%f,%f,%d,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f\n", iTrk, cluster.num, err2Y, err2Z, clusterState, xx, yy, zz, mP[0], mP[1], mP[2], mP[3], mP[4], mC[0], mC[2], mC[5], mC[9], mC[14]);
286306
#endif
287307

288308
if (rejectChi2 >= GPUTPCGMPropagator::rejectInterFill) {

0 commit comments

Comments (0)