diff --git a/python/tvm/sparse/format.py b/python/tvm/sparse/format.py index 6c7ea1cb4..efb8e3588 100644 --- a/python/tvm/sparse/format.py +++ b/python/tvm/sparse/format.py @@ -56,6 +56,41 @@ class FormatRewriteRule(Object): inv_idx_map_func : Union[Callable, IndexMap] A function describing the coordinate mapping from indices in new format. to indices in old format. + + Examples + -------- + .. code-block:: python + + >>> import tvm + >>> from tvm.sparse import format_decompose + >>> def bsr(block_size: int) + ... @T.prim_func + ... def func( + ... a: T.handle, + ... indptr: T.handle, + ... indices: T.handle, + ... m: T.int32, + ... n: T.int32, + ... nnz: T.int32, + ... ) -> None: + ... IO = T.dense_fixed(m) + ... JO = T.sparse_variable(IO, (n, nnz), (indptr, indices), "int32") + ... II = T.dense_fixed(block_size) + ... JI = T.dense_fixed(block_size) + ... A = T.match_sparse_buffer(a, (IO, JO, II, JI), "float32") + ... T.evaluate(0) + ... return func + ... + >>> rewrite_bsr_32 = FormatRewriteRule( + ... str(32), + ... bsr(32), + ... ["A"], + ... ["I", "J"], + ... ["IO", "JO", "II", "JI"], + ... {"I": ["IO", "II"], "J": ["JO", "JI"]}, + ... lambda i, j: (i // 32, j // 32, i % 32, j % 32), + ... lambda io, jo, ii, ji: (io * 32 + ii, jo * 32 + ji), + ... ) """ def __init__(