Skip to content

Commit

Permalink
add tweedie post_transform (#242)
Browse files Browse the repository at this point in the history
add tweedie to lgbm
  • Loading branch information
interesaaat authored Aug 19, 2020
1 parent 7e5aa66 commit af2d3cd
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 44 deletions.
10 changes: 9 additions & 1 deletion doc/html/hummingbird/ml/operator_converters/constants.html
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ <h1 class="title">Module <code>hummingbird.ml.operator_converters.constants</cod
<details class="source">
<summary>
<span>Expand source code</span>
<a href="https://github.com/microsoft/hummingbird/blob/master/hummingbird/ml/operator_converters/constants.py#L0-L47" class="git-link">Browse git</a>
<a href="https://github.com/microsoft/hummingbird/blob/master/hummingbird/ml/operator_converters/constants.py#L0-L50" class="git-link">Browse git</a>
</summary>
<pre><code class="python"># -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
Expand All @@ -54,6 +54,9 @@ <h1 class="title">Module <code>hummingbird.ml.operator_converters.constants</cod
SOFTMAX = &#34;SOFTMAX&#34;
&#34;&#34;&#34;Softmax post transform&#34;&#34;&#34;

TWEEDIE = &#34;TWEEDIE&#34;
&#34;&#34;&#34;Tweedie post transform&#34;&#34;&#34;

GET_PARAMETERS_FOR_TREE_TRAVERSAL = &#34;get_parameters_for_tree_trav&#34;
&#34;&#34;&#34;Which function to use to extract the parameters for the tree traversal strategy&#34;&#34;&#34;

Expand Down Expand Up @@ -136,6 +139,10 @@ <h2 class="section-title" id="header-variables">Global variables</h2>
<dd>
<div class="desc"><p>The test input data for models that need to be traced</p></div>
</dd>
<dt id="hummingbird.ml.operator_converters.constants.TWEEDIE"><code class="name">var <span class="ident">TWEEDIE</span></code></dt>
<dd>
<div class="desc"><p>Tweedie post transform</p></div>
</dd>
</dl>
</section>
<section>
Expand Down Expand Up @@ -173,6 +180,7 @@ <h1>Index</h1>
<li><code><a title="hummingbird.ml.operator_converters.constants.SIGMOID" href="#hummingbird.ml.operator_converters.constants.SIGMOID">SIGMOID</a></code></li>
<li><code><a title="hummingbird.ml.operator_converters.constants.SOFTMAX" href="#hummingbird.ml.operator_converters.constants.SOFTMAX">SOFTMAX</a></code></li>
<li><code><a title="hummingbird.ml.operator_converters.constants.TEST_INPUT" href="#hummingbird.ml.operator_converters.constants.TEST_INPUT">TEST_INPUT</a></code></li>
<li><code><a title="hummingbird.ml.operator_converters.constants.TWEEDIE" href="#hummingbird.ml.operator_converters.constants.TWEEDIE">TWEEDIE</a></code></li>
</ul>
</li>
</ul>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,18 +327,10 @@ <h3>Ancestors</h3>
<h3>Methods</h3>
<dl>
<dt id="hummingbird.ml.operator_converters.onnx.onnx_operator.Cast.forward"><code class="name flex">
<span>def <span class="ident">forward</span></span>(<span>self, x)</span>
<span>def <span class="ident">forward</span></span>(<span>self, x) -> Callable[..., Any]</span>
</code></dt>
<dd>
<div class="desc"><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the :class:<code>Module</code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div></div>
<div class="desc"></div>
<details class="source">
<summary>
<span>Expand source code</span>
Expand Down Expand Up @@ -377,18 +369,10 @@ <h3>Ancestors</h3>
<h3>Methods</h3>
<dl>
<dt id="hummingbird.ml.operator_converters.onnx.onnx_operator.Concat.forward"><code class="name flex">
<span>def <span class="ident">forward</span></span>(<span>self, *x)</span>
<span>def <span class="ident">forward</span></span>(<span>self, *x) -> Callable[..., Any]</span>
</code></dt>
<dd>
<div class="desc"><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the :class:<code>Module</code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div></div>
<div class="desc"></div>
<details class="source">
<summary>
<span>Expand source code</span>
Expand Down Expand Up @@ -429,18 +413,10 @@ <h3>Ancestors</h3>
<h3>Methods</h3>
<dl>
<dt id="hummingbird.ml.operator_converters.onnx.onnx_operator.Reshape.forward"><code class="name flex">
<span>def <span class="ident">forward</span></span>(<span>self, x)</span>
<span>def <span class="ident">forward</span></span>(<span>self, x) -> Callable[..., Any]</span>
</code></dt>
<dd>
<div class="desc"><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the :class:<code>Module</code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div></div>
<div class="desc"></div>
<details class="source">
<summary>
<span>Expand source code</span>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ <h1 class="title">Module <code>hummingbird.ml.operator_converters.sklearn.lightg
<details class="source">
<summary>
<span>Expand source code</span>
<a href="https://github.com/microsoft/hummingbird/blob/master/hummingbird/ml/operator_converters/sklearn/lightgbm.py#L0-L100" class="git-link">Browse git</a>
<a href="https://github.com/microsoft/hummingbird/blob/master/hummingbird/ml/operator_converters/sklearn/lightgbm.py#L0-L103" class="git-link">Browse git</a>
</summary>
<pre><code class="python"># -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
Expand All @@ -42,6 +42,7 @@ <h1 class="title">Module <code>hummingbird.ml.operator_converters.sklearn.lightg
import numpy as np
from onnxconverter_common.registration import register_converter

from .. import constants
from .._gbdt_commons import convert_gbdt_classifier_common, convert_gbdt_common
from .._tree_commons import TreeParameters

Expand Down Expand Up @@ -122,6 +123,8 @@ <h1 class="title">Module <code>hummingbird.ml.operator_converters.sklearn.lightg
# Get tree information out of the model.
n_features = operator.raw_operator._n_features
tree_infos = operator.raw_operator.booster_.dump_model()[&#34;tree_info&#34;]
if operator.raw_operator._objective == &#34;tweedie&#34;:
extra_config[constants.POST_TRANSFORM] = constants.TWEEDIE

return convert_gbdt_common(tree_infos, _get_tree_parameters, n_features, extra_config=extra_config)

Expand Down Expand Up @@ -161,7 +164,7 @@ <h2 id="returns">Returns</h2>
<details class="source">
<summary>
<span>Expand source code</span>
<a href="https://github.com/microsoft/hummingbird/blob/master/hummingbird/ml/operator_converters/sklearn/lightgbm.py#L55-L74" class="git-link">Browse git</a>
<a href="https://github.com/microsoft/hummingbird/blob/master/hummingbird/ml/operator_converters/sklearn/lightgbm.py#L56-L75" class="git-link">Browse git</a>
</summary>
<pre><code class="python">def convert_sklearn_lgbm_classifier(operator, device, extra_config):
&#34;&#34;&#34;
Expand Down Expand Up @@ -207,7 +210,7 @@ <h2 id="returns">Returns</h2>
<details class="source">
<summary>
<span>Expand source code</span>
<a href="https://github.com/microsoft/hummingbird/blob/master/hummingbird/ml/operator_converters/sklearn/lightgbm.py#L77-L95" class="git-link">Browse git</a>
<a href="https://github.com/microsoft/hummingbird/blob/master/hummingbird/ml/operator_converters/sklearn/lightgbm.py#L78-L98" class="git-link">Browse git</a>
</summary>
<pre><code class="python">def convert_sklearn_lgbm_regressor(operator, device, extra_config):
&#34;&#34;&#34;
Expand All @@ -226,6 +229,8 @@ <h2 id="returns">Returns</h2>
# Get tree information out of the model.
n_features = operator.raw_operator._n_features
tree_infos = operator.raw_operator.booster_.dump_model()[&#34;tree_info&#34;]
if operator.raw_operator._objective == &#34;tweedie&#34;:
extra_config[constants.POST_TRANSFORM] = constants.TWEEDIE

return convert_gbdt_common(tree_infos, _get_tree_parameters, n_features, extra_config=extra_config)</code></pre>
</details>
Expand Down
12 changes: 2 additions & 10 deletions doc/html/hummingbird/ml/operator_converters/sklearn/skl_sv.html
Original file line number Diff line number Diff line change
Expand Up @@ -304,18 +304,10 @@ <h3>Ancestors</h3>
<h3>Methods</h3>
<dl>
<dt id="hummingbird.ml.operator_converters.sklearn.skl_sv.SVC.forward"><code class="name flex">
<span>def <span class="ident">forward</span></span>(<span>self, x)</span>
<span>def <span class="ident">forward</span></span>(<span>self, x) -> Callable[..., Any]</span>
</code></dt>
<dd>
<div class="desc"><p>Defines the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the :class:<code>Module</code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div></div>
<div class="desc"></div>
<details class="source">
<summary>
<span>Expand source code</span>
Expand Down
8 changes: 8 additions & 0 deletions hummingbird/ml/operator_converters/_gbdt_commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ def apply_sigmoid(x):
# Softmax post transform: normalizes the raw scores with a softmax along dim 1.
# NOTE(review): dim 1 is presumably the class axis of a (batch, n_classes)
# tensor -- confirm against the callers that install this transform.
def apply_softmax(x):
return torch.softmax(x, dim=1)

# Tweedie post transform: returns the element-wise exponential of the raw scores.
# NOTE(review): presumably this inverts the log link used by LightGBM's
# "tweedie" objective so converted outputs line up with model.predict --
# confirm against the LightGBM objective documentation.
def apply_tweedie(x):
return torch.exp(x)

# For models following the Sklearn API we need to build the post transform ourselves.
if classes is not None and constants.POST_TRANSFORM not in extra_config:
if len(classes) <= 2:
Expand All @@ -143,6 +146,11 @@ def apply_softmax(x):
extra_config[constants.POST_TRANSFORM] = lambda x: apply_softmax(apply_base_prediction(base_prediction)(x))
else:
extra_config[constants.POST_TRANSFORM] = apply_softmax
elif extra_config[constants.POST_TRANSFORM] == constants.TWEEDIE:
if constants.BASE_PREDICTION in extra_config:
extra_config[constants.POST_TRANSFORM] = lambda x: apply_tweedie(apply_base_prediction(base_prediction)(x))
else:
extra_config[constants.POST_TRANSFORM] = apply_tweedie
else:
raise NotImplementedError("Post transform {} not implemeneted yet".format(extra_config[constants.POST_TRANSFORM]))
elif constants.BASE_PREDICTION in extra_config:
Expand Down
3 changes: 3 additions & 0 deletions hummingbird/ml/operator_converters/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
SOFTMAX = "SOFTMAX"
"""Softmax post transform"""

TWEEDIE = "TWEEDIE"
"""Tweedie post transform"""

GET_PARAMETERS_FOR_TREE_TRAVERSAL = "get_parameters_for_tree_trav"
"""Which function to use to extract the parameters for the tree traversal strategy"""

Expand Down
3 changes: 3 additions & 0 deletions hummingbird/ml/operator_converters/sklearn/lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np
from onnxconverter_common.registration import register_converter

from .. import constants
from .._gbdt_commons import convert_gbdt_classifier_common, convert_gbdt_common
from .._tree_commons import TreeParameters

Expand Down Expand Up @@ -91,6 +92,8 @@ def convert_sklearn_lgbm_regressor(operator, device, extra_config):
# Get tree information out of the model.
n_features = operator.raw_operator._n_features
tree_infos = operator.raw_operator.booster_.dump_model()["tree_info"]
if operator.raw_operator._objective == "tweedie":
extra_config[constants.POST_TRANSFORM] = constants.TWEEDIE

return convert_gbdt_common(tree_infos, _get_tree_parameters, n_features, extra_config=extra_config)

Expand Down
17 changes: 17 additions & 0 deletions tests/test_lightgbm_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,23 @@ def test_lgbm_classifier_random_forest(self):
self.assertIsNotNone(torch_model)
np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)

# Test Tweedie loss in lgbm: a LightGBM regressor fitted with
# objective="tweedie" must produce the same predictions after being
# converted to the "torch" backend.
@unittest.skipIf(not lightgbm_installed(), reason="LightGBM test requires LightGBM installed")
def test_lgbm_tweedie(self):
warnings.filterwarnings("ignore")
model = lgb.LGBMRegressor(objective="tweedie", n_estimators=2, max_depth=5)

# Fixed seed so the randomly generated data set is reproducible.
np.random.seed(0)
# 100 samples x 200 features, cast to float32.
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
# Non-negative integer targets drawn from [0, 100).
y = np.random.randint(100, size=100)

model.fit(X, y)

torch_model = hummingbird.ml.convert(model, "torch")
self.assertIsNotNone(torch_model)
# The converted model must match the original predictions within 1e-06
# relative and absolute tolerance.
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)

# Backend tests.
# Test TorchScript backend regression.
@unittest.skipIf(not lightgbm_installed(), reason="LightGBM test requires LightGBM installed")
Expand Down

0 comments on commit af2d3cd

Please sign in to comment.