VMML/tntorch

tntorch - Tensor Network Learning with PyTorch

New: our Read the Docs site is out!

Welcome to tntorch, a PyTorch-powered modeling and learning library using tensor networks. Such networks are unique in that they use multilinear neural units (instead of non-linear activation units). Features include:

  • Basic and fancy indexing of tensors, broadcasting, assignment, etc.
  • Tensor decomposition and reconstruction
  • Element-wise and tensor-tensor arithmetic
  • Building tensors from black-box functions using cross-approximation
  • Statistics and sensitivity analysis
  • Optimization using autodifferentiation
  • Misc. operations on tensors: stacking, unfolding, sampling, differentiating, etc.

Available tensor formats include the tensor train (TT) and TT-Tucker formats, among others. For example, a 4D tensor (i.e. a real function defined on a grid of I1 x I2 x I3 x I4 points) can be represented as a network in either of these two formats.

In tntorch, all tensor decompositions share the same interface. You can handle them in a transparent form, as if they were plain NumPy arrays or PyTorch tensors:

> import tntorch as tn
> t = tn.randn(32, 32, 32, 32, ranks_tt=5)  # Random 4D TT tensor of shape 32 x 32 x 32 x 32 and TT-rank 5
> print(t)

4D TT tensor:

 32  32  32  32
  |   |   |   |
 (0) (1) (2) (3)
 / \ / \ / \ / \
1   5   5   5   1

> print(tn.mean(t))

tensor(8.0388)

> print(tn.norm(t))

tensor(9632.3726)

Decompressing tensors is easy:

> print(t.torch().shape)
torch.Size([32, 32, 32, 32])
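
To make the decompression step more concrete, here is a minimal NumPy sketch of what reconstructing a TT tensor amounts to. It does not use tntorch or claim to mirror its internals: the helper name `tt_to_full` is hypothetical, and the core layout (left rank, mode size, right rank) is an assumption chosen to match the diagram printed above.

```python
import numpy as np

def tt_to_full(cores):
    """Contract TT cores of shape (r_prev, I_k, r_next) into a dense array."""
    full = cores[0]  # shape (1, I_1, r_1)
    for core in cores[1:]:
        # Contract the trailing rank index of `full` with the leading one of `core`
        full = np.tensordot(full, core, axes=([-1], [0]))
    # Both boundary ranks are 1, so drop them
    return full.reshape(full.shape[1:-1])

rng = np.random.default_rng(0)
shape, rank = (32, 32, 32, 32), 5
ranks = [1] + [rank] * (len(shape) - 1) + [1]  # 1, 5, 5, 5, 1
cores = [rng.standard_normal((ranks[k], shape[k], ranks[k + 1]))
         for k in range(len(shape))]

full = tt_to_full(cores)
n_compressed = sum(core.size for core in cores)
print(full.shape)
print(n_compressed, "stored values instead of", full.size)
```

Each entry of the dense tensor is a product of one matrix slice per core, which is why the compressed form needs far fewer stored values than the dense one.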

Thanks to PyTorch's automatic differentiation, you can easily define all sorts of loss functions on tensors:

import torch

def loss(t):
    return torch.norm(t[:, 0, 10:, [3, 4]].torch())  # NumPy-like "fancy indexing" for arrays
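
As a rough illustration of how such a loss can be minimized, the sketch below uses plain PyTorch only. It is not tntorch's own optimization routine: the cores are ordinary leaf tensors, and `full_tensor`, the small shape, the rank, the learning rate, and the iteration count are all assumptions made for the example.

```python
import torch

torch.manual_seed(0)
shape, rank = (8, 8, 16, 8), 3
ranks = [1] + [rank] * (len(shape) - 1) + [1]
# The TT cores are the trainable parameters
cores = [torch.randn(ranks[k], shape[k], ranks[k + 1], requires_grad=True)
         for k in range(len(shape))]

def full_tensor(cores):
    # Contract the chain of four cores into a dense 4D tensor
    return torch.einsum('aib,bjc,ckd,dle->ijkl', *cores)

def loss(cores):
    # Same slice as in the loss above, applied to the dense tensor
    return torch.norm(full_tensor(cores)[:, 0, 10:, [3, 4]])

optimizer = torch.optim.Adam(cores, lr=0.05)
initial = loss(cores).item()
for _ in range(200):
    optimizer.zero_grad()
    value = loss(cores)
    value.backward()  # gradients flow back to every core
    optimizer.step()
final = loss(cores).item()
print(initial, final)
```

Because the dense tensor is built from the cores with differentiable operations, autograd can update the compressed parameters directly, even though the loss is written in terms of the decompressed slice.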

Most importantly, loss functions can be defined on compressed tensors as well:

def loss(t):
    return tn.norm(t[:3, :3, :3, :3] - t[-3:, -3:, -3:, -3:])
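
One reason a loss like this can stay in compressed form is that slicing and subtraction both have direct counterparts on TT cores. The NumPy sketch below shows that idea under stated assumptions: `tt_to_full` and `tt_sub` are hypothetical helpers, not tntorch functions, and the dense reconstruction is used only to check the result.

```python
import numpy as np

def tt_to_full(cores):
    """Contract TT cores of shape (r_prev, I_k, r_next) into a dense array."""
    full = cores[0]
    for core in cores[1:]:
        full = np.tensordot(full, core, axes=([-1], [0]))
    return full.reshape(full.shape[1:-1])

def tt_sub(a, b):
    """Cores of a - b (needs at least two cores); ranks add up."""
    out, last = [], len(a) - 1
    for k, (ca, cb) in enumerate(zip(a, b)):
        if k == 0:
            # First core: concatenate along the right rank, negating b once
            out.append(np.concatenate([ca, -cb], axis=2))
        elif k == last:
            # Last core: stack along the left rank
            out.append(np.concatenate([ca, cb], axis=0))
        else:
            # Middle cores: block-diagonal in the two rank indices
            ra, n, sa = ca.shape
            rb, _, sb = cb.shape
            core = np.zeros((ra + rb, n, sa + sb))
            core[:ra, :, :sa] = ca
            core[ra:, :, sa:] = cb
            out.append(core)
    return out

rng = np.random.default_rng(0)
ranks = [1, 5, 5, 5, 1]
cores = [rng.standard_normal((ranks[k], 32, ranks[k + 1])) for k in range(4)]

# Slicing a TT tensor only slices the middle index of each core
head = [c[:, :3, :] for c in cores]
tail = [c[:, -3:, :] for c in cores]
diff = tt_sub(head, tail)

dense = tt_to_full(cores)
expected = dense[:3, :3, :3, :3] - dense[-3:, -3:, -3:, -3:]
print(np.allclose(tt_to_full(diff), expected))
print([c.shape for c in diff])
```

The difference has TT-rank 10 rather than 5, but it is never formed as a dense 32 x 32 x 32 x 32 array; only the small cores are touched.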

Check out the introductory notebook for all the details on the basics.

Tutorial Notebooks

Installation

The main dependencies are NumPy and PyTorch. To download and install tntorch:

git clone https://github.com/rballester/tntorch.git
cd tntorch
pip install .

Testing

We use pytest. Simply run:

cd tests/
pytest

Contributing

Pull requests are welcome!

Besides using the issue tracker, also feel free to contact me at [email protected].
