When requesting higher taxonomic ranks (genus, family, order, etc.), TreeOfLifeClassifier.predict() seems to aggregate the species-level probabilities up the taxonomy tree. This aggregation currently runs as CPU loops with dictionary lookups, which creates a major bottleneck.
With 10k plant images and the 350K Plantae species currently in the TOL embeddings, that comes to 3.5 billion dictionary operations (10,000 images × 350,000 species) on a single CPU core. I suggest precomputing the species-to-rank mappings once and moving the accumulation to the GPU with PyTorch scatter_add. On my 10k-image dataset, family-level prediction time dropped from ~130 minutes to ~4.8 minutes.
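A minimal sketch of the idea, not the library's actual internals: the species-to-rank mapping is turned into an integer index tensor once, and each batch of species probabilities is then summed per taxon with `scatter_add_`. The helper names (`build_rank_index`, `aggregate_probs`) and the four-species toy taxonomy are hypothetical and only serve as illustration.

```python
import torch


def build_rank_index(species_to_rank):
    """Precompute, once, an integer taxon id for every species column.

    species_to_rank: list with one higher-rank label per species column
    (hypothetical input; the real mapping would come from the taxonomy).
    """
    labels = sorted(set(species_to_rank))
    label_to_id = {label: i for i, label in enumerate(labels)}
    index = torch.tensor([label_to_id[label] for label in species_to_rank],
                         dtype=torch.long)
    return index, labels


def aggregate_probs(species_probs, rank_index, num_taxa):
    """Sum species probabilities into their parent taxon.

    species_probs: (batch, num_species) tensor of probabilities.
    rank_index:    (num_species,) tensor of taxon ids from build_rank_index.
    Returns a (batch, num_taxa) tensor.
    """
    out = torch.zeros(species_probs.shape[0], num_taxa,
                      dtype=species_probs.dtype, device=species_probs.device)
    # Broadcast the per-species taxon id across the batch dimension.
    idx = rank_index.to(species_probs.device).unsqueeze(0).expand_as(species_probs)
    out.scatter_add_(1, idx, species_probs)
    return out


# Toy example: four species belonging to three families (made-up data).
device = "cuda" if torch.cuda.is_available() else "cpu"
species_to_family = ["Rosaceae", "Rosaceae", "Fabaceae", "Poaceae"]
rank_index, family_names = build_rank_index(species_to_family)

species_probs = torch.tensor([[0.10, 0.20, 0.30, 0.40],
                              [0.05, 0.05, 0.80, 0.10]], device=device)
family_probs = aggregate_probs(species_probs, rank_index, len(family_names)).cpu()
top_family = [family_names[i] for i in family_probs.argmax(dim=1).tolist()]
```

Because the index tensor depends only on the taxonomy and the requested rank, it can be built once per rank and reused for every batch, so the per-image work reduces to a single tensor operation.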
Happy to submit a PR if interested.