[已答复] 如何加速这个含有kron的for循环

[复制链接]
cxz9503 发表于 2021-4-8 11:39:51
本帖最后由 cxz9503 于 2021-4-8 19:57 编辑

  1. % x是N个data组成的data matrix 维度 n by N。
  2. % S维度 k by N. S是非负矩阵。
  3. % 用rand生成x, S, 以下需要跑19s左右。为测试方便,N取得比较小。

  4. n= 800; N =100; k =10;
  5. x = rand(n,N); S = rand(k,N); H = 0;
  6. for i = 1: size(x,2)
  7.   X = x(:,i)*x(:,i)' ;
  8.   DW = diag( S(:,i) ) - S(:,i)*S(:,i)' ;  
  9.   H = H + kron(X,DW);
  10. end

复制代码


想不用for循环,全部向量化,但是不知道怎么处理,绕不过kron。

以下是我的尝试:
将 kron(X,DW)拆分,得到 kron(X,DW) = kron(x(:,i)*x(:,i)' , diag( S(:,i) ) ) - kron( x(:,i)*x(:,i)' , S(:,i)*S(:,i)' ) ;  
后半部分kron( x(:,i)*x(:,i)' , S(:,i)*S(:,i)' ) 可做如下改写:
使得 kron(X, S(:,i)*S(:,i)') 变成 kron(x(:,i), S(:,i))*kron(x(:,i), S(:,i))'。
这样只需要计算kron(x(:,i), S(:,i)). 而 kron(x(:,i), S(:,i)) = reshape( S(:,i)*x(:,i)' , [], 1);

所以求和后半部分(注释中)可改写为:


  1. % for i = 1:N
  2. %   H_1 = H_1 + kron( x(:,i)*x(:,i)' , S(:,i)*S(:,i)' ) ;
  3. % end

  4. for i = 1: N
  5.    temp(:,i) = reshape( S(:,i)*x(:,i)' , [], 1);
  6. end
  7. H_1 = temp*temp' ;  
复制代码


但是求和前半部分 kron(x(:,i)*x(:,i)' , diag( S(:,i) ) )  不知道怎么处理。



4 条回复


cxz9503 发表于 2021-4-8 19:57:08
本帖最后由 cxz9503 于 2021-4-9 09:37 编辑

我自己优化了一下,运行时间从19s降到1s左右。
还是使用 , [latex]则 kron(x(:,i)*x(:,i)', diag(S(:,i)) ) = kron( x(:,i), sqrt (diag(S(:,i)) ) ) * kron( x(:,i), sqrt (diag(S(:,i)) ) ).' ; 所以只需要计算 kron( x(:,i), sqrt (diag(S(:,i)) )
(S是非负矩阵)。

记 kron( x(:,i), sqrt (diag(S(:,i)) ) = M_i. 则 求和 sum_i [ kron(x(:,i)*x(:,i)', diag(S(:,i)) ) ] = [M_1 M_2 ... M_n]* [M_1' ; M_2' ;... ;M_n],即求和变成“向量积”。
再结合之前对求和第二部分的改写,具体代码如下


  1. n = 800; N = 100; k = 10;
  2. x = rand(n,N); S = rand(k,N);
  3. H1= 0; K_D= zeros(n*k, k*1, N); K_S = zeros(n*k,N); %K_D记录kron( x(:,i), sqrt (diag(S(:,i)) ) )  , K_S 记录kron(x(:,i), S(:,i));
  4. for i = 1:N
  5.       D_half =  diag( sqrt(S(:,i)) ) ;
  6.       K_D(:,:,i) = kron( x(:,i),D_half);
  7.       K_S(:,i) =  reshape (S(:,i)*x(:,i)',[],1);
  8. end

  9. K_D = reshape(K_D,n*k,[]);
  10. H = K_D*K_D' - K_S*K_S';
复制代码


以上代码如何再次优化加速,或者对原问题有新的思路吗?
谢谢!

caicaibi 发表于 2021-4-10 22:25:52
本帖最后由 caicaibi 于 2021-4-10 22:28 编辑

我把你的思路,搬到四维空间来了,估计平均下来,提高了5个点吧
  1. clear;clc
  2. n= 800; N =100; k =10;
  3. x = rand(n,N); S = rand(k,N);
  4. D_half=sqrt(shiftdim(repmat(S,1,1,k).*shiftdim(repmat(eye(k),1,1,N),1),2));
  5. K_D=reshape(reshape(x,1,n,1,N).*reshape(D_half,k,1,k,N),n*k,[]);
  6. K_S2=reshape(permute(repmat(S,1,1,n).*rot90(shiftdim(repmat(x,1,1,k),1)),[1 3 2]),n*k,N);
  7. H3= K_D*K_D' - K_S2*K_S2';
复制代码




caicaibi 发表于 2021-4-10 22:30:35
很久没遇到这么有意思的题目了,感谢题主,

cxz9503 发表于 2021-4-10 23:54:47
caicaibi 发表于 2021-4-10 22:30
很久没遇到这么有意思的题目了,感谢题主,

你的代码暂时有点看不懂,抽空好好看一下,速度的确有提升,谢谢!
您需要登录后才可以回帖 登录 | 注册

本版积分规则

相关帖子
热门教程
站长推荐
快速回复 返回顶部 返回列表