序:
这篇文章的内容跟上一篇是一模一样的,只不过调换了两部分的顺序,以更好地吸引读者。
其实在写上一篇的时候,我就反复思考过两部分的顺序问题,最终做出了错误的决定 >_<
这两篇文章的对比,可以作为如何写科普文的一个教材 :D
Matlab 作为一个强大的科学计算工具,内置了许多算法的高效实现。但智者千虑,必有一失,我就曾经在 Matlab 中发现一个函数的实现非常低效,以至于在应用的时候会产生「卡死」的假象。而这个函数的功能其实非常简单,不过是一道普通的面试算法题的水平。
在这篇文章中,我就带大家看看 Matlab 是怎么失手的,并用高效的实现解决这道算法题。这道算法题看似简单,实际上却有深厚的应用背景,我也会给大家展示一下它有什么用。
一、如何「掰平」一个不单调的序列?
这道算法题是这样的:已知一个不单调递增的序列
,现在要用最小的代价把它「掰平」。用数学语言说,就是求另一个(不需要严格)单调递增的序列
,使得它跟
尽可能接近。「接近」的标准是让误差
最小,其中
是序列第
项的权重。为讨论简便,认为所有权重均为正。
举一个最简单的样例:假设输入的序列为
,每项权重均为 1。输出序列应该是
,它把原序列的下降段
「掰平」了,而且这是误差最小的掰法。
Matlab 自带的做法是这样的:以输入序列
为例,所有元素权重均为 1。
首先看
这个下降段。现在要把它掰平,那么把两个数都掰成多少能使得误差最小呢?不难发现,答案应该是 4 和 3 的平均数,即 3.5。如果这两个数有不同的权重,那么使得误差最小的,就应该是它们的加权平均数(证明留给读者)。
按照这种思路,可以把输入序列中三个下降段 

分别掰成 3.5、3、6,得到
。这就结束了吗?并没有 —— 因为 3.5 和 3 这两个段落又违反单调性了。此时,就要把 3.5 和 3 这两个段落整体掰平。3.5 段落的总权重为 2,3 段落的总权重为 3,所以掰平的结果应该是加权平均数 3.2。此时得到的序列为
,满足单调性,所以这就是要求的序列
。
上面逐渐把下降段「掰平」的过程用图象表示如下,蓝点为输入序列,红点及红线为所求的单调序列。

概括一下上述算法的流程,就是不断地在序列中寻找下降段,并把下降段掰成整体的加权平均数,直到序列单调递增为止。Matlab 中实现这个算法的核心代码如下:
yhat = y; % 用输入序列初始化输出序列
block = 1:length(y); % block(i) 表示第 i 个元素属于第几个段落
% 初始时每个元素独立成段
while true
diffs = diff(yhat); % 求所有相邻元素之差
if all(diffs >= 0), break; end % 若已满足单调性,退出
idx = cumsum([1; (diffs > 0)]); % 找出序列中所有的下降段,并依次编号
% 例如,若输入为 1,4,3,5,3,1,7,5
% 则编号结果为 1,2,2,3,3,3,4,4
sumyhat = accumarray(idx, w.*yhat); % 计算每段元素的加权和
w = accumarray(idx, w); % 计算每段元素的总权重
yhat = sumyhat ./ w; % 求出每段元素的加权平均数
block = idx(block); % 更新每个元素所属的段落编号
end
yhat = yhat(block); % 构建输出序列
这段代码使用了一些 Matlab 特有的操作(比如 cumsum、accumarray),可能比较难理解。理解的关键在于,在迭代过程中,yhat 并不是记录了完整的序列,而是对序列中每一个水平段落,只记录一个值。上文所举例子的执行过程如下表所示,它会帮助你理解。

上面的实现方式是正确的,但在面试中只能得到一半的分数。它有什么问题呢?当然是复杂度啦!不难看出,每次迭代的时间复杂度为
,而迭代次数的上限也是
,所以总复杂度为
。下面这个例子可以达到复杂度的上限:输入序列
,权重
。这个例子的精髓在于,序列有且仅有前两个元素组成下降段,并且因为第一个元素 10000 的权重很大,把前两个元素取加权平均合并后,序列第一段的值依然会很大。这个巨大的值会在每次迭代中吃且仅吃掉后面的一个元素,导致迭代次数达到
。
事实上,「反复合并下降段」这个过程,完全可以用
的时间复杂度来实现。我们从左到右扫描序列的每一个元素,并用一个栈来维护已经扫描的部分「掰平」后的各个水平段落。当扫描到一个新的元素的时候,先把它作为一个单独的段落压入栈顶,然后反复查看栈顶的两个段落,如果它们违反了单调性,就把它们合并。这种实现的代码如下:
yhat = y; N = length(y); % 用输入序列初始化输出序列
bstart = zeros(1,N); bend = zeros(1,N); % 栈:bstart(i), bend(i) 记录第 i 段的起止位置
% 此外 yhat 和 w 也兼用作栈,
% yhat(i) 与 w(i) 表示第 i 段的值和总权重
b = 0; % 栈顶指针
for i = 1:N % 依次扫描每个元素
b = b + 1; % 由此往下三行:新元素作为单独的段落入栈
yhat(b) = yhat(i); w(b) = w(i);
bstart(b) = i; bend(b) = i;
while b > 1 && yhat(b) < yhat(b-1) % 栈顶两个段落违反单调性
yhat(b-1) = (yhat(b-1) * w(b-1) + yhat(b) * w(b)) / (w(b-1) + w(b));
w(b-1) = w(b-1) + w(b);
bend(b-1) = bend(b);
b = b - 1; % 由此往上四行:栈顶两个段落取加权平均合并
end
end
block = zeros(1,N);
for i = 1:b
block(bstart(i) : bend(i)) = i; % 由栈中信息反推出输出序列的每个元素位于第几段
end
yhat = yhat(block); % 构建输出序列这段代码的主体循环没有用到 Matlab 的黑科技,比较好懂,所以样例数据的执行过程我就不写了。
我的实现的时间复杂度为
。虽然有时一个元素入栈会引发连锁式的段落合并,但考虑算法的整个执行过程,一共会有
个元素入栈,最多有
次段落合并,所以复杂度为
。
现在反过来想想,Matlab 自带的实现慢在哪儿了呢?仍然考虑极端输入
,
。在迭代过程中,序列的尾部始终是单调递增的,但 Matlab 的实现在每次迭代中都徒劳无功地在序列的尾部检查是否有下降段。这就是它慢的原因。
二、「掰平」算法的应用:Multi-dimensional Scaling
在第一部分中,我们成功地把「掰平」算法的复杂度从
降到了
,其中
为序列长度。把不单调的序列「掰平」这件事儿,在数学上有个学名,叫做单调回归变换(isotonic regression)。不过它有什么用呢?
Matlab 中用来求解最优单调回归变换的函数叫 lsqisotonic,其中 lsq 是 least squares(最小二乘)的意思,指的是误差函数的形式;isotonic 就是单调的意思啦。这个函数并不能直接调用,因为它是统计工具包中的一个私有函数,专供 mdscale 函数使用。而 mdscale 函数做的事情叫做 multi-dimensional scaling,这就是上面小算法题的大应用啦!单调回归变换与 multi-dimensional scaling 的关系有点儿长,且听我娓娓道来。
. Multi-dimensional scaling 是一种数据可视化的方法。这个名字不太容易翻译成中文,主要是因为 scaling 这个词的用法比较奇怪。维基百科给出的中文翻译是「多维标度」,其实挺不知所云的,甚至都不能体现出这是一个动名词。而日文翻译叫「多次元尺度構成法」,我觉得可以把二者融合一下,译作「多维尺度构成法」。在下文中,我就把这种方法简称为 MDS 了。
MDS 并不是目前最流行的数据可视化方法,最流行的应该是机器学习大佬 Hinton 开创的 t-SNE。在这个回答中,我用 MDS 和 t-SNE 两种方法对我 2012 年的人人好友关系进行了可视化,并比较了它们的效果:王赟 Maigo:有没有那种方式可以将高维数据进行可视化?比如保持数据结构不变将高维数据映射到低维空间?下图就是用 MDS 可视化的结果:

MDS 的输入,是
个对象两两之间的差异度(dissimilarities),共
个数值。如果已知的是相似度(similarities),则可以通过一个单调递减的函数转换成差异度。记第
个对象之间的差异度为
。MDS 要做的事情,是在一个给定维数(通常为二维或三维)的空间中找一组点
来代表这些对象,使得第
两个点之间的距离
尽可能接近给定的差异度
。具体来说,是要最小化如下的目标函数,这个函数称为压力(stress):
其中
是点对
的权重。一般来说,所有权重都取为 1;如果输入数据不全,某一组差异度
没有测量到,那么可以通过设置
来把这个点对排除掉。当然,如果认为某些点对的差异度比另一些点对更重要,也可以给每一个点对赋予不同的权重。
求使得压力最小化的点集
的方法有很多。比如梯度下降法就可以使用;Matlab 中 mdscale 函数使用的是一种共轭梯度法。除此之外,Modern Multidimensional Scaling 一书的第 8 章还介绍了一种称为 SMACOF 的迭代算法,它与机器学习中常见的 EM 算法有相似之处,二者都是 MM 算法的特例。不过求点集
的算法不是本文讨论的重点,所以我就不继续展开了。
在实际问题中,差异度不一定是定比数据,而有可能只是定序的(参见:王赟 Maigo:华裔数学家陶哲轩IQ230,是智商100聪明程度的几倍?),即差异度的数值并无意义,而只有它们之间的大小关系有意义。这种情形的 MDS,称为 non-metric MDS。上文中的压力函数依赖于差异度的具体数值,在 non-metric MDS 中再使用这种压力函数,就显得不合理了。于是就有了下面这种新的压力函数:
其中
是差异度经过变换
的结果。变换
仅需要满足单调性,它的作用就是说明差异度的数值并不重要,重要的只有大小关系 —— 注意了,这就是第一部分研究的单调回归变换!
要最小化这种新的压力函数,一方面需要求出一组点的坐标
,另一方面还要求出一个变换
,形成了一个「鸡生蛋,蛋生鸡」的问题。这种问题一般也是通过迭代算法来解决的,即不停重复下面的步骤:
- 固定
,求使得压力最小化的
。事实上这一步并不需要使得压力「最小化」,只要能让它减小就行了。这一步可以使用梯度下降法、共轭梯度法、SMACOF 等任意一种方法,且只需迭代一次。 - 固定
,求使得压力最小化的单调回归变换
。这一步同样只要让压力减小就行了,不过让压力最小化也不困难,因为我们在本文的第一部分已经解决了这个问题啦!
也许你还没有看出第一部分做的算法题是怎么在 MDS 里面用来最小化压力函数的。我们把所有的差异度
从小到大排序,得到
,其中
是点对的数目。与
对应的那个点对在空间中的距离
记作
,其权重记作
。现在我们要做的,就是求一组变换后的差异度
,我们把它们记作
。它们要满足跟
一样的大小关系,即满足单调性:
。经过了这些变量替换,压力函数就变成了
—— 看,这不就是算法题中的误差函数嘛!
我用来可视化的人人好友数据,总共有 1000 多人。在运行 Matlab 自带的 mdscale 时,遭遇了「卡死」的现象 —— 程序迟迟运行不出结果。现在就可以知道原因了:Matlab 求解最优单调回归变换的复杂度是
,注意这里面的
是点对的数目,它与好友人数
的关系是
。也就是说,Matlab 自带的 lsqisotonic 函数的时间复杂度,达到了吓人的
!面对上千人的大数据,难怪会卡死了。
三、附记
我实现的 lsqisotonic 函数,可以从 Mathworks File Exchange 上下载。这个函数位于 Matlab 安装目录下的 toolbox\stats\stats\private 子目录,可以用我的版本替代原有版本。
对 MDS 感兴趣的读者,推荐阅读 Modern Multidimensional Scaling 一书。其中第 8、9 章介绍的就是本文讨论的 non-metric MDS。第 12 章介绍了 metric MDS 的另一种情形 classical MDS,它最小化的目标函数并不是 stress,而是另一种称为 strain 的目标函数;其优点是求解过程不是迭代的,而是可以一步到位。Classical MDS 在 Matlab 中由 cmdscale 函数实现。

