diff --git a/src/seg/SemExtractor.hpp b/src/seg/SemExtractor.hpp index 7a519f7..a968dba 100644 --- a/src/seg/SemExtractor.hpp +++ b/src/seg/SemExtractor.hpp @@ -5,32 +5,34 @@ #include #include #include +#include using sem_array_t = std::array; +using sem_dict_t = std::unordered_map; template class SemExtractor { public: - SemExtractor(const Chunk & sem) + explicit SemExtractor(const Chunk & sem) :m_sem(sem){} void collectVoxel(Coord c, Tseg segid) { auto sem_label = m_sem[c[0]][c[1]][c[2]]; - if (m_labels[segid][1] > 0 or sem_label == 0) { - return; + if (not m_labels.contains(segid)) { + m_labels[segid] = sem_dict_t(); } - - if (m_labels[segid][0] == 0) { - m_labels[segid][0] = sem_label; - } else if (m_labels[segid][0] != sem_label){ - m_labels[segid][1] = 1; + auto & sem_dict = m_labels[segid]; + if (sem_dict.contains(sem_label)) { + sem_dict[sem_label] += 1; + } else { + sem_dict[sem_label] = 1; } } void collectBoundary(int face, Coord c, Tseg segid) {} void collectContactingSurface(int nv, Coord c, Tseg segid1, Tseg segid2) {} - const MapContainer & sem_labels() { + const MapContainer & sem_labels() { return m_labels; } @@ -41,13 +43,12 @@ class SemExtractor assert(of.is_open()); for (const auto & [k, v]: m_labels) { auto svid = k; + auto sem_labels = mapToArray(v); if (chunkMap.count(k) > 0) { svid = chunkMap.at(k); } - if (remapped_labels.count(svid) == 0 or remapped_labels[svid][0] == 0) { - remapped_labels[svid] = v; - } else if (v[0] > 0 and remapped_labels[svid][0] != v[0]) { - remapped_labels[svid][1] = 1; + if ((remapped_labels.count(svid) == 0) or (remapped_labels[svid][1] < sem_labels[1])) { + remapped_labels[svid] = sem_labels; } } for (const auto & [k,v] : remapped_labels) { @@ -57,6 +58,19 @@ class SemExtractor } private: + sem_array_t mapToArray(const sem_dict_t & entries) { + std::vector > vec(entries.begin(), entries.end()); + std::sort(vec.begin(), vec.end(), [](const auto & a, const auto & b) { + return a.second > b.second; + }); + auto & top = vec[0]; + auto sum = std::accumulate(vec.begin(), vec.end(), static_cast(0), + [](auto partialSum, const auto & element) { + return partialSum + element.second; + }); + return sem_array_t{top.first, top.second, sum}; + } + void write_sem(auto & io, const auto & k, const auto & v) { io.write(reinterpret_cast(&k), sizeof(k)); io.write(reinterpret_cast(&v), sizeof(v)); @@ -68,7 +82,7 @@ class SemExtractor // dummy, axon, bv, dendrite, glia, soma //static constexpr std::array sem_map = {-1,1,2,0,2,-1}; const Chunk & m_sem; - MapContainer m_labels; + MapContainer m_labels; };