
Accelerating Tree-Based Models in SQL with Orbital 0.3.0

Written by Alessandro Molina
2024-11-07
Tree diagram and SQL code for a decision tree model showing data branching into various numerical outcomes based on CASE logic.

Orbital is a library that converts Scikit-learn pipelines into SQL queries, enabling machine learning model inference directly within SQL databases.

In Orbital 0.3.0, the SQL generation logic for tree-based models (Decision Trees, Random Forests, and Gradient Boosted Trees) has changed substantially, with a significant impact on the performance and size of the generated queries:

  • The way branches are built has been rewritten using a dedicated BranchConditionCreator
  • Transformation functions have been refactored to their own module that supports registering new transformations
  • Class scores at each leaf are now preserved as temporary variables

This allows the new implementation to introduce intermediate variable reuse: repeated sub-expressions are now computed once and stored as SQL variables, dramatically reducing query size and execution time. This change leads to cleaner SQL, faster predictions, and better compatibility with complex database optimizers.

Trees in SQL

Many tree models can be reproduced in pure SQL using CASE branches.

Each CASE compares a feature against a threshold and directs evaluation to one of its branches, until the final leaf node is reached, where the model returns a numerical score (or “vote”) for a given class.

This results in fairly long and complex SQL that looks like this (each "tcl_vNNN" is one of the features provided to the model):

orbital.export_sql("DB_TABLE", treemodel_orbital_pipeline)
CASE 
  WHEN "t3"."tcl_v269" <= -4.108497619628906 THEN
    CASE 
      WHEN "t3"."tcl_v271" <= 0.8464506268501282 THEN
        CASE 
          WHEN "t3"."tcl_v271" <= 0.5293939113616943 THEN
            -0.06815864890813828
          ELSE
            -0.07052286714315414
        END
      ELSE
        CASE 
          WHEN "t3"."tcl_v269" <= -4.217958927154541 THEN
            -0.07220334559679031
          ELSE
            -0.07770664244890213
        END
    END
  ELSE
    CASE 
      WHEN "t3"."tcl_v269" <= -4.100806713104248 THEN
        CASE 
          WHEN "t3"."tcl_v271" <= 0.05221748352050781 THEN
            0.1321529746055603
          ELSE
            0.0915503054857254
        END
      ELSE
        CASE 
          WHEN "t3"."tcl_v273" <= -0.3382622003555298 THEN
            0.005495843011885881
          ELSE
            -0.0033024316653609276
        END
    END
END
- "t3"."sfmx_v274" AS "output_probability.medium"
FROM "t3" AS "t3";

This was after the Orbital optimizer had already folded every constant expression it could, created temporary variables for the features consumed as inputs of the tree, and precomputed CASE expressions where possible. A second pass by the SQLGlot optimizer then improved the query further.

Nonetheless, pipeline_boosted_tree_classifier.py from the Orbital examples resulted in a query of 2,188,462 characters. While most production database systems were able to parse and execute the query, a query of that size was significantly harder for them to handle.

Running the example on GitHub CI using DuckDB took 1 minute and 30 seconds, which is not bad, but definitely not as fast as running the model in scikit-learn on a fast system.

Running example: examples/pipeline_boosted_tree_classifier.py

real    1m29.621s
user    1m29.767s
sys     0m0.853s
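
To make the CASE-based translation concrete, here is a minimal sketch of how a small decision tree could be rendered as nested CASE expressions. It is not Orbital's implementation: the `Leaf` and `Split` classes and the `to_case` helper are hypothetical names introduced only for illustration, and the thresholds and scores are made up.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class Leaf:
    score: float  # the vote returned when evaluation reaches this leaf


@dataclass
class Split:
    feature: str      # column name, e.g. "tcl_v269"
    threshold: float  # evaluation goes to `left` when feature <= threshold
    left: "Node"
    right: "Node"


Node = Union[Leaf, Split]


def to_case(node: Node, table: str = "t3", indent: int = 0) -> str:
    """Render a tree as nested SQL CASE expressions (hypothetical helper)."""
    pad = "  " * indent
    if isinstance(node, Leaf):
        return f"{pad}{node.score!r}"
    return "\n".join([
        f"{pad}CASE",
        f'{pad}  WHEN "{table}"."{node.feature}" <= {node.threshold!r} THEN',
        to_case(node.left, table, indent + 2),
        f"{pad}  ELSE",
        to_case(node.right, table, indent + 2),
        f"{pad}END",
    ])


# Illustrative tree: two splits, three leaves (values are invented).
tree = Split(
    "tcl_v269", -4.1,
    Split("tcl_v271", 0.85, Leaf(-0.068), Leaf(-0.072)),
    Leaf(0.0055),
)
sql = to_case(tree)
print(sql)
```

Every split adds one CASE level, so the text grows with the number of nodes, and an ensemble repeats this for each tree and each class. That is why the inlined query above becomes so large.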

Optimizing Trees

Orbital 0.3.0 introduced various improvements that contribute to the performance of tree-based models, primarily:

  • Using a dedicated BranchConditionCreator to create tree branches
  • Preserving votes for each class in temporary variables

Nearly all these improvements are centered around detecting repeated expressions and precomputing them.

As the branches of the trees tend to reuse the same comparisons, identifying the recurring ones avoids the cost of recomputing them.

Unique identifier for expressions

The new BranchConditionCreator introduced in Orbital 0.3.0 takes care of building the comparison expressions for the CASE statements, typically something like WHEN "t3"."tcl_v273" <= -0.3382622003555298 THEN.

Those expressions are all composed of the same three elements:

  • The feature being compared
  • The comparison operator
  • The threshold to which the feature is compared

The threshold and the operator are constants, so two expressions are equal exactly when they share the same feature name, operator, and threshold.

Those expressions can then be precomputed, assigned to a variable, and retrieved via a lookup table. This allows all the previous CASE expressions to become something like:

CASE WHEN "t3"."cnc_v853" THEN 
    -0.0005402276874519885 
ELSE 
    0.09522797167301178 
END
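
The lookup-table idea can be sketched in a few lines of Python. This is only an illustration under assumptions, not the actual BranchConditionCreator: the `ConditionRegistry` class, its methods, and the way variable names are numbered are all hypothetical.

```python
from typing import Dict, List, Tuple

# A condition is fully identified by (feature, operator, threshold).
ConditionKey = Tuple[str, str, float]


class ConditionRegistry:
    """Hypothetical lookup table: one boolean variable per distinct condition."""

    def __init__(self, prefix: str = "cnc_v", start: int = 0) -> None:
        self._names: Dict[ConditionKey, str] = {}
        self._prefix = prefix
        self._next = start

    def variable_for(self, feature: str, op: str, threshold: float) -> str:
        """Return the variable for a condition, creating it on first use."""
        key = (feature, op, threshold)
        if key not in self._names:
            self._names[key] = f"{self._prefix}{self._next}"
            self._next += 1
        return self._names[key]

    def definitions(self, table: str = "t3") -> List[str]:
        """SQL projections that compute each distinct condition once."""
        return [
            f'"{table}"."{feat}" {op} {thr!r} AS "{name}"'
            for (feat, op, thr), name in self._names.items()
        ]


registry = ConditionRegistry()
a = registry.variable_for("tcl_v273", "<=", -0.3382622003555298)
b = registry.variable_for("tcl_v271", "<=", 0.5293939113616943)
c = registry.variable_for("tcl_v273", "<=", -0.3382622003555298)  # repeated
print(a, b, c)
print(registry.definitions())
```

The repeated comparison maps to the variable that already exists, so only two definitions are emitted for three requests. The branches then test the boolean variable instead of repeating the comparison.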

Preserving votes

In tree-based classifiers such as Random Forests or Gradient Boosted Trees, the model doesn’t directly output a class label. Instead, each tree contributes a vote (or a numerical score) for every possible class.

The final prediction is obtained by aggregating these votes (for example, by averaging them in Random Forests, or summing them in Gradient Boosted Trees) and then converting the result into class probabilities using a normalization step such as the softmax function.

In SQL, the contribution of a tree can be expressed as a CASE statement that returns the vote (or score) associated with the leaf that is reached:

CASE 
  WHEN "t3"."tcl_v273" <= 1.4952754974365234 THEN 0.0908372700214386 
  ELSE 0.0740339607000351 
END

A complete prediction requires combining all these leaf-level votes, often hundreds of such fragments, across all trees in the ensemble. Previously, Orbital would inline these vote expressions every time they appeared, which inflated query size and duplicated computation.

In Orbital 0.3.0, votes are now preserved as temporary variables, so each vote expression is evaluated only once and reused wherever needed. This optimization greatly simplifies the final SQL while improving runtime performance, since databases can now cache and reuse these intermediate results:

( ...
  EXP("t4"."vte_v943" - GREATEST("t4"."vte_v942", "t4"."vte_v943", "t4"."vte_v944")) +
  EXP("t4"."vte_v944" - GREATEST("t4"."vte_v942", "t4"."vte_v943", "t4"."vte_v944"))
) AS "output_probability.medium"

This change reduces redundant computations and ensures that the probability for each class is derived efficiently from the precomputed vote variables.
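
The aggregation and normalization described above can be sketched in Python as follows. This is an assumption-laden illustration rather than Orbital's code: the function names and the example votes are invented, and subtracting the largest score before exponentiating is read off the `EXP(x - GREATEST(...))` pattern in the fragment above.

```python
import math
from typing import Dict, List


def aggregate_votes(per_tree_votes: List[Dict[str, float]]) -> Dict[str, float]:
    """Sum the per-class votes of every tree (as in gradient boosting)."""
    totals: Dict[str, float] = {}
    for votes in per_tree_votes:
        for label, vote in votes.items():
            totals[label] = totals.get(label, 0.0) + vote
    return totals


def softmax(scores: Dict[str, float]) -> Dict[str, float]:
    """Softmax with the maximum subtracted first, so EXP never overflows."""
    peak = max(scores.values())
    exps = {label: math.exp(s - peak) for label, s in scores.items()}
    denom = sum(exps.values())
    return {label: e / denom for label, e in exps.items()}


# Invented votes from two trees for three classes.
votes = [
    {"low": 0.1, "medium": 0.5, "high": -0.2},
    {"low": 0.0, "medium": 0.3, "high": 0.1},
]
totals = aggregate_votes(votes)
probabilities = softmax(totals)
print(probabilities)
```

Each total is computed once and then read several times by the normalization, which is the same reuse that the vote variables provide in the generated SQL.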

SQL in Orbital 0.3.0

All those improvements combined reduced the size of the generated SQL query for the Gradient Boosted Tree classifier example from 2,188,462 characters to 301,875, roughly a sevenfold reduction. This means the database parser has much less work to do, and the query plan generated from the query is much smaller and simpler.

The impact is not limited to the size of the query; a simpler query means that the database optimizer has an easier time finding optimization strategies, and the temporary variables Orbital itself introduced help the database compute expressions only once.

This means that the same example now runs more than three times faster and completes in less than 30 seconds on GitHub CI using DuckDB.

>>> Running example: examples/pipeline_boosted_tree_classifier.py

real    0m26.235s
user    0m26.688s
sys     0m0.390s
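
As a quick check of the figures reported in this post, the ratios can be computed directly from the query sizes and the `real` timings of the two runs:

```python
# Query sizes in characters, before and after Orbital 0.3.0.
old_chars, new_chars = 2_188_462, 301_875

# Wall-clock ("real") times of the two CI runs, in seconds.
old_secs = 1 * 60 + 29.621  # 1m29.621s
new_secs = 26.235           # 0m26.235s

size_ratio = old_chars / new_chars
speedup = old_secs / new_secs
print(f"query is {size_ratio:.2f}x smaller, run is {speedup:.2f}x faster")
```

These are single CI runs, so the speedup should be read as indicative rather than as a benchmark.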

Next steps

While these improvements bring a major performance boost, lightweight database engines such as SQLite still struggle with the complexity of the generated SQL.

Even though the datasets themselves may be small, SQLite often runs out of memory during query plan generation, before the actual execution, due to the large number of nested CASE expressions and temporary variables.

One of Orbital’s long-term goals is to make tree-based prediction possible even in minimal environments like SQLite, enabling local experimentation without needing a full database server.

To achieve this, we plan to explore precomputing shared subtrees within the model. By identifying branches that reuse the same internal nodes or partial decision paths, Orbital could compute those once and reference them across multiple parts of the SQL.

This would further shrink query size and reduce parsing overhead, bringing support for more lightweight database engines closer to reality.
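
One way such shared subtrees might be detected is by using a hashable representation of each subtree and counting how often it appears across the ensemble. The sketch below is purely illustrative and does not reflect any existing or planned Orbital API: the tuple encoding of trees and the `shared_subtrees` helper are assumptions made for the example.

```python
from collections import Counter
from typing import List, Tuple, Union

# Leaf: a float score. Split: (feature, threshold, left, right).
Tree = Union[float, Tuple[str, float, "Tree", "Tree"]]


def count_subtrees(tree: Tree, seen: Counter) -> None:
    """Record every internal subtree; identical subtrees share the same key."""
    if isinstance(tree, tuple):
        seen[tree] += 1
        _, _, left, right = tree
        count_subtrees(left, seen)
        count_subtrees(right, seen)


def shared_subtrees(forest: List[Tree]) -> List[Tree]:
    """Return the subtrees that occur more than once across the forest."""
    seen: Counter = Counter()
    for tree in forest:
        count_subtrees(tree, seen)
    return [subtree for subtree, n in seen.items() if n > 1]


# Invented forest in which two trees contain the same subtree.
common = ("tcl_v271", 0.5, -0.07, -0.06)
forest = [
    ("tcl_v269", -4.1, common, 0.13),
    ("tcl_v273", -0.3, 0.005, common),
]
shared = shared_subtrees(forest)
print(shared)
```

A subtree found this way could be emitted once as a temporary variable and referenced from each tree that contains it, in the same spirit as the condition and vote variables introduced in 0.3.0.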