Skip to content

Commit

Permalink
Hack to use segment ids as semantic labels
Browse files Browse the repository at this point in the history
  • Loading branch information
ranlu committed Sep 30, 2024
1 parent 980e3e9 commit 85eb19f
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 23 deletions.
33 changes: 17 additions & 16 deletions src/agg/mean_aggl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ struct agglomeration_size_heuristic_t

struct agglomeration_semantic_heuristic_t
{
aff_t aff_threshold = 0.5;
aff_t aff_threshold = 1.0;
size_t total_signal_threshold = 100'000;
double dominant_signal_ratio = 0.6;
};
Expand Down Expand Up @@ -269,7 +269,12 @@ std::vector<sem_array_t> load_sem(const char * sem_filename, const std::vector<s
std::sort(std::execution::par, std::begin(sem_array), std::end(sem_array), [](auto & a, auto & b) { return a.first < b.first; });

for (auto & [k, v] : sem_array) {
std::transform(sem_counts[k].begin(), sem_counts[k].end(), v.begin(), sem_counts[k].begin(), std::plus<>());
if (sem_counts[k][0] > 0 and v[0] > 0 and v[0] != sem_counts[k][0]) {
sem_counts[k][1] = 1;
}
if (v[0] > 0) {
sem_counts[k] = v;
}
}
return sem_counts;
}
Expand Down Expand Up @@ -624,20 +629,14 @@ std::pair<size_t, size_t> sem_label(const sem_array_t & labels)

bool sem_can_merge(const sem_array_t & labels1, const sem_array_t & labels2, const agglomeration_semantic_heuristic_t & sem_params)
{
auto max_label1 = std::distance(labels1.begin(), std::max_element(labels1.begin(), labels1.end()));
auto max_label2 = std::distance(labels2.begin(), std::max_element(labels2.begin(), labels2.end()));
auto total_label1 = std::accumulate(labels1.begin(), labels1.end(), static_cast<size_t>(0));
auto total_label2 = std::accumulate(labels2.begin(), labels2.end(), static_cast<size_t>(0));
if (labels1[max_label1] < sem_params.dominant_signal_ratio * total_label1 || total_label1 < sem_params.total_signal_threshold) { //unsure about the semantic label
return true;
if (labels1[1] > 0 or labels2[1] > 0) {
return false;
}
if (labels2[max_label2] < sem_params.dominant_signal_ratio * total_label2 || total_label2 < sem_params.total_signal_threshold) { //unsure about the semantic label
if (labels1[0] == 0 or labels2[0] == 0 or labels1[0] == labels2[0]) {
return true;
} else {
return false;
}
if (max_label1 == max_label2) {
return true;
}
return false;
}

template <class T, class Compare = std::greater<T>, class Plus = std::plus<T>,
Expand Down Expand Up @@ -752,9 +751,11 @@ inline agglomeration_output_t<T> agglomerate_cc(agglomeration_data_t<T, Compare>
std::swap(seg_size[v0], seg_size[s]);

if (!sem_counts.empty()) {
std::transform(sem_counts[v0].begin(), sem_counts[v0].end(), sem_counts[v1].begin(), sem_counts[v0].begin(), std::plus<size_t>());
sem_counts[v1] = sem_array_t();
std::swap(sem_counts[v0], sem_counts[s]);
if (sem_counts[v0][0] > 0) {
std::swap(sem_counts[v0], sem_counts[s]);
} else {
std::swap(sem_counts[v1], sem_counts[s]);
}
}

output.merged_rg_vector.push_back(*(e.edge));
Expand Down
18 changes: 12 additions & 6 deletions src/seg/SemExtractor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <iostream>
#include <fstream>

using sem_array_t = std::array<size_t, 3>;
using sem_array_t = std::array<seg_t, 3>;

template<typename Tseg, typename Tsem, typename Chunk>
class SemExtractor
Expand All @@ -17,8 +17,14 @@ class SemExtractor
void collectVoxel(Coord c, Tseg segid)
{
auto sem_label = m_sem[c[0]][c[1]][c[2]];
if (sem_label >= 0 and sem_map[sem_label] >= 0) {
m_labels[segid][sem_map[sem_label]] += 1;
if (m_labels[segid][1] > 0 or sem_label == 0) {
return;
}

if (m_labels[segid][0] == 0) {
m_labels[segid][0] = sem_label;
} else if (m_labels[segid][0] != sem_label){
m_labels[segid][1] = 1;
}
}

Expand All @@ -38,10 +44,10 @@ class SemExtractor
if (chunkMap.count(k) > 0) {
svid = chunkMap.at(k);
}
if (remapped_labels.count(svid) == 0) {
if (remapped_labels.count(svid) == 0 or remapped_labels[svid][0] == 0) {
remapped_labels[svid] = v;
} else {
std::transform(remapped_labels[svid].begin(), remapped_labels[svid].end(), v.begin(), remapped_labels[svid].begin(), std::plus<size_t>());
} else if (v[0] > 0 and remapped_labels[svid][0] != v[0]) {
remapped_labels[svid][1] = 1;
}
}
for (const auto & [k,v] : remapped_labels) {
Expand Down
2 changes: 1 addition & 1 deletion src/seg/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ using ContactRegionExt = MapContainer<Coord, int, HashFunction<Coord> >;
template <class Ta>
using Edge = std::array<MapContainer<Coord, Ta, HashFunction<Coord> >, 3>;

using semantic_t = uint8_t;
using semantic_t = uint64_t;

template <class T>
struct __attribute__((packed)) matching_entry_t
Expand Down

0 comments on commit 85eb19f

Please sign in to comment.