diff --git a/metagraph/external-libraries/caches b/metagraph/external-libraries/caches index b969baa3e7..9843c53447 160000 --- a/metagraph/external-libraries/caches +++ b/metagraph/external-libraries/caches @@ -1 +1 @@ -Subproject commit b969baa3e774fac964487e14bb0122708b09114b +Subproject commit 9843c534478a9188343c1cf41fb3e9269df872dd diff --git a/metagraph/integration_tests/base.py b/metagraph/integration_tests/base.py index c655bf7055..84f3f5aafe 100644 --- a/metagraph/integration_tests/base.py +++ b/metagraph/integration_tests/base.py @@ -106,7 +106,9 @@ def _clean(graph, output, extra_params=''): def _annotate_graph(input, graph_path, output, anno_repr, separate=False, no_fork_opt=False, no_anchor_opt=False): target_anno = anno_repr - if anno_repr in {'row_sparse'} or anno_repr.endswith('brwt') or anno_repr.startswith('row_diff'): + if (anno_repr in {'row_sparse', 'column_coord'} or + anno_repr.endswith('brwt') or + anno_repr.startswith('row_diff')): target_anno = anno_repr anno_repr = 'column' elif anno_repr in {'flat', 'rbfish'}: @@ -117,6 +119,9 @@ def _annotate_graph(input, graph_path, output, anno_repr, -i {graph_path} --anno-type {anno_repr} \ -o {output} {input}' + if target_anno.endswith('_coord'): + command += ' --coordinates' + with_counts = target_anno.endswith('int_brwt') if with_counts: command += ' --count-kmers' diff --git a/metagraph/integration_tests/test_align.py b/metagraph/integration_tests/test_align.py index fa789f5e70..b7fcb25522 100644 --- a/metagraph/integration_tests/test_align.py +++ b/metagraph/integration_tests/test_align.py @@ -60,6 +60,66 @@ def test_simple_align_all_graphs(self, representation): self.assertEqual(last_split[1], "AACAGAGAATTGTTTAAATTACAATCTTAGCTATGGGTGCTAAAGGTGGAGTTATAGACTTTTTCACTGATTTGTCGTTGGAAAAAGCTTTTCATCTCGGGTTTACAAGTCTGGTGTATTTGTTTATACTAGAAGGACAGGCGCATTTGA") self.assertEqual(last_split[4], "22") + @parameterized.expand(GRAPH_TYPES) + def test_simple_align_map_all_graphs(self, representation): + + self._build_graph(input=TEST_DATA_DIR + '/genome.MT.fa', + output=self.tempdir.name + '/genome.MT', + k=11, repr=representation, + extra_params="--mask-dummy") + + res = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) + params_str = res.stdout.decode().split('\n')[2:] + self.assertEqual('k: 11', params_str[0]) + self.assertEqual('nodes (k): 16438', params_str[1]) + self.assertEqual('mode: basic', params_str[2]) + + stats_command = '{exe} align -i {graph} --map --count-kmers {reads}'.format( + exe=METAGRAPH, + graph=self.tempdir.name + '/genome.MT' + graph_file_extension[representation], + reads=TEST_DATA_DIR + '/genome_MT1.fq', + ) + res = subprocess.run(stats_command.split(), stdout=PIPE) + self.assertEqual(res.returncode, 0) + params_str = res.stdout.decode().rstrip().split('\n') + self.assertEqual(len(params_str), 6) + self.assertEqual(params_str[0], 'MT-10/1\t1/140/1') + self.assertEqual(params_str[1], 'MT-8/1\t140/140/140') + self.assertEqual(params_str[2], 'MT-6/1\t140/140/140') + self.assertEqual(params_str[3], 'MT-4/1\t0/140/0') + self.assertEqual(params_str[4], 'MT-2/1\t140/140/140') + self.assertEqual(params_str[5], 'MT-11/1\t1/140/1') + + @parameterized.expand(GRAPH_TYPES) + def test_simple_align_map_canonical_all_graphs(self, representation): + + self._build_graph(input=TEST_DATA_DIR + '/genome.MT.fa', + output=self.tempdir.name + '/genome.MT', + k=11, repr=representation, mode='canonical', + extra_params="--mask-dummy") + + res = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) + params_str = res.stdout.decode().split('\n')[2:] + self.assertEqual('k: 11', params_str[0]) + self.assertEqual('nodes (k): 32782', params_str[1]) + self.assertEqual('mode: canonical', params_str[2]) + + stats_command = '{exe} align -i {graph} --map --count-kmers {reads}'.format( + exe=METAGRAPH, + graph=self.tempdir.name + '/genome.MT' + graph_file_extension[representation], + reads=TEST_DATA_DIR + '/genome_MT1.fq', + ) + res = subprocess.run(stats_command.split(), stdout=PIPE) + self.assertEqual(res.returncode, 0) + params_str = res.stdout.decode().rstrip().split('\n') + self.assertEqual(len(params_str), 6) + self.assertEqual(params_str[0], 'MT-10/1\t140/140/140') + self.assertEqual(params_str[1], 'MT-8/1\t140/140/140') + self.assertEqual(params_str[2], 'MT-6/1\t140/140/140') + self.assertEqual(params_str[3], 'MT-4/1\t129/140/129') + self.assertEqual(params_str[4], 'MT-2/1\t140/140/139') + self.assertEqual(params_str[5], 'MT-11/1\t2/140/2') + @parameterized.expand(['succinct']) def test_simple_align_json_all_graphs(self, representation): diff --git a/metagraph/integration_tests/test_query.py b/metagraph/integration_tests/test_query.py index 1780680900..159cab88e3 100644 --- a/metagraph/integration_tests/test_query.py +++ b/metagraph/integration_tests/test_query.py @@ -6,6 +6,7 @@ from tempfile import TemporaryDirectory import glob import os +import numpy as np from helpers import get_test_class_name from base import TestingBase, METAGRAPH, TEST_DATA_DIR, graph_file_extension @@ -16,6 +17,7 @@ PROTEIN_MODE = os.readlink(METAGRAPH).endswith("_Protein") anno_file_extension = {'column': '.column.annodbg', + 'column_coord': '.column_coord.annodbg', 'row': '.row.annodbg', 'row_diff': '.row_diff.annodbg', 'row_sparse': '.row_sparse.annodbg', @@ -503,6 +505,257 @@ def test_batch_query_with_tiny_batch(self): self.assertEqual(res.returncode, 0) self.assertEqual(len(res.stdout), 136959) + def test_query_coordinates(self): + if not self.anno_repr.endswith('_coord'): + self.skipTest('annotation does not support coordinates') + + query_command = f'{METAGRAPH} query --query-coords \ + -i {self.tempdir.name}/graph{graph_file_extension[self.graph_repr]} \ + -a {self.tempdir.name}/annotation{anno_file_extension[self.anno_repr]} \ + --discovery-fraction 0.05 {TEST_DATA_DIR}/transcripts_100.fa' + + res = subprocess.run(query_command.split(), stdout=PIPE) + self.assertEqual(res.returncode, 0) + self.assertEqual(len(res.stdout), 2155983) + + query_command = f'{METAGRAPH} query --query-coords \ + -i {self.tempdir.name}/graph{graph_file_extension[self.graph_repr]} \ + -a {self.tempdir.name}/annotation{anno_file_extension[self.anno_repr]} \ + --discovery-fraction 0.95 {TEST_DATA_DIR}/transcripts_100.fa' + + res = subprocess.run(query_command.split(), stdout=PIPE) + self.assertEqual(res.returncode, 0) + self.assertEqual(len(res.stdout), 687712) + + +@parameterized_class(('graph_repr', 'anno_repr'), + input_values=product( + [repr for repr in GRAPH_TYPES if not (repr == 'bitmap' and PROTEIN_MODE)], + ['int_brwt', 'row_diff_int_brwt'] + ), + class_name_func=get_test_class_name +) +class TestQueryCounts(TestingBase): + @classmethod + def setUpClass(cls): + cls.tempdir = TemporaryDirectory() + + cls.kmer_counts_1 = { + 'AAA': 1, + 'AAC': 2, + 'ACC': 3, + 'CCC': 4, + 'CCG': 5, + 'CGG': 6, + 'GGG': 7, + 'GGT': 8, + 'GTT': 9, + 'TTT': 10, + 'TTA': 11, + 'TAA': 12, + } + cls.kmer_counts_2 = { + 'AAA': 11, + 'AAC': 12, + 'ACC': 13, + 'CCC': 14, + 'CCG': 15, + 'CGG': 16, + 'GGG': 17, + 'GGT': 18, + 'GTT': 19, + 'TTT': 20, + } + fasta_file = cls.tempdir.name + '/file.fa' + with open(fasta_file, 'w') as f: + for kmer, count in cls.kmer_counts_1.items(): + f.write(f'>L1\n{kmer}\n' * count) + + for kmer, count in cls.kmer_counts_2.items(): + f.write(f'>L2\n{kmer}\n' * count) + + cls.k = 3 + + cls.with_bloom = False + if cls.graph_repr == 'succinct_bloom': + cls.graph_repr = 'succinct' + cls.with_bloom = True + + cls.mask_dummy = False + if cls.graph_repr == 'succinct_mask': + cls.graph_repr = 'succinct' + cls.mask_dummy = True + + construct_command = f"{METAGRAPH} build {'--mask-dummy' if cls.mask_dummy else ''} -p {NUM_THREADS} \ + --graph {cls.graph_repr} -k {cls.k} -o {cls.tempdir.name}/graph {fasta_file}" + + res = subprocess.run([construct_command], shell=True) + assert(res.returncode == 0) + + stats_command = '{exe} stats {graph}'.format( + exe=METAGRAPH, + graph=cls.tempdir.name + '/graph' + graph_file_extension[cls.graph_repr], + ) + res = subprocess.run(stats_command.split(), stdout=PIPE) + assert(res.returncode == 0) + params_str = res.stdout.decode().split('\n')[2:] + assert('k: 3' == params_str[0]) + if cls.graph_repr != 'succinct' or cls.mask_dummy: + assert('nodes (k): 12' == params_str[1]) + assert('mode: basic' == params_str[2]) + + if cls.with_bloom: + convert_command = '{exe} transform -o {outfile} --initialize-bloom {bloom_param} {input}'.format( + exe=METAGRAPH, + outfile=cls.tempdir.name + '/graph', + bloom_param='--bloom-fpp 0.1', + input=cls.tempdir.name + '/graph' + graph_file_extension[cls.graph_repr], + ) + res = subprocess.run([convert_command], shell=True) + assert(res.returncode == 0) + + def check_suffix(anno_repr, suffix): + match = anno_repr.endswith(suffix) + if match: + anno_repr = anno_repr[:-len(suffix)] + return anno_repr, match + + cls.anno_repr, separate = check_suffix(cls.anno_repr, '_separate') + cls.anno_repr, no_fork_opt = check_suffix(cls.anno_repr, '_no_fork_opt') + cls.anno_repr, no_anchor_opt = check_suffix(cls.anno_repr, '_no_anchor_opt') + + cls._annotate_graph( + fasta_file, + cls.tempdir.name + '/graph' + graph_file_extension[cls.graph_repr], + cls.tempdir.name + '/annotation', + cls.anno_repr, + separate, + no_fork_opt, + no_anchor_opt + ) + + # check annotation + anno_stats_command = '{exe} stats -a {annotation}'.format( + exe=METAGRAPH, + annotation=cls.tempdir.name + '/annotation' + anno_file_extension[cls.anno_repr], + ) + res = subprocess.run(anno_stats_command.split(), stdout=PIPE) + assert(res.returncode == 0) + params_str = res.stdout.decode().split('\n')[2:] + assert('labels: 2' == params_str[0]) + if cls.graph_repr != 'hashfast' and (cls.graph_repr != 'succinct' or cls.mask_dummy): + assert('objects: 12' == params_str[1]) + assert('representation: ' + cls.anno_repr == params_str[3]) + + def test_count_query(self): + query_file = self.tempdir.name + '/query.fa' + queries = [ + 'AAA', + 'AAAA', + 'AAAAAAAAAAAAA', + 'CCC', + 'CCCC', + 'CCCCCCCCCCCCC', + 'TTT', + 'AAACCCGGGTTT', + 'AAACCCGGGTTTTTT', + 'AAACCCGGGTTTAAA', + 'TTTAAACCCGGG', + 'ACACACACACACATTTAAACCCGGG', + ] + for discovery_rate in np.linspace(0, 1, 5): + expected_output = '' + with open(query_file, 'w') as f: + for i, s in enumerate(queries): + f.write(f'>s{i}\n{s}\n') + expected_output += f'{i}\ts{i}' + def get_count(d, kmer): + try: + return d[kmer] + except: + return 0 + + num_kmers = len(s) - self.k + 1 + + num_matches_1 = sum([get_count(self.kmer_counts_1, s[i:i + self.k]) > 0 for i in range(num_kmers)]) + count_1 = sum([get_count(self.kmer_counts_1, s[i:i + self.k]) for i in range(len(s) - self.k + 1)]) + + num_matches_2 = sum([get_count(self.kmer_counts_2, s[i:i + self.k]) > 0 for i in range(num_kmers)]) + count_2 = sum([get_count(self.kmer_counts_2, s[i:i + self.k]) for i in range(len(s) - self.k + 1)]) + + for (c, i, n) in [(count_1, 1, num_matches_1), (count_2, 0, num_matches_2)]: + if n >= discovery_rate * num_kmers: + expected_output += f'\t:{c}' + + expected_output += '\n' + + query_command = f'{METAGRAPH} query --fast --count-kmers \ + -i {self.tempdir.name}/graph{graph_file_extension[self.graph_repr]} \ + -a {self.tempdir.name}/annotation{anno_file_extension[self.anno_repr]} \ + --discovery-fraction {discovery_rate} {query_file}' + + res = subprocess.run(query_command.split(), stdout=PIPE) + self.assertEqual(res.returncode, 0) + self.assertEqual(res.stdout.decode(), expected_output) + + def test_count_quantiles(self): + query_file = self.tempdir.name + '/query.fa' + queries = [ + 'AAA', + 'AAAA', + 'AAAAAAAAAAAAA', + 'CCC', + 'CCCC', + 'CCCCCCCCCCCCC', + 'TTT', + 'AAACCCGGGTTT', + 'AAACCCGGGTTTTTT', + 'AAACCCGGGTTTAAA', + 'TTTAAACCCGGG', + 'ACACACACACACATTTAAACCCGGG', + ] + quantiles = np.linspace(0, 1, 100) + expected_output = '' + with open(query_file, 'w') as f: + for i, s in enumerate(queries): + f.write(f'>s{i}\n{s}\n') + expected_output += f'{i}\ts{i}\t' + def get_count(d, kmer): + try: + return d[kmer] + except: + return 0 + for p in quantiles: + counts = [get_count(self.kmer_counts_1, s[i:i + self.k]) for i in range(len(s) - self.k + 1)] + expected_output += f':{np.quantile(counts, p, interpolation="lower")}' + expected_output += f'\t' + for p in quantiles: + counts = [get_count(self.kmer_counts_2, s[i:i + self.k]) for i in range(len(s) - self.k + 1)] + expected_output += f':{np.quantile(counts, p, interpolation="lower")}' + expected_output += '\n' + + query_command = f'{METAGRAPH} query --fast --count-quantiles \ + -i {self.tempdir.name}/graph{graph_file_extension[self.graph_repr]} \ + -a {self.tempdir.name}/annotation{anno_file_extension[self.anno_repr]} \ + --discovery-fraction 0.0 {query_file}' + + query_command = query_command.split() + query_command[4] = ' '.join([str(p) for p in quantiles]) + res = subprocess.run(query_command, stdout=PIPE) + self.assertEqual(res.returncode, 0) + self.assertEqual(res.stdout.decode(), expected_output) + + query_command = f'{METAGRAPH} query --fast --count-quantiles \ + -i {self.tempdir.name}/graph{graph_file_extension[self.graph_repr]} \ + -a {self.tempdir.name}/annotation{anno_file_extension[self.anno_repr]} \ + --discovery-fraction 1.0 {query_file}' + + query_command = query_command.split() + query_command[4] = ' '.join([str(p) for p in quantiles]) + res = subprocess.run(query_command, stdout=PIPE) + self.assertEqual(res.returncode, 0) + self.assertEqual(len(res.stdout.decode()), 5230) + @parameterized_class(('graph_repr', 'anno_repr'), input_values=(product(list(set(GRAPH_TYPES) - {'hashstr'}), ANNO_TYPES) + diff --git a/metagraph/src/annotation/binary_matrix/column_sparse/column_major.cpp b/metagraph/src/annotation/binary_matrix/column_sparse/column_major.cpp index 3942571a39..c9174e188d 100644 --- a/metagraph/src/annotation/binary_matrix/column_sparse/column_major.cpp +++ b/metagraph/src/annotation/binary_matrix/column_sparse/column_major.cpp @@ -56,6 +56,39 @@ ColumnMajor::get_rows(const std::vector &row_ids) const { return rows; } +Vector> +ColumnMajor::get_column_ranks(Row row) const { + assert(row < num_rows() || !columns_.size()); + + Vector> result; + for (size_t i = 0; i < columns_.size(); ++i) { + assert(columns_[i]); + + if (uint64_t r = columns_[i]->conditional_rank1(row)) + result.emplace_back(i, r); + } + return result; +} + +std::vector>> +ColumnMajor::get_column_ranks(const std::vector &row_ids) const { + std::vector>> result(row_ids.size()); + + for (size_t j = 0; j < columns_.size(); ++j) { + assert(columns_[j]); + const bit_vector &col = *columns_[j]; + + for (size_t i = 0; i < row_ids.size(); ++i) { + assert(row_ids[i] < num_rows()); + + if (uint64_t r = col.conditional_rank1(row_ids[i])) + result[i].emplace_back(j, r); + } + } + + return result; +} + std::vector ColumnMajor::slice_rows(const std::vector &row_ids) const { std::vector slice; @@ -91,6 +124,7 @@ bool ColumnMajor::load(std::istream &in) { for (auto &c : columns_) { assert(!c); + // TODO: switch to bit_vector_smart? c = std::make_unique(); if (!c->load(in)) return false; diff --git a/metagraph/src/annotation/binary_matrix/column_sparse/column_major.hpp b/metagraph/src/annotation/binary_matrix/column_sparse/column_major.hpp index 91d7451ec0..7aa95bf639 100644 --- a/metagraph/src/annotation/binary_matrix/column_sparse/column_major.hpp +++ b/metagraph/src/annotation/binary_matrix/column_sparse/column_major.hpp @@ -23,6 +23,10 @@ class ColumnMajor : public BinaryMatrix { bool get(Row row, Column column) const override; SetBitPositions get_row(Row row) const override; std::vector get_rows(const std::vector &rows) const override; + // query row and get ranks of each set bit in its column + Vector> get_column_ranks(Row row) const; + std::vector>> + get_column_ranks(const std::vector &rows) const; std::vector get_column(Column column) const override; // get all selected rows appended with -1 and concatenated std::vector slice_rows(const std::vector &rows) const override; diff --git a/metagraph/src/annotation/int_matrix/base/int_matrix.cpp b/metagraph/src/annotation/int_matrix/base/int_matrix.cpp index 07562f983a..11b8418a2f 100644 --- a/metagraph/src/annotation/int_matrix/base/int_matrix.cpp +++ b/metagraph/src/annotation/int_matrix/base/int_matrix.cpp @@ -7,29 +7,31 @@ namespace matrix { IntMatrix::RowValues IntMatrix::sum_row_values(const std::vector> &index_counts, - size_t min, - size_t cap) const { - assert(cap >= min); - - if (!cap) - return {}; - - min = std::max(min, size_t(1)); + size_t min_count) const { + min_count = std::max(min_count, size_t(1)); + std::vector rows; + rows.reserve(index_counts.size()); size_t total_sum = 0; - for (const auto &pair : index_counts) { - total_sum += pair.second; + for (const auto &[i, count] : index_counts) { + total_sum += count; + rows.push_back(i); } - if (total_sum < min) + if (total_sum < min_count) return {}; std::vector sum_row(num_columns(), 0); + std::vector counts(num_columns(), 0); - for (auto [i, count] : index_counts) { - for (auto [j, value] : get_row_values(i)) { + auto row_values = get_row_values(rows); + + for (size_t t = 0; t < index_counts.size(); ++t) { + auto [i, count] = index_counts[t]; + for (const auto &[j, value] : row_values[t]) { assert(j < sum_row.size()); sum_row[j] += count * value; + counts[j] += count; } } @@ -37,14 +39,47 @@ IntMatrix::sum_row_values(const std::vector> &index_count result.reserve(sum_row.size()); for (size_t j = 0; j < num_columns(); ++j) { - if (sum_row[j] >= min) { - result.emplace_back(j, std::min(sum_row[j], cap)); + if (counts[j] >= min_count) { + result.emplace_back(j, sum_row[j]); } } return result; } + +// return sizes of all non-empty tuples in the row +MultiIntMatrix::RowValues MultiIntMatrix::get_row_values(Row row) const { + RowTuples row_tuples = get_row_tuples(row); + + RowValues row_values(row_tuples.size()); + + for (size_t i = 0; i < row_tuples.size(); ++i) { + row_values[i].first = row_tuples[i].first; + row_values[i].second = row_tuples[i].second.size(); + } + + return row_values; +} + +// for each row return the sizes of all non-empty tuples +std::vector +MultiIntMatrix::get_row_values(const std::vector &rows) const { + std::vector row_tuples = get_row_tuples(rows); + + std::vector row_values(row_tuples.size()); + + for (size_t i = 0; i < row_tuples.size(); ++i) { + row_values[i].resize(row_tuples[i].size()); + for (size_t j = 0; j < row_tuples[i].size(); ++j) { + row_values[i][j].first = row_tuples[i][j].first; + row_values[i][j].second = row_tuples[i][j].second.size(); + } + } + + return row_values; +} + } // namespace matrix } // namespace annot } // namespace mtg diff --git a/metagraph/src/annotation/int_matrix/base/int_matrix.hpp b/metagraph/src/annotation/int_matrix/base/int_matrix.hpp index 1f595b5082..00beb51f6e 100644 --- a/metagraph/src/annotation/int_matrix/base/int_matrix.hpp +++ b/metagraph/src/annotation/int_matrix/base/int_matrix.hpp @@ -23,13 +23,38 @@ class IntMatrix : public binmat::BinaryMatrix { virtual std::vector get_row_values(const std::vector &rows) const = 0; - // Get all columns for which the sum of the values in queried rows - // is greater than or equal to |min|. Stop counting if the sum is - // greater than |cap|. + // sum up values for each column with at least |min_count| non-zero values virtual RowValues sum_row_values(const std::vector> &index_counts, - size_t min = 1, - size_t cap = std::numeric_limits::max()) const; + size_t min_count = 1) const; +}; + + +// Entries are tuples and their aggregated `values` are tuple sizes +class MultiIntMatrix : public IntMatrix { + public: + typedef SmallVector Tuple; + typedef Vector> RowTuples; + + virtual ~MultiIntMatrix() {} + + // return tuple sizes (if not zero) at each entry + virtual RowValues get_row_values(Row row) const; + + virtual std::vector + get_row_values(const std::vector &rows) const; + + // return total number of attributes in all tuples + virtual uint64_t num_attributes() const = 0; + + // return entries of the matrix -- where each entry is a set of integers + virtual RowTuples get_row_tuples(Row row) const = 0; + + virtual std::vector + get_row_tuples(const std::vector &rows) const = 0; + + virtual bool load_tuples(std::istream &in) = 0; + virtual void serialize_tuples(std::ostream &out) const = 0; }; } // namespace matrix diff --git a/metagraph/src/annotation/int_matrix/rank_extended/tuple_csc_matrix.hpp b/metagraph/src/annotation/int_matrix/rank_extended/tuple_csc_matrix.hpp new file mode 100644 index 0000000000..08c5a78b11 --- /dev/null +++ b/metagraph/src/annotation/int_matrix/rank_extended/tuple_csc_matrix.hpp @@ -0,0 +1,203 @@ +#ifndef __TUPLE_CSC_MATRIX_HPP__ +#define __TUPLE_CSC_MATRIX_HPP__ + +#include + +#include "annotation/int_matrix/base/int_matrix.hpp" +#include "common/logger.hpp" + + +namespace mtg { +namespace annot { +namespace matrix { + +/** + * Multi-Value Compressed Sparse Column Matrix (column-rank extended) + * + * Matrix which stores the non-empty tuples externally and indexes their + * positions in a binary matrix. These values are indexed by rank1 called + * on binary columns of the indexing matrix. + */ +template , + class Delims = bit_vector_smart> +class TupleCSCMatrix : public MultiIntMatrix { + public: + TupleCSCMatrix() {} + + TupleCSCMatrix(BaseMatrix&& index_matrix) + : binary_matrix_(std::move(index_matrix)) {} + + // return tuple sizes (if not zero) at each entry + RowValues get_row_values(Row row) const; + + std::vector + get_row_values(const std::vector &rows) const; + + uint64_t num_attributes() const; + + // return entries of the matrix -- where each entry is a set of integers + RowTuples get_row_tuples(Row row) const; + + std::vector + get_row_tuples(const std::vector &rows) const; + + uint64_t num_columns() const { return binary_matrix_.num_columns(); } + uint64_t num_rows() const { return binary_matrix_.num_rows(); } + uint64_t num_relations() const { return binary_matrix_.num_relations(); } + + // row is in [0, num_rows), column is in [0, num_columns) + bool get(Row row, Column column) const { return binary_matrix_.get(row, column); } + SetBitPositions get_row(Row row) const { return binary_matrix_.get_row(row); } + std::vector get_rows(const std::vector &rows) const { + return binary_matrix_.get_rows(rows); + } + std::vector get_column(Column column) const { + return binary_matrix_.get_column(column); + } + // get all selected rows appended with -1 and concatenated + std::vector slice_rows(const std::vector &rows) const { + return binary_matrix_.slice_rows(rows); + } + + bool load(std::istream &in); + void serialize(std::ostream &out) const; + + bool load_tuples(std::istream &in); + void serialize_tuples(std::ostream &out) const; + + const BaseMatrix& get_binary_matrix() const { return binary_matrix_; } + + private: + BaseMatrix binary_matrix_; + std::vector delimiters_; + std::vector column_values_; +}; + + +template +inline typename TupleCSCMatrix::RowValues +TupleCSCMatrix::get_row_values(Row row) const { + const auto &column_ranks = binary_matrix_.get_column_ranks(row); + RowValues row_values; + row_values.reserve(column_ranks.size()); + for (auto [j, r] : column_ranks) { + assert(r >= 1 && "matches can't have zero-rank"); + size_t tuple_size = delimiters_[j].select1(r + 1) - delimiters_[j].select1(r) - 1; + row_values.emplace_back(j, tuple_size); + } + return row_values; +} + +template +inline std::vector::RowValues> +TupleCSCMatrix::get_row_values(const std::vector &rows) const { + const auto &column_ranks = binary_matrix_.get_column_ranks(rows); + std::vector row_values(rows.size()); + // TODO: reshape? + for (size_t i = 0; i < rows.size(); ++i) { + row_values[i].reserve(column_ranks[i].size()); + for (auto [j, r] : column_ranks[i]) { + assert(r >= 1 && "matches can't have zero-rank"); + size_t tuple_size = delimiters_[j].select1(r + 1) - delimiters_[j].select1(r) - 1; + row_values[i].emplace_back(j, tuple_size); + } + } + return row_values; +} + +template +uint64_t TupleCSCMatrix::num_attributes() const { + uint64_t num_attributes = 0; + for (size_t j = 0; j < column_values_.size(); ++j) { + num_attributes += column_values_[j].size(); + } + return num_attributes; +} + +template +inline typename TupleCSCMatrix::RowTuples +TupleCSCMatrix::get_row_tuples(Row row) const { + const auto &column_ranks = binary_matrix_.get_column_ranks(row); + RowTuples row_tuples; + row_tuples.reserve(column_ranks.size()); + for (auto [j, r] : column_ranks) { + assert(r >= 1 && "matches can't have zero-rank"); + size_t begin = delimiters_[j].select1(r) + 1 - r; + size_t end = delimiters_[j].select1(r + 1) - r; + Tuple tuple; + tuple.reserve(end - begin); + for (size_t t = begin; t < end; ++t) { + tuple.push_back(column_values_[j][t]); + } + row_tuples.emplace_back(j, std::move(tuple)); + } + return row_tuples; +} + +template +inline std::vector::RowTuples> +TupleCSCMatrix::get_row_tuples(const std::vector &rows) const { + const auto &column_ranks = binary_matrix_.get_column_ranks(rows); + std::vector row_tuples(rows.size()); + // TODO: reshape? + for (size_t i = 0; i < rows.size(); ++i) { + row_tuples[i].reserve(column_ranks[i].size()); + for (auto [j, r] : column_ranks[i]) { + assert(r >= 1 && "matches can't have zero-rank"); + size_t begin = delimiters_[j].select1(r) + 1 - r; + size_t end = delimiters_[j].select1(r + 1) - r; + Tuple tuple; + tuple.reserve(end - begin); + for (size_t t = begin; t < end; ++t) { + tuple.push_back(column_values_[j][t]); + } + row_tuples[i].emplace_back(j, std::move(tuple)); + } + } + return row_tuples; +} + +template +inline bool TupleCSCMatrix::load(std::istream &in) { + return binary_matrix_.load(in) && load_tuples(in); +} + +template +inline bool TupleCSCMatrix::load_tuples(std::istream &in) { + delimiters_.clear(); + column_values_.clear(); + + delimiters_.resize(num_columns()); + column_values_.resize(num_columns()); + for (size_t j = 0; j < column_values_.size(); ++j) { + try { + delimiters_[j].load(in); + column_values_[j].load(in); + } catch (...) { + common::logger->error("Couldn't load tuple attributes for column {}", j); + return false; + } + } + return true; +} + +template +inline void TupleCSCMatrix::serialize(std::ostream &out) const { + binary_matrix_.serialize(out); + serialize_tuples(out); +} + +template +inline void TupleCSCMatrix::serialize_tuples(std::ostream &out) const { + for (size_t j = 0; j < column_values_.size(); ++j) { + delimiters_[j].serialize(out); + column_values_[j].serialize(out); + } +} + +} // namespace matrix +} // namespace annot +} // namespace mtg + +#endif // __TUPLE_CSC_MATRIX_HPP__ diff --git a/metagraph/src/annotation/representation/annotation_matrix/annotation_matrix.cpp b/metagraph/src/annotation/representation/annotation_matrix/annotation_matrix.cpp index 211e5722a9..32f06e4909 100644 --- a/metagraph/src/annotation/representation/annotation_matrix/annotation_matrix.cpp +++ b/metagraph/src/annotation/representation/annotation_matrix/annotation_matrix.cpp @@ -214,5 +214,7 @@ template class StaticBinRelAnnotator; +template class StaticBinRelAnnotator, std::string>; + } // namespace annot } // namespace mtg diff --git a/metagraph/src/annotation/representation/annotation_matrix/static_annotators_def.hpp b/metagraph/src/annotation/representation/annotation_matrix/static_annotators_def.hpp index 889afa4ca2..d1f53f31ac 100644 --- a/metagraph/src/annotation/representation/annotation_matrix/static_annotators_def.hpp +++ b/metagraph/src/annotation/representation/annotation_matrix/static_annotators_def.hpp @@ -16,6 +16,7 @@ #include "annotation/int_matrix/rank_extended/csc_matrix.hpp" #include "annotation/int_matrix/row_diff/int_row_diff.hpp" #include "annotation/int_matrix/csr_matrix/csr_matrix.hpp" +#include "annotation/int_matrix/rank_extended/tuple_csc_matrix.hpp" namespace mtg { @@ -51,6 +52,8 @@ typedef StaticBinRelAnnotator IntRowAnnotator; +typedef StaticBinRelAnnotator, std::string> ColumnCoordAnnotator; + template <> inline const std::string RowFlatAnnotator::kExtension = ".flat.annodbg"; @@ -80,6 +83,8 @@ template <> inline const std::string IntRowDiffBRWTAnnotator::kExtension = ".row_diff_int_brwt.annodbg"; template <> inline const std::string IntRowAnnotator::kExtension = ".int_csr.annodbg"; +template <> +inline const std::string ColumnCoordAnnotator::kExtension = ".column_coord.annodbg"; } // namespace annot } // namespace mtg diff --git a/metagraph/src/annotation/representation/column_compressed/annotate_column_compressed.cpp b/metagraph/src/annotation/representation/column_compressed/annotate_column_compressed.cpp index 9a5f419262..8cdf19f6ce 100644 --- a/metagraph/src/annotation/representation/column_compressed/annotate_column_compressed.cpp +++ b/metagraph/src/annotation/representation/column_compressed/annotate_column_compressed.cpp @@ -766,10 +766,11 @@ bitmap_builder& ColumnCompressed