GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现

2023-09-06,

GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现

RNN
GRU
matlab codes

RNN网络考虑到了具有时间数列的样本数据,但是RNN仍存在着一些问题,比如随着时间的推移,RNN单元就失去了对很久之前信息的保存和处理的能力,而且存在着gradient vanishing问题。

所以有些特殊类型的RNN网络相继被提出,比如LSTM(long short term memory)和GRU(gated recurrent unit)(Chao,et al. 2014).这里我主要推导一下GRU参数的迭代过程

GRU单元结构如下图所示

1479126283494.jpg

数据流过程如下

其中表示Hadamard积,即对应元素乘积;下标表示节点的index,上标表示时刻;表示隐层到输出层的参数矩阵,分别是隐层和输出层的节点个数;分别表示输入和上一时刻隐层到更新门z的连接矩阵,表示输入数据的维度;分别表示输入和上一时刻隐层到重置门r的连接矩阵;分别表示输入和上一时刻的隐层到待选状态的连接矩阵。

针对于时刻t,使用链式求导法则,计算参数矩阵的梯度,其中E是代价函数,首先计算对隐层输出的梯度,因为隐层输出牵涉到多个时刻

所以

其中分别是对应激活函数的线性和部分

现在对参数计算梯度

将上面的式子矢量化(行向量)表示:

那接下来使用matlab来实现一个小例子,看看GRU的效果,同样是二进制相加的问题

    function error= GRUtest( ) 


    % 初始化训练数据 


    uNum=16;%单元个数 


    maxInt=2^uNum; 


    % 初始化网络结构 


    xdim=2; 


    ydim=1; 


    hdim=16; 


    eta=0.1; 


    %初始化网络参数 


    Wy=rand(hdim,ydim)*2-1; 


    Wr=rand(xdim,hdim)*2-1; 


    Ur=rand(hdim,hdim)*2-1; 


    W =rand(xdim,hdim)*2-1; 


    U =rand(hdim,hdim)*2-1; 


    Wz=rand(xdim,hdim)*2-1; 


    Uz=rand(hdim,hdim)*2-1; 



    rvalues=zeros(uNum+1,hdim); 


    zvalues=zeros(uNum+1,hdim); 


    hbarvalues=zeros(uNum,hdim); 


    hvalues = zeros(uNum,hdim); 


    yvalues=zeros(uNum,ydim); 



    for p=1:10000 


    aInt=randi(maxInt/2); 


    bInt=randi(maxInt/2); 


    cInt=aInt+bInt; 


    at=dec2bin(aInt)-'0'; 


    bt=dec2bin(bInt)-'0'; 


    ct=dec2bin(cInt)-'0'; 


    a=zeros(1,uNum); 


    b=zeros(1,uNum); 


    c=zeros(1,uNum); 


    a(1:size(at,2))=at(end:-1:1); 


    b(1:size(bt,2))=bt(end:-1:1); 


    c(1:size(ct,2))=ct(end:-1:1); 


    xvalues=[a;b]'; 


    d=c'; 



    % 前向计算 


    rvalues(1,:)=sigmoid(xvalues(1,:)*Wr); 


    hbarvalues(1,:)=outTanh(xvalues(1,:)*W); 


    zvalues(1,:)=sigmoid(xvalues(1,:)*Wz); 


    hvalues(1,:)=zvalues(1,:).*hbarvalues(1,:); 


    yvalues(1,:)=sigmoid(hvalues(1,:)*Wy); 


    for t=2:uNum 


    rvalues(t,:)=sigmoid(xvalues(t,:)*Wr+hvalues(t-1,:)*Ur); 


    hbarvalues(t,:)=outTanh(xvalues(t,:)*W+(rvalues(t,:).*hvalues(t-1,:))*U); 


    zvalues(t,:)=sigmoid(xvalues(t,:)*Wz+hvalues(t-1,:)*Uz); 


    hvalues(t,:)=(1-zvalues(t,:)).*hvalues(t-1,:)+zvalues(t,:).*hbarvalues(t,:); 


    yvalues(t,:)=sigmoid(hvalues(t,:)*Wy);  


    end 



    % 误差反向传播 


    delta_r_next=zeros(1,hdim); 


    delta_z_next=zeros(1,hdim); 


    delta_h_next=zeros(1,hdim); 


    delta_next=zeros(1,hdim); 



    dWy=zeros(hdim,ydim); 


    dWr=zeros(xdim,hdim); 


    dUr=zeros(hdim,hdim); 


    dW=zeros(xdim,hdim); 


    dU=zeros(hdim,hdim); 


    dWz=zeros(xdim,hdim); 


    dUz=zeros(hdim,hdim); 



    for t=uNum:-1:2 


    delta_y=(yvalues(t,:)-d(t,:)).*diffsigmoid(yvalues(t,:)); 


    delta_h=delta_y*Wy'+delta_z_next*Uz'+delta_next*U'.*rvalues(t+1,:)+delta_r_next*Ur'+delta_h_next.*(1-zvalues(t+1,:)); 


    delta_z=delta_h.*(hbarvalues(t,:)-hvalues(t-1,:)).*diffsigmoid(zvalues(t,:)); 


    delta =delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:)); 


    delta_r=hvalues(t-1,:).*((delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:)))*U').*diffsigmoid(rvalues(t,:)); 



    dWy=dWy+hvalues(t,:)'*delta_y; 


    dWz=dWz+xvalues(t,:)'*delta_z; 


    dUz=dUz+hvalues(t-1,:)'*delta_z; 


    dW =dW+xvalues(t,:)'*delta; 


    dU =dU+(rvalues(t,:).*hvalues(t-1,:))'*delta ; 


    dWr=dWr+xvalues(t,:)'*delta_r; 


    dUr=dUr+hvalues(t-1,:)'*delta_r; 



    delta_r_next=delta_r; 


    delta_z_next=delta_z; 


    delta_h_next=delta_h; 


    delta_next =delta; 



    end 



    t=1; 


    delta_y=(yvalues(t,:)-d(t,:)).*diffsigmoid(yvalues(t,:)); 


    delta_h=delta_y*Wy'+delta_z_next*Uz'+delta_next*U'.*rvalues(t+1,:)+delta_r_next*Ur'+delta_h_next.*(1-zvalues(t+1,:)); 


    delta_z=delta_h.*(hbarvalues(t,:)-0).*diffsigmoid(zvalues(t,:)); 


    delta =delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:)); 


    delta_r=0.*((delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:)))*U').*diffsigmoid(rvalues(t,:)); 



    dWy=dWy+hvalues(t,:)'*delta_y; 


    dWz=dWz+xvalues(t,:)'*delta_z; 


    dW =dW+xvalues(t,:)'*delta; 


    dWr=dWr+xvalues(t,:)'*delta_r; 



    Wy = Wy-eta*dWy; 


    Wr = Wr-eta*dWr; 


    Ur = Ur-eta*dUr; 


    W = W -eta*dW; 


    U = U-eta*dU; 


    Wz = Wz-eta*dWz; 


    Uz = Uz-eta*dUz; 


    error = (norm(yvalues-d,2))/2.0; 


    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 


    if mod(p,500)==0 


    fprintf('******************第%s次迭代****************\n',int2str(p)); 


    yvalues=round(yvalues(end:-1:1)); 


    y=bin2dec(int2str(yvalues')); 


    fprintf('y=%d\n',y); 


    fprintf('c=%d\n',cInt); 


    fprintf('样本误差:e=%f\n',error); 


    end 


    end 


    end 



    function f=sigmoid(x) 


    f=1./(1+exp(-x)); 


    end 



    function fd = diffsigmoid(f) 


    fd=f.*(1-f); 


    end 



    function g=outTanh(x) 


    g=1-2./(1+exp(2*x)); 


    end 



    function gd=diffoutTanh(g) 


    gd=1-g.^2; 


    end 


部分实验结果

1479392393541.jpg

GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现的相关教程结束。

《GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现.doc》

下载本文的Word格式文档,以方便收藏与打印。