Strassen矩陣乘法
矩陣乘法是線性代數(shù)中最常見的運算之一,它在數(shù)值計算中有廣泛的應(yīng)用。若A和B是2個n×n的矩陣,則它們的乘積C=AB同樣是一個n×n的矩陣。A和B的乘積矩陣C中的元素C[i,j]定義為:
若依此定義來計算A和B的乘積矩陣C,則每計算C的一個元素C[i,j],需要做n個乘法和n-1次加法。因此,求出矩陣C的n2個元素所需的計算時間為0(n3)。
60年代末,Strassen采用了類似于在大整數(shù)乘法中用過的分治技術(shù),將計算2個n階矩陣乘積所需的計算時間改進到O(nlog7)=O(n2.18)。
首先,我們還是需要假設(shè)n是2的冪。將矩陣A,B和C中每一矩陣都分塊成為4個大小相等的子矩陣,每個子矩陣都是n/2×n/2的方陣。由此可將方程C=AB重寫為:
由此可得:
C11=A11B11+A12B21 (2)
C12=A11B12+A12B22 (3)
C21=A21B11+A22B21 (4)
C22=A21B12+A22B22 (5)
如果n=2,則2個2階方陣的乘積可以直接用(2)-(3)式計算出來,共需8次乘法和4次加法。當(dāng)子矩陣的階大于2時,為求2個子矩陣的積,可以繼續(xù)將子矩陣分塊,直到子矩陣的階降為2。這樣,就產(chǎn)生了一個分治降階的遞歸算法。依此算法,計算2個n階方陣的乘積轉(zhuǎn)化為計算8個n/2階方陣的乘積和4個n/2階方陣的加法。2個n/2×n/2矩陣的加法顯然可以在c*n2/4時間內(nèi)完成,這里c是一個常數(shù)。因此,上述分治法的計算時間耗費T(n)應(yīng)該滿足:
這個遞歸方程的解仍然是T(n)=O(n3)。因此,該方法并不比用原始定義直接計算更有效。究其原因,乃是由于式(2)-(5)并沒有減少矩陣的乘法次數(shù)。而矩陣乘法耗費的時間要比矩陣加減法耗費的時間多得多。要想改進矩陣乘法的計算時間復(fù)雜性,必須減少子矩陣乘法運算的次數(shù)。按照上述分治法的思想可以看出,要想減少乘法運算次數(shù),關(guān)鍵在于計算2個2階方陣的乘積時,能否用少于8次的乘法運算。Strassen提出了一種新的算法來計算2個2階方陣的乘積。他的算法只用了7次乘法運算,但增加了加、減法的運算次數(shù)。這7次乘法是:
M1=A11(B12-B22)
M2=(A11+A12)B22
M3=(A21+A22)B11
M4=A22(B21-B11)
M5=(A11+A22)(B11+B22)
M6=(A12-A22)(B21+B22)
M7=(A11-A21)(B11+B12)
做了這7次乘法后,再做若干次加、減法就可以得到:
C11=M5+M4-M2+M6
C12=M1+M2
C21=M3+M4
C22=M5+M1-M3-M7
以上計算的正確性很容易驗證。例如:
C22=M5+M1-M3-M7
=(A11+A22)(B11+B22)+A11(B12-B22)-(A21+A22)B11-(A11-A21)(B11+B12)
=A11B11+A11B22+A22B11+A22B22+A11B12
-A11B22-A21B11-A22B11-A11B11-A11B12+A21B11+A21B12
=A21B12+A22B22
由(2)式便知其正確性。
至此,我們可以得到完整的Strassen算法如下:
procedure STRASSEN(n,A,B,C);
begin
if n=2 then MATRIX-MULTIPLY(A,B,C)
else begin
將矩陣A和B依(1)式分塊;
STRASSEN(n/2,A11,B12-B22,M1);
STRASSEN(n/2,A11+A12,B22,M2);
STRASSEN(n/2,A21+A22,B11,M3);
STRASSEN(n/2,A22,B21-B11,M4);
STRASSEN(n/2,A11+A22,B11+B22,M5);
STRASSEN(n/2,A12-A22,B21+B22,M6);
STRASSEN(n/2,A11-A21,B11+B12,M7);; end; end;
其中MATRIX-MULTIPLY(A,B,C)是按通常的矩陣乘法計算C=AB的子算法。
Strassen矩陣乘積分治算法中,用了7次對于n/2階矩陣乘積的遞歸調(diào)用和18次n/2階矩陣的加減運算。由此可知,該算法的所需的計算時間T(n)滿足如下的遞歸方程:
按照解遞歸方程的套用公式法,其解為T(n)=O(nlog7)≈O(n2.81)。由此可見,Strassen矩陣乘法的計算時間復(fù)雜性比普通矩陣乘法有階的改進。
有人曾列舉了計算2個2階矩陣乘法的36種不同方法。但所有的方法都要做7次乘法。除非能找到一種計算2階方陣乘積的算法,使乘法的計算次數(shù)少于7次,按上述思路才有可能進一步改進矩陣乘積的計算時間的上界。但是Hopcroft和Kerr(197l)已經(jīng)證明,計算2個2×2矩陣的乘積,7次乘法是必要的。因此,要想進一步改進矩陣乘法的時間復(fù)雜性,就不能再寄希望于計算2×2矩陣的乘法次數(shù)的減少?;蛟S應(yīng)當(dāng)研究3×3或5×5矩陣的更好算法。在Strassen之后又有許多算法改進了矩陣乘法的計算時間復(fù)雜性。目前最好的計算時間上界是O(n2.367)。而目前所知道的矩陣乘法的最好下界仍是它的平凡下界Ω(n2)。因此到目前為止還無法確切知道矩陣乘法的時間復(fù)雜性。關(guān)于這一研究課題還有許多工作可做。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191 |
/**
* Strassen矩陣乘法
* */ import java.util.*;
public class Strassen{
public Strassen(){
A = new int [NUMBER][NUMBER];
B = new int [NUMBER][NUMBER];
C = new int [NUMBER][NUMBER];
}
/**
* 輸入矩陣函數(shù)
* */
public void input( int a[][]){
Scanner scanner = new Scanner(System.in);
for ( int i = 0 ; i < a.length; i++) {
for ( int j = 0 ; j < a[i].length; j++) {
a[i][j] = scanner.nextInt();
}
}
}
/**
* 輸出矩陣
* */
public void output( int [][] resault){
for ( int b[] : resault) {
for ( int temp : b) {
System.out.print(temp + " " );
}
System.out.println();
}
}
/**
* 矩陣乘法,此處只是定義了2*2矩陣的乘法
* */
public void Mul( int [][] first, int [][] second, int [][] resault){
for ( int i = 0 ; i < 2 ; ++i) {
for ( int j = 0 ; j < 2 ; ++j) {
resault[i][j] = 0 ;
for ( int k = 0 ; k < 2 ; ++k) {
resault[i][j] += first[i][k] * second[k][j];
}
}
}
}
/**
* 矩陣的加法運算
* */
public void Add( int [][] first, int [][] second, int [][] resault){
for ( int i = 0 ; i < first.length; i++) {
for ( int j = 0 ; j < first[i].length; j++) {
resault[i][j] = first[i][j] + second[i][j];
}
}
}
/**
* 矩陣的減法運算
* */
public void sub( int [][] first, int [][] second, int [][] resault){
for ( int i = 0 ; i < first.length; i++) {
for ( int j = 0 ; j < first[i].length; j++) {
resault[i][j] = first[i][j] - second[i][j];
}
}
}
/**
* strassen矩陣算法
* */
public void strassen( int [][] A, int [][] B, int [][] C){
//定義一些中間變量
int [][] M1= new int [NUMBER][NUMBER];
int [][] M2= new int [NUMBER][NUMBER];
int [][] M3= new int [NUMBER][NUMBER];
int [][] M4= new int [NUMBER][NUMBER];
int [][] M5= new int [NUMBER][NUMBER];
int [][] M6= new int [NUMBER][NUMBER];
int [][] M7= new int [NUMBER][NUMBER];
int [][] C11= new int [NUMBER][NUMBER];
int [][] C12= new int [NUMBER][NUMBER];
int [][] C21= new int [NUMBER][NUMBER];
int [][] C22= new int [NUMBER][NUMBER];
int [][] A11= new int [NUMBER][NUMBER];
int [][] A12= new int [NUMBER][NUMBER];
int [][] A21= new int [NUMBER][NUMBER];
int [][] A22= new int [NUMBER][NUMBER];
int [][] B11= new int [NUMBER][NUMBER];
int [][] B12= new int [NUMBER][NUMBER];
int [][] B21= new int [NUMBER][NUMBER];
int [][] B22= new int [NUMBER][NUMBER];
int [][] temp= new int [NUMBER][NUMBER];
int [][] temp1= new int [NUMBER][NUMBER];
if (A.length== 2 ){
Mul(A, B, C);
} else {
//首先將矩陣A,B 分為4塊
for ( int i = 0 ; i < A.length/ 2 ; i++) {
for ( int j = 0 ; j < A.length/ 2 ; j++) {
A11[i][j]=A[i][j];
A12[i][j]=A[i][j+A.length/ 2 ];
A21[i][j]=A[i+A.length/ 2 ][j];
A22[i][j]=A[i+A.length/ 2 ][j+A.length/ 2 ];
B11[i][j]=B[i][j];
B12[i][j]=B[i][j+A.length/ 2 ];
B21[i][j]=B[i+A.length/ 2 ][j];
B22[i][j]=B[i+A.length/ 2 ][j+A.length/ 2 ];
}
}
//計算M1
sub(B12, B22, temp);
Mul(A11, temp, M1);
//計算M2
Add(A11, A12, temp);
Mul(temp, B22, M2);
//計算M3
Add(A21, A22, temp);
Mul(temp, B11, M3);
//M4
sub(B21, B11, temp);
Mul(A22, temp, M4);
//M5
Add(A11, A22, temp1);
Add(B11, B22, temp);
Mul(temp1, temp, M5);
//M6
sub(A12, A22, temp1);
Add(B21, B22, temp);
Mul(temp1, temp, M6);
//M7
sub(A11, A21, temp1);
Add(B11, B12, temp);
Mul(temp1, temp, M7);
//計算C11
Add(M5, M4, temp1);
sub(temp1, M2, temp);
Add(temp, M6, C11);
//計算C12
Add(M1, M2, C12);
//C21
Add(M3, M4, C21);
//C22
Add(M5, M1, temp1);
sub(temp1, M3, temp);
sub(temp, M7, C22);
//結(jié)果送回C中
for ( int i = 0 ; i < C.length/ 2 ; i++) {
for ( int j = 0 ; j < C.length/ 2 ; j++) {
C[i][j]=C11[i][j];
C[i][j+C.length/ 2 ]=C12[i][j];
C[i+C.length/ 2 ][j]=C21[i][j];
C[i+C.length/ 2 ][j+C.length/ 2 ]=C22[i][j];
}
}
}
}
public static void main(String[] args){
Strassen demo= new Strassen();
System.out.println( "輸入矩陣A" );
demo.input(A);
System.out.println( "輸入矩陣B" );
demo.input(B);
demo.strassen(A, B, C);
demo.output(C);
}
private static int A[][];
private static int B[][];
private static int C[][];
private final static int NUMBER = 4 ; } |
【測試】:
1 1 1 1
1 1 1 1
1 1 1 1
1 1 1 1
-----------
2 2 2 2
2 2 2 2
2 2 2 2
2 2 2 2
--------
8 8 8 8
8 8 8 8
8 8 8 8
8 8 8 8