From 04426f5e0d3b4f61d4eaa640c63064da9d2c8ca5 Mon Sep 17 00:00:00 2001 From: Mikhail Karasikov Date: Fri, 9 Jul 2021 14:12:41 +0100 Subject: [PATCH 1/4] query k-mer count quantiles (#331) --- metagraph/integration_tests/test_query.py | 171 ++++++++++++++++++++++ metagraph/src/cli/config/config.cpp | 8 +- metagraph/src/cli/config/config.hpp | 1 + metagraph/src/cli/query.cpp | 25 +++- metagraph/src/cli/query.hpp | 4 +- metagraph/src/graph/annotated_dbg.cpp | 82 +++++++++++ metagraph/src/graph/annotated_dbg.hpp | 6 + 7 files changed, 293 insertions(+), 4 deletions(-) diff --git a/metagraph/integration_tests/test_query.py b/metagraph/integration_tests/test_query.py index 1780680900..472cee85fa 100644 --- a/metagraph/integration_tests/test_query.py +++ b/metagraph/integration_tests/test_query.py @@ -6,6 +6,7 @@ from tempfile import TemporaryDirectory import glob import os +import numpy as np from helpers import get_test_class_name from base import TestingBase, METAGRAPH, TEST_DATA_DIR, graph_file_extension @@ -504,6 +505,176 @@ def test_batch_query_with_tiny_batch(self): self.assertEqual(len(res.stdout), 136959) +@parameterized_class(('graph_repr', 'anno_repr'), + input_values=product( + [repr for repr in GRAPH_TYPES if not (repr == 'bitmap' and PROTEIN_MODE)], + ['int_brwt', 'row_diff_int_brwt'] + ), + class_name_func=get_test_class_name +) +class TestQueryCounts(TestingBase): + @classmethod + def setUpClass(cls): + cls.tempdir = TemporaryDirectory() + + cls.kmer_counts_1 = { + 'AAA': 1, + 'AAC': 2, + 'ACC': 3, + 'CCC': 4, + 'CCG': 5, + 'CGG': 6, + 'GGG': 7, + 'GGT': 8, + 'GTT': 9, + 'TTT': 10, + 'TTA': 11, + 'TAA': 12, + } + cls.kmer_counts_2 = { + 'AAA': 11, + 'AAC': 12, + 'ACC': 13, + 'CCC': 14, + 'CCG': 15, + 'CGG': 16, + 'GGG': 17, + 'GGT': 18, + 'GTT': 19, + 'TTT': 20, + 'TTA': 21, + 'TAA': 22, + } + fasta_file = cls.tempdir.name + '/file.fa' + with open(fasta_file, 'w') as f: + for kmer, count in cls.kmer_counts_1.items(): + f.write(f'>L1\n{kmer}\n' * count) + + for kmer, count in cls.kmer_counts_2.items(): + f.write(f'>L2\n{kmer}\n' * count) + + cls.k = 3 + + cls.with_bloom = False + if cls.graph_repr == 'succinct_bloom': + cls.graph_repr = 'succinct' + cls.with_bloom = True + + cls.mask_dummy = False + if cls.graph_repr == 'succinct_mask': + cls.graph_repr = 'succinct' + cls.mask_dummy = True + + construct_command = f"{METAGRAPH} build {'--mask-dummy' if cls.mask_dummy else ''} -p {NUM_THREADS} \ + --graph {cls.graph_repr} -k {cls.k} -o {cls.tempdir.name}/graph {fasta_file}" + + res = subprocess.run([construct_command], shell=True) + assert(res.returncode == 0) + + stats_command = '{exe} stats {graph}'.format( + exe=METAGRAPH, + graph=cls.tempdir.name + '/graph' + graph_file_extension[cls.graph_repr], + ) + res = subprocess.run(stats_command.split(), stdout=PIPE) + assert(res.returncode == 0) + params_str = res.stdout.decode().split('\n')[2:] + assert('k: 3' == params_str[0]) + if cls.graph_repr != 'succinct' or cls.mask_dummy: + assert('nodes (k): 12' == params_str[1]) + assert('mode: basic' == params_str[2]) + + if cls.with_bloom: + convert_command = '{exe} transform -o {outfile} --initialize-bloom {bloom_param} {input}'.format( + exe=METAGRAPH, + outfile=cls.tempdir.name + '/graph', + bloom_param='--bloom-fpp 0.1', + input=cls.tempdir.name + '/graph' + graph_file_extension[cls.graph_repr], + ) + res = subprocess.run([convert_command], shell=True) + assert(res.returncode == 0) + + def check_suffix(anno_repr, suffix): + match = anno_repr.endswith(suffix) + if match: + anno_repr = anno_repr[:-len(suffix)] + return anno_repr, match + + cls.anno_repr, separate = check_suffix(cls.anno_repr, '_separate') + cls.anno_repr, no_fork_opt = check_suffix(cls.anno_repr, '_no_fork_opt') + cls.anno_repr, no_anchor_opt = check_suffix(cls.anno_repr, '_no_anchor_opt') + + cls._annotate_graph( + fasta_file, + cls.tempdir.name + '/graph' + graph_file_extension[cls.graph_repr], + cls.tempdir.name + '/annotation', + cls.anno_repr, + separate, + no_fork_opt, + no_anchor_opt + ) + + # check annotation + anno_stats_command = '{exe} stats -a {annotation}'.format( + exe=METAGRAPH, + annotation=cls.tempdir.name + '/annotation' + anno_file_extension[cls.anno_repr], + ) + res = subprocess.run(anno_stats_command.split(), stdout=PIPE) + assert(res.returncode == 0) + params_str = res.stdout.decode().split('\n')[2:] + assert('labels: 2' == params_str[0]) + if cls.graph_repr != 'hashfast' and (cls.graph_repr != 'succinct' or cls.mask_dummy): + assert('objects: 12' == params_str[1]) + assert('representation: ' + cls.anno_repr == params_str[3]) + + def test_count_quantiles(self): + query_file = self.tempdir.name + '/query.fa' + queries = [ + 'AAA', + 'AAAA', + 'AAAAAAAAAAAAA', + 'CCC', + 'CCCC', + 'CCCCCCCCCCCCC', + 'TTT', + 'AAACCCGGGTTT', + 'AAACCCGGGTTTTTT', + 'AAACCCGGGTTTAAA', + 'TTTAAACCCGGG', + 'ACACACACACACATTTAAACCCGGG', + ] + quantiles = np.linspace(0, 1, 100) + expected_output = '' + with open(query_file, 'w') as f: + for i, s in enumerate(queries): + f.write(f'>s{i}\n{s}\n') + expected_output += f'{i}\ts{i}\t' + def get_count(d, kmer): + try: + return d[kmer] + except: + return 0 + for p in quantiles: + counts = [get_count(self.kmer_counts_1, s[i:i + self.k]) for i in range(len(s) - self.k + 1)] + expected_output += f':{np.quantile(counts, p, interpolation="lower")}' + expected_output += f'\t' + for p in quantiles: + counts = [get_count(self.kmer_counts_2, s[i:i + self.k]) for i in range(len(s) - self.k + 1)] + expected_output += f':{np.quantile(counts, p, interpolation="lower")}' + expected_output += '\n' + + query_command = '{exe} query --fast --count-quantiles -i {graph} -a {annotation} --discovery-fraction 0.0 {input}'.format( + exe=METAGRAPH, + graph=self.tempdir.name + '/graph' + graph_file_extension[self.graph_repr], + annotation=self.tempdir.name + '/annotation' + anno_file_extension[self.anno_repr], + input=query_file + ) + query_command = query_command.split() + query_command[4] = ' '.join([str(p) for p in quantiles]) + res = subprocess.run(query_command, stdout=PIPE) + self.assertEqual(res.returncode, 0) + self.assertEqual(res.stdout.decode(), expected_output) + + @parameterized_class(('graph_repr', 'anno_repr'), input_values=(product(list(set(GRAPH_TYPES) - {'hashstr'}), ANNO_TYPES) + product(['succinct_bloom', 'succinct_mask'], ['flat'])), diff --git a/metagraph/src/cli/config/config.cpp b/metagraph/src/cli/config/config.cpp index a3f45a8717..8ea9a19378 100644 --- a/metagraph/src/cli/config/config.cpp +++ b/metagraph/src/cli/config/config.cpp @@ -180,6 +180,10 @@ Config::Config(int argc, char *argv[]) { for (const auto &border : utils::split_string(get_value(i++), " ")) { count_slice_quantiles.push_back(std::stod(border)); } + } else if (!strcmp(argv[i], "--count-quantiles")) { + for (const auto &p : utils::split_string(get_value(i++), " ")) { + count_quantiles.push_back(std::stod(p)); + } } else if (!strcmp(argv[i], "--aggregate-columns")) { aggregate_columns = true; } else if (!strcmp(argv[i], "--intersected-anno")) { @@ -1168,6 +1172,8 @@ void Config::print_usage(const std::string &prog_name, IdentityType identity) { fprintf(stderr, "\n"); fprintf(stderr, "\t --count-labels \t\tcount labels for k-mers from querying sequences [off]\n"); fprintf(stderr, "\t --count-kmers \t\tweight k-mers with their annotated counts (requires count annotation) [off]\n"); + fprintf(stderr, "\t --count-quantiles [FLOAT ...] \tk-mer count quantiles to compute for each label [off]\n" + "\t \t\tExample: --count-quantiles '0.33 0.5 0.66 1'\n"); fprintf(stderr, "\t --print-signature \t\tprint vectors indicating present/absent k-mers [off]\n"); fprintf(stderr, "\t --num-top-labels \t\tmaximum number of frequent labels to print [off]\n"); fprintf(stderr, "\t --discovery-fraction [FLOAT] fraction of labeled k-mers required for annotation [0.7]\n"); @@ -1186,7 +1192,7 @@ void Config::print_usage(const std::string &prog_name, IdentityType identity) { fprintf(stderr, "\t --align-min-path-score [INT]\t\t\tthe minimum score that a reported path can have [0]\n"); fprintf(stderr, "\t --align-edit-distance \t\t\tuse unit costs for scoring matrix [off]\n"); fprintf(stderr, "\t --align-max-nodes-per-seq-char [FLOAT]\tmaximum number of nodes to consider per sequence character [12.0]\n"); - fprintf(stderr, "\t --align-max-ram [FLOAT]\t\tmaximum amount of RAM used per alignment in MB [200.0]\n"); + fprintf(stderr, "\t --align-max-ram [FLOAT]\t\t\tmaximum amount of RAM used per alignment in MB [200.0]\n"); fprintf(stderr, "\n"); fprintf(stderr, "\t --batch-align \t\talign against query graph [off]\n"); fprintf(stderr, "\t --max-hull-forks [INT]\tmaximum number of forks to take when expanding query graph [4]\n"); diff --git a/metagraph/src/cli/config/config.hpp b/metagraph/src/cli/config/config.hpp index 5649a2ef18..d4b76d6bb7 100644 --- a/metagraph/src/cli/config/config.hpp +++ b/metagraph/src/cli/config/config.hpp @@ -133,6 +133,7 @@ class Config { double min_fraction = 0.0; double max_fraction = 1.0; std::vector count_slice_quantiles; + std::vector count_quantiles; std::vector fnames; std::vector anno_labels; diff --git a/metagraph/src/cli/query.cpp b/metagraph/src/cli/query.cpp index 41d8631d17..a098f9e442 100644 --- a/metagraph/src/cli/query.cpp +++ b/metagraph/src/cli/query.cpp @@ -60,7 +60,8 @@ std::string QueryExecutor::execute_query(const std::string &seq_name, double discovery_fraction, std::string anno_labels_delimiter, const AnnotatedDBG &anno_graph, - bool with_kmer_counts) { + bool with_kmer_counts, + const std::vector &count_quantiles) { std::string output; output.reserve(1'000); @@ -84,6 +85,26 @@ std::string QueryExecutor::execute_query(const std::string &seq_name, output += '\n'; + } else if (count_quantiles.size()) { + auto result = anno_graph.get_label_count_quantiles(sequence, + num_top_labels, + discovery_fraction, + count_quantiles); + + if (!result.size() && suppress_unlabeled) + return ""; + + output += seq_name; + + for (const auto &[label, quantiles] : result) { + output += "\t<" + label + ">"; + for (uint64_t count : quantiles) { + output += fmt::format(":{}", count); + } + } + + output += '\n'; + } else if (count_labels) { auto top_labels = anno_graph.get_top_labels(sequence, num_top_labels, @@ -893,7 +914,7 @@ std::string query_sequence(size_t id, std::string name, std::string seq, config.count_labels, config.print_signature, config.suppress_unlabeled, config.num_top_labels, config.discovery_fraction, config.anno_labels_delimiter, - anno_graph, config.count_kmers); + anno_graph, config.count_kmers, config.count_quantiles); } void QueryExecutor::query_fasta(const string &file, diff --git a/metagraph/src/cli/query.hpp b/metagraph/src/cli/query.hpp index c007d9bada..be7ae21dff 100644 --- a/metagraph/src/cli/query.hpp +++ b/metagraph/src/cli/query.hpp @@ -5,6 +5,7 @@ #include #include #include +#include class ThreadPool; @@ -63,7 +64,8 @@ class QueryExecutor { double discovery_fraction, std::string anno_labels_delimiter, const graph::AnnotatedDBG &anno_graph, - bool with_kmer_counts = false); + bool with_kmer_counts = false, + const std::vector &count_quantiles = {}); private: const Config &config_; diff --git a/metagraph/src/graph/annotated_dbg.cpp b/metagraph/src/graph/annotated_dbg.cpp index a9480dc2dc..033ae3d3ca 100644 --- a/metagraph/src/graph/annotated_dbg.cpp +++ b/metagraph/src/graph/annotated_dbg.cpp @@ -294,6 +294,88 @@ AnnotatedDBG::get_top_labels(std::string_view sequence, return top_labels; } +std::vector>> +AnnotatedDBG::get_label_count_quantiles(std::string_view sequence, + size_t num_top_labels, + double presence_ratio, + const std::vector &count_quantiles) const { + assert(presence_ratio >= 0.); + assert(presence_ratio <= 1.); + assert(check_compatibility()); + if (!std::is_sorted(count_quantiles.begin(), count_quantiles.end())) + throw std::runtime_error("Quantiles must be sorted"); + if (count_quantiles.at(0) < 0. || count_quantiles.back() > 1.) + throw std::runtime_error("Quantiles must be in range [0, 1]"); + + if (sequence.size() < dbg_.get_k()) + return {}; + + std::vector rows; + size_t num_kmers = sequence.size() - dbg_.get_k() + 1; + rows.reserve(num_kmers); + + graph_->map_to_nodes(sequence, [&](node_index i) { + if (i > 0) + rows.push_back(graph_to_anno_index(i)); + }); + + uint64_t min_count = std::max(1.0, std::ceil(presence_ratio * num_kmers)); + if (rows.size() < min_count) + return {}; + + std::vector q_low(count_quantiles.size()); + for (size_t i = 0; i < count_quantiles.size(); ++i) { + q_low[i] = (num_kmers - 1) * count_quantiles[i]; + } + + VectorMap> code_to_counts; + for (const auto &row_values : dynamic_cast(annotator_->get_matrix()) + .get_row_values(rows)) { + for (const auto &[column, count] : row_values) { + code_to_counts[column].push_back(count); + } + } + + std::vector>> code_counts; + code_counts.reserve(code_to_counts.size()); + for (auto &[j, counts] : code_to_counts.values_container()) { + // filter by the number of matched k-mers + if (counts.size() >= min_count) + code_counts.emplace_back(j, std::move(counts)); + } + // sort by the number of matched k-mers + std::sort(code_counts.begin(), code_counts.end(), + [](const auto &x, const auto &y) { + return x.second.size() > y.second.size() + || (x.second.size() == y.second.size() && x.first < y.first); + }); + // keep only the first |num_top_labels| top labels + if (code_counts.size() > num_top_labels) + code_counts.resize(num_top_labels); + + std::vector>> label_quantiles; + label_quantiles.reserve(code_counts.size()); + // Quantiles are defined as `count[i]` where `i < q * N <= i + 1` + for (auto &[j, counts] : code_counts) { + std::sort(counts.begin(), counts.end()); + const size_t num_zeros = num_kmers - counts.size(); + + label_quantiles.emplace_back(annotator_->get_label_encoder().decode(j), + std::vector(q_low.size())); + + std::vector &quantiles = label_quantiles.back().second; + for (size_t q = 0; q < q_low.size(); ++q) { + if (q_low[q] < num_zeros) { + quantiles[q] = 0; + } else { + quantiles[q] = counts[q_low[q] - num_zeros]; + } + } + } + + return label_quantiles; +} + std::vector> AnnotatedDBG::get_top_label_signatures(std::string_view sequence, size_t num_top_labels, diff --git a/metagraph/src/graph/annotated_dbg.hpp b/metagraph/src/graph/annotated_dbg.hpp index 2142f74c7b..d9b55284c5 100644 --- a/metagraph/src/graph/annotated_dbg.hpp +++ b/metagraph/src/graph/annotated_dbg.hpp @@ -116,6 +116,12 @@ class AnnotatedDBG : public AnnotatedSequenceGraph { size_t min_count = 0, bool with_kmer_counts = false) const; + std::vector>> + get_label_count_quantiles(std::string_view sequence, + size_t num_top_labels, + double presence_ratio, + const std::vector &count_quantiles) const; + std::vector> get_top_label_signatures(std::string_view sequence, size_t num_top_labels, From 8cf46aab16df9f1c8a19d23f664da35451b8516c Mon Sep 17 00:00:00 2001 From: Mikhail Karasikov Date: Wed, 14 Jul 2021 16:45:21 +0200 Subject: [PATCH 2/4] clarified syntax for flag --count-quantiles --- metagraph/src/cli/config/config.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/metagraph/src/cli/config/config.cpp b/metagraph/src/cli/config/config.cpp index 8ea9a19378..43eb36a94d 100644 --- a/metagraph/src/cli/config/config.cpp +++ b/metagraph/src/cli/config/config.cpp @@ -1173,7 +1173,8 @@ void Config::print_usage(const std::string &prog_name, IdentityType identity) { fprintf(stderr, "\t --count-labels \t\tcount labels for k-mers from querying sequences [off]\n"); fprintf(stderr, "\t --count-kmers \t\tweight k-mers with their annotated counts (requires count annotation) [off]\n"); fprintf(stderr, "\t --count-quantiles [FLOAT ...] \tk-mer count quantiles to compute for each label [off]\n" - "\t \t\tExample: --count-quantiles '0.33 0.5 0.66 1'\n"); + "\t \t\tExample: --count-quantiles '0 0.33 0.5 0.66 1'\n" + "\t \t\t(0 corresponds to MIN, 1 corresponds to MAX)\n"); fprintf(stderr, "\t --print-signature \t\tprint vectors indicating present/absent k-mers [off]\n"); fprintf(stderr, "\t --num-top-labels \t\tmaximum number of frequent labels to print [off]\n"); fprintf(stderr, "\t --discovery-fraction [FLOAT] fraction of labeled k-mers required for annotation [0.7]\n"); From 8325931dbded1211dea09f754834be7139eeebe1 Mon Sep 17 00:00:00 2001 From: Mikhail Karasikov Date: Sat, 17 Jul 2021 15:31:28 +0100 Subject: [PATCH 3/4] query k-mer coordinates (#337) * interfaces for querying k-mer coordinates * transform to column_coord * count queries without weighting by counts when filtering by the number of matches --- metagraph/integration_tests/base.py | 7 +- metagraph/integration_tests/test_query.py | 98 ++++++++- .../column_sparse/column_major.cpp | 34 +++ .../column_sparse/column_major.hpp | 4 + .../annotation/int_matrix/base/int_matrix.cpp | 65 ++++-- .../annotation/int_matrix/base/int_matrix.hpp | 35 ++- .../rank_extended/tuple_csc_matrix.hpp | 203 ++++++++++++++++++ .../annotation_matrix/annotation_matrix.cpp | 2 + .../static_annotators_def.hpp | 5 + metagraph/src/cli/config/config.cpp | 8 + metagraph/src/cli/config/config.hpp | 2 + .../src/cli/load/load_annotated_graph.cpp | 2 +- metagraph/src/cli/load/load_annotation.cpp | 7 + metagraph/src/cli/query.cpp | 41 +++- metagraph/src/cli/query.hpp | 3 +- metagraph/src/cli/stats.cpp | 6 + metagraph/src/cli/transform_annotation.cpp | 21 ++ metagraph/src/graph/annotated_dbg.cpp | 126 ++++++++++- metagraph/src/graph/annotated_dbg.hpp | 11 + 19 files changed, 636 insertions(+), 44 deletions(-) create mode 100644 metagraph/src/annotation/int_matrix/rank_extended/tuple_csc_matrix.hpp diff --git a/metagraph/integration_tests/base.py b/metagraph/integration_tests/base.py index c655bf7055..84f3f5aafe 100644 --- a/metagraph/integration_tests/base.py +++ b/metagraph/integration_tests/base.py @@ -106,7 +106,9 @@ def _clean(graph, output, extra_params=''): def _annotate_graph(input, graph_path, output, anno_repr, separate=False, no_fork_opt=False, no_anchor_opt=False): target_anno = anno_repr - if anno_repr in {'row_sparse'} or anno_repr.endswith('brwt') or anno_repr.startswith('row_diff'): + if (anno_repr in {'row_sparse', 'column_coord'} or + anno_repr.endswith('brwt') or + anno_repr.startswith('row_diff')): target_anno = anno_repr anno_repr = 'column' elif anno_repr in {'flat', 'rbfish'}: @@ -117,6 +119,9 @@ def _annotate_graph(input, graph_path, output, anno_repr, -i {graph_path} --anno-type {anno_repr} \ -o {output} {input}' + if target_anno.endswith('_coord'): + command += ' --coordinates' + with_counts = target_anno.endswith('int_brwt') if with_counts: command += ' --count-kmers' diff --git a/metagraph/integration_tests/test_query.py b/metagraph/integration_tests/test_query.py index 472cee85fa..159cab88e3 100644 --- a/metagraph/integration_tests/test_query.py +++ b/metagraph/integration_tests/test_query.py @@ -17,6 +17,7 @@ PROTEIN_MODE = os.readlink(METAGRAPH).endswith("_Protein") anno_file_extension = {'column': '.column.annodbg', + 'column_coord': '.column_coord.annodbg', 'row': '.row.annodbg', 'row_diff': '.row_diff.annodbg', 'row_sparse': '.row_sparse.annodbg', @@ -504,6 +505,28 @@ def test_batch_query_with_tiny_batch(self): self.assertEqual(res.returncode, 0) self.assertEqual(len(res.stdout), 136959) + def test_query_coordinates(self): + if not self.anno_repr.endswith('_coord'): + self.skipTest('annotation does not support coordinates') + + query_command = f'{METAGRAPH} query --query-coords \ + -i {self.tempdir.name}/graph{graph_file_extension[self.graph_repr]} \ + -a {self.tempdir.name}/annotation{anno_file_extension[self.anno_repr]} \ + --discovery-fraction 0.05 {TEST_DATA_DIR}/transcripts_100.fa' + + res = subprocess.run(query_command.split(), stdout=PIPE) + self.assertEqual(res.returncode, 0) + self.assertEqual(len(res.stdout), 2155983) + + query_command = f'{METAGRAPH} query --query-coords \ + -i {self.tempdir.name}/graph{graph_file_extension[self.graph_repr]} \ + -a {self.tempdir.name}/annotation{anno_file_extension[self.anno_repr]} \ + --discovery-fraction 0.95 {TEST_DATA_DIR}/transcripts_100.fa' + + res = subprocess.run(query_command.split(), stdout=PIPE) + self.assertEqual(res.returncode, 0) + self.assertEqual(len(res.stdout), 687712) + @parameterized_class(('graph_repr', 'anno_repr'), input_values=product( @@ -542,8 +565,6 @@ def setUpClass(cls): 'GGT': 18, 'GTT': 19, 'TTT': 20, - 'TTA': 21, - 'TAA': 22, } fasta_file = cls.tempdir.name + '/file.fa' with open(fasta_file, 'w') as f: @@ -626,6 +647,57 @@ def check_suffix(anno_repr, suffix): assert('objects: 12' == params_str[1]) assert('representation: ' + cls.anno_repr == params_str[3]) + def test_count_query(self): + query_file = self.tempdir.name + '/query.fa' + queries = [ + 'AAA', + 'AAAA', + 'AAAAAAAAAAAAA', + 'CCC', + 'CCCC', + 'CCCCCCCCCCCCC', + 'TTT', + 'AAACCCGGGTTT', + 'AAACCCGGGTTTTTT', + 'AAACCCGGGTTTAAA', + 'TTTAAACCCGGG', + 'ACACACACACACATTTAAACCCGGG', + ] + for discovery_rate in np.linspace(0, 1, 5): + expected_output = '' + with open(query_file, 'w') as f: + for i, s in enumerate(queries): + f.write(f'>s{i}\n{s}\n') + expected_output += f'{i}\ts{i}' + def get_count(d, kmer): + try: + return d[kmer] + except: + return 0 + + num_kmers = len(s) - self.k + 1 + + num_matches_1 = sum([get_count(self.kmer_counts_1, s[i:i + self.k]) > 0 for i in range(num_kmers)]) + count_1 = sum([get_count(self.kmer_counts_1, s[i:i + self.k]) for i in range(len(s) - self.k + 1)]) + + num_matches_2 = sum([get_count(self.kmer_counts_2, s[i:i + self.k]) > 0 for i in range(num_kmers)]) + count_2 = sum([get_count(self.kmer_counts_2, s[i:i + self.k]) for i in range(len(s) - self.k + 1)]) + + for (c, i, n) in [(count_1, 1, num_matches_1), (count_2, 0, num_matches_2)]: + if n >= discovery_rate * num_kmers: + expected_output += f'\t:{c}' + + expected_output += '\n' + + query_command = f'{METAGRAPH} query --fast --count-kmers \ + -i {self.tempdir.name}/graph{graph_file_extension[self.graph_repr]} \ + -a {self.tempdir.name}/annotation{anno_file_extension[self.anno_repr]} \ + --discovery-fraction {discovery_rate} {query_file}' + + res = subprocess.run(query_command.split(), stdout=PIPE) + self.assertEqual(res.returncode, 0) + self.assertEqual(res.stdout.decode(), expected_output) + def test_count_quantiles(self): query_file = self.tempdir.name + '/query.fa' queries = [ @@ -662,18 +734,28 @@ def get_count(d, kmer): expected_output += f':{np.quantile(counts, p, interpolation="lower")}' expected_output += '\n' - query_command = '{exe} query --fast --count-quantiles -i {graph} -a {annotation} --discovery-fraction 0.0 {input}'.format( - exe=METAGRAPH, - graph=self.tempdir.name + '/graph' + graph_file_extension[self.graph_repr], - annotation=self.tempdir.name + '/annotation' + anno_file_extension[self.anno_repr], - input=query_file - ) + query_command = f'{METAGRAPH} query --fast --count-quantiles \ + -i {self.tempdir.name}/graph{graph_file_extension[self.graph_repr]} \ + -a {self.tempdir.name}/annotation{anno_file_extension[self.anno_repr]} \ + --discovery-fraction 0.0 {query_file}' + query_command = query_command.split() query_command[4] = ' '.join([str(p) for p in quantiles]) res = subprocess.run(query_command, stdout=PIPE) self.assertEqual(res.returncode, 0) self.assertEqual(res.stdout.decode(), expected_output) + query_command = f'{METAGRAPH} query --fast --count-quantiles \ + -i {self.tempdir.name}/graph{graph_file_extension[self.graph_repr]} \ + -a {self.tempdir.name}/annotation{anno_file_extension[self.anno_repr]} \ + --discovery-fraction 1.0 {query_file}' + + query_command = query_command.split() + query_command[4] = ' '.join([str(p) for p in quantiles]) + res = subprocess.run(query_command, stdout=PIPE) + self.assertEqual(res.returncode, 0) + self.assertEqual(len(res.stdout.decode()), 5230) + @parameterized_class(('graph_repr', 'anno_repr'), input_values=(product(list(set(GRAPH_TYPES) - {'hashstr'}), ANNO_TYPES) + diff --git a/metagraph/src/annotation/binary_matrix/column_sparse/column_major.cpp b/metagraph/src/annotation/binary_matrix/column_sparse/column_major.cpp index 3942571a39..c9174e188d 100644 --- a/metagraph/src/annotation/binary_matrix/column_sparse/column_major.cpp +++ b/metagraph/src/annotation/binary_matrix/column_sparse/column_major.cpp @@ -56,6 +56,39 @@ ColumnMajor::get_rows(const std::vector &row_ids) const { return rows; } +Vector> +ColumnMajor::get_column_ranks(Row row) const { + assert(row < num_rows() || !columns_.size()); + + Vector> result; + for (size_t i = 0; i < columns_.size(); ++i) { + assert(columns_[i]); + + if (uint64_t r = columns_[i]->conditional_rank1(row)) + result.emplace_back(i, r); + } + return result; +} + +std::vector>> +ColumnMajor::get_column_ranks(const std::vector &row_ids) const { + std::vector>> result(row_ids.size()); + + for (size_t j = 0; j < columns_.size(); ++j) { + assert(columns_[j]); + const bit_vector &col = *columns_[j]; + + for (size_t i = 0; i < row_ids.size(); ++i) { + assert(row_ids[i] < num_rows()); + + if (uint64_t r = col.conditional_rank1(row_ids[i])) + result[i].emplace_back(j, r); + } + } + + return result; +} + std::vector ColumnMajor::slice_rows(const std::vector &row_ids) const { std::vector slice; @@ -91,6 +124,7 @@ bool ColumnMajor::load(std::istream &in) { for (auto &c : columns_) { assert(!c); + // TODO: switch to bit_vector_smart? c = std::make_unique(); if (!c->load(in)) return false; diff --git a/metagraph/src/annotation/binary_matrix/column_sparse/column_major.hpp b/metagraph/src/annotation/binary_matrix/column_sparse/column_major.hpp index 91d7451ec0..7aa95bf639 100644 --- a/metagraph/src/annotation/binary_matrix/column_sparse/column_major.hpp +++ b/metagraph/src/annotation/binary_matrix/column_sparse/column_major.hpp @@ -23,6 +23,10 @@ class ColumnMajor : public BinaryMatrix { bool get(Row row, Column column) const override; SetBitPositions get_row(Row row) const override; std::vector get_rows(const std::vector &rows) const override; + // query row and get ranks of each set bit in its column + Vector> get_column_ranks(Row row) const; + std::vector>> + get_column_ranks(const std::vector &rows) const; std::vector get_column(Column column) const override; // get all selected rows appended with -1 and concatenated std::vector slice_rows(const std::vector &rows) const override; diff --git a/metagraph/src/annotation/int_matrix/base/int_matrix.cpp b/metagraph/src/annotation/int_matrix/base/int_matrix.cpp index 07562f983a..11b8418a2f 100644 --- a/metagraph/src/annotation/int_matrix/base/int_matrix.cpp +++ b/metagraph/src/annotation/int_matrix/base/int_matrix.cpp @@ -7,29 +7,31 @@ namespace matrix { IntMatrix::RowValues IntMatrix::sum_row_values(const std::vector> &index_counts, - size_t min, - size_t cap) const { - assert(cap >= min); - - if (!cap) - return {}; - - min = std::max(min, size_t(1)); + size_t min_count) const { + min_count = std::max(min_count, size_t(1)); + std::vector rows; + rows.reserve(index_counts.size()); size_t total_sum = 0; - for (const auto &pair : index_counts) { - total_sum += pair.second; + for (const auto &[i, count] : index_counts) { + total_sum += count; + rows.push_back(i); } - if (total_sum < min) + if (total_sum < min_count) return {}; std::vector sum_row(num_columns(), 0); + std::vector counts(num_columns(), 0); - for (auto [i, count] : index_counts) { - for (auto [j, value] : get_row_values(i)) { + auto row_values = get_row_values(rows); + + for (size_t t = 0; t < index_counts.size(); ++t) { + auto [i, count] = index_counts[t]; + for (const auto &[j, value] : row_values[t]) { assert(j < sum_row.size()); sum_row[j] += count * value; + counts[j] += count; } } @@ -37,14 +39,47 @@ IntMatrix::sum_row_values(const std::vector> &index_count result.reserve(sum_row.size()); for (size_t j = 0; j < num_columns(); ++j) { - if (sum_row[j] >= min) { - result.emplace_back(j, std::min(sum_row[j], cap)); + if (counts[j] >= min_count) { + result.emplace_back(j, sum_row[j]); } } return result; } + +// return sizes of all non-empty tuples in the row +MultiIntMatrix::RowValues MultiIntMatrix::get_row_values(Row row) const { + RowTuples row_tuples = get_row_tuples(row); + + RowValues row_values(row_tuples.size()); + + for (size_t i = 0; i < row_tuples.size(); ++i) { + row_values[i].first = row_tuples[i].first; + row_values[i].second = row_tuples[i].second.size(); + } + + return row_values; +} + +// for each row return the sizes of all non-empty tuples +std::vector +MultiIntMatrix::get_row_values(const std::vector &rows) const { + std::vector row_tuples = get_row_tuples(rows); + + std::vector row_values(row_tuples.size()); + + for (size_t i = 0; i < row_tuples.size(); ++i) { + row_values[i].resize(row_tuples[i].size()); + for (size_t j = 0; j < row_tuples[i].size(); ++j) { + row_values[i][j].first = row_tuples[i][j].first; + row_values[i][j].second = row_tuples[i][j].second.size(); + } + } + + return row_values; +} + } // namespace matrix } // namespace annot } // namespace mtg diff --git a/metagraph/src/annotation/int_matrix/base/int_matrix.hpp b/metagraph/src/annotation/int_matrix/base/int_matrix.hpp index 1f595b5082..00beb51f6e 100644 --- a/metagraph/src/annotation/int_matrix/base/int_matrix.hpp +++ b/metagraph/src/annotation/int_matrix/base/int_matrix.hpp @@ -23,13 +23,38 @@ class IntMatrix : public binmat::BinaryMatrix { virtual std::vector get_row_values(const std::vector &rows) const = 0; - // Get all columns for which the sum of the values in queried rows - // is greater than or equal to |min|. Stop counting if the sum is - // greater than |cap|. + // sum up values for each column with at least |min_count| non-zero values virtual RowValues sum_row_values(const std::vector> &index_counts, - size_t min = 1, - size_t cap = std::numeric_limits::max()) const; + size_t min_count = 1) const; +}; + + +// Entries are tuples and their aggregated `values` are tuple sizes +class MultiIntMatrix : public IntMatrix { + public: + typedef SmallVector Tuple; + typedef Vector> RowTuples; + + virtual ~MultiIntMatrix() {} + + // return tuple sizes (if not zero) at each entry + virtual RowValues get_row_values(Row row) const; + + virtual std::vector + get_row_values(const std::vector &rows) const; + + // return total number of attributes in all tuples + virtual uint64_t num_attributes() const = 0; + + // return entries of the matrix -- where each entry is a set of integers + virtual RowTuples get_row_tuples(Row row) const = 0; + + virtual std::vector + get_row_tuples(const std::vector &rows) const = 0; + + virtual bool load_tuples(std::istream &in) = 0; + virtual void serialize_tuples(std::ostream &out) const = 0; }; } // namespace matrix diff --git a/metagraph/src/annotation/int_matrix/rank_extended/tuple_csc_matrix.hpp b/metagraph/src/annotation/int_matrix/rank_extended/tuple_csc_matrix.hpp new file mode 100644 index 0000000000..08c5a78b11 --- /dev/null +++ b/metagraph/src/annotation/int_matrix/rank_extended/tuple_csc_matrix.hpp @@ -0,0 +1,203 @@ +#ifndef __TUPLE_CSC_MATRIX_HPP__ +#define __TUPLE_CSC_MATRIX_HPP__ + +#include + +#include "annotation/int_matrix/base/int_matrix.hpp" +#include "common/logger.hpp" + + +namespace mtg { +namespace annot { +namespace matrix { + +/** + * Multi-Value Compressed Sparse Column Matrix (column-rank extended) + * + * Matrix which stores the non-empty tuples externally and indexes their + * positions in a binary matrix. These values are indexed by rank1 called + * on binary columns of the indexing matrix. + */ +template , + class Delims = bit_vector_smart> +class TupleCSCMatrix : public MultiIntMatrix { + public: + TupleCSCMatrix() {} + + TupleCSCMatrix(BaseMatrix&& index_matrix) + : binary_matrix_(std::move(index_matrix)) {} + + // return tuple sizes (if not zero) at each entry + RowValues get_row_values(Row row) const; + + std::vector + get_row_values(const std::vector &rows) const; + + uint64_t num_attributes() const; + + // return entries of the matrix -- where each entry is a set of integers + RowTuples get_row_tuples(Row row) const; + + std::vector + get_row_tuples(const std::vector &rows) const; + + uint64_t num_columns() const { return binary_matrix_.num_columns(); } + uint64_t num_rows() const { return binary_matrix_.num_rows(); } + uint64_t num_relations() const { return binary_matrix_.num_relations(); } + + // row is in [0, num_rows), column is in [0, num_columns) + bool get(Row row, Column column) const { return binary_matrix_.get(row, column); } + SetBitPositions get_row(Row row) const { return binary_matrix_.get_row(row); } + std::vector get_rows(const std::vector &rows) const { + return binary_matrix_.get_rows(rows); + } + std::vector get_column(Column column) const { + return binary_matrix_.get_column(column); + } + // get all selected rows appended with -1 and concatenated + std::vector slice_rows(const std::vector &rows) const { + return binary_matrix_.slice_rows(rows); + } + + bool load(std::istream &in); + void serialize(std::ostream &out) const; + + bool load_tuples(std::istream &in); + void serialize_tuples(std::ostream &out) const; + + const BaseMatrix& get_binary_matrix() const { return binary_matrix_; } + + private: + BaseMatrix binary_matrix_; + std::vector delimiters_; + std::vector column_values_; +}; + + +template +inline typename TupleCSCMatrix::RowValues +TupleCSCMatrix::get_row_values(Row row) const { + const auto &column_ranks = binary_matrix_.get_column_ranks(row); + RowValues row_values; + row_values.reserve(column_ranks.size()); + for (auto [j, r] : column_ranks) { + assert(r >= 1 && "matches can't have zero-rank"); + size_t tuple_size = delimiters_[j].select1(r + 1) - delimiters_[j].select1(r) - 1; + row_values.emplace_back(j, tuple_size); + } + return row_values; +} + +template +inline std::vector::RowValues> +TupleCSCMatrix::get_row_values(const std::vector &rows) const { + const auto &column_ranks = binary_matrix_.get_column_ranks(rows); + std::vector row_values(rows.size()); + // TODO: reshape? + for (size_t i = 0; i < rows.size(); ++i) { + row_values[i].reserve(column_ranks[i].size()); + for (auto [j, r] : column_ranks[i]) { + assert(r >= 1 && "matches can't have zero-rank"); + size_t tuple_size = delimiters_[j].select1(r + 1) - delimiters_[j].select1(r) - 1; + row_values[i].emplace_back(j, tuple_size); + } + } + return row_values; +} + +template +uint64_t TupleCSCMatrix::num_attributes() const { + uint64_t num_attributes = 0; + for (size_t j = 0; j < column_values_.size(); ++j) { + num_attributes += column_values_[j].size(); + } + return num_attributes; +} + +template +inline typename TupleCSCMatrix::RowTuples +TupleCSCMatrix::get_row_tuples(Row row) const { + const auto &column_ranks = binary_matrix_.get_column_ranks(row); + RowTuples row_tuples; + row_tuples.reserve(column_ranks.size()); + for (auto [j, r] : column_ranks) { + assert(r >= 1 && "matches can't have zero-rank"); + size_t begin = delimiters_[j].select1(r) + 1 - r; + size_t end = delimiters_[j].select1(r + 1) - r; + Tuple tuple; + tuple.reserve(end - begin); + for (size_t t = begin; t < end; ++t) { + tuple.push_back(column_values_[j][t]); + } + row_tuples.emplace_back(j, std::move(tuple)); + } + return row_tuples; +} + +template +inline std::vector::RowTuples> +TupleCSCMatrix::get_row_tuples(const std::vector &rows) const { + const auto &column_ranks = binary_matrix_.get_column_ranks(rows); + std::vector row_tuples(rows.size()); + // TODO: reshape? + for (size_t i = 0; i < rows.size(); ++i) { + row_tuples[i].reserve(column_ranks[i].size()); + for (auto [j, r] : column_ranks[i]) { + assert(r >= 1 && "matches can't have zero-rank"); + size_t begin = delimiters_[j].select1(r) + 1 - r; + size_t end = delimiters_[j].select1(r + 1) - r; + Tuple tuple; + tuple.reserve(end - begin); + for (size_t t = begin; t < end; ++t) { + tuple.push_back(column_values_[j][t]); + } + row_tuples[i].emplace_back(j, std::move(tuple)); + } + } + return row_tuples; +} + +template +inline bool TupleCSCMatrix::load(std::istream &in) { + return binary_matrix_.load(in) && load_tuples(in); +} + +template +inline bool TupleCSCMatrix::load_tuples(std::istream &in) { + delimiters_.clear(); + column_values_.clear(); + + delimiters_.resize(num_columns()); + column_values_.resize(num_columns()); + for (size_t j = 0; j < column_values_.size(); ++j) { + try { + delimiters_[j].load(in); + column_values_[j].load(in); + } catch (...) { + common::logger->error("Couldn't load tuple attributes for column {}", j); + return false; + } + } + return true; +} + +template +inline void TupleCSCMatrix::serialize(std::ostream &out) const { + binary_matrix_.serialize(out); + serialize_tuples(out); +} + +template +inline void TupleCSCMatrix::serialize_tuples(std::ostream &out) const { + for (size_t j = 0; j < column_values_.size(); ++j) { + delimiters_[j].serialize(out); + column_values_[j].serialize(out); + } +} + +} // namespace matrix +} // namespace annot +} // namespace mtg + +#endif // __TUPLE_CSC_MATRIX_HPP__ diff --git a/metagraph/src/annotation/representation/annotation_matrix/annotation_matrix.cpp b/metagraph/src/annotation/representation/annotation_matrix/annotation_matrix.cpp index 211e5722a9..32f06e4909 100644 --- a/metagraph/src/annotation/representation/annotation_matrix/annotation_matrix.cpp +++ b/metagraph/src/annotation/representation/annotation_matrix/annotation_matrix.cpp @@ -214,5 +214,7 @@ template class StaticBinRelAnnotator; +template class StaticBinRelAnnotator, std::string>; + } // namespace annot } // namespace mtg diff --git a/metagraph/src/annotation/representation/annotation_matrix/static_annotators_def.hpp b/metagraph/src/annotation/representation/annotation_matrix/static_annotators_def.hpp index 889afa4ca2..d1f53f31ac 100644 --- a/metagraph/src/annotation/representation/annotation_matrix/static_annotators_def.hpp +++ b/metagraph/src/annotation/representation/annotation_matrix/static_annotators_def.hpp @@ -16,6 +16,7 @@ #include "annotation/int_matrix/rank_extended/csc_matrix.hpp" #include "annotation/int_matrix/row_diff/int_row_diff.hpp" #include "annotation/int_matrix/csr_matrix/csr_matrix.hpp" +#include "annotation/int_matrix/rank_extended/tuple_csc_matrix.hpp" namespace mtg { @@ -51,6 +52,8 @@ typedef StaticBinRelAnnotator IntRowAnnotator; +typedef StaticBinRelAnnotator, std::string> ColumnCoordAnnotator; + template <> inline const std::string RowFlatAnnotator::kExtension = ".flat.annodbg"; @@ -80,6 +83,8 @@ template <> inline const std::string IntRowDiffBRWTAnnotator::kExtension = ".row_diff_int_brwt.annodbg"; template <> inline const std::string IntRowAnnotator::kExtension = ".int_csr.annodbg"; +template <> +inline const std::string ColumnCoordAnnotator::kExtension = ".column_coord.annodbg"; } // namespace annot } // namespace mtg diff --git a/metagraph/src/cli/config/config.cpp b/metagraph/src/cli/config/config.cpp index 43eb36a94d..d5c2536db4 100644 --- a/metagraph/src/cli/config/config.cpp +++ b/metagraph/src/cli/config/config.cpp @@ -200,6 +200,8 @@ Config::Config(int argc, char *argv[]) { discovery_fraction = std::stof(get_value(i++)); } else if (!strcmp(argv[i], "--query-presence")) { query_presence = true; + } else if (!strcmp(argv[i], "--query-coords")) { + query_coords = true; } else if (!strcmp(argv[i], "--filter-present")) { filter_present = true; } else if (!strcmp(argv[i], "--count-labels")) { @@ -700,6 +702,8 @@ std::string Config::annotype_to_string(AnnotationType state) { return "int_brwt"; case IntRowDiffBRWT: return "row_diff_int_brwt"; + case ColumnCoord: + return "column_coord"; } throw std::runtime_error("Never happens"); } @@ -733,6 +737,8 @@ Config::AnnotationType Config::string_to_annotype(const std::string &string) { return AnnotationType::IntBRWT; } else if (string == "row_diff_int_brwt") { return AnnotationType::IntRowDiffBRWT; + } else if (string == "column_coord") { + return AnnotationType::ColumnCoord; } else { std::cerr << "Error: unknown annotation representation" << std::endl; exit(1); @@ -795,6 +801,7 @@ DeBruijnGraph::Mode Config::string_to_graphmode(const std::string &string) { void Config::print_usage(const std::string &prog_name, IdentityType identity) { const char annotation_list[] = "\t\t( column, brwt, rb_brwt, int_brwt,\n" + "\t\t column_coord,\n" "\t\t row_diff, row_diff_brwt, row_diff_sparse, row_diff_int_brwt,\n" "\t\t row, flat, row_sparse, rbfish, bin_rel_wt, bin_rel_wt_sdsl )"; @@ -1175,6 +1182,7 @@ void Config::print_usage(const std::string &prog_name, IdentityType identity) { fprintf(stderr, "\t --count-quantiles [FLOAT ...] \tk-mer count quantiles to compute for each label [off]\n" "\t \t\tExample: --count-quantiles '0 0.33 0.5 0.66 1'\n" "\t \t\t(0 corresponds to MIN, 1 corresponds to MAX)\n"); + fprintf(stderr, "\t --query-coords \t\tquery k-mer coordinates (requires coord annotation) [off]\n"); fprintf(stderr, "\t --print-signature \t\tprint vectors indicating present/absent k-mers [off]\n"); fprintf(stderr, "\t --num-top-labels \t\tmaximum number of frequent labels to print [off]\n"); fprintf(stderr, "\t --discovery-fraction [FLOAT] fraction of labeled k-mers required for annotation [0.7]\n"); diff --git a/metagraph/src/cli/config/config.hpp b/metagraph/src/cli/config/config.hpp index d4b76d6bb7..4a19a3ab32 100644 --- a/metagraph/src/cli/config/config.hpp +++ b/metagraph/src/cli/config/config.hpp @@ -39,6 +39,7 @@ class Config { bool count_kmers = false; bool print_signature = false; bool query_presence = false; + bool query_coords = false; bool filter_present = false; bool dump_text_anno = false; bool sparse = false; @@ -199,6 +200,7 @@ class Config { RbBRWT, IntBRWT, IntRowDiffBRWT, + ColumnCoord, }; enum GraphType { diff --git a/metagraph/src/cli/load/load_annotated_graph.cpp b/metagraph/src/cli/load/load_annotated_graph.cpp index 68b1977930..3c9e0ec90c 100644 --- a/metagraph/src/cli/load/load_annotated_graph.cpp +++ b/metagraph/src/cli/load/load_annotated_graph.cpp @@ -17,7 +17,6 @@ namespace mtg { namespace cli { using namespace mtg::graph; - using mtg::common::logger; @@ -52,6 +51,7 @@ std::unique_ptr initialize_annotated_dbg(std::shared_ptr(annotation_temp->get_matrix()); diff --git a/metagraph/src/cli/load/load_annotation.cpp b/metagraph/src/cli/load/load_annotation.cpp index 9482b47fd8..d9865ad60a 100644 --- a/metagraph/src/cli/load/load_annotation.cpp +++ b/metagraph/src/cli/load/load_annotation.cpp @@ -21,6 +21,9 @@ Config::AnnotationType parse_annotation_type(const std::string &filename) { if (utils::ends_with(filename, annot::ColumnCompressed<>::kExtension)) { return Config::AnnotationType::ColumnCompressed; + } else if (utils::ends_with(filename, annot::ColumnCoordAnnotator::kExtension)) { + return Config::AnnotationType::ColumnCoord; + } else if (utils::ends_with(filename, annot::RowDiffColumnAnnotator::kExtension)) { return Config::AnnotationType::RowDiff; @@ -137,6 +140,10 @@ initialize_annotation(Config::AnnotationType anno_type, annotation.reset(new annot::IntRowDiffBRWTAnnotator()); break; } + case Config::ColumnCoord: { + annotation.reset(new annot::ColumnCoordAnnotator()); + break; + } } return annotation; diff --git a/metagraph/src/cli/query.cpp b/metagraph/src/cli/query.cpp index a098f9e442..13322d0f7d 100644 --- a/metagraph/src/cli/query.cpp +++ b/metagraph/src/cli/query.cpp @@ -61,7 +61,8 @@ std::string QueryExecutor::execute_query(const std::string &seq_name, std::string anno_labels_delimiter, const AnnotatedDBG &anno_graph, bool with_kmer_counts, - const std::vector &count_quantiles) { + const std::vector &count_quantiles, + bool query_coords) { std::string output; output.reserve(1'000); @@ -85,6 +86,25 @@ std::string QueryExecutor::execute_query(const std::string &seq_name, output += '\n'; + } else if (query_coords) { + auto result = anno_graph.get_kmer_coordinates(sequence, + num_top_labels, + discovery_fraction); + + if (!result.size() && suppress_unlabeled) + return ""; + + output += seq_name; + + for (const auto &[label, tuples] : result) { + output += "\t<" + label + ">"; + for (const auto &coords : tuples) { + output += fmt::format(":{}", fmt::join(coords, ",")); + } + } + + output += '\n'; + } else if (count_quantiles.size()) { auto result = anno_graph.get_label_count_quantiles(sequence, num_top_labels, @@ -97,15 +117,12 @@ std::string QueryExecutor::execute_query(const std::string &seq_name, output += seq_name; for (const auto &[label, quantiles] : result) { - output += "\t<" + label + ">"; - for (uint64_t count : quantiles) { - output += fmt::format(":{}", count); - } + output += fmt::format("\t<{}>:{}", label, fmt::join(quantiles, ":")); } output += '\n'; - } else if (count_labels) { + } else if (count_labels || with_kmer_counts) { auto top_labels = anno_graph.get_top_labels(sequence, num_top_labels, discovery_fraction, @@ -117,10 +134,7 @@ std::string QueryExecutor::execute_query(const std::string &seq_name, output += seq_name; for (const auto &[label, count] : top_labels) { - output += "\t<"; - output += label; - output += ">:"; - output += fmt::format_int(count).c_str(); + output += fmt::format("\t<{}>:{}", label, count); } output += '\n'; @@ -914,7 +928,8 @@ std::string query_sequence(size_t id, std::string name, std::string seq, config.count_labels, config.print_signature, config.suppress_unlabeled, config.num_top_labels, config.discovery_fraction, config.anno_labels_delimiter, - anno_graph, config.count_kmers, config.count_quantiles); + anno_graph, config.count_kmers, config.count_quantiles, + config.query_coords); } void QueryExecutor::query_fasta(const string &file, @@ -924,6 +939,10 @@ void QueryExecutor::query_fasta(const string &file, seq_io::FastaParser fasta_parser(file, config_.forward_and_reverse); if (config_.fast) { + if (config_.query_coords) { + logger->error("Querying coordinates in batch mode is not supported"); + exit(1); + } // Construct a query graph and query against it batched_query_fasta(fasta_parser, callback); return; diff --git a/metagraph/src/cli/query.hpp b/metagraph/src/cli/query.hpp index be7ae21dff..f39b16feaa 100644 --- a/metagraph/src/cli/query.hpp +++ b/metagraph/src/cli/query.hpp @@ -65,7 +65,8 @@ class QueryExecutor { std::string anno_labels_delimiter, const graph::AnnotatedDBG &anno_graph, bool with_kmer_counts = false, - const std::vector &count_quantiles = {}); + const std::vector &count_quantiles = {}, + bool query_coords = false); private: const Config &config_; diff --git a/metagraph/src/cli/stats.cpp b/metagraph/src/cli/stats.cpp index 36429282b3..c9961a49b1 100644 --- a/metagraph/src/cli/stats.cpp +++ b/metagraph/src/cli/stats.cpp @@ -167,6 +167,12 @@ void print_stats(const Annotator &annotation) { << utils::split_string(annotation.file_extension(), ".").at(0) << std::endl; using namespace annot::binmat; + using mtg::annot::matrix::MultiIntMatrix; + + if (const auto *mat_coord = dynamic_cast(&annotation.get_matrix())) { + std::cout << "================== COORDINATES STATS ===================" << std::endl; + std::cout << "coordinates: " << mat_coord->num_attributes() << std::endl; + } if (const auto *rbmat = dynamic_cast(&annotation.get_matrix())) { std::cout << "================= RAINBOW MATRIX STATS =================" << std::endl; diff --git a/metagraph/src/cli/transform_annotation.cpp b/metagraph/src/cli/transform_annotation.cpp index 8eda812d5e..402cd744e3 100644 --- a/metagraph/src/cli/transform_annotation.cpp +++ b/metagraph/src/cli/transform_annotation.cpp @@ -652,6 +652,27 @@ int transform_annotation(Config *config) { assert(false); break; } + case Config::ColumnCoord: { + auto label_encoder = annotator->get_label_encoder(); + auto tuple_matrix = std::make_unique>( + annotator->release_matrix()); + if (files.size() > 1) { + logger->error("Merging coordinates from multiple columns is not supported"); + exit(1); + } + auto coords_fname = utils::remove_suffix(files.at(0), + ColumnCompressed<>::kExtension) + + ColumnCompressed<>::kCoordExtension; + std::ifstream in(coords_fname); + tuple_matrix->load_tuples(in); + + ColumnCoordAnnotator column_coord(std::move(tuple_matrix), label_encoder); + + logger->trace("Annotation converted in {} sec", timer.elapsed()); + column_coord.serialize(config->outfbase); + logger->trace("Serialized to {}", config->outfbase); + break; + } case Config::RowDiffBRWT: { logger->error("Convert to row_diff first, and then to row_diff_brwt"); return 0; diff --git a/metagraph/src/graph/annotated_dbg.cpp b/metagraph/src/graph/annotated_dbg.cpp index 033ae3d3ca..1461a9baf2 100644 --- a/metagraph/src/graph/annotated_dbg.cpp +++ b/metagraph/src/graph/annotated_dbg.cpp @@ -15,12 +15,15 @@ #include "common/aligned_vector.hpp" #include "common/vectors/vector_algorithm.hpp" #include "common/vector_map.hpp" +#include "common/logger.hpp" namespace mtg { namespace graph { +using mtg::common::logger; using mtg::annot::matrix::IntMatrix; +using mtg::annot::matrix::MultiIntMatrix; typedef AnnotatedDBG::Label Label; typedef std::pair StringCountPair; @@ -328,9 +331,14 @@ AnnotatedDBG::get_label_count_quantiles(std::string_view sequence, q_low[i] = (num_kmers - 1) * count_quantiles[i]; } + const auto *int_matrix = dynamic_cast(&annotator_->get_matrix()); + if (!int_matrix) { + logger->error("k-mer counts are not indexed in this annotator"); + exit(1); + } + VectorMap> code_to_counts; - for (const auto &row_values : dynamic_cast(annotator_->get_matrix()) - .get_row_values(rows)) { + for (const auto &row_values : int_matrix->get_row_values(rows)) { for (const auto &[column, count] : row_values) { code_to_counts[column].push_back(count); } @@ -376,6 +384,120 @@ AnnotatedDBG::get_label_count_quantiles(std::string_view sequence, return label_quantiles; } +std::vector>>> +AnnotatedDBG::get_kmer_coordinates(std::string_view sequence, + size_t num_top_labels, + double presence_ratio) const { + assert(presence_ratio >= 0.); + assert(presence_ratio <= 1.); + assert(check_compatibility()); + + if (sequence.size() < dbg_.get_k()) + return {}; + + std::vector path; + size_t num_kmers = sequence.size() - dbg_.get_k() + 1; + path.reserve(num_kmers); + + graph_->map_to_nodes(sequence, [&](node_index i) { + path.push_back(i); + }); + + return get_kmer_coordinates(path, num_top_labels, presence_ratio); +} + +std::vector>>> +AnnotatedDBG::get_kmer_coordinates(const std::vector &path, + size_t num_top_labels, + double presence_ratio) const { + assert(presence_ratio >= 0.); + assert(presence_ratio <= 1.); + assert(check_compatibility()); + + if (!path.size()) + return {}; + + std::vector rows; + rows.reserve(path.size()); + + std::vector ids; + ids.reserve(path.size()); + + for (node_index i : path) { + if (i > 0) { + ids.push_back(rows.size()); + rows.push_back(graph_to_anno_index(i)); + } else { + ids.push_back(-1); + } + } + + uint64_t min_count = std::max(1.0, std::ceil(presence_ratio * path.size())); + if (rows.size() < min_count) + return {}; + + const auto *tuple_matrix = dynamic_cast(&annotator_->get_matrix()); + if (!tuple_matrix) { + logger->error("k-mer coordinates are not indexed in this annotator"); + exit(1); + } + + auto rows_tuples = tuple_matrix->get_row_tuples(rows); + + VectorMap code_to_count; + for (const auto &row_tuples : rows_tuples) { + for (const auto &[column, tuple] : row_tuples) { + code_to_count[column] += 1; + } + } + + auto code_counts = code_to_count.values_container(); + // sort by the number of matched k-mers + std::sort(code_counts.begin(), code_counts.end(), + [](const auto &x, const auto &y) { + return x.second > y.second || (x.second == y.second && x.first < y.first); + }); + + // keep only the first |num_top_labels| top labels + if (code_counts.size() > num_top_labels) + code_counts.resize(num_top_labels); + + // filter by the number of matched k-mers + code_counts.erase( + std::upper_bound(code_counts.begin(), code_counts.end(), min_count, + [](uint64_t min, const auto &x) { return x.second < min; }), + code_counts.end() + ); + + code_to_count = VectorMap(code_counts.begin(), code_counts.end()); + + std::vector>>> result(code_to_count.size()); + + for (size_t j = 0; j < result.size(); ++j) { + result[j].first = annotator_->get_label_encoder().decode(code_counts[j].first); + } + + for (size_t i : ids) { + for (size_t j = 0; j < result.size(); ++j) { + // append empty tuple + result[j].second.emplace_back(); + } + + // leave all tuples empty if the k-mer is missing + if (i == (size_t)-1) + continue; + + // set the non-empty tuples + for (auto &[j, tuple] : rows_tuples[i]) { + auto it = code_to_count.find(j); + if (it != code_to_count.end()) + result[it - code_to_count.begin()].second.back() = std::move(tuple); + } + } + + return result; +} + std::vector> AnnotatedDBG::get_top_label_signatures(std::string_view sequence, size_t num_top_labels, diff --git a/metagraph/src/graph/annotated_dbg.hpp b/metagraph/src/graph/annotated_dbg.hpp index d9b55284c5..1ffd58d4d8 100644 --- a/metagraph/src/graph/annotated_dbg.hpp +++ b/metagraph/src/graph/annotated_dbg.hpp @@ -9,6 +9,7 @@ #include "representation/base/sequence_graph.hpp" #include "annotation/representation/base/annotation.hpp" +#include "common/vector.hpp" namespace mtg { @@ -122,6 +123,16 @@ class AnnotatedDBG : public AnnotatedSequenceGraph { double presence_ratio, const std::vector &count_quantiles) const; + std::vector>>> + get_kmer_coordinates(std::string_view sequence, + size_t num_top_labels, + double presence_ratio) const; + + std::vector>>> + get_kmer_coordinates(const std::vector &path, + size_t num_top_labels, + double presence_ratio) const; + std::vector> get_top_label_signatures(std::string_view sequence, size_t num_top_labels, From 0d156d312385920206f6dfd923a9c0f9105251d3 Mon Sep 17 00:00:00 2001 From: Harun Mustafa Date: Sun, 18 Jul 2021 01:22:01 +0200 Subject: [PATCH 4/4] Various minor fixes and tweaks (#332) * CanonicalDBG fixes * reduce the number of alignment unit tests * fix corner case in exact k-mer mapping * Create a separate CanonicalDBG for each batch * Updated caches lib * use TryGet instead of Get * added copy constructor for CanonicalDBG * added exact map integration tests --- metagraph/external-libraries/caches | 2 +- metagraph/integration_tests/test_align.py | 60 ++++++++ .../annotate_column_compressed.cpp | 9 +- metagraph/src/cli/align.cpp | 29 ++-- .../graph/representation/canonical_dbg.cpp | 141 ++++++++++++------ .../graph/representation/canonical_dbg.hpp | 14 ++ .../tests/graph/all/test_dbg_helpers.hpp | 9 +- metagraph/tests/graph/test_aligner.cpp | 3 +- .../tests/graph/test_aligner_helpers.hpp | 13 +- metagraph/tests/graph/test_canonical_dbg.cpp | 3 +- 10 files changed, 201 insertions(+), 82 deletions(-) diff --git a/metagraph/external-libraries/caches b/metagraph/external-libraries/caches index b969baa3e7..9843c53447 160000 --- a/metagraph/external-libraries/caches +++ b/metagraph/external-libraries/caches @@ -1 +1 @@ -Subproject commit b969baa3e774fac964487e14bb0122708b09114b +Subproject commit 9843c534478a9188343c1cf41fb3e9269df872dd diff --git a/metagraph/integration_tests/test_align.py b/metagraph/integration_tests/test_align.py index fa789f5e70..b7fcb25522 100644 --- a/metagraph/integration_tests/test_align.py +++ b/metagraph/integration_tests/test_align.py @@ -60,6 +60,66 @@ def test_simple_align_all_graphs(self, representation): self.assertEqual(last_split[1], "AACAGAGAATTGTTTAAATTACAATCTTAGCTATGGGTGCTAAAGGTGGAGTTATAGACTTTTTCACTGATTTGTCGTTGGAAAAAGCTTTTCATCTCGGGTTTACAAGTCTGGTGTATTTGTTTATACTAGAAGGACAGGCGCATTTGA") self.assertEqual(last_split[4], "22") + @parameterized.expand(GRAPH_TYPES) + def test_simple_align_map_all_graphs(self, representation): + + self._build_graph(input=TEST_DATA_DIR + '/genome.MT.fa', + output=self.tempdir.name + '/genome.MT', + k=11, repr=representation, + extra_params="--mask-dummy") + + res = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) + params_str = res.stdout.decode().split('\n')[2:] + self.assertEqual('k: 11', params_str[0]) + self.assertEqual('nodes (k): 16438', params_str[1]) + self.assertEqual('mode: basic', params_str[2]) + + stats_command = '{exe} align -i {graph} --map --count-kmers {reads}'.format( + exe=METAGRAPH, + graph=self.tempdir.name + '/genome.MT' + graph_file_extension[representation], + reads=TEST_DATA_DIR + '/genome_MT1.fq', + ) + res = subprocess.run(stats_command.split(), stdout=PIPE) + self.assertEqual(res.returncode, 0) + params_str = res.stdout.decode().rstrip().split('\n') + self.assertEqual(len(params_str), 6) + self.assertEqual(params_str[0], 'MT-10/1\t1/140/1') + self.assertEqual(params_str[1], 'MT-8/1\t140/140/140') + self.assertEqual(params_str[2], 'MT-6/1\t140/140/140') + self.assertEqual(params_str[3], 'MT-4/1\t0/140/0') + self.assertEqual(params_str[4], 'MT-2/1\t140/140/140') + self.assertEqual(params_str[5], 'MT-11/1\t1/140/1') + + @parameterized.expand(GRAPH_TYPES) + def test_simple_align_map_canonical_all_graphs(self, representation): + + self._build_graph(input=TEST_DATA_DIR + '/genome.MT.fa', + output=self.tempdir.name + '/genome.MT', + k=11, repr=representation, mode='canonical', + extra_params="--mask-dummy") + + res = self._get_stats(self.tempdir.name + '/genome.MT' + graph_file_extension[representation]) + params_str = res.stdout.decode().split('\n')[2:] + self.assertEqual('k: 11', params_str[0]) + self.assertEqual('nodes (k): 32782', params_str[1]) + self.assertEqual('mode: canonical', params_str[2]) + + stats_command = '{exe} align -i {graph} --map --count-kmers {reads}'.format( + exe=METAGRAPH, + graph=self.tempdir.name + '/genome.MT' + graph_file_extension[representation], + reads=TEST_DATA_DIR + '/genome_MT1.fq', + ) + res = subprocess.run(stats_command.split(), stdout=PIPE) + self.assertEqual(res.returncode, 0) + params_str = res.stdout.decode().rstrip().split('\n') + self.assertEqual(len(params_str), 6) + self.assertEqual(params_str[0], 'MT-10/1\t140/140/140') + self.assertEqual(params_str[1], 'MT-8/1\t140/140/140') + self.assertEqual(params_str[2], 'MT-6/1\t140/140/140') + self.assertEqual(params_str[3], 'MT-4/1\t129/140/129') + self.assertEqual(params_str[4], 'MT-2/1\t140/140/139') + self.assertEqual(params_str[5], 'MT-11/1\t2/140/2') + @parameterized.expand(['succinct']) def test_simple_align_json_all_graphs(self, representation): diff --git a/metagraph/src/annotation/representation/column_compressed/annotate_column_compressed.cpp b/metagraph/src/annotation/representation/column_compressed/annotate_column_compressed.cpp index 9a5f419262..8cdf19f6ce 100644 --- a/metagraph/src/annotation/representation/column_compressed/annotate_column_compressed.cpp +++ b/metagraph/src/annotation/representation/column_compressed/annotate_column_compressed.cpp @@ -766,10 +766,11 @@ bitmap_builder& ColumnCompressed