Skip to content

Commit 3d5b10b

Browse files
committed
Adding cluster error NN evaluation, CSV readout and scaling of error
1 parent 301b925 commit 3d5b10b

File tree

10 files changed

+161
-27
lines changed

10 files changed

+161
-27
lines changed

Common/ML/include/ML/OrtInterface.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,16 @@ class OrtModel
9191

9292
// Inferencing
9393
template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. OrtDataType::Float16_t from O2/Common/ML/include/ML/GPUORTFloat16.h
94-
std::vector<O> inference(std::vector<I>&);
94+
std::vector<O> inference(std::vector<I>&) const;
9595

9696
template <class I, class O>
97-
std::vector<O> inference(std::vector<std::vector<I>>&);
97+
std::vector<O> inference(std::vector<std::vector<I>>&) const;
9898

9999
template <class I, class O>
100-
void inference(I*, int64_t, O*);
100+
void inference(I*, int64_t, O*) const;
101101

102102
template <class I, class O>
103-
void inference(I**, int64_t, O*);
103+
void inference(I**, int64_t, O*) const;
104104

105105
void release(bool = false);
106106

@@ -112,7 +112,8 @@ class OrtModel
112112
// Input & Output specifications of the loaded network
113113
std::vector<const char*> mInputNamesChar, mOutputNamesChar;
114114
std::vector<std::string> mInputNames, mOutputNames;
115-
std::vector<std::vector<int64_t>> mInputShapes, mOutputShapes, mInputShapesCopy, mOutputShapesCopy; // Input shapes
115+
std::vector<std::vector<int64_t>> mInputShapes, mOutputShapes;
116+
mutable std::vector<std::vector<int64_t>> mInputShapesCopy, mOutputShapesCopy; // Mutable per-call working copies of the input/output shapes (batch dim rewritten inside const inference())
116117
std::vector<int64_t> mInputSizePerNode, mOutputSizePerNode; // Flattened element count per input/output node
117118
int32_t mInputsTotal = 0, mOutputsTotal = 0; // Total number of inputs and outputs
118119

Common/ML/src/OrtInterface.cxx

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ void OrtModel::setEnv(Ort::Env* env)
289289

290290
// Inference
291291
template <class I, class O>
292-
std::vector<O> OrtModel::inference(std::vector<I>& input)
292+
std::vector<O> OrtModel::inference(std::vector<I>& input) const
293293
{
294294
std::vector<int64_t> inputShape = mInputShapes[0];
295295
inputShape[0] = input.size();
@@ -310,12 +310,12 @@ std::vector<O> OrtModel::inference(std::vector<I>& input)
310310
return outputValuesVec;
311311
}
312312

313-
template std::vector<float> o2::ml::OrtModel::inference<float, float>(std::vector<float>&);
314-
template std::vector<float> o2::ml::OrtModel::inference<OrtDataType::Float16_t, float>(std::vector<OrtDataType::Float16_t>&);
315-
template std::vector<OrtDataType::Float16_t> o2::ml::OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>&);
313+
template std::vector<float> o2::ml::OrtModel::inference<float, float>(std::vector<float>&) const;
314+
template std::vector<float> o2::ml::OrtModel::inference<OrtDataType::Float16_t, float>(std::vector<OrtDataType::Float16_t>&) const;
315+
template std::vector<OrtDataType::Float16_t> o2::ml::OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>&) const;
316316

317317
template <class I, class O>
318-
void OrtModel::inference(I* input, int64_t input_size, O* output)
318+
void OrtModel::inference(I* input, int64_t input_size, O* output) const
319319
{
320320
// std::vector<std::string> providers = Ort::GetAvailableProviders();
321321
// for (const auto& provider : providers) {
@@ -350,13 +350,13 @@ void OrtModel::inference(I* input, int64_t input_size, O* output)
350350
// mOutputNamesChar.size());
351351
}
352352

353-
template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t*, int64_t, OrtDataType::Float16_t*);
354-
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t*, int64_t, float*);
355-
template void OrtModel::inference<float, OrtDataType::Float16_t>(float*, int64_t, OrtDataType::Float16_t*);
356-
template void OrtModel::inference<float, float>(float*, int64_t, float*);
353+
template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t*, int64_t, OrtDataType::Float16_t*) const;
354+
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t*, int64_t, float*) const;
355+
template void OrtModel::inference<float, OrtDataType::Float16_t>(float*, int64_t, OrtDataType::Float16_t*) const;
356+
template void OrtModel::inference<float, float>(float*, int64_t, float*) const;
357357

358358
template <class I, class O>
359-
void OrtModel::inference(I** input, int64_t input_size, O* output)
359+
void OrtModel::inference(I** input, int64_t input_size, O* output) const
360360
{
361361
std::vector<Ort::Value> inputTensors(mInputShapesCopy.size());
362362

@@ -410,13 +410,13 @@ void OrtModel::inference(I** input, int64_t input_size, O* output)
410410
mOutputNamesChar.size());
411411
}
412412

413-
template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t**, int64_t, OrtDataType::Float16_t*);
414-
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t**, int64_t, float*);
415-
template void OrtModel::inference<float, OrtDataType::Float16_t>(float**, int64_t, OrtDataType::Float16_t*);
416-
template void OrtModel::inference<float, float>(float**, int64_t, float*);
413+
template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t**, int64_t, OrtDataType::Float16_t*) const;
414+
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t**, int64_t, float*) const;
415+
template void OrtModel::inference<float, OrtDataType::Float16_t>(float**, int64_t, OrtDataType::Float16_t*) const;
416+
template void OrtModel::inference<float, float>(float**, int64_t, float*) const;
417417

418418
template <class I, class O>
419-
std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& inputs)
419+
std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& inputs) const
420420
{
421421
std::vector<Ort::Value> input_tensors;
422422

@@ -461,8 +461,8 @@ std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& inputs)
461461
return output_vec;
462462
}
463463

464-
template std::vector<float> OrtModel::inference<float, float>(std::vector<std::vector<float>>&);
465-
template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<std::vector<OrtDataType::Float16_t>>&);
464+
template std::vector<float> OrtModel::inference<float, float>(std::vector<std::vector<float>>&) const;
465+
template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<std::vector<OrtDataType::Float16_t>>&) const;
466466

467467
// Release session
468468
void OrtModel::release(bool profilingEnabled)

GPU/GPUTracking/Base/GPUParam.cxx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ void GPUParam::UpdateSettings(const GPUSettingsGRP* g, const GPUSettingsProcessi
125125
qptB5Scaler = CAMath::Abs(bzkG) > 0.1f ? CAMath::Abs(bzkG) / 5.006680f : 1.f; // Repeat here, since passing in g is optional
126126
if (p) {
127127
UpdateRun3ClusterErrors(p->param.tpcErrorParamY, p->param.tpcErrorParamZ);
128+
initClusterErrorModel(p->nn);
128129
}
129130
if (w) {
130131
par.dodEdx = dodEdxEnabled = w->steps.isSet(gpudatatypes::RecoStep::TPCdEdx);

GPU/GPUTracking/Base/GPUParam.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#include "GPUSettings.h"
2222
#include "GPUTPCGMPolynomialField.h"
2323

24+
#include "ML/OrtInterface.h"
25+
2426
#if !defined(GPUCA_GPUCODE)
2527
namespace o2::base
2628
{
@@ -67,6 +69,11 @@ struct GPUParam_t {
6769

6870
GPUParamSector SectorParam[GPUCA_NSECTORS];
6971

72+
std::unique_ptr<o2::ml::OrtModel> mModelClusterErrors; // For cluster error estimation
73+
bool useClusterErrorNetwork = false; // Whether to use the cluster error network at all, can be set to false to save time if not needed
74+
bool dumpClusterErrorCSV = false;
75+
float scaleError = 1.f;
76+
7077
protected:
7178
#ifdef GPUCA_TPC_GEOMETRY_O2
7279
float ParamErrors[2][4][4]; // cluster error parameterization used during seeding and fit
@@ -87,6 +94,46 @@ struct GPUParam : public internal::GPUParam_t<GPUSettingsRec, GPUSettingsParam>
8794
void UpdateRun3ClusterErrors(const float* yErrorParam, const float* zErrorParam);
8895
#endif
8996

97+
// Configure the NN-based cluster-error evaluation from the NN-clusterizer settings.
// The use/dump/scale members are set unconditionally; the ONNX model is only
// instantiated and its session created when the network is enabled AND a model
// path was provided (otherwise mModelClusterErrors stays null).
void initClusterErrorModel(const GPUSettingsProcessingNNclusterizer& p) {
98+
useClusterErrorNetwork = p.nnUseClusterErrorNetwork;
99+
dumpClusterErrorCSV = p.dumpClusterErrorCSV;
100+
scaleError = p.nnScaleClusterError;
101+
if (useClusterErrorNetwork && !p.nnClusterErrorModelPath.empty()) {
102+
mModelClusterErrors = std::make_unique<o2::ml::OrtModel>();
103+
LOG(info) << "Loading cluster error network from " << p.nnClusterErrorModelPath;
104+
// LOG(info) << "use=" << p.nnUseClusterErrorNetwork
105+
// << " model=" << p.nnClusterErrorModelPath
106+
// << " dev=" << p.nnInferenceDevice
107+
// << " allocDevMem=" << p.nnInferenceAllocateDevMem
108+
// << " intra=" << p.nnInferenceIntraOpNumThreads
109+
// << " inter=" << p.nnInferenceInterOpNumThreads
110+
// << " opt=" << p.nnInferenceEnableOrtOptimization
111+
// << " det=" << p.nnInferenceUseDeterministicCompute
112+
// << " prof=" << p.nnInferenceOrtProfiling
113+
// << " verb=" << p.nnInferenceVerbosity;
114+
// Key/value options consumed by o2::ml::OrtModel::initOptions().
std::unordered_map<std::string, std::string> mOrtOptions = {
115+
{"model-path", p.nnClusterErrorModelPath},
116+
{"device-type", p.nnInferenceDevice},
117+
{"allocate-device-memory", std::to_string(p.nnInferenceAllocateDevMem)},
118+
// NOTE(review): thread counts are hard-coded to 1 — p.nnInferenceIntraOpNumThreads /
// p.nnInferenceInterOpNumThreads are read nowhere here; confirm this is intentional.
{"intra-op-num-threads", "1"},
119+
{"inter-op-num-threads", "1"},
120+
{"enable-optimizations", std::to_string(p.nnInferenceEnableOrtOptimization)},
121+
{"deterministic-compute", std::to_string(p.nnInferenceUseDeterministicCompute)}, // TODO: This unfortunately doesn't guarantee determinism (25.07.2025)
122+
{"enable-profiling", std::to_string(p.nnInferenceOrtProfiling)},
123+
{"profiling-output-path", p.nnInferenceOrtProfilingPath},
124+
{"logging-level", std::to_string(p.nnInferenceVerbosity)},
125+
{"onnx-environment-name", "cluster_error"}
126+
};
127+
// LOG(info) << "NN cluster error options done!";
128+
mModelClusterErrors->initOptions(mOrtOptions);
129+
// LOG(info) << "NN cluster error options loaded!";
130+
mModelClusterErrors->initEnvironment();
131+
// LOG(info) << "NN cluster error environment initialized!";
132+
mModelClusterErrors->initSession();
133+
// LOG(info) << "NN cluster error session initialized!";
134+
}
135+
}
136+
90137
GPUd() float Alpha(int32_t iSector) const
91138
{
92139
if (iSector >= GPUCA_NSECTORS / 2) {

GPU/GPUTracking/Base/GPUReconstruction.cxx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,6 +1278,7 @@ void GPUReconstruction::SetSettings(const GPUSettingsGRP* grp, const GPUSettings
12781278
mRecoSteps.outputs = workflow->outputs;
12791279
}
12801280
param().SetDefaults(mGRPSettings.get(), rec, proc, workflow);
1281+
// param().initClusterErrorModel(proc->nn);
12811282
}
12821283

12831284
void GPUReconstruction::SetOutputControl(void* ptr, size_t size)

GPU/GPUTracking/Definitions/GPUSettingsList.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,10 @@ AddOption(nnSigmoidTrafoClassThreshold, int, 1, "", 0, "If true (default), then
283283
AddOption(nnEvalMode, std::string, "c1:r1", "", 0, "Concatention of modes, e.g. c1:r1 (classification class 1, regression class 1)")
284284
AddOption(nnClusterizerUseClassification, int, 1, "", 0, "If 1, the classification output of the network is used to select clusters, else only the regression output is used and no clusters are rejected by classification")
285285
AddOption(nnClusterizerForceGpuInputFill, int, 0, "", 0, "Forces to use the fillInputNNGPU function")
286+
AddOption(nnUseClusterErrorNetwork, int, 1, "", 0, "If 1, the cluster error network is used to parametrize the cluster errors, else a fixed parametrization is used")
287+
AddOption(nnClusterErrorModelPath, std::string, "", "", 0, "Network for cluster error parameterization")
288+
AddOption(dumpClusterErrorCSV, int, 0, "", 0, "Dumps the cluster errors to CSV if enabled")
289+
AddOption(nnScaleClusterError, float, 1.0, "", 0, "Scale factor for the cluster errors predicted by the network, can be used to effectively increase or decrease the cluster errors without retraining the network")
286290
// CCDB
287291
AddOption(nnLoadFromCCDB, int, 0, "", 0, "If 1 networks are fetched from ccdb, else locally")
288292
AddOption(nnCCDBDumpToFile, int, 0, "", 0, "If 1, additionally dump fetched CCDB networks to nnLocalFolder")

GPU/GPUTracking/Global/GPUChainTracking.cxx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ void GPUChainTracking::RegisterPermanentMemoryAndProcessors()
7878
{
7979
fpdumperr = fopen("dump_cluster_error.csv", "a");
8080
fpdumptrk = fopen("dump_trk_index.csv", "a");
81-
fprintf(fpdumperr, "internal_trkid,cluster.num,err2Y,err2Z,clusterState,clusterY,clusterZ,mP[0],mP[1],mP[2],mP[3],mP[4],mC[0],mC[2],mC[5],mC[9],mC[14]\n");
82-
fprintf(fpdumptrk, "internal_trkid,trkid\n");
81+
// fprintf(fpdumperr, "internal_trkid,cluster.num,err2Y,err2Z,clusterState,cluster.getSigmaPad(),cluster.getSigmaTime(),invAvgCharge,invCharge,xx,yy,zz,mP[0],mP[1],mP[2],mP[3],mP[4],mC[0],mC[2],mC[5],mC[9],mC[14]\n");
82+
// fprintf(fpdumptrk, "internal_trkid,trkid\n");
8383
if (mRec->IsGPU()) {
8484
mFlatObjectsShadow.InitGPUProcessor(mRec, GPUProcessor::PROCESSOR_TYPE_SLAVE);
8585
mFlatObjectsDevice.InitGPUProcessor(mRec, GPUProcessor::PROCESSOR_TYPE_DEVICE, &mFlatObjectsShadow);

GPU/GPUTracking/Merger/GPUTPCGMO2Output.cxx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,9 @@ GPUdii() void GPUTPCGMO2Output::Thread<GPUTPCGMO2Output::output>(int32_t nBlocks
272272
oTrack.setHasASideClusters();
273273
}
274274
#ifndef GPUCA_GPUCODE
275-
fprintf(fpdumptrk, "%d,%d\n", i, iTmp);
275+
if (merger.Param().dumpClusterErrorCSV) {
276+
fprintf(fpdumptrk, "%d,%d\n", i, iTmp);
277+
}
276278
#endif
277279
outputTracks[iTmp] = oTrack;
278280
}

GPU/GPUTracking/Merger/GPUTPCGMTrackParam.cxx

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,47 @@ GPUd() bool GPUTPCGMTrackParam::Fit(GPUTPCGMMerger* GPUrestrict() merger, int32_
279279
const float invCharge = merger->GetConstantMem()->ioPtrs.clustersNative ? (1.f / merger->GetConstantMem()->ioPtrs.clustersNative->clustersLinear[cluster.num].qMax) : 0.f;
280280
float invAvgCharge = (sumInvSqrtCharge += invSqrtCharge) / ++nAvgCharge;
281281
invAvgCharge *= invAvgCharge;
282-
prop.GetErr2(err2Y, err2Z, param, zz, cluster.row, clusterState, cluster.sector, time, invAvgCharge, invCharge);
282+
if (param.useClusterErrorNetwork) {
283+
// Python expands clusterState into 4 bits (cs0..cs3) and drops clusterState.
284+
// Final X dimension: 19 features (xx,yy,zz + sigmaPad,sigmaTime + 5 mP + 5 mC + 4 clusterState bits).
286+
float inputFeatures[19]; // NOTE(review): was [17] — indices 17 and 18 (cs2, cs3) are written below, overflowing a 17-element array
286+
float outputFeatures[2];
287+
288+
inputFeatures[0] = xx;
289+
inputFeatures[1] = yy;
290+
inputFeatures[2] = zz;
291+
292+
inputFeatures[3] = static_cast<float>(merger->GetConstantMem()->ioPtrs.clustersNative->clustersLinear[cluster.num].getSigmaPad());
293+
inputFeatures[4] = static_cast<float>(merger->GetConstantMem()->ioPtrs.clustersNative->clustersLinear[cluster.num].getSigmaTime());
294+
295+
inputFeatures[5] = mP[0];
296+
inputFeatures[6] = mP[1];
297+
inputFeatures[7] = mP[2];
298+
inputFeatures[8] = mP[3];
299+
inputFeatures[9] = mP[4];
300+
301+
inputFeatures[10] = mC[0];
302+
inputFeatures[11] = mC[2];
303+
inputFeatures[12] = mC[5];
304+
inputFeatures[13] = mC[9];
305+
inputFeatures[14] = mC[14];
306+
307+
inputFeatures[15] = static_cast<float>((clusterState >> 0) & 1); // cs0
308+
inputFeatures[16] = static_cast<float>((clusterState >> 1) & 1); // cs1
309+
inputFeatures[17] = static_cast<float>((clusterState >> 2) & 1); // cs2
310+
inputFeatures[18] = static_cast<float>((clusterState >> 3) & 1); // cs3
311+
312+
param.mModelClusterErrors->inference(inputFeatures, (int64_t)1, outputFeatures);
313+
err2Y = param.scaleError*outputFeatures[0];
314+
err2Z = param.scaleError*outputFeatures[1];
315+
} else {
316+
prop.GetErr2(err2Y, err2Z, param, zz, cluster.row, clusterState, cluster.sector, time, invAvgCharge, invCharge);
317+
}
283318

284319
#ifndef GPUCA_GPUCODE
285-
fprintf(fpdumperr, "%d,%d,%f,%f,%d,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f\n", iTrk, cluster.num, err2Y, err2Z, clusterState, yy, zz, mP[0], mP[1], mP[2], mP[3], mP[4], mC[0], mC[2], mC[5], mC[9], mC[14]);
320+
if (param.dumpClusterErrorCSV) {
321+
fprintf(fpdumperr, "%d,%d,%f,%f,%d,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f\n", iTrk, cluster.num, err2Y, err2Z, clusterState, merger->GetConstantMem()->ioPtrs.clustersNative->clustersLinear[cluster.num].getSigmaPad(), merger->GetConstantMem()->ioPtrs.clustersNative->clustersLinear[cluster.num].getSigmaTime(), invAvgCharge, invCharge, xx, yy, zz, mP[0], mP[1], mP[2], mP[3], mP[4], mC[0], mC[2], mC[5], mC[9], mC[14]);
322+
}
286323
#endif
287324

288325
if (rejectChi2 >= GPUTPCGMPropagator::rejectInterFill) {

nn_cluster_error.diff

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
diff --git a/GPU/GPUTracking/Base/GPUParam.cxx b/GPU/GPUTracking/Base/GPUParam.cxx
2+
index aa4c3c7671..3963eeced7 100644
3+
--- a/GPU/GPUTracking/Base/GPUParam.cxx
4+
+++ b/GPU/GPUTracking/Base/GPUParam.cxx
5+
@@ -157,6 +157,7 @@ void GPUParam::SetDefaults(const GPUSettingsGRP* g, const GPUSettingsRec* r, con
6+
rec = *r;
7+
}
8+
UpdateSettings(g, p, w);
9+
+ initClusterErrorModel(p->nn);
10+
}
11+
12+
void GPUParam::UpdateRun3ClusterErrors(const float* yErrorParam, const float* zErrorParam)
13+
diff --git a/GPU/GPUTracking/Base/GPUParam.h b/GPU/GPUTracking/Base/GPUParam.h
14+
index 1b46dc4c9c..9c31093b60 100644
15+
--- a/GPU/GPUTracking/Base/GPUParam.h
16+
+++ b/GPU/GPUTracking/Base/GPUParam.h
17+
@@ -93,9 +93,9 @@ struct GPUParam : public internal::GPUParam_t<GPUSettingsRec, GPUSettingsParam>
18+
void UpdateRun3ClusterErrors(const float* yErrorParam, const float* zErrorParam);
19+
#endif
20+
21+
- void initClusterErrorModel(const GPUSettingsProcessing* p) {
22+
- useClusterErrorNetwork = p->nn.nnUseClusterErrorNetwork;
23+
- mOrtOptions["model_path"] = p->nn.nnClusterErrorModelPath;
24+
+ void initClusterErrorModel(const GPUSettingsProcessingNNclusterizer& p) {
25+
+ useClusterErrorNetwork = p.nnUseClusterErrorNetwork;
26+
+ mOrtOptions["model_path"] = p.nnClusterErrorModelPath;
27+
mModelClusterErrors.initOptions(mOrtOptions);
28+
mModelClusterErrors.setIntraOpNumThreads(1);
29+
mModelClusterErrors.initEnvironment();
30+
diff --git a/GPU/GPUTracking/Base/GPUReconstruction.cxx b/GPU/GPUTracking/Base/GPUReconstruction.cxx
31+
index fbbe815f63..c32f574ef1 100644
32+
--- a/GPU/GPUTracking/Base/GPUReconstruction.cxx
33+
+++ b/GPU/GPUTracking/Base/GPUReconstruction.cxx
34+
@@ -1278,6 +1278,7 @@ void GPUReconstruction::SetSettings(const GPUSettingsGRP* grp, const GPUSettings
35+
mRecoSteps.outputs = workflow->outputs;
36+
}
37+
param().SetDefaults(mGRPSettings.get(), rec, proc, workflow);
38+
+ // param().initClusterErrorModel(proc->nn);
39+
}
40+
41+
void GPUReconstruction::SetOutputControl(void* ptr, size_t size)

0 commit comments

Comments
 (0)