YoloKokura

最近在LeetCode Daily Challenge中碰到一道朴实无华的排序算法题,你只需要想办法在O(nlog(n))内排序一个整型数组就可以了,比如用归并排序、堆排序等。我觉得值得一写的原因是,我缺少对各种排序算法的整理,也没有仔细看过Golang中的排序算法源码实现,这篇blog补全了这个缺憾。

排序算法总结

选择排序

  • 找到当前数组未处理部分中最小的元素,和数组中第i个元素交换位置,直到整个数组排序完成。

这其实算是一种暴力解法,即每次拿出最小元素放到结果数组里(只是不用另外开辟内存空间而已),其时间复杂度为O(n2),因为每次找最小元素的时间(worst case)都与数组长度有关。

造成算法低效的原因是:每次扫描并不能给下次扫描提供任何信息,因此每次扫描都需要从头开始。另外,算法运行时间与数组中元素的初始顺序无关,即使给出一个有序数组,算法仍然需要进行n次扫描(当然,最后一次只有一个元素,相当于可以省下一次扫描)。

go
func sort(nums []int) {
    for i := 0; i < len(nums); i++ {
        min := i
        for j := i + 1; j < len(nums); j++ {
            if nums[j] < nums[min] {
                min = j
            }
        }
        nums[min], nums[i] = nums[i], nums[min]
    }
}

插入排序

  • 遍历数组元素,将当前元素插入到已排序部分的合适位置,直到整个数组排序完成。

时间复杂度为O(n2),一方面遍历数组需要O(n),另一方面每次插入元素需要遍历已排序部分,同样需要O(n)。但相比于选择排序,插入排序利用了已排序部分的信息,在数组基本有序的情况下,不需要像选择排序那样进行重复扫描。插入排序对于小规模数组的排序效率很高,因为它的时间复杂度接近于O(n)

go
// 用交换的方式实现插入排序
func sort(nums []int) {
    for i := 1; i < len(nums); i++ {
        for j := i; j > 0 && nums[j] < nums[j-1]; j-- {
            nums[j], nums[j-1] = nums[j-1], nums[j]
        }
    }
}

// 用右移的方式实现插入排序
func sort(nums []int) {
    for i := 1; i < len(nums); i++ {
        e := nums[i]
        j := i
        for ; j > 0 && nums[j-1] > e; j-- {
            // 将大于e的元素右移
            nums[j] = nums[j-1]
        }
        nums[j] = e
    }
}

希尔排序

  • 将数组分为h个子数组,对每个子数组进行插入排序(这些子数组的元素交织在一起,同一个子数组的元素位置相差h)。
  • 重复上述过程,h不断减小,直到h=1,此时数组已经基本有序。

当数组相对有序的时候,插入排序的效率会很高,而希尔排序权衡了子数组的规模和有序性。对规模较小的子数组进行插入排序,时间开销较小,而排序子数组之后,数组的局部有序性会提高,这样就可以进一步减少插入排序的时间开销。

希尔排序的时间复杂度和h的选择有关,如果h的选择不当,希尔排序的时间复杂度可能会退化到O(n2)。但是,如果h的选择合理,希尔排序的时间复杂度可以达到O(nlog(n))

go
func sort(nums []int) {
    // h的选择可以变化
    h := 1
    for h < len(nums)/3 {
        h = 3*h + 1
    }

    for h >= 1 {
        for i := h; i < len(nums); i++ {
            e := nums[i]
            j := i
            // 对h个子数组分别进行插入排序,这里使用右移法
            for ; j >= h && nums[j-h] > e; j -= h {
                nums[j] = nums[j-h]
            }
            nums[j] = e
        }
        h /= 3
    }
}

归并排序

  • 递归地将数组分为两个子数组,分别对两个子数组进行归并排序
  • 子数组长度为1时,认为该子数组已经有序,直接返回。
  • 合并子数组,利用两个指针分别指向两个子数组的头部,比较两个指针指向的元素,将较小的元素放入结果数组中,直到其中一个子数组遍历完毕,将另一个子数组剩余的元素放入结果数组中。

归并排序的背后是Divide and Conquer的思想,即将一个大问题分解为若干个小问题,分别解决,最后将结果合并起来。归并排序的时间复杂度为O(nlog(n)),但是自底而上的归并排序的代码量更小。关于时间复杂度的讨论可以参考这本书^1。然而,归并排序的空间复杂度为O(n),因为需要额外的空间存储归并结果,见下面代码块中对aux数组的讨论。

go
// 下面的代码片段将aux作为merge函数的局部变量,但是这并不是最好的实现,因为算法在每次使用merge函数时都需要反复申请和释放内存,更好的方法是将aux作为merge函数的参数传入。

// 自顶而下的归并排序: 递归分解数组,然后合并子数组
func sort(nums []int) {
    divide(nums, 0, len(nums))
}

func divide(nums []int, left, right int) {
    if right-left <= 1 {
        // 子数组长度为1,认为该子数组已经有序
        return
    }
    mid := left + (right - left) / 2
    divide(nums, left, mid)
    divide(nums, mid, right)
    merge(nums, left, mid, right)
}

func merge(nums []int, left, mid, right int) {
    // 将排序结果放入aux数组中暂存,最后复制到原数组
    aux := make([]int, right-left)
    i, j, k := left, mid, 0
    for i < mid && j < right {
        if nums[i] < nums[j] {
            aux[k] = nums[i]
            i++
        } else {
            aux[k] = nums[j]
            j++
        }
        k++
    }
    for i < mid {
        aux[k] = nums[i]
        i++
        k++
    }
    for j < right {
        aux[k] = nums[j]
        j++
        k++
    }
    copy(nums[left:right], aux)
}

// 自底而上的归并排序:先对小数组进行排序,然后两两合并,再对大数组进行排序
func sort(nums []int) {
    for sz := 1; sz < len(nums); sz *= 2 {
        for left := 0; left < len(nums)-sz; left += 2*sz {
            // 合并两个子数组, merge函数同上
            merge(nums, left, left+sz, min(left+2*sz, len(nums)))
        }
    }
}

func min(a, b int) int {
    if a < b {
        return a
    }
    return b
}

快速排序

  • 随机从数组中选择一个元素作为切分。
  • 将数组分为两个子数组,左子数组的元素都小于切分元素,右子数组的元素都大于切分元素。
  • 在子数组中递归地进行快速排序

快速排序同样采用分治思想,在每一次切分的过程之中,必然会确定切分元素在有序数组中的位置。此外,快排的切分过程是原地的,不需要额外的空间。反观归并排序,每次合并都需要额外的空间来暂存合并结果。快排的时间复杂度为O(nlog(n)),且不需要额外的空间开销。

go
func sort(nums []int) {
    // 打乱数组
    rand.Seed(time.Now().UnixNano())
    rand.Shuffle(len(nums), func(i, j int) {
        nums[i], nums[j] = nums[j], nums[i]
    })
    divide(nums, 0, len(nums))
}

func divide(nums []int, left, right int) {
    if right-left <= 1 {
        return
    }
    // 切分,由于数组已经打乱,因此可以保证切分元素是随机选择的
    j := partition(nums, left, right)
    divide(nums, left, j)
    divide(nums, j+1, right)
}

func partition(nums []int, left, right int) int {
    i, j := left+1, right-1
    mid := nums[left]
    for {
        // 在左右数组中分别找到不符合条件的元素,并进行交换
        for i < right && nums[i] < mid {
            i++
        }
        for j > left && nums[j] > mid {
            j--
        }
        if i >= j {
            break
        }
        nums[i], nums[j] = nums[j], nums[i]
        i++
        j--
    }
    // 将切分元素放到正确的位置
    nums[left], nums[j] = nums[j], nums[left]
    return j
}

TIP

当切分元素是数组中的最大或最小元素时,快排的时间复杂度退化为O(n2)。在切分之前先将数组打乱可以尽量避免这种情况。

也有一些针对快排的改进算法,比如在小数组排序上切换的效率更高的插入排序,或者使用三向切分(用子数组一部分元素的中位数作为切分元素)的方法,这些讨论同样参见之前提到的资料^1

堆排序

  • 将数组构造成一个最大堆或者最小堆
  • 依次将堆顶元素pop出来,放到数组的末尾(或开头),这样数组就变成了有序的。

由于堆或者优先级队列的本质是一种数据结构而非排序算法,本文不赘述它的具体实现。我们在这里只需要知道堆顶元素一定是最值元素即可。堆排序的时间复杂度为O(nlog(n)),但是由于堆的实现需要额外的空间开销,因此空间复杂度为O(n)

go
func sort(nums []int) {
    pq := make(PriorityQueue, len(nums))
    heap.Init(&pq)
    for _, num := range nums {
        heap.Push(&pq, num)
    }
    i := 0
    for heap.Len(&pq) > 0 {
        nums[i] = heap.Pop(&pq).(int)
        i++
    }
}

// Golang实现堆的boilerplate代码,确实很繁琐😡,尤其是当我们只需要处理简单的数据结构的时候

type PriorityQueue []int

func (pq PriorityQueue) Len() int {
    return len(pq)
}

func (pq PriorityQueue) Less(i, j int) bool {
    return pq[i] > pq[j]
}

func (pq PriorityQueue) Swap(i, j int) {
    pq[i], pq[j] = pq[j], pq[i]
}

func (pq *PriorityQueue) Push(x interface{}) {
    *pq = append(*pq, x.(int))
}

func (pq *PriorityQueue) Pop() interface{} {
    n := len(*pq)
    x := (*pq)[n-1]
    *pq = (*pq)[:n-1]
    return x
}

Golang 的 sort 包

Package sort provides primitives for sorting slices and user-defined collections.

基本使用

对于基本数据类型切片,sort内置3类常见数据类型切片的排序函数:

  • sort.Ints:对int切片进行排序。
  • sort.Float64s:对float64切片进行排序。
  • sort.Strings:对string切片进行排序。

事实上,sort预先定义了用于排序的Interface接口,只要实现了这个接口,就可以使用sort.Sort函数进行排序。而针对上述3种基本数据类型,sort包已经实现了对应的Interface接口:IntSlice、Float64Slice、StringSlice,在对应的排序函数中实际上进行了类型转换。

go
// 类似于前面堆排序中的PriorityQueue,这里的Interface要求实现Len、Less、Swap三个方法以便排序
type Interface interface {
        Len() int
        // 如果需要降序排序,可以将Less的返回值取反,即写Greater的逻辑
        Less(i, j int) bool
        Swap(i, j int)
}

// 下面代码以IntSlice为例,其他类型的实现类似
func Ints(x []int) { Sort(IntSlice(x)) }

type IntSlice []int
// Interface接口实现
func (x IntSlice) Len() int           { return len(x) }
func (x IntSlice) Less(i, j int) bool { return x[i] < x[j] }
func (x IntSlice) Swap(i, j int)      { x[i], x[j] = x[j], x[i] }

func (x IntSlice) Sort() { Sort(x) }

// 也可以自定义数据类型,实现Interface接口,使用sort.Sort进行排序
func Sort(data Interface) {
	n := data.Len()
	if n <= 1 {
		return
	}
	limit := bits.Len(uint(n))
	pdqsort(data, 0, n, limit)
}

What's inside this pdqsort function?

Golang 实际上在该函数中采用了一定的具体排序算法的选择策略,根据具体的元素平衡情况和切片长度,分别在不同的层级采用快排、插入排序和堆排序,具体的注释可参见下面sort包源码的代码块:

go
// pdqsort sorts data[a:b].
// The algorithm based on pattern-defeating quicksort(pdqsort), but without the optimizations from BlockQuicksort.
// pdqsort paper: https://arxiv.org/pdf/2106.05123.pdf
// C++ implementation: https://github.com/orlp/pdqsort
// Rust implementation: https://docs.rs/pdqsort/latest/pdqsort/
// limit is the number of allowed bad (very unbalanced) pivots before falling back to heapsort.
func pdqsort(data Interface, a, b, limit int) {
	const maxInsertion = 12

	var (
		wasBalanced    = true // whether the last partitioning was reasonably balanced
		wasPartitioned = true // whether the slice was already partitioned
	)

	for {
		length := b - a

		if length <= maxInsertion {
            // 当切片长度小于等于12时,使用插入排序,如前面讨论,切片长度较小时,插入排序的性能较好
			insertionSort(data, a, b)
			return
		}

		// Fall back to heapsort if too many bad choices were made.
		if limit == 0 {
            // 如果切片长度较大,且不平衡的切分过多,则直接使用堆排序
			heapSort(data, a, b)
			return
		}

		// If the last partitioning was imbalanced, we need to breaking patterns.
		if !wasBalanced {
			breakPatterns(data, a, b)
			limit--
		}

        // 以下可以理解为快排
		pivot, hint := choosePivot(data, a, b)
		if hint == decreasingHint {
			reverseRange(data, a, b)
			// The chosen pivot was pivot-a elements after the start of the array.
			// After reversing it is pivot-a elements before the end of the array.
			// The idea came from Rust's implementation.
			pivot = (b - 1) - (pivot - a)
			hint = increasingHint
		}

		// The slice is likely already sorted.
		if wasBalanced && wasPartitioned && hint == increasingHint {
			if partialInsertionSort(data, a, b) {
				return
			}
		}

		// Probably the slice contains many duplicate elements, partition the slice into
		// elements equal to and elements greater than the pivot.
		if a > 0 && !data.Less(a-1, pivot) {
			mid := partitionEqual(data, a, b, pivot)
			a = mid
			continue
		}

		mid, alreadyPartitioned := partition(data, a, b, pivot)
		wasPartitioned = alreadyPartitioned

		leftLen, rightLen := mid-a, b-mid
		balanceThreshold := length / 8
		if leftLen < rightLen {
			wasBalanced = leftLen >= balanceThreshold
			pdqsort(data, a, mid, limit)
			a = mid + 1
		} else {
			wasBalanced = rightLen >= balanceThreshold
			pdqsort(data, mid+1, b, limit)
			b = mid
		}
	}
}

CPP STL algorithm sort()

std::sort is a sorting algorithm that sorts the elements in the range [first, last) in non-descending order. The order of equal elements is not guaranteed to be preserved.

同Golang,在默认情况下,CPP使用< operator,即lesser进行排序。

以下是SGI STL中的sort函数实现。

cpp
template <class _RandomAccessIter>
inline void sort(_RandomAccessIter __first, _RandomAccessIter __last) {
    __STL_REQUIRES(_RandomAccessIter, _Mutable_RandomAccessIterator);
    __STL_REQUIRES(typename iterator_traits<_RandomAccessIter>::value_type,
                    _LessThanComparable);
    if (__first != __last) {
        // 当容器长度大于16时,使用introsort,introsort本身类似于三分位快排
        __introsort_loop(__first, __last,
                        __VALUE_TYPE(__first),
                        __lg(__last - __first) * 2);
        // 最后使用插入排序完成小数组的排序
        __final_insertion_sort(__first, __last);
    }
}

template <class _Size>
inline _Size __lg(_Size __n) {
    // 计算log2(__n)的值,即2的多少次方等于__n, __k将用作introsort的递归深度
    _Size __k;
    for (__k = 0; __n != 1; __n >>= 1) ++__k;
    return __k;
}

template <class _RandomAccessIter, class _Tp, class _Size>
void __introsort_loop(_RandomAccessIter __first,
                        _RandomAccessIter __last, _Tp*,
                        _Size __depth_limit)
{
    // __stl_threshold = 16 类似于Golang的maxInsertion
    // __depth_limit即递归深度,类似于Golang的limit
    while (__last - __first > __stl_threshold) {
        if (__depth_limit == 0) {
            // partial_sort用堆排序实现
            partial_sort(__first, __last, __last);
            return;
        }
        --__depth_limit;
        // __median求取首、尾、中位数三者的中位数
        _RandomAccessIter __cut =
          __unguarded_partition(__first, __last,
                                _Tp(__median(*__first,
                                             *(__first + (__last - __first)/2),
                                             *(__last - 1))));
        __introsort_loop(__cut, __last, (_Tp*) 0, __depth_limit);
        __last = __cut;
    }
}

template <class _RandomAccessIter, class _Tp>
_RandomAccessIter __unguarded_partition(_RandomAccessIter __first, 
                                        _RandomAccessIter __last, 
                                        _Tp __pivot) 
{
    // 类似于快排的partition
    while (true) {
        while (*__first < __pivot)
            ++__first;
        --__last;
        while (__pivot < *__last)
            --__last;
        if (!(__first < __last))
            return __first;
        iter_swap(__first, __last);
        ++__first;
    }
}

Tags: