如何泛化某个算法代码?

引言

假设有一段很棒的代码放在一个文件里,如何将这段代码泛化,使之成为一个符合STL风格的算法?

来考虑一个目前STL中还没有的算法。首先说明它要解决的问题 : 假设你有一组数(数学意义上的"集合",而非容器意义上的set),这些数是不连续的、无序的,理论上取自正整数,例如:

{5,2,7,9,4,0,1}

我们要找的是这组数中最小的缺失数。对于上面这个例子,答案是3。如果把这些数排好序,解决方案就显而易见了——adjacent_find(相邻查找)可以帮助我们找到第一个缺口。

初始解法(O(n.log(n)))

// 假设我们的集合在一个数组或vector中
std::sort(a.begin(), a.end());
auto p = std::adjacent_find(
    a.begin(), a.end(), 
    [] (int x, int y) { return y-x > 1; });
int unused = a.back() + 1;
if (p != a.end())
  unused = *p + 1;

这个解法的渐进时间复杂度是 O(n.log(n)),因为 adjacent_find 是线性的,而 sort 的复杂度占主导。

更好的解法:寻找最小未使用数

有没有更好的方法?乍一看并不明显,但这个问题其实有一个线性时间复杂度的解法。关键在于使用分治策略,而这项工作的主力军就是 partition(划分)。

让我们从假设最小未使用值是0开始。我们选择一个轴值 m,它是序列的假设中点——它实际上不一定要出现在序列中,但我们简单地假定它就是中间值。然后用这个轴值对序列进行划分,我们会得到这样一种情况:第一个大于等于 m 的值位于位置 p(这是 partition 的返回值)。如果序列的前半部分没有空缺,m 就等于 p,我们就可以递归处理后半部分,并将新的最小未使用值设为 m

我们知道,由于集合中不存在重复元素,m 不可能小于 p。因此,如果 m 不等于 p,它必然更大——这意味着在位置 p 之前至少存在一个空缺,我们可以在保持最小未使用值不变的情况下,对前半部分进行递归。

算法的基本情况是当我们需要划分的序列为空时。此时我们就找到了最小未使用值。

算法代码

void min_unused()
{
  // 初始化随机数生成器
  std::array seed_data;
  std::random_device r;
  std::generate_n(seed_data.data(), seed_data.size(), std::ref(r));
  std::seed_seq seq(std::begin(seed_data), std::end(seed_data));
  std::mt19937 gen(seq);

  // 用整数填充数组,打乱顺序,然后丢弃一些
  int a[10];
  std::iota(&a[0], &a[10], 0);
  std::shuffle(&a[0], &a[10], gen);
  int first_idx = 0;
  int last_idx = 7; // 任意的截断点

  for (int i = first_idx; i < last_idx; ++i)
    std::cout << a[i] << '\n';

  // 算法主体
  int unused = 0;
  while (first_idx != last_idx) {
    int m = unused + (last_idx - first_idx + 1)/2;
    auto p = std::partition(&a[first_idx], &a[last_idx],
                            [&] (int i) { return i < m; });
    if (p - &a[first_idx] == m - unused) {
      unused = m;
      first_idx = p - &a[0];
    } else {
      last_idx = p - &a[0];
    }
  }

  std::cout << "Min unused: " << unused << '\n';
}

注意:这个算法的渐进时间复杂度是线性的;在每个阶段我们都在运行 partition,这是 O(n) 的操作,然后我们只在序列的一半上递归。因此复杂度为:

O(n + n/2 + n/4 + n/8 + …) = O(∑_{i=0}∞ n/2^i) = O(2n) = O(n)

从具体到泛化:第一步

现在我们有了一个可以工作的算法……但只适用于C风格数组,而且只在一个地方使用。显然我们想要尽可能地将其泛化;那么,将其转化为一个可以像STL中那些算法一样方便使用的优秀算法,需要哪些步骤呢?

版本1:接收Vector的函数

// 版本1:接收vector的函数
int min_unused(std::vector& v)
{
  int* first = &v[0];
  int* last = first + v.size();
  int unused = 0;
  while (first != last) {
    int m = unused + (last - first + 1)/2;
    auto p = std::partition(first, last,
                            [&] (int i) { return i < m; });
    if (p - first == m - unused) {
      unused = m;
      first = p;
    } else {
      last = p;
    }
  }
  return unused;
}

版本2:使用迭代器

// 版本2:使用迭代器,像真正的算法那样
template <typename It>
inline int min_unused(It first, It last)
{
  int unused = 0;
  while (first != last) {
    int m = unused + (last - first + 1)/2;
    auto p = std::partition(first, last,
                            [&] (int i) { return i < m; });
    if (p - first == m - unused) {
      unused = m;
      first = p;
    } else {
      last = p;
    }
  }
  return unused;
}

还要注意,由于 min_unused 现在是一个函数模板,我们添加了 inline 关键字,这意味着我们可以把它放在头文件中,而不会在多个翻译单元中实例化时产生链接错误。

版本3:推断value_type

// 版本3:推断value_type
template <typename It>
inline typename std::iterator_traits<It>::value_type min_unused(It first, It last)
{
  using T = typename std::iterator_traits<It>::value_type;
  T unused{};
  while (first != last) {
    T m = unused + (last - first + 1)/2;
    auto p = std::partition(first, last,
                            [&] (const T& i) { return i < m; });
    if (p - first == m - unused) {
      unused = m;
      first = p;
    } else {
      last = p;
    }
  }
  return unused;
}

这里我们使用 iterator_traits 来发现传入迭代器的 value_type。这适用于任何满足Iterator概念的类型。因此我们也将散落在代码中的 int 替换成了 T。注意在第6行,我们对 T 进行了值初始化,这会对整型进行正确的零初始化。

版本4:允许用户指定初始最小值

// 版本4:允许用户指定初始最小值
template <typename It, typename T = typename std::iterator_traits<It>::value_type>
inline T min_unused(It first, It last, T init = T{})
{
  while (first != last) {
    T m = init + (last - first + 1)/2;
    auto p = std::partition(first, last,
                            [&] (const T& i) { return i < m; });
    if (p - first == m - init) {
      init = m;
      first = p;
    } else {
      last = p;
    }
  }
  return init;
}

调用者提供的最小值可能会非常有用——在很多场景下0是保留值,或者需要一个非零的起始值。

迭代器再探

此时,min_unused 看起来至少在表面上更像一个真正的STL算法了,但它还不完全达标。第7行和第10行做的是普通的算术运算,这意味着传入的迭代器需要支持这些运算。STL的一个基本概念是迭代器类别的思想:这是一个定义不同类型迭代器可用操作的概念。

版本5:泛化处理迭代器算术

// 版本5:泛化处理迭代器算术
template <typename It, typename T = typename std::iterator_traits<It>::value_type>
inline T min_unused(It first, It last, T init = T{})
{
  while (first != last) {
    T m = init + (std::distance(first, last)+1)/2;
    auto p = std::partition(first, last,
                            [&] (const T& i) { return i < m; });
    if (std::distance(first, p) == m - init) {
      init = m;
      first = p;
    } else {
      last = p;
    }
  }
  return init;
}

在上面的代码中,第7行和第10行的算术运算已经被替换成了 distance 调用。这是一个关键改动!现在我们不再局限于 vector 或普通指针那样的随机访问迭代器了。我们的算法现在可以作用于所有类型的迭代器。

查看 distancepartition 的文档,我们发现只需要前向迭代器就足够了。

版本6:记录放宽的迭代器类别

// 版本6:记录放宽的迭代器类别
template <typename ForwardIt, typename T = typename std::iterator_traits<ForwardIt>::value_type>
inline T min_unused(ForwardIt first, ForwardIt last, T init = T{})
{
  while (first != last) {
    T m = init + (std::distance(first, last)+1)/2;
    auto p = std::partition(first, last,
                            [&] (const T& i) { return i < m; });
    if (std::distance(first, p) == m - init) {
      init = m;
      first = p;
    } else {
      last = p;
    }
  }
  return init;
}

更加泛化:chrono的方式

T 上做减法产生的结果类型可能是 T 本身以外的东西。chrono 库中就是这样,time_pointduration 就是这种情况。两个 time_point 相减得到一个 duration。两个 time_point 相加是无意义的,因此也是非法的,但 duration 可以和自身相加,也可以和 time_point 相加。

用数学语言来说,chrono 建模了一个一维仿射空间。

min_unused 的情况也是如此:T 是我们的 time_point 对应物,而 distance 的结果是我们的 duration 对应物,这表明我们可以将差分类型提取为一个模板参数:

template <typename ForwardIt, 
          typename T = typename std::iterator_traits<ForwardIt>::value_type,
          typename DiffT = typename std::iterator_traits<ForwardIt>::difference_type>
inline T min_unused(ForwardIt first, ForwardIt last, T init = T{})
{
  while (first != last) {
    T m = init + DiffT{(std::distance(first, last)+1)/2};
    auto p = std::partition(first, last,
                            [&] (const T& i) { return i < m; });
    if (DiffT{std::distance(first, p)} == m - init) {
      init = m;
      first = p;
    } else {
      last = p;
    }
  }
  return init;
}

这样做的好处是,我们现在可以用 min_unused 来处理任何值类型和差分类型不同的代码,比如 chrono

// 一个time_point的vector,间隔1秒
constexpr auto VECTOR_SIZE = 10;
std::vector<std::chrono::system_clock::time_point> v;
auto start_time = std::chrono::system_clock::time_point{};
std::generate_n(std::back_inserter(v), VECTOR_SIZE,
                [&] () {
                  auto t = start_time;
                  start_time += 1s;
                  return t;
                });
std::shuffle(v.begin(), v.end(), gen);
v.resize(3*VECTOR_SIZE/4);

for (auto i : v)
  std::cout
    << std::chrono::duration_cast<std::chrono::seconds>(i.time_since_epoch()).count()
    << '\n';

// 现在可以找到不存在的最小time_point
auto u = min_unused<
  decltype(v.begin()), decltype(start_time), std::chrono::seconds>(
      v.begin(), v.end());

cout << "Min unused: "
     << std::chrono::duration_cast<std::chrono::seconds>(u.time_since_epoch()).count()
     << '\n';