如何泛化某个算法代码?
引言
假设有一段很棒的代码放在一个文件里,如何将这段代码泛化,使之成为一个符合STL风格的算法?
来考虑一个目前STL中还没有的算法。首先说明它要解决的问题 : 假设你有一组数(数学意义上的"集合",而非容器意义上的set),这些数是不连续的、无序的,理论上取自正整数,例如:
{5,2,7,9,4,0,1}
我们要找的是这组数中最小的缺失数。对于上面这个例子,答案是3。如果把这些数排好序,解决方案就显而易见了——adjacent_find(相邻查找)可以帮助我们找到第一个缺口。
初始解法(O(n.log(n)))
// 假设我们的集合在一个数组或vector中
std::sort(a.begin(), a.end());
auto p = std::adjacent_find(
a.begin(), a.end(),
[] (int x, int y) { return y-x > 1; });
int unused = a.back() + 1;
if (p != a.end())
unused = *p + 1;
这个解法的渐进时间复杂度是 O(n.log(n)),因为 adjacent_find 是线性的,而 sort 的复杂度占主导。
更好的解法:寻找最小未使用数
有没有更好的方法?乍一看并不明显,但这个问题其实有一个线性时间复杂度的解法。关键在于使用分治策略,而这项工作的主力军就是 partition(划分)。
让我们从假设最小未使用值是0开始。我们选择一个轴值 m,它是序列的假设中点——它实际上不一定要出现在序列中,但我们简单地假定它就是中间值。然后用这个轴值对序列进行划分,我们会得到这样一种情况:第一个大于等于 m 的值位于位置 p(这是 partition 的返回值)。如果序列的前半部分没有空缺,m 就等于 p,我们就可以递归处理后半部分,并将新的最小未使用值设为 m。
我们知道,由于集合中不存在重复元素,m 不可能小于 p。因此,如果 m 不等于 p,它必然更大——这意味着在位置 p 之前至少存在一个空缺,我们可以在保持最小未使用值不变的情况下,对前半部分进行递归。
算法的基本情况是当我们需要划分的序列为空时。此时我们就找到了最小未使用值。
算法代码
void min_unused()
{
// 初始化随机数生成器
std::array seed_data;
std::random_device r;
std::generate_n(seed_data.data(), seed_data.size(), std::ref(r));
std::seed_seq seq(std::begin(seed_data), std::end(seed_data));
std::mt19937 gen(seq);
// 用整数填充数组,打乱顺序,然后丢弃一些
int a[10];
std::iota(&a[0], &a[10], 0);
std::shuffle(&a[0], &a[10], gen);
int first_idx = 0;
int last_idx = 7; // 任意的截断点
for (int i = first_idx; i < last_idx; ++i)
std::cout << a[i] << '\n';
// 算法主体
int unused = 0;
while (first_idx != last_idx) {
int m = unused + (last_idx - first_idx + 1)/2;
auto p = std::partition(&a[first_idx], &a[last_idx],
[&] (int i) { return i < m; });
if (p - &a[first_idx] == m - unused) {
unused = m;
first_idx = p - &a[0];
} else {
last_idx = p - &a[0];
}
}
std::cout << "Min unused: " << unused << '\n';
}
注意:这个算法的渐进时间复杂度是线性的;在每个阶段我们都在运行
partition,这是 O(n) 的操作,然后我们只在序列的一半上递归。因此复杂度为:
O(n + n/2 + n/4 + n/8 + …) = O(∑_{i=0}∞ n/2^i) = O(2n) = O(n)
从具体到泛化:第一步
现在我们有了一个可以工作的算法……但只适用于C风格数组,而且只在一个地方使用。显然我们想要尽可能地将其泛化;那么,将其转化为一个可以像STL中那些算法一样方便使用的优秀算法,需要哪些步骤呢?
版本1:接收Vector的函数
// 版本1:接收vector的函数
int min_unused(std::vector& v)
{
int* first = &v[0];
int* last = first + v.size();
int unused = 0;
while (first != last) {
int m = unused + (last - first + 1)/2;
auto p = std::partition(first, last,
[&] (int i) { return i < m; });
if (p - first == m - unused) {
unused = m;
first = p;
} else {
last = p;
}
}
return unused;
}
版本2:使用迭代器
// 版本2:使用迭代器,像真正的算法那样
template <typename It>
inline int min_unused(It first, It last)
{
int unused = 0;
while (first != last) {
int m = unused + (last - first + 1)/2;
auto p = std::partition(first, last,
[&] (int i) { return i < m; });
if (p - first == m - unused) {
unused = m;
first = p;
} else {
last = p;
}
}
return unused;
}
还要注意,由于
min_unused现在是一个函数模板,我们添加了inline关键字,这意味着我们可以把它放在头文件中,而不会在多个翻译单元中实例化时产生链接错误。
版本3:推断value_type
// 版本3:推断value_type
template <typename It>
inline typename std::iterator_traits<It>::value_type min_unused(It first, It last)
{
using T = typename std::iterator_traits<It>::value_type;
T unused{};
while (first != last) {
T m = unused + (last - first + 1)/2;
auto p = std::partition(first, last,
[&] (const T& i) { return i < m; });
if (p - first == m - unused) {
unused = m;
first = p;
} else {
last = p;
}
}
return unused;
}
这里我们使用
iterator_traits来发现传入迭代器的value_type。这适用于任何满足Iterator概念的类型。因此我们也将散落在代码中的int替换成了T。注意在第6行,我们对T进行了值初始化,这会对整型进行正确的零初始化。
版本4:允许用户指定初始最小值
// 版本4:允许用户指定初始最小值
template <typename It, typename T = typename std::iterator_traits<It>::value_type>
inline T min_unused(It first, It last, T init = T{})
{
while (first != last) {
T m = init + (last - first + 1)/2;
auto p = std::partition(first, last,
[&] (const T& i) { return i < m; });
if (p - first == m - init) {
init = m;
first = p;
} else {
last = p;
}
}
return init;
}
调用者提供的最小值可能会非常有用——在很多场景下0是保留值,或者需要一个非零的起始值。
迭代器再探
此时,min_unused 看起来至少在表面上更像一个真正的STL算法了,但它还不完全达标。第7行和第10行做的是普通的算术运算,这意味着传入的迭代器需要支持这些运算。STL的一个基本概念是迭代器类别的思想:这是一个定义不同类型迭代器可用操作的概念。
版本5:泛化处理迭代器算术
// 版本5:泛化处理迭代器算术
template <typename It, typename T = typename std::iterator_traits<It>::value_type>
inline T min_unused(It first, It last, T init = T{})
{
while (first != last) {
T m = init + (std::distance(first, last)+1)/2;
auto p = std::partition(first, last,
[&] (const T& i) { return i < m; });
if (std::distance(first, p) == m - init) {
init = m;
first = p;
} else {
last = p;
}
}
return init;
}
在上面的代码中,第7行和第10行的算术运算已经被替换成了
distance调用。这是一个关键改动!现在我们不再局限于vector或普通指针那样的随机访问迭代器了。我们的算法现在可以作用于所有类型的迭代器。
查看
distance和partition的文档,我们发现只需要前向迭代器就足够了。
版本6:记录放宽的迭代器类别
// 版本6:记录放宽的迭代器类别
template <typename ForwardIt, typename T = typename std::iterator_traits<ForwardIt>::value_type>
inline T min_unused(ForwardIt first, ForwardIt last, T init = T{})
{
while (first != last) {
T m = init + (std::distance(first, last)+1)/2;
auto p = std::partition(first, last,
[&] (const T& i) { return i < m; });
if (std::distance(first, p) == m - init) {
init = m;
first = p;
} else {
last = p;
}
}
return init;
}
更加泛化:chrono的方式
T 上做减法产生的结果类型可能是 T 本身以外的东西。chrono 库中就是这样,time_point 和 duration 就是这种情况。两个 time_point 相减得到一个 duration。两个 time_point 相加是无意义的,因此也是非法的,但 duration 可以和自身相加,也可以和 time_point 相加。
用数学语言来说,chrono 建模了一个一维仿射空间。
min_unused 的情况也是如此:T 是我们的 time_point 对应物,而 distance 的结果是我们的 duration 对应物,这表明我们可以将差分类型提取为一个模板参数:
template <typename ForwardIt,
typename T = typename std::iterator_traits<ForwardIt>::value_type,
typename DiffT = typename std::iterator_traits<ForwardIt>::difference_type>
inline T min_unused(ForwardIt first, ForwardIt last, T init = T{})
{
while (first != last) {
T m = init + DiffT{(std::distance(first, last)+1)/2};
auto p = std::partition(first, last,
[&] (const T& i) { return i < m; });
if (DiffT{std::distance(first, p)} == m - init) {
init = m;
first = p;
} else {
last = p;
}
}
return init;
}
这样做的好处是,我们现在可以用 min_unused 来处理任何值类型和差分类型不同的代码,比如 chrono:
// 一个time_point的vector,间隔1秒
constexpr auto VECTOR_SIZE = 10;
std::vector<std::chrono::system_clock::time_point> v;
auto start_time = std::chrono::system_clock::time_point{};
std::generate_n(std::back_inserter(v), VECTOR_SIZE,
[&] () {
auto t = start_time;
start_time += 1s;
return t;
});
std::shuffle(v.begin(), v.end(), gen);
v.resize(3*VECTOR_SIZE/4);
for (auto i : v)
std::cout
<< std::chrono::duration_cast<std::chrono::seconds>(i.time_since_epoch()).count()
<< '\n';
// 现在可以找到不存在的最小time_point
auto u = min_unused<
decltype(v.begin()), decltype(start_time), std::chrono::seconds>(
v.begin(), v.end());
cout << "Min unused: "
<< std::chrono::duration_cast<std::chrono::seconds>(u.time_since_epoch()).count()
<< '\n';