From af2d3cdd20c56e14b1d859bfbde7443f0659d1ca Mon Sep 17 00:00:00 2001 From: Matteo Interlandi Date: Wed, 19 Aug 2020 09:12:43 -0700 Subject: [PATCH] add tweedie post_transform (#242) add tweedie to lgbm --- .../ml/operator_converters/constants.html | 10 +++++- .../onnx/onnx_operator.html | 36 ++++--------------- .../operator_converters/sklearn/lightgbm.html | 11 ++++-- .../operator_converters/sklearn/skl_sv.html | 12 ++----- .../ml/operator_converters/_gbdt_commons.py | 8 +++++ .../ml/operator_converters/constants.py | 3 ++ .../operator_converters/sklearn/lightgbm.py | 3 ++ tests/test_lightgbm_converter.py | 17 +++++++++ 8 files changed, 56 insertions(+), 44 deletions(-) diff --git a/doc/html/hummingbird/ml/operator_converters/constants.html b/doc/html/hummingbird/ml/operator_converters/constants.html index 22a324f1b..a1237f237 100644 --- a/doc/html/hummingbird/ml/operator_converters/constants.html +++ b/doc/html/hummingbird/ml/operator_converters/constants.html @@ -27,7 +27,7 @@

Module hummingbird.ml.operator_converters.constants Expand source code -Browse git +Browse git
# -------------------------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
@@ -54,6 +54,9 @@ 

Module hummingbird.ml.operator_converters.constantsGlobal variables

The test input data for models that need to be traced

+
var TWEEDIE
+
+

Tweedie post transform

+
@@ -173,6 +180,7 @@

Index

  • SIGMOID
  • SOFTMAX
  • TEST_INPUT
  • +
  • TWEEDIE
  • diff --git a/doc/html/hummingbird/ml/operator_converters/onnx/onnx_operator.html b/doc/html/hummingbird/ml/operator_converters/onnx/onnx_operator.html index a4d297863..5a387b069 100644 --- a/doc/html/hummingbird/ml/operator_converters/onnx/onnx_operator.html +++ b/doc/html/hummingbird/ml/operator_converters/onnx/onnx_operator.html @@ -327,18 +327,10 @@

    Ancestors

    Methods

    -def forward(self, x) +def forward(self, x) -> Callable[..., Any]
    -

    Defines the computation performed at every call.

    -

    Should be overridden by all subclasses.

    -
    -

    Note

    -

    Although the recipe for forward pass needs to be defined within -this function, one should call the :class:Module instance afterwards -instead of this since the former takes care of running the -registered hooks while the latter silently ignores them.

    -
    +
    Expand source code @@ -377,18 +369,10 @@

    Ancestors

    Methods

    -def forward(self, *x) +def forward(self, *x) -> Callable[..., Any]
    -

    Defines the computation performed at every call.

    -

    Should be overridden by all subclasses.

    -
    -

    Note

    -

    Although the recipe for forward pass needs to be defined within -this function, one should call the :class:Module instance afterwards -instead of this since the former takes care of running the -registered hooks while the latter silently ignores them.

    -
    +
    Expand source code @@ -429,18 +413,10 @@

    Ancestors

    Methods

    -def forward(self, x) +def forward(self, x) -> Callable[..., Any]
    -

    Defines the computation performed at every call.

    -

    Should be overridden by all subclasses.

    -
    -

    Note

    -

    Although the recipe for forward pass needs to be defined within -this function, one should call the :class:Module instance afterwards -instead of this since the former takes care of running the -registered hooks while the latter silently ignores them.

    -
    +
    Expand source code diff --git a/doc/html/hummingbird/ml/operator_converters/sklearn/lightgbm.html b/doc/html/hummingbird/ml/operator_converters/sklearn/lightgbm.html index e8cdb6fc8..bb98c478a 100644 --- a/doc/html/hummingbird/ml/operator_converters/sklearn/lightgbm.html +++ b/doc/html/hummingbird/ml/operator_converters/sklearn/lightgbm.html @@ -27,7 +27,7 @@

    Module hummingbird.ml.operator_converters.sklearn.lightg
    Expand source code -Browse git +Browse git
    # -------------------------------------------------------------------------
     # Copyright (c) Microsoft Corporation. All rights reserved.
    @@ -42,6 +42,7 @@ 

    Module hummingbird.ml.operator_converters.sklearn.lightg import numpy as np from onnxconverter_common.registration import register_converter +from .. import constants from .._gbdt_commons import convert_gbdt_classifier_common, convert_gbdt_common from .._tree_commons import TreeParameters @@ -122,6 +123,8 @@

    Module hummingbird.ml.operator_converters.sklearn.lightg # Get tree information out of the model. n_features = operator.raw_operator._n_features tree_infos = operator.raw_operator.booster_.dump_model()["tree_info"] + if operator.raw_operator._objective == "tweedie": + extra_config[constants.POST_TRANSFORM] = constants.TWEEDIE return convert_gbdt_common(tree_infos, _get_tree_parameters, n_features, extra_config=extra_config) @@ -161,7 +164,7 @@

    Returns

    Expand source code -Browse git +Browse git
    def convert_sklearn_lgbm_classifier(operator, device, extra_config):
         """
    @@ -207,7 +210,7 @@ 

    Returns

    Expand source code -Browse git +Browse git
    def convert_sklearn_lgbm_regressor(operator, device, extra_config):
         """
    @@ -226,6 +229,8 @@ 

    Returns

    # Get tree information out of the model. n_features = operator.raw_operator._n_features tree_infos = operator.raw_operator.booster_.dump_model()["tree_info"] + if operator.raw_operator._objective == "tweedie": + extra_config[constants.POST_TRANSFORM] = constants.TWEEDIE return convert_gbdt_common(tree_infos, _get_tree_parameters, n_features, extra_config=extra_config)
    diff --git a/doc/html/hummingbird/ml/operator_converters/sklearn/skl_sv.html b/doc/html/hummingbird/ml/operator_converters/sklearn/skl_sv.html index 647c79f7d..51e5ea058 100644 --- a/doc/html/hummingbird/ml/operator_converters/sklearn/skl_sv.html +++ b/doc/html/hummingbird/ml/operator_converters/sklearn/skl_sv.html @@ -304,18 +304,10 @@

    Ancestors

    Methods

    -def forward(self, x) +def forward(self, x) -> Callable[..., Any]
    -

    Defines the computation performed at every call.

    -

    Should be overridden by all subclasses.

    -
    -

    Note

    -

    Although the recipe for forward pass needs to be defined within -this function, one should call the :class:Module instance afterwards -instead of this since the former takes care of running the -registered hooks while the latter silently ignores them.

    -
    +
    Expand source code diff --git a/hummingbird/ml/operator_converters/_gbdt_commons.py b/hummingbird/ml/operator_converters/_gbdt_commons.py index 68cd5e8b2..7aab75742 100644 --- a/hummingbird/ml/operator_converters/_gbdt_commons.py +++ b/hummingbird/ml/operator_converters/_gbdt_commons.py @@ -124,6 +124,9 @@ def apply_sigmoid(x): def apply_softmax(x): return torch.softmax(x, dim=1) + def apply_tweedie(x): + return torch.exp(x) + # For models following the Sklearn API we need to build the post transform ourselves. if classes is not None and constants.POST_TRANSFORM not in extra_config: if len(classes) <= 2: @@ -143,6 +146,11 @@ def apply_softmax(x): extra_config[constants.POST_TRANSFORM] = lambda x: apply_softmax(apply_base_prediction(base_prediction)(x)) else: extra_config[constants.POST_TRANSFORM] = apply_softmax + elif extra_config[constants.POST_TRANSFORM] == constants.TWEEDIE: + if constants.BASE_PREDICTION in extra_config: + extra_config[constants.POST_TRANSFORM] = lambda x: apply_tweedie(apply_base_prediction(base_prediction)(x)) + else: + extra_config[constants.POST_TRANSFORM] = apply_tweedie else: raise NotImplementedError("Post transform {} not implemeneted yet".format(extra_config[constants.POST_TRANSFORM])) elif constants.BASE_PREDICTION in extra_config: diff --git a/hummingbird/ml/operator_converters/constants.py b/hummingbird/ml/operator_converters/constants.py index a03bf9633..4c1f7219a 100644 --- a/hummingbird/ml/operator_converters/constants.py +++ b/hummingbird/ml/operator_converters/constants.py @@ -23,6 +23,9 @@ SOFTMAX = "SOFTMAX" """Softmax post transform""" +TWEEDIE = "TWEEDIE" +"""Tweedie post transform""" + GET_PARAMETERS_FOR_TREE_TRAVERSAL = "get_parameters_for_tree_trav" """Which function to use to extract the parameters for the tree traversal strategy""" diff --git a/hummingbird/ml/operator_converters/sklearn/lightgbm.py b/hummingbird/ml/operator_converters/sklearn/lightgbm.py index f6de6b258..b51847e15 100644 --- a/hummingbird/ml/operator_converters/sklearn/lightgbm.py +++ b/hummingbird/ml/operator_converters/sklearn/lightgbm.py @@ -11,6 +11,7 @@ import numpy as np from onnxconverter_common.registration import register_converter +from .. import constants from .._gbdt_commons import convert_gbdt_classifier_common, convert_gbdt_common from .._tree_commons import TreeParameters @@ -91,6 +92,8 @@ def convert_sklearn_lgbm_regressor(operator, device, extra_config): # Get tree information out of the model. n_features = operator.raw_operator._n_features tree_infos = operator.raw_operator.booster_.dump_model()["tree_info"] + if operator.raw_operator._objective == "tweedie": + extra_config[constants.POST_TRANSFORM] = constants.TWEEDIE return convert_gbdt_common(tree_infos, _get_tree_parameters, n_features, extra_config=extra_config) diff --git a/tests/test_lightgbm_converter.py b/tests/test_lightgbm_converter.py index c51a004da..16954887b 100644 --- a/tests/test_lightgbm_converter.py +++ b/tests/test_lightgbm_converter.py @@ -243,6 +243,23 @@ def test_lgbm_classifier_random_forest(self): self.assertIsNotNone(torch_model) np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06) + # Test Tweedie loss in lgbm + @unittest.skipIf(not lightgbm_installed(), reason="LightGBM test requires LightGBM installed") + def test_lgbm_tweedie(self): + warnings.filterwarnings("ignore") + model = lgb.LGBMRegressor(objective="tweedie", n_estimators=2, max_depth=5) + + np.random.seed(0) + X = np.random.rand(100, 200) + X = np.array(X, dtype=np.float32) + y = np.random.randint(100, size=100) + + model.fit(X, y) + + torch_model = hummingbird.ml.convert(model, "torch") + self.assertIsNotNone(torch_model) + np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06) + # Backend tests. # Test TorchScript backend regression. @unittest.skipIf(not lightgbm_installed(), reason="LightGBM test requires LightGBM installed")