#ifndef __CLUSTER_FOREST_HPP__ #define __CLUSTER_FOREST_HPP__ #include #include "SalmonSpinLock.hpp" #include "Transcript.hpp" #include "TranscriptCluster.hpp" #include #include #include /** A forest of transcript clusters */ class ClusterForest { public: ClusterForest(size_t numTranscripts, std::vector& refs) : rank_(std::vector(numTranscripts, 0)), parent_(std::vector(numTranscripts, 0)), disjointSets_(&rank_[0], &parent_[0]), clusters_(std::vector(numTranscripts)) { // Initially make a unique set for each transcript for (size_t tnum = 0; tnum < numTranscripts; ++tnum) { disjointSets_.make_set(tnum); clusters_[tnum].members_.push_front(tnum); clusters_[tnum].addMass(refs[tnum].mass()); } } template void mergeClusters(typename std::vector::iterator start, typename std::vector::iterator finish) { // Use a lock_guard to ensure this is a locked (and exception-safe) operation #if defined __APPLE__ spin_lock::scoped_lock sl(clusterMutex_); #else std::lock_guard lock(clusterMutex_); #endif size_t firstCluster, otherCluster; auto firstTranscriptID = start->transcriptID(); ++start; for (auto it = start; it != finish; ++it) { firstCluster = disjointSets_.find_set(firstTranscriptID); otherCluster = disjointSets_.find_set(it->transcriptID()); if (otherCluster != firstCluster) { disjointSets_.link(firstCluster, otherCluster); auto parentClust = disjointSets_.find_set(it->transcriptID()); auto childClust = (parentClust == firstCluster) ? otherCluster : firstCluster; if (parentClust == firstCluster or parentClust == otherCluster) { clusters_[parentClust].merge(clusters_[childClust]); clusters_[childClust].deactivate(); } else { std::cerr << "DANGER\n"; } } } } template void mergeClusters(typename std::vector::iterator start, typename std::vector::iterator finish) { // Use a lock_guard to ensure this is a locked (and exception-safe) operation #if defined __APPLE__ spin_lock::scoped_lock sl(clusterMutex_); #else std::lock_guard lock(clusterMutex_); #endif auto firstTranscriptID = (*start)->transcriptID(); decltype(firstTranscriptID) firstCluster, otherCluster; ++start; for (auto it = start; it != finish; ++it) { firstCluster = disjointSets_.find_set(firstTranscriptID); otherCluster = disjointSets_.find_set((*it)->transcriptID()); if (otherCluster != firstCluster) { disjointSets_.link(firstCluster, otherCluster); auto parentClust = disjointSets_.find_set((*it)->transcriptID()); auto childClust = (parentClust == firstCluster) ? otherCluster : firstCluster; if (parentClust == firstCluster or parentClust == otherCluster) { clusters_[parentClust].merge(clusters_[childClust]); clusters_[childClust].deactivate(); } else { std::cerr << "DANGER\n"; } } } } /* void mergeClusters(AlignmentBatch::iterator start, AlignmentBatch::iterator finish) { // Use a lock_guard to ensure this is a locked (and exception-safe) operation std::lock_guard lock(clusterMutex_); size_t firstCluster, otherCluster; auto firstTranscriptID = start->read1->core.tid; ++start; for (auto it = start; it != finish; ++it) { firstCluster = disjointSets_.find_set(firstTranscriptID); otherCluster = disjointSets_.find_set(it->read1->core.tid); if (otherCluster != firstCluster) { disjointSets_.link(firstCluster, otherCluster); auto parentClust = disjointSets_.find_set(it->read1->core.tid); auto childClust = (parentClust == firstCluster) ? otherCluster : firstCluster; if (parentClust == firstCluster or parentClust == otherCluster) { clusters_[parentClust].merge(clusters_[childClust]); clusters_[childClust].deactivate(); } else { std::cerr << "DANGER\n"; } } } } */ void updateCluster(size_t memberTranscript, size_t newCount, double logNewMass, bool updateCount) { // Use a lock_guard to ensure this is a locked (and exception-safe) operation #if defined __APPLE__ spin_lock::scoped_lock sl(clusterMutex_); #else std::lock_guard lock(clusterMutex_); #endif auto clusterID = disjointSets_.find_set(memberTranscript); auto& cluster = clusters_[clusterID]; if (updateCount) { cluster.incrementCount(newCount); } cluster.addMass(logNewMass); } std::vector getClusters() { std::vector clusters; std::unordered_set observedReps; for (size_t i = 0; i < clusters_.size(); ++i) { auto rep = disjointSets_.find_set(i); if (observedReps.find(rep) == observedReps.end()) { if (!clusters_[rep].isActive()) { std::cerr << "returning a non-active cluster!\n"; std::exit(1); } clusters.push_back(&clusters_[rep]); observedReps.insert(rep); } } return clusters; } private: std::vector rank_; std::vector parent_; boost::disjoint_sets disjointSets_; std::vector clusters_; #if defined __APPLE__ spin_lock clusterMutex_; #else std::mutex clusterMutex_; #endif }; #endif // __CLUSTER_FOREST_HPP__