In your question, you mentioned that you are having trouble selecting a pivot that is not at the start of the list, because this would require traversing the list. If you do it correctly, you only have to traverse the entire list twice:
- once for finding the middle and the end of the list in order to select a good pivot (e.g. using the "median-of-three" rule)
- once for the actual sorting
The first step is not necessary if you don't care much about selecting a good pivot and are happy to simply select the first element of the list as the pivot (which causes worst-case O(n^2) time complexity if the data is already sorted).
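The first traversal can collect the length, the middle node and the end node in a single pass. This is a minimal sketch that assumes a plain int list. The names IntNode, ListInfo and scan_list are hypothetical and do not come from the question. The middle pointer advances by one node for every second node visited, which is the same technique the implementation below uses before its main loop.

```cpp
#include <cassert>
#include <cstddef>

// Minimal singly-linked node (hypothetical stand-in for the list in the question).
struct IntNode
{
    int data;
    IntNode *next;
};

// Everything the pivot selection needs, gathered in one traversal.
struct ListInfo
{
    std::size_t length;
    IntNode *middle; // node at zero-based index length / 2
    IntNode *end;    // last node
};

// Walks the list exactly once. Every second step the middle pointer moves
// forward by one node, so it always stays halfway along the visited prefix.
ListInfo scan_list( IntNode *head )
{
    assert( head != nullptr );
    ListInfo info{ 1, head, head };
    while ( info.end->next != nullptr )
    {
        info.end = info.end->next;
        ++info.length;
        if ( info.length % 2 == 0 ) info.middle = info.middle->next;
    }
    return info;
}
```

If you keep the returned end pointer, you never need to walk the list again just to find its last node.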
If you remember the end of the list the first time you traverse it by maintaining a pointer to the end, you should never have to traverse it again to find the end. Also, if you are using the standard Lomuto partition scheme (which I am not using in its standard form, for the reasons stated below), you must also maintain two pointers into the list that represent the i and j indices of the standard Lomuto partition scheme. By using these pointers, you should never have to traverse the list to access a single element.
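As a reference point, the standard Lomuto partition can be written for a singly-linked list with two node pointers that play the roles of i and j, so no element is ever reached by index. This minimal sketch uses the hypothetical names IntNode and lomuto_partition. It partitions the nodes from first to last (inclusive) around last->data and swaps values instead of relinking nodes. It is not the three-way variant used in the implementation in this answer.

```cpp
#include <cassert>
#include <utility>

struct IntNode
{
    int data;
    IntNode *next;
};

// Lomuto partition of the nodes [first, last] (inclusive) using last->data as
// the pivot. 'j' visits every node once; 'i' marks the last node of the
// "less than pivot" region (nullptr means "before first"). Only values move;
// the links stay unchanged. Returns the node that ends up holding the pivot.
IntNode *lomuto_partition( IntNode *first, IntNode *last )
{
    assert( first != nullptr && last != nullptr );
    const int pivot = last->data;
    IntNode *i = nullptr;
    for ( IntNode *j = first; j != last; j = j->next )
    {
        if ( j->data < pivot )
        {
            i = ( i == nullptr ) ? first : i->next;
            std::swap( i->data, j->data );
        }
    }
    // place the pivot directly after the "less than" region
    i = ( i == nullptr ) ? first : i->next;
    std::swap( i->data, last->data );
    return i;
}
```

For the list 7 2 9 4 5, the pivot is 5 and the values become 2 4 5 7 9, with the third node holding the pivot.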
Also, if you maintain a pointer to the middle and the end of every partition, then when you later process one of these partitions, you will not have to traverse that partition again to find its middle and end.
I have now created my own implementation of the QuickSelect algorithm for linked lists, which I have posted below.
Since you stated that the linked list is singly-linked and cannot be upgraded to a doubly-linked list, I can't use the Hoare partition scheme, as iterating a singly-linked list backwards is very expensive. Therefore, I am using the generally less efficient Lomuto partition scheme instead.
When using the Lomuto partition scheme, either the first element or the last element is typically selected as the pivot. However, selecting either of those has the disadvantage that already sorted data causes the algorithm to have the worst-case time complexity of O(n^2). This can be prevented by using the "median-of-three" rule, which selects as the pivot the median of the first, middle and last elements. Therefore, my implementation uses this "median-of-three" rule.
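The median-of-three rule needs nothing more than operator<. This is a minimal sketch that assumes T is copyable and comparable with <. The name median_of_three is hypothetical. For already sorted data, the first, middle and last elements are in increasing order, so the rule picks the middle element and splits the data roughly in half.

```cpp
#include <cassert>

// Returns the median of three values using only operator<.
template <typename T>
T median_of_three( const T &a, const T &b, const T &c )
{
    if ( a < b )
    {
        if ( b < c ) return b;    // a < b < c
        return ( a < c ) ? c : a; // c <= b, so the median is max(a, c)
    }
    // here b <= a
    if ( a < c ) return a;        // b <= a < c
    return ( b < c ) ? c : b;     // c <= a, so the median is max(b, c)
}
```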
Also, the Lomuto partition scheme typically creates two partitions, one for values smaller than the pivot and one for values larger than or equal to the pivot. However, this will cause the worst-case time complexity of O(n^2) if all values are identical. Therefore, in my implementation, I am creating three partitions, one for values smaller than the pivot, one for values larger than the pivot, and one for values equal to the pivot.
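A three-way split of a singly-linked list takes one pass. Each node is relinked into one of three separate lists, and each list keeps a head pointer and a pointer to the "next" slot at its tail. This is a minimal sketch using the hypothetical names SubList and partition3, for int data only. It also advances a middle pointer on every second append, the same trick that lets the next round pick a median-of-three pivot without another traversal. If all values are identical, every node lands in the equal-to list. The median then lies in that list, and no further partitioning is needed.

```cpp
#include <cassert>
#include <cstddef>

struct IntNode
{
    int data;
    IntNode *next;
};

// One output list. tail_slot points at the 'next' field (or head) where the
// following node will be written; middle_slot does the same for the node at
// zero-based index count / 2. Copying is disabled because both slots may
// point into this object.
struct SubList
{
    IntNode *head = nullptr;
    IntNode **tail_slot = &head;
    IntNode **middle_slot = &head;
    std::size_t count = 0;

    SubList() = default;
    SubList( const SubList & ) = delete;
    SubList &operator=( const SubList & ) = delete;

    void append( IntNode *node )
    {
        assert( node != nullptr );
        *tail_slot = node;
        tail_slot = &node->next;
        node->next = nullptr;
        ++count;
        // advance the middle once for every two appended nodes
        if ( count % 2 == 0 ) middle_slot = &( *middle_slot )->next;
    }

    IntNode *middle() const { return *middle_slot; }
};

// Moves every node of 'list' into less, equal or greater relative to pivot,
// keeping the original order inside each group. Each node is visited once.
void partition3( IntNode *list, int pivot, SubList &less, SubList &equal, SubList &greater )
{
    while ( list != nullptr )
    {
        IntNode *node = list;
        list = list->next; // read before append() overwrites node->next
        if ( node->data < pivot )
            less.append( node );
        else if ( pivot < node->data )
            greater.append( node );
        else
            equal.append( node );
    }
}
```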
Although these measures don't completely eliminate the possibility of worst-case time complexity of O(n^2), they at least make it highly unlikely (unless the input is provided by a malicious attacker). In order to guarantee a time complexity of O(n), a more complex pivot selection algorithm would have to be used, such as median of medians.
One significant problem I encountered is that, for an even number of elements, the median is defined as the arithmetic mean of the two "middle" elements. For this reason, I can't simply write a function similar to std::nth_element. For example, if the total number of elements is 14, then I am looking for the 7th and 8th smallest elements, so I would have to call such a function twice, which would be inefficient. Therefore, I have instead written a function that searches for both median elements at once. Although this makes the code more complex, the overhead of the extra code should be minimal compared to the cost of a second selection pass.
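The index arithmetic is easy to check against a brute-force reference. This sketch sorts a copy with std::sort, so it runs in O(n log n) and is only meant for verification, not as a replacement for the linked-list algorithm. The name reference_median is hypothetical. The two zero-based positions are (n - 1) / 2 and n / 2. For n = 14 they are 6 and 7, the 7th and 8th smallest elements. For odd n they coincide, so no special case is needed.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Brute-force reference median: sort a copy and average the elements at the
// zero-based positions (n - 1) / 2 and n / 2 (the same element when n is odd).
double reference_median( std::vector<int> v )
{
    assert( !v.empty() );
    std::sort( v.begin(), v.end() );
    const std::size_t n = v.size();
    const std::size_t lower = ( n - 1 ) / 2; // n = 14 -> 6 (7th smallest)
    const std::size_t upper = n / 2;         // n = 14 -> 7 (8th smallest)
    // convert before adding so that the sum of two large ints cannot overflow
    return ( static_cast<double>( v[lower] ) + v[upper] ) / 2;
}
```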
Please note that although my implementation compiles perfectly on a C++ compiler, I wouldn't call it textbook C++ code, because the question states that I am not allowed to use anything from the C++ standard template library. Therefore, my code is rather a hybrid of C code and C++ code.
In the following code, I only use the standard template library (in particular the function std::nth_element) for testing my algorithm and verifying its results. I do not use any of these functions in the algorithm itself.
#include <iostream>
#include <iomanip>
#include <cassert>
// The following two headers are only required for testing the algorithm and verifying
// the correctness of its results. They are not used in the algorithm itself.
#include <random>
#include <algorithm>
#include <memory> // std::make_unique
#include <new>    // std::bad_alloc
// The following setting can be changed to print extra debugging information
// possible settings:
// 0: no extra debugging information
// 1: print the state and length of all partitions in every loop iteration
// 2: additionally print the contents of all partitions (if they are not too big)
#define PRINT_DEBUG_LEVEL 0
template <typename T>
struct Node
{
T data;
Node<T> *next;
};
// NOTE:
// The return type is not necessarily the same as the data type. The reason for this is
// that, for example, the data type "int" requires a "double" as a return type, so that
// the arithmetic mean of "3" and "6" returns "4.5".
// This function may require template specializations to handle overflow or wrapping.
template<typename T, typename U>
U arithmetic_mean( const T &first, const T &second )
{
return ( static_cast<U>(first) + static_cast<U>(second) ) / 2;
}
//the main loop of the function find_median can be in one of the following three states
enum LoopState
{
//we are looking for one median value
LOOPSTATE_LOOKINGFORONE,
//we are looking for two median values, and the returned median
//will be the arithmetic mean of the two
LOOPSTATE_LOOKINGFORTWO,
//one of the median values has been found, but we are still searching for
//the second one
LOOPSTATE_FOUNDONE
};
template <
typename T, //type of the data
typename U //type of the return value
>
U find_median( Node<T> *list )
{
//This variable points to the pointer to the first element of the current partition.
//During the partition phase, the linked list will be broken and reassembled afterwards, so
//the pointer this pointer points to will be nullptr until it is reassembled.
Node<T> **pp_start = &list;
//This pointer represents nothing more than the cached value of *pp_start and it is
//not always valid
Node<T> *p_start = *pp_start;
//These pointers are maintained for accessing the middle of the list for selecting a pivot
//using the "median-of-three" rule.
Node<T> *p_middle;
Node<T> *p_end;
//result is not defined if list is empty
assert( p_start != nullptr );
//in the main loop, this variable always holds the number of elements in the current partition
int num_total = 1;
// First, we must traverse the entire linked list to determine the number of elements n,
// which is needed to calculate k. If n is odd, the median is the k'th smallest element
// (zero-based), where k = (n - 1) / 2. If n is even, the median is the arithmetic mean of
// the k'th and the (k+1)'th smallest elements.
// We also set pointers to the nodes in the middle and at the end, which will be required
// later for selecting a pivot according to the "median-of-three" rule.
p_middle = p_start;
for ( p_end = p_start; p_end->next != nullptr; p_end = p_end->next )
{
num_total++;
if ( num_total % 2 == 0 ) p_middle = p_middle->next;
}
// find out whether we are looking for only one or two median values
enum LoopState loop_state = num_total % 2 == 0 ? LOOPSTATE_LOOKINGFORTWO : LOOPSTATE_LOOKINGFORONE;
//set k to the index of the middle element, or if there are two middle elements, to the left one
int k = ( num_total - 1 ) / 2;
// If we are looking for two median values, but we have only found one, then this variable will
// hold the value of the one we found. Whether we have found one can be determined by the state of
// the variable loop_state.
T val_found;
for (;;)
{
//make p_start cache the value of *pp_start again, because a previous iteration of the loop
//may have changed the value of pp_start
p_start = *pp_start;
assert( p_start != nullptr );
assert( p_middle != nullptr );
assert( p_end != nullptr );
assert( num_total != 0 );
if ( num_total == 1 )
{
switch ( loop_state )
{
case LOOPSTATE_LOOKINGFORONE:
return p_start->data;
case LOOPSTATE_FOUNDONE:
return arithmetic_mean<T,U>( val_found, p_start->data );
default:
assert( false ); //this should be unreachable
}
}
//select the pivot according to the "median-of-three" rule
T pivot;
if ( p_start->data < p_middle->data )
{
if ( p_middle->data < p_end->data )
pivot = p_middle->data;
else if ( p_start->data < p_end->data )
pivot = p_end->data;
else
pivot = p_start->data;
}
else
{
if ( p_start->data < p_end->data )
pivot = p_start->data;
else if ( p_middle->data < p_end->data )
pivot = p_end->data;
else
pivot = p_middle->data;
}
#if PRINT_DEBUG_LEVEL >= 1
//this line is conditionally compiled for extra debugging information
std::cout << "\nmedian of three: " << (*pp_start)->data << " " << p_middle->data << " " << p_end->data << " ->" << pivot << std::endl;
#endif
// We will be dividing the current partition into 3 new partitions (less-than,
// equal-to and greater-than) each represented as a linked list. Each list
// requires a pointer to the start of the list and a pointer to the pointer at
// the end of the list to write the address of new elements to. Also, when
// traversing the lists, we need to keep a pointer to the middle of the list,
// as this information will be required for selecting a new pivot in the next
// iteration of the loop. The latter is not required for the equal-to partition,
// as it would never be used.
Node<T> *p_less = nullptr, **pp_less_end = &p_less, **pp_less_middle = &p_less;
Node<T> *p_equal = nullptr, **pp_equal_end = &p_equal;
Node<T> *p_greater = nullptr, **pp_greater_end = &p_greater, **pp_greater_middle = &p_greater;
// These pointers are only used to cache the location of the last node.
// Despite their similar names, their function is quite different from that of
// pp_less_end and pp_greater_end.
Node<T> *p_less_end = nullptr;
Node<T> *p_greater_end = nullptr;
// counter for the number of elements in each partition
int num_less = 0;
int num_equal = 0;
int num_greater = 0;
// NOTE:
// The following loop will temporarily split the linked list. It will be merged later.
Node<T> *p_next_node = p_start;
//the following line isn't necessary; it is only used to clarify that the pointers no
//longer point to anything meaningful
*pp_start = p_start = nullptr;
for ( int i = 0; i < num_total; i++ )
{
assert( p_next_node != nullptr );
Node<T> *p_current_node = p_next_node;
p_next_node = p_next_node->next;
if ( p_current_node->data < pivot )
{
//link node to pp_less
assert( *pp_less_end == nullptr );
*pp_less_end = p_less_end = p_current_node;
pp_less_end = &p_current_node->next;
p_current_node->next = nullptr;
num_less++;
if ( num_less % 2 == 0 )
{
pp_less_middle = &(*pp_less_middle)->next;
}
}
else if ( p_current_node->data == pivot )
{
//link node to pp_equal
assert( *pp_equal_end == nullptr );
*pp_equal_end = p_current_node;
pp_equal_end = &p_current_node->next;
p_current_node->next = nullptr;
num_equal++;
}
else
{
//link node to pp_greater
assert( *pp_greater_end == nullptr );
*pp_greater_end = p_greater_end = p_current_node;
pp_greater_end = &p_current_node->next;
p_current_node->next = nullptr;
num_greater++;
if ( num_greater % 2 == 0 )
{
pp_greater_middle = &(*pp_greater_middle)->next;
}
}
}
assert( num_total == num_less + num_equal + num_greater );
assert( num_equal >= 1 );
#if PRINT_DEBUG_LEVEL >= 1
//this section is conditionally compiled for extra debugging information
{
std::cout << std::setfill( '0' );
switch ( loop_state )
{
case LOOPSTATE_LOOKINGFORONE:
std::cout << "LOOPSTATE_LOOKINGFORONE k = " << k << "\n";
break;
case LOOPSTATE_LOOKINGFORTWO:
std::cout << "LOOPSTATE_LOOKINGFORTWO k = " << k << "\n";
break;
case LOOPSTATE_FOUNDONE:
std::cout << "LOOPSTATE_FOUNDONE k = " << k << " val_found = " << val_found << "\n";
}
std::cout << "partition lengths: ";
std::cout <<
std::setw( 2 ) << num_less << " " <<
std::setw( 2 ) << num_equal << " " <<
std::setw( 2 ) << num_greater << " " <<
std::setw( 2 ) << num_total << "\n";
#if PRINT_DEBUG_LEVEL >= 2
Node<T> *p;
std::cout << "less: ";
if ( num_less > 10 )
std::cout << "too many to print";
else
for ( p = p_less; p != nullptr; p = p->next ) std::cout << p->data << " ";
std::cout << "\nequal: ";
if ( num_equal > 10 )
std::cout << "too many to print";
else
for ( p = p_equal; p != nullptr; p = p->next ) std::cout << p->data << " ";
std::cout << "\ngreater: ";
if ( num_greater > 10 )
std::cout << "too many to print";
else
for ( p = p_greater; p != nullptr; p = p->next ) std::cout << p->data << " ";
std::cout << "\n\n" << std::flush;
#endif
std::cout << std::flush;
}
#endif
//insert less-than partition into list
assert( *pp_start == nullptr );
*pp_start = p_less;
//insert equal-to partition into list
assert( *pp_less_end == nullptr );
*pp_less_end = p_equal;
//insert greater-than partition into list
assert( *pp_equal_end == nullptr );
*pp_equal_end = p_greater;
//link list to previously cut off part
assert( *pp_greater_end == nullptr );
*pp_greater_end = p_next_node;
//if less-than partition is large enough to hold both possible median values
if ( k + 2 <= num_less )
{
//set the next iteration of the loop to process the less-than partition
//pp_start is already set to the desired value
p_middle = *pp_less_middle;
p_end = p_less_end;
num_total = num_less;
}
//else if less-than partition holds one of the two possible median values
else if ( k + 1 == num_less )
{
if ( loop_state == LOOPSTATE_LOOKINGFORTWO )
{
//the equal-to partition never needs sorting, because all members are already equal
val_found = p_equal->data;
loop_state = LOOPSTATE_FOUNDONE;
}
//set the next iteration of the loop to process the less-than partition
//pp_start is already set to the desired value
p_middle = *pp_less_middle;
p_end = p_less_end;
num_total = num_less;
}
//else if equal-to partition holds both possible median values
else if ( k + 2 <= num_less + num_equal )
{
//the equal-to partition never needs sorting, because all members are already equal
if ( loop_state == LOOPSTATE_FOUNDONE )
return arithmetic_mean<T,U>( val_found, p_equal->data );
return p_equal->data;
}
//else if equal-to partition holds one of the two possible median values
else if ( k + 1 == num_less + num_equal )
{
switch ( loop_state )
{
case LOOPSTATE_LOOKINGFORONE:
return p_equal->data;
case LOOPSTATE_LOOKINGFORTWO:
val_found = p_equal->data;
loop_state = LOOPSTATE_FOUNDONE;
k = 0;
//set the next iteration of the loop to process the greater-than partition
pp_start = pp_equal_end;
p_middle = *pp_greater_middle;
p_end = p_greater_end;
num_total = num_greater;
break;
case LOOPSTATE_FOUNDONE:
return arithmetic_mean<T,U>( val_found, p_equal->data );
}
}
//else both possible median values must be in the greater-than partition
else
{
k = k - num_less - num_equal;
//set the next iteration of the loop to process the greater-than partition
pp_start = pp_equal_end;
p_middle = *pp_greater_middle;
p_end = p_greater_end;
num_total = num_greater;
}
}
}
// NOTE:
// The following code is not part of the algorithm, but is only intended to test the algorithm.
// This simple class is designed to contain a singly-linked list.
template <typename T>
class List
{
public:
List() : first( nullptr ) {}
// the following is required to abide by the rule of three/five/zero
// see: https://en.cppreference.com/w/cpp/language/rule_of_three
List( const List<T> & ) = delete;
List( List<T> && ) = delete;
List<T>& operator=( const List<T> & ) = delete;
List<T>& operator=( List<T> && ) = delete;
~List()
{
Node<T> *p = first;
while ( p != nullptr )
{
Node<T> *temp = p;
p = p->next;
delete temp;
}
}
void push_front( const T &data )
{
Node<T> *temp = new Node<T>;
temp->data = data;
temp->next = first;
first = temp;
}
//member variables
Node<T> *first;
};
int main()
{
//generated random numbers will be between 0 and 2 billion (fits in 32-bit signed int)
constexpr int min_val = 0;
constexpr int max_val = 2*1000*1000*1000;
//will allocate array for 1 million ints and fill with random numbers
constexpr int num_values = 1*1000*1000;
//this class contains the singly-linked list and is empty for now
List<int> l;
double result;
//These variables are used for random number generation
std::random_device rd;
std::mt19937 gen( rd() );
std::uniform_int_distribution<> dis( min_val, max_val );
try
{
//fill array with random data
std::cout << "Filling array with random data..." << std::flush;
auto unsorted_data = std::make_unique<int[]>( num_values );
for ( int i = 0; i < num_values; i++ ) unsorted_data[i] = dis( gen );
//fill the singly-linked list
std::cout << "done\nFilling linked list..." << std::flush;
for ( int i = 0; i < num_values; i++ ) l.push_front( unsorted_data[i] );
std::cout << "done\nCalculating median using STL function..." << std::flush;
//calculate the median using the functions provided by the C++ standard template library.
//Note: this is only done to compare the results with the algorithm provided in this file
//raw pointers to the first element and one past the last element; computing the end
//pointer this way avoids indexing the element one past the end of the array
int *const data_begin = unsorted_data.get();
int *const data_end = data_begin + num_values;
if ( num_values % 2 == 0 )
{
int median1, median2;
std::nth_element( data_begin, data_begin + (num_values - 1) / 2, data_end );
median1 = data_begin[(num_values - 1) / 2];
std::nth_element( data_begin, data_begin + num_values / 2, data_end );
median2 = data_begin[num_values / 2];
result = arithmetic_mean<int,double>( median1, median2 );
}
else
{
int median;
std::nth_element( data_begin, data_begin + num_values / 2, data_end );
median = data_begin[num_values / 2];
result = static_cast<double>( median );
}
std::cout << "done\nMedian according to STL function: " << std::setprecision( 12 ) << result << std::endl;
// NOTE: Since the STL functions only sorted the array, but not the linked list, the
// order of the linked list is still random and not pre-sorted.
//calculate the median using the algorithm provided in this file
std::cout << "Starting algorithm" << std::endl;
result = find_median<int,double>( l.first );
std::cout << "The calculated median is: " << std::setprecision( 12 ) << result << std::endl;
std::cout << "Cleaning up\n\n" << std::flush;
}
catch ( const std::bad_alloc & )
{
std::cerr << "Error: Unable to allocate sufficient memory!" << std::endl;
return -1;
}
return 0;
}
I have successfully tested my code with one million randomly generated elements and it found the correct median virtually instantaneously.
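For the verification step only, the test harness in main calls std::nth_element twice when the number of values is even, but one call is enough. After std::nth_element has placed the element at index n / 2, no element before that position compares greater than it. The lower of the two middle values is therefore the maximum of that first part. This is a sketch using only std::nth_element and std::max_element, and the name median_nth_element is hypothetical. It does not replace the linked-list algorithm, which must not use the standard library.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Median of a non-empty vector using a single std::nth_element call.
// Afterwards no element in [begin, upper) compares greater than *upper, so for
// even n the lower middle value is the largest element of that range.
double median_nth_element( std::vector<int> v )
{
    assert( !v.empty() );
    const std::size_t n = v.size();
    const auto upper = v.begin() + static_cast<std::ptrdiff_t>( n / 2 );
    std::nth_element( v.begin(), upper, v.end() );
    if ( n % 2 != 0 )
        return *upper;
    const int lower = *std::max_element( v.begin(), upper );
    return ( static_cast<double>( lower ) + *upper ) / 2;
}
```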