我尝试用 std::priority_queue 替换 std::multiset。但我对速度结果感到失望。算法运行时间增加 50%...
以下是相应的命令:
top() = begin();
pop() = erase(knn.begin());
push() = insert();
我对priority_queue 实现的速度感到惊讶,我期待不同的结果(对PQ 更好)......从概念上讲,多重集被用作优先级队列。为什么优先级队列和多重集具有如此不同的性能,即使是-O2
?
十个结果的平均值,MSVS 2010,Win XP,32 位,方法 findAllKNN2()(请参见下文)
MS
N time [s]
100 000 0.5
1 000 000 8
PQ
N time [s]
100 000 0.8
1 000 000 12
什么可能导致这些结果?没有对源代码进行其他更改...感谢您的帮助...
微软实施:
template <typename Point>
struct TKDNodePriority
{
KDNode <Point> *node;
typename Point::Type priority;
TKDNodePriority() : node ( NULL ), priority ( 0 ) {}
TKDNodePriority ( KDNode <Point> *node_, typename Point::Type priority_ ) : node ( node_ ), priority ( priority_ ) {}
bool operator < ( const TKDNodePriority <Point> &n1 ) const
{
return priority > n1.priority;
}
};
template <typename Point>
struct TNNeighboursList
{
typedef std::multiset < TKDNodePriority <Point> > Type;
};
方法:
template <typename Point>
template <typename Point2>
void KDTree2D <Point>::findAllKNN2 ( const Point2 * point, typename TNNeighboursList <Point>::Type & knn, unsigned int k, KDNode <Point> *node, const unsigned int depth ) const
{
if ( node == NULL )
{
return;
}
if ( point->getCoordinate ( depth % 2 ) <= node->getData()->getCoordinate ( depth % 2 ) )
{
findAllKNN2 ( point, knn, k, node->getLeft(), depth + 1 );
}
else
{
findAllKNN2 ( point, knn, k, node->getRight(), depth + 1 );
}
typename Point::Type dist_q_node = ( node->getData()->getX() - point->getX() ) * ( node->getData()->getX() - point->getX() ) +
( node->getData()->getY() - point->getY() ) * ( node->getData()->getY() - point->getY() );
if (knn.size() == k)
{
if (dist_q_node < knn.begin()->priority )
{
knn.erase(knn.begin());
knn.insert ( TKDNodePriority <Point> ( node, dist_q_node ) );
}
}
else
{
knn.insert ( TKDNodePriority <Point> ( node, dist_q_node ) );
}
typename Point::Type dist_q_node_straight = ( point->getCoordinate ( node->getDepth() % 2 ) - node->getData()->getCoordinate ( node->getDepth() % 2 ) ) *
( point->getCoordinate ( node->getDepth() % 2 ) - node->getData()->getCoordinate ( node->getDepth() % 2 ) ) ;
typename Point::Type top_priority = knn.begin()->priority;
if ( knn.size() < k || dist_q_node_straight < top_priority )
{
if ( point->getCoordinate ( node->getDepth() % 2 ) < node->getData()->getCoordinate ( node->getDepth() % 2 ) )
{
findAllKNN2 ( point, knn, k, node->getRight(), depth + 1 );
}
else
{
findAllKNN2 ( point, knn, k, node->getLeft(), depth + 1 );
}
}
}
PQ 实施(较慢,为什么?)
template <typename Point>
struct TKDNodePriority
{
KDNode <Point> *node;
typename Point::Type priority;
TKDNodePriority() : node ( NULL ), priority ( 0 ) {}
TKDNodePriority ( KDNode <Point> *node_, typename Point::Type priority_ ) : node ( node_ ), priority ( priority_ ) {}
bool operator < ( const TKDNodePriority <Point> &n1 ) const
{
return priority > n1.priority;
}
};
template <typename Point>
struct TNNeighboursList
{
typedef std::priority_queue< TKDNodePriority <Point> > Type;
};
方法:
template <typename Point>
template <typename Point2>
void KDTree2D <Point>::findAllKNN2 ( const Point2 * point, typename TNNeighboursList <Point>::Type & knn, unsigned int k, KDNode <Point> *node, const unsigned int depth ) const
{
if ( node == NULL )
{
return;
}
if ( point->getCoordinate ( depth % 2 ) <= node->getData()->getCoordinate ( depth % 2 ) )
{
findAllKNN2 ( point, knn, k, node->getLeft(), depth + 1 );
}
else
{
findAllKNN2 ( point, knn, k, node->getRight(), depth + 1 );
}
typename Point::Type dist_q_node = ( node->getData()->getX() - point->getX() ) * ( node->getData()->getX() - point->getX() ) +
( node->getData()->getY() - point->getY() ) * ( node->getData()->getY() - point->getY() );
if (knn.size() == k)
{
if (dist_q_node < knn.top().priority )
{
knn.pop();
knn.push ( TKDNodePriority <Point> ( node, dist_q_node ) );
}
}
else
{
knn.push ( TKDNodePriority <Point> ( node, dist_q_node ) );
}
typename Point::Type dist_q_node_straight = ( point->getCoordinate ( node->getDepth() % 2 ) - node->getData()->getCoordinate ( node->getDepth() % 2 ) ) *
( point->getCoordinate ( node->getDepth() % 2 ) - node->getData()->getCoordinate ( node->getDepth() % 2 ) ) ;
typename Point::Type top_priority = knn.top().priority;
if ( knn.size() < k || dist_q_node_straight < top_priority )
{
if ( point->getCoordinate ( node->getDepth() % 2 ) < node->getData()->getCoordinate ( node->getDepth() % 2 ) )
{
findAllKNN2 ( point, knn, k, node->getRight(), depth + 1 );
}
else
{
findAllKNN2 ( point, knn, k, node->getLeft(), depth + 1 );
}
}
}