GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现
RNN
GRU
matlab codes
RNN网络考虑到了具有时间数列的样本数据,但是RNN仍存在着一些问题,比如随着时间的推移,RNN单元就失去了对很久之前信息的保存和处理的能力,而且存在着gradient vanishing问题。
所以有些特殊类型的RNN网络相继被提出,比如LSTM(long short term memory)和GRU(gated recurrent unit)(Chao,et al. 2014).这里我主要推导一下GRU参数的迭代过程
GRU单元结构如下图所示
1479126283494.jpg
数据流过程如下
其中表示Hadamard积,即对应元素乘积;下标表示节点的index,上标表示时刻;表示隐层到输出层的参数矩阵,分别是隐层和输出层的节点个数;分别表示输入和上一时刻隐层到更新门z的连接矩阵,表示输入数据的维度;分别表示输入和上一时刻隐层到重置门r的连接矩阵;分别表示输入和上一时刻的隐层到待选状态的连接矩阵。
针对于时刻t,使用链式求导法则,计算参数矩阵的梯度,其中E是代价函数,首先计算对隐层输出的梯度,因为隐层输出牵涉到多个时刻
所以
其中分别是对应激活函数的线性和部分
现在对参数计算梯度
令
则
将上面的式子矢量化(行向量)表示:
那接下来使用matlab来实现一个小例子,看看GRU的效果,同样是二进制相加的问题
- function error= GRUtest( )
% 初始化训练数据
uNum=16;%单元个数
maxInt=2^uNum;
% 初始化网络结构
xdim=2;
ydim=1;
hdim=16;
eta=0.1;
%初始化网络参数
Wy=rand(hdim,ydim)*2-1;
Wr=rand(xdim,hdim)*2-1;
Ur=rand(hdim,hdim)*2-1;
W =rand(xdim,hdim)*2-1;
U =rand(hdim,hdim)*2-1;
Wz=rand(xdim,hdim)*2-1;
Uz=rand(hdim,hdim)*2-1;
rvalues=zeros(uNum+1,hdim);
zvalues=zeros(uNum+1,hdim);
hbarvalues=zeros(uNum,hdim);
hvalues = zeros(uNum,hdim);
yvalues=zeros(uNum,ydim);
for p=1:10000
aInt=randi(maxInt/2);
bInt=randi(maxInt/2);
cInt=aInt+bInt;
at=dec2bin(aInt)-'0';
bt=dec2bin(bInt)-'0';
ct=dec2bin(cInt)-'0';
a=zeros(1,uNum);
b=zeros(1,uNum);
c=zeros(1,uNum);
a(1:size(at,2))=at(end:-1:1);
b(1:size(bt,2))=bt(end:-1:1);
c(1:size(ct,2))=ct(end:-1:1);
xvalues=[a;b]';
d=c';
% 前向计算
rvalues(1,:)=sigmoid(xvalues(1,:)*Wr);
hbarvalues(1,:)=outTanh(xvalues(1,:)*W);
zvalues(1,:)=sigmoid(xvalues(1,:)*Wz);
hvalues(1,:)=zvalues(1,:).*hbarvalues(1,:);
yvalues(1,:)=sigmoid(hvalues(1,:)*Wy);
for t=2:uNum
rvalues(t,:)=sigmoid(xvalues(t,:)*Wr+hvalues(t-1,:)*Ur);
hbarvalues(t,:)=outTanh(xvalues(t,:)*W+(rvalues(t,:).*hvalues(t-1,:))*U);
zvalues(t,:)=sigmoid(xvalues(t,:)*Wz+hvalues(t-1,:)*Uz);
hvalues(t,:)=(1-zvalues(t,:)).*hvalues(t-1,:)+zvalues(t,:).*hbarvalues(t,:);
yvalues(t,:)=sigmoid(hvalues(t,:)*Wy);
end
% 误差反向传播
delta_r_next=zeros(1,hdim);
delta_z_next=zeros(1,hdim);
delta_h_next=zeros(1,hdim);
delta_next=zeros(1,hdim);
dWy=zeros(hdim,ydim);
dWr=zeros(xdim,hdim);
dUr=zeros(hdim,hdim);
dW=zeros(xdim,hdim);
dU=zeros(hdim,hdim);
dWz=zeros(xdim,hdim);
dUz=zeros(hdim,hdim);
for t=uNum:-1:2
delta_y=(yvalues(t,:)-d(t,:)).*diffsigmoid(yvalues(t,:));
delta_h=delta_y*Wy'+delta_z_next*Uz'+delta_next*U'.*rvalues(t+1,:)+delta_r_next*Ur'+delta_h_next.*(1-zvalues(t+1,:));
delta_z=delta_h.*(hbarvalues(t,:)-hvalues(t-1,:)).*diffsigmoid(zvalues(t,:));
delta =delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:));
delta_r=hvalues(t-1,:).*((delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:)))*U').*diffsigmoid(rvalues(t,:));
dWy=dWy+hvalues(t,:)'*delta_y;
dWz=dWz+xvalues(t,:)'*delta_z;
dUz=dUz+hvalues(t-1,:)'*delta_z;
dW =dW+xvalues(t,:)'*delta;
dU =dU+(rvalues(t,:).*hvalues(t-1,:))'*delta ;
dWr=dWr+xvalues(t,:)'*delta_r;
dUr=dUr+hvalues(t-1,:)'*delta_r;
delta_r_next=delta_r;
delta_z_next=delta_z;
delta_h_next=delta_h;
delta_next =delta;
end
t=1;
delta_y=(yvalues(t,:)-d(t,:)).*diffsigmoid(yvalues(t,:));
delta_h=delta_y*Wy'+delta_z_next*Uz'+delta_next*U'.*rvalues(t+1,:)+delta_r_next*Ur'+delta_h_next.*(1-zvalues(t+1,:));
delta_z=delta_h.*(hbarvalues(t,:)-0).*diffsigmoid(zvalues(t,:));
delta =delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:));
delta_r=0.*((delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:)))*U').*diffsigmoid(rvalues(t,:));
dWy=dWy+hvalues(t,:)'*delta_y;
dWz=dWz+xvalues(t,:)'*delta_z;
dW =dW+xvalues(t,:)'*delta;
dWr=dWr+xvalues(t,:)'*delta_r;
Wy = Wy-eta*dWy;
Wr = Wr-eta*dWr;
Ur = Ur-eta*dUr;
W = W -eta*dW;
U = U-eta*dU;
Wz = Wz-eta*dWz;
Uz = Uz-eta*dUz;
error = (norm(yvalues-d,2))/2.0;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if mod(p,500)==0
fprintf('******************第%s次迭代****************\n',int2str(p));
yvalues=round(yvalues(end:-1:1));
y=bin2dec(int2str(yvalues'));
fprintf('y=%d\n',y);
fprintf('c=%d\n',cInt);
fprintf('样本误差:e=%f\n',error);
end
end
end
function f=sigmoid(x)
f=1./(1+exp(-x));
end
function fd = diffsigmoid(f)
fd=f.*(1-f);
end
function g=outTanh(x)
g=1-2./(1+exp(2*x));
end
function gd=diffoutTanh(g)
gd=1-g.^2;
end
部分实验结果
1479392393541.jpg