区间DP 矩阵链的乘法

矩阵相乘

实际上矩阵相乘题是典型的区间动态规划

区间DP的操作如图,区间范围由小到大逐步增加
在这里插入图片描述
首先我们来看看动态规划的四个步骤

  1. 找出最优解的性质,并且刻画其结构特性;
  2. 递归的定义最优解;
  3. 以自底向上的方式刻画最优值;
  4. 根据计算最优值时候得到的信息,构造最优解

其中改进的动态规划算法:备忘录法,是以自顶向下的方式刻画最优值,对于动态规划方法和备忘录方法,两者的使用情况如下:

一般来讲,当一个问题的所有子问题都至少要解一次时,使用动态规划算法比使用备忘录方法好。此时,动态规划算法没有任何多余的计算。同时,对于许多问题,常常可以利用其规则的表格存取方式,减少动态规划算法的计算时间和空间需求。当子问题空间中的部分子问题可以不必求解时,使用备忘录方法则较为有利,因为从其控制结构可以看出,该方法只解那些确实需要求解的问题。

对于动态规划算法,我们必须明确两个基本要素,这两个要素对于在设计求解具体问题的算法时,是否选择动态规划算法具有指导意义:

  1. 算法有效性依赖于问题本身所具有的最优子结构性质设计算法的第一步通常是要刻画最优解的结构。当问题的最优解包含了子问题的最优解时,称该问题具有最优子结构性质。问题的最优子结构性质提供了该问题可以使用动态规划算法求解的重要线索

在矩阵连乘积最优次序问题中注意到,若A1A2…An的最优完全加括号方式在Ak和Ak+1之间断开,则由此可以确定的子链A1A2A3…Ak和Ak+1Ak+2…An的完全加括号方式也最优,即该问题具有最优子结构性质。在分析该问题的最优子结构性质时候,所使用的方法具有普遍性。首先假设由原问题导出的子问题的借不是最优解,然后在设法说明在这个假设下可以构造出比原问题最优解更好的解,从而导致矛盾。

在动态规划算法中,利用问题的最优子结构性质,以自底向上的方式递归的从子问题的最优解逐渐构造出整个问题的最优解。算法考察的子问题的空间规模较小。例如在举证连乘积的最优计算次序问题中,子问题空间由矩阵链的所有不用的子链组成。所有不用的子链的个数为o(n* n),因而子问题的空间规模为o(n* n)
  1. 可以用动态规划算法求解问题应该具备另一个基本要素是子问题的重叠性。在用递归算法自顶向下求解此问题时候,每次产生的子问题并不总是新问题,有些子问题被反复计算多次。动态规划算法正是利用了这种子问题的重叠性质,对每一个子问题都只是求解一次,而后将其保存到一个表格中,当再次需要解此问题时,只是简单使用常数时间查看一下结果。 通常,不同子问题个数随着问题大小呈多项式增长。因此使用动态规划算法通常只是需要多项式时间,从而获得较高的解题效率。

逐步分析

P=<30,35,15,5,10,20> 它对应5个矩阵

1
A1:30*35   A2:35*15   A3:15*5   A4:5*10   A5:10*20

先计算:
r=1表示两个矩阵相乘的运算量
m[1,1]=0 m[2,2]=0 m[3,3]=0 m[4,4]=0 m[5,5]=0

r=2表示两个矩阵相乘的运算量
m[1,2]=303515=15750
m[2,3]=35155=2625
m[3,4]=15510=750
m[4,5]=51020=1000

r=3表示3个矩阵相乘的运算量
m[1,3]=min{m[1,2]+30155,m[2,3]+30355}=min{15750+2625,2625+5250}=7875 A1(A2A3) s[1,3]=1
m[2,4]=min{m[2,3]+35510,m[3,4]+351510}={2625+1750,750+5250}=4375 (A2A3)A4 s[2,4]=3
m[3,5]=min{m[3,4]+151020,m[4,5]+15520}=2500 A3(A4A5) s[3,5]=3

r=4表示4个矩阵相乘的运算量
m[1,4]=min{m[2,4]+303510,m[1,2]+m[3,4]+301510,m[1,3]+30510}
=min{4375+10500,15750+750+4500,7875+1500}
=9375 [A1A2A3]A4—->(A1(A2A3))A4 s[1,4]=3
m[2,5]=min{m[3,5]+351520,m[2,3]+m[4,5]+35520,m[2,4]+351020}
=min{2500+10500,2625+1000+3500,4375+7000}
=7125 (A2A3)(A4A5) s[2,5]=3

r=5表示5个矩阵相乘的运算量
m[1,5]=min{m[2,5]+303520,m[1,2]+m[3,5]+301520,m[1,3]+m[4,5]+30520,m[1,4]+301020}
=min{7125+21000,15750+2500+9000,7875+1000+3000,9375+9000}
=11875 [A1A2A3](A4A5)---->(A1(A2A3))
(A4A5) s[1,5]=3

优化函数备忘录:

r=1 m[1,1]=0 m[2,2]=0 m[3,3]=0 m[4,4]=0 m[5,5]=0
r=2 m[1,2]=15750 m[2,3]=2625 m[3,4]=750 m[4,5]=1000
r=3 m[1,3]=7875 m[2,4]=4375 m[3,5]=2500
r=4 m[1,4]=9375 m[2,5]=7125
r=5 m[1,5]=11875

标记函数

r=2 s[1,2]=1 s[2,3]=2 s[3,4]=3 s[4,5]=4
r=3 s[1,3]=1 s[2,4]=3 s[3,5]=3
r=4 s[1,4]=3 s[2,5]=3
r=5 s[1,5]=3

根据 s[1,5]=3推出最后一次划分的位置为3 [A1A2A3][A4A5]
然后再由[A1A2A3]找s[1,3]=1
推出最后一次划分位置为1 A1(A2A3)
所以最终答案为:(A1(A2A3))(A4A5) 运算次数为:11875


练习:
再将P=<20,70,25,30,5,35,10>计算出来

Java实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
public class Strassen {
/*
* array[i][j]表示Ai...Aj相乘最少计算次数
* s[i][j]=k,表示Ai...Aj这(j-i+1)个矩阵中最优子结构为Ai...Ak和A(k+1)...Aj
* p[i]表示Ai的行数,p[i+1]表示Ai的列数
*/
private int array[][];
private int p[];
private int s[][];

public Strassen(){
p=new int[]{2,4,5,5,3};
array=new int[4][4];
s=new int[4][4];
}

public Strassen(int n,int []p){
this.p=new int[n+1];
this.array=new int[n][n];
this.s=new int[4][4];
for(int i=0;i<p.length;i++)
this.p[i]=p[i];
}
/*********************方法一,动态规划**********************************/
public void martixChain(){
int n=array.length;
for(int i=0;i<n;i++)
array[i][i]=0;
for(int r=2;r<=n;r++){
for(int i=0;i<=n-r;i++){
int j=i+r-1;
array[i][j]=array[i+1][j]+p[i]*p[i+1]*p[j+1];
s[i][j]=i;
for(int k=i+1;k<j;k++){
int t=array[i][k]+array[k+1][j]+p[i]*p[k+1]*p[j];
if(t<array[i][j]){
array[i][j]=t;
s[i][j]=k;
}
}
}
}
}
/*
* 如果待求矩阵为:Ap...Aq,then a=0,b=q-p
*/
public void traceBack(int a,int b){
if(a<b){
traceBack(a, s[a][b]);
traceBack(s[a][b]+1, b);
System.out.println("先把A"+a+"到A"+s[a][b]+"括起来,在把A"+(s[a][b]+1)+"到A"+b+"括起来,然后把A"+a+"到A"+b+"括起来");
}
}

/*********************方法二:备忘录方法*****************************/
public int memorizedMatrixChain(){
int n=array.length;
for(int i=0;i<n;i++){
for(int j=i;j<n;j++)
array[i][j]=0;
}
return lookUpChain(0,n-1);
}

public int lookUpChain(int a,int b){
if(array[a][b]!=0)
return array[a][b];
if(a==b)
return 0;
array[a][b]=lookUpChain(a, a)+lookUpChain(a+1, b)+p[a]*p[a+1]*p[b+1];
s[a][b]=a;
for(int k=a+1;k<b;k++){
int t=lookUpChain(a, k)+lookUpChain(k+1, b)+p[a]*p[k+1]*p[b+1];
if(t<array[a][b]){
array[a][b]=t;
s[a][b]=k;
}
}
return array[a][b];
}
public static void main(String[] args) {
Strassen strassen=new Strassen();
//strassen.martixChain();
strassen.memorizedMatrixChain();
strassen.traceBack(0, 3);
}
}

C++实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
LL a[100],d[100][100],b[100][100],n,cost;
//a为矩阵链各向量长度,d为优化函数,b为标记函数
void link(int l,int r) {
if (l==r) {
printf("A%d",l);
return;
}
//输出左边矩阵/矩阵链的括号表示法
if (b[l][r]-l>0) putchar('(');
link(l,b[l][r]);
if (b[l][r]-l>0) putchar(')');
//输出右边矩阵/矩阵链的括号表示法
if (r-b[l][r]>1) putchar('(');
link(b[l][r]+1,r);
if (r-b[l][r]>1) putchar(')');
}
int main()
{
//读入矩阵相乘的长度
scanf("%lld",&n);
for (int i=0;i<=n;i++) scanf("%lld",a+i);
//逐层计算r长度矩阵链的最优运算量,r=k+1
for (int k=1;k<n;k++) {
//p枚举区间右端,所计算区间为[p-k,p]
for (int p=k+1;p<=n;p++) {
//枚举标记点,按照题意,是左区间末端
for (int i=p-k;i<p;i++) {
cost=d[p-k][i]+d[i+1][p]+a[p-k-1]*a[p]*a[i];
if (!b[p-k][p] || cost<d[p-k][p]) {
b[p-k][p]=i; //标记
d[p-k][p]=cost; //更新优化函数
}
}
}
}
//优化函数备忘录
puts("优化函数备忘录");
for (int k=0;k<n;k++) {
printf("r=%d",k+1);
//p枚举区间右端,所计算区间为[p-k,p]
for (int p=k+1;p<=n;p++)
printf(" m[%d,%d]=%lld",p-k,p,d[p-k][p]);
printf("\n");
}
//标记函数
puts("标记函数");
for (int k=1;k<n;k++) {
printf("r=%d",k+1);
//p枚举区间右端,所计算区间为[p-k,p]
for (int p=k+1;p<=n;p++)
printf(" s[%d,%d]=%lld",p-k,p,b[p-k][p]);
printf("\n");
}
//括号表示法用axx表示矩阵
puts("括号表示法");
link(1,n);
printf("\n");
return 0;
}

/*
6
20 70 25 30 5 35 10
===================================================================
优化函数备忘录
r=1 m[1,1]=0 m[2,2]=0 m[3,3]=0 m[4,4]=0 m[5,5]=0 m[6,6]=0
r=2 m[1,2]=35000 m[2,3]=52500 m[3,4]=3750 m[4,5]=5250 m[5,6]=1750
r=3 m[1,3]=50000 m[2,4]=12500 m[3,5]=8125 m[4,6]=3250
r=4 m[1,4]=19500 m[2,5]=24750 m[3,6]=6750
r=5 m[1,5]=23000 m[2,6]=17750
r=6 m[1,6]=22250
===================================================================
标记函数
r=2 s[1,2]=1 s[2,3]=2 s[3,4]=3 s[4,5]=4 s[5,6]=5
r=3 s[1,3]=2 s[2,4]=2 s[3,5]=4 s[4,6]=4
r=4 s[1,4]=1 s[2,5]=4 s[3,6]=4
r=5 s[1,5]=4 s[2,6]=4
r=6 s[1,6]=4
===================================================================
括号表示法
(A1(A2(A3A4)))(A5A6)
*/