Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sparse matrices as Imgs #331

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
70 changes: 70 additions & 0 deletions src/main/java/net/imglib2/img/sparse/SparseCSCImg.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package net.imglib2.img.sparse;

import net.imglib2.Cursor;
import net.imglib2.Interval;
import net.imglib2.RandomAccess;
import net.imglib2.img.Img;
import net.imglib2.img.ImgFactory;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.NumericType;

public class SparseCSCImg<
D extends NumericType<D> & NativeType<D>,
I extends IntegerType<I> & NativeType<I>> extends SparseImg<D,I> {

public SparseCSCImg(final long numCols, final long numRows, final Img<D> data, final Img<I> indices, final Img<I> indptr) {
super(numCols, numRows, data, indices, indptr);
}

@Override
public RandomAccess<D> randomAccess() {
return new SparseRandomAccess<D, I>(this, 1);
}

@Override
public Cursor<D> localizingCursor() {
return new SparseLocalizingCursor<>(this, 1, data.firstElement());
}

@Override
public SparseCSCImg<D,I> copy() {
Img<D> dataCopy = data.copy();
Img<I> indicesCopy = indices.copy();
Img<I> indptrCopy = indptr.copy();
return new SparseCSCImg<>(dimension(0), dimension(1), dataCopy, indicesCopy, indptrCopy);
}

@Override
public ImgFactory<D> factory() {
return new SparseImgFactory<>(data.getAt(0), indices.getAt(0), 1);
}

@Override
public ColumnMajorIterationOrder2D iterationOrder() {
return new ColumnMajorIterationOrder2D(this);
}

/**
* An iteration order that scans a 2D image in column-major order.
* I.e., cursors iterate column by column and row by row. For instance a
* sparse img ranging from <em>(0,0)</em> to <em>(1,1)</em> is iterated like
* <em>(0,0), (0,1), (1,0), (1,1)</em>
*/
public static class ColumnMajorIterationOrder2D {

private final Interval interval;
public ColumnMajorIterationOrder2D(final Interval interval) {
this.interval = interval;
}

@Override
public boolean equals(final Object obj) {

if (!(obj instanceof SparseCSCImg.ColumnMajorIterationOrder2D))
return false;

return SparseImg.haveSameIterationSpace(interval, ((ColumnMajorIterationOrder2D) obj).interval);
}
}
}
70 changes: 70 additions & 0 deletions src/main/java/net/imglib2/img/sparse/SparseCSRImg.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package net.imglib2.img.sparse;

import net.imglib2.Cursor;
import net.imglib2.Interval;
import net.imglib2.RandomAccess;
import net.imglib2.img.Img;
import net.imglib2.img.ImgFactory;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.NumericType;

public class SparseCSRImg<
D extends NumericType<D> & NativeType<D>,
I extends IntegerType<I> & NativeType<I>> extends SparseImg<D,I> {

public SparseCSRImg(final long numCols, final long numRows, final Img<D> data, final Img<I> indices, final Img<I> indptr) {
super(numCols, numRows, data, indices, indptr);
}

@Override
public RandomAccess<D> randomAccess() {
return new SparseRandomAccess<D, I>(this, 0);
}

@Override
public Cursor<D> localizingCursor() {
return new SparseLocalizingCursor<>(this, 0, data.firstElement());
}

@Override
public RowMajorIterationOrder2D iterationOrder() {
return new RowMajorIterationOrder2D(this);
}

@Override
public SparseCSRImg<D,I> copy() {
Img<D> dataCopy = data.copy();
Img<I> indicesCopy = indices.copy();
Img<I> indptrCopy = indptr.copy();
return new SparseCSRImg<>(dimension(0), dimension(1), dataCopy, indicesCopy, indptrCopy);
}

@Override
public ImgFactory<D> factory() {
return new SparseImgFactory<>(data.getAt(0), indices.getAt(0), 0);
}

/**
* An iteration order that scans a 2D image in row-major order.
* I.e., cursors iterate row by row and column by column. For instance a
* sparse img ranging from <em>(0,0)</em> to <em>(1,1)</em> is iterated like
* <em>(0,0), (1,0), (0,1), (1,1)</em>
*/
public static class RowMajorIterationOrder2D {

private final Interval interval;
public RowMajorIterationOrder2D(final Interval interval) {
this.interval = interval;
}

@Override
public boolean equals(final Object obj) {

if (!(obj instanceof RowMajorIterationOrder2D))
return false;

return SparseImg.haveSameIterationSpace(interval, ((RowMajorIterationOrder2D) obj).interval);
}
}
}
173 changes: 173 additions & 0 deletions src/main/java/net/imglib2/img/sparse/SparseImg.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
package net.imglib2.img.sparse;

import net.imglib2.Cursor;
import net.imglib2.Interval;
import net.imglib2.RandomAccess;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.NumericType;
import net.imglib2.type.numeric.integer.LongType;

import java.util.ArrayList;
import java.util.List;

abstract public class SparseImg<
D extends NumericType<D> & NativeType<D>,
I extends IntegerType<I> & NativeType<I>> implements Img<D> {

protected final long[] max;
protected final Img<D> data;
protected final Img<I> indices;
protected final Img<I> indptr;

public SparseImg(long numCols, long numRows, Img<D> data, Img<I> indices, Img<I> indptr) {

this.data = data;
this.indices = indices;
this.indptr = indptr;
this.max = new long[]{numCols-1, numRows-1};

if (data.numDimensions() != 1 || indices.numDimensions() != 1 || indptr.numDimensions() != 1)
throw new IllegalArgumentException("Data, index, and indptr Img must be one dimensional.");
if (data.min(0) != 0 || indices.min(0) != 0 || indptr.min(0) != 0)
throw new IllegalArgumentException("Data, index, and indptr arrays must start from 0.");
if (data.max(0) != indices.max(0))
throw new IllegalArgumentException("Data and index array must be of the same size.");
if (indptr.max(0) != max[0]+1 && indptr.max(0) != max[1]+1)
throw new IllegalArgumentException("Indptr array does not fit number of slices.");
}

public static <T extends NumericType<T> & NativeType<T>> SparseImg<T, LongType> convertToSparse(Img<T> img) {
return convertToSparse(img, 0); // CSR per default
}

public static <T extends NumericType<T> & NativeType<T>> SparseImg<T, LongType> convertToSparse(Img<T> img, int leadingDimension) {
if (leadingDimension != 0 && leadingDimension != 1)
throw new IllegalArgumentException("Leading dimension in sparse array must be 0 or 1.");

T zeroValue = img.getAt(0, 0).copy();
zeroValue.setZero();

int nnz = getNumberOfNonzeros(img);
int ptrDimension = 1 - leadingDimension;
Img<T> data = new ArrayImgFactory<>(zeroValue).create(nnz);
Img<LongType> indices = new ArrayImgFactory<>(new LongType()).create(nnz);
Img<LongType> indptr = new ArrayImgFactory<>(new LongType()).create(img.dimension(ptrDimension) + 1);

long count = 0;
T actualValue;
RandomAccess<T> ra = img.randomAccess();
RandomAccess<T> dataAccess = data.randomAccess();
RandomAccess<LongType> indicesAccess = indices.randomAccess();
RandomAccess<LongType> indptrAccess = indptr.randomAccess();
indptrAccess.setPosition(0,0);
indptrAccess.get().setLong(0L);

for (long j = 0; j < img.dimension(ptrDimension); j++) {
ra.setPosition(j, ptrDimension);
for (long i = 0; i < img.dimension(leadingDimension); i++) {
ra.setPosition(i, leadingDimension);
actualValue = ra.get();
if (!actualValue.valueEquals(zeroValue)) {
dataAccess.setPosition(count, 0);
dataAccess.get().set(actualValue);
indicesAccess.setPosition(count, 0);
indicesAccess.get().setLong(i);
count++;
}
}
indptrAccess.fwd(0);
indptrAccess.get().setLong(count);
}

return (leadingDimension == 0) ? new SparseCSRImg<>(img.dimension(0), img.dimension(1), data, indices, indptr)
: new SparseCSCImg<>(img.dimension(0), img.dimension(1), data, indices, indptr);
}

public static <T extends NumericType<T>> int getNumberOfNonzeros(Img<T> img) {
T zeroValue = img.getAt(0, 0).copy();
zeroValue.setZero();

int nnz = 0;
for (T pixel : img)
if (!pixel.valueEquals(zeroValue))
++nnz;
return nnz;
}

@Override
public long min(int d) {
return 0L;
}

@Override
public long max(int d) {
return max[d];
}

@Override
public int numDimensions() {
return 2;
}

@Override
public RandomAccess<D> randomAccess(Interval interval) {
return randomAccess();
}

public Img<D> getDataArray() {
return data;
}

public Img<I> getIndicesArray() {
return indices;
}

public Img<I> getIndexPointerArray() {
return indptr;
}

@Override
public Cursor<D> cursor() {
return localizingCursor();
}

@Override
public long size() {
return max[0] * max[1];
}

/**
* Checks if two intervals have the same iteration space.
*
* @param a One interval
* @param b Other interval
* @return true if both intervals have compatible non-singleton dimensions, false otherwise
*/
protected static boolean haveSameIterationSpace(Interval a, Interval b) {
List<Integer> nonSingletonDimA = nonSingletonDimensions(a);
List<Integer> nonSingletonDimB = nonSingletonDimensions(b);

if (nonSingletonDimA.size() != nonSingletonDimB.size())
return false;

for (int i = 0; i < nonSingletonDimA.size(); i++) {
Integer dimA = nonSingletonDimA.get(i);
Integer dimB = nonSingletonDimB.get(i);
if (a.min(dimA) != b.min(dimB) || a.max(dimA) != b.max(dimB))
return false;
}

return true;
}

protected static List<Integer> nonSingletonDimensions(Interval interval) {
List<Integer> nonSingletonDim = new ArrayList<>();
for (int i = 0; i < interval.numDimensions(); i++)
if (interval.dimension(i) > 1)
nonSingletonDim.add(i);
return nonSingletonDim;
}
}
60 changes: 60 additions & 0 deletions src/main/java/net/imglib2/img/sparse/SparseImgFactory.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package net.imglib2.img.sparse;

import net.imglib2.Dimensions;
import net.imglib2.exception.IncompatibleTypeException;
import net.imglib2.img.Img;
import net.imglib2.img.ImgFactory;
import net.imglib2.img.array.ArrayImg;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.NumericType;

/**
* Factory for {@link SparseImg}s.
* @param <D> type of data
* @param <I> type of indices
*/
public class SparseImgFactory<
D extends NumericType<D> & NativeType<D>,
I extends IntegerType<I> & NativeType<I>> extends ImgFactory<D> {

protected final int leadingDimension;
protected final I indexType;


protected SparseImgFactory(D type, I indexType, int leadingDimension) {
super(type);
this.leadingDimension = leadingDimension;
this.indexType = indexType;
}

@Override
public SparseImg<D, I> create(long... dimensions) {
if (dimensions.length != 2)
throw new IllegalArgumentException("Only 2D images are supported");

Dimensions.verify(dimensions);
ArrayImg<D, ?> data = new ArrayImgFactory<>(type()).create(1);
ArrayImg<I, ?> indices = new ArrayImgFactory<>(indexType).create(1);
int secondaryDimension = 1 - leadingDimension;
ArrayImg<I, ?> indptr = new ArrayImgFactory<>(indexType).create(dimensions[secondaryDimension] + 1);

return (leadingDimension == 0) ? new SparseCSRImg<>(dimensions[0], dimensions[1], data, indices, indptr)
: new SparseCSCImg<>(dimensions[0], dimensions[1], data, indices, indptr);
}

@Override
@SuppressWarnings({"unchecked", "rawtypes"})
public <S> ImgFactory<S> imgFactory(S type) throws IncompatibleTypeException {
if (type instanceof NumericType && type instanceof NativeType)
return new SparseImgFactory<>((NumericType & NativeType) type, indexType, leadingDimension);
else
throw new IncompatibleTypeException(this, type.getClass().getCanonicalName() + " does not implement NumericType & NativeType.");
}

@Override
public Img<D> create(long[] dim, D type) {
return create(dim);
}
}
Loading