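A convenient way to work with MNIST in Python is to install the MNIST toolkit (the `mnist` package imported below).
Once it is installed, the training and test data can be loaded directly: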
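<code class="language-python">import matplotlib.pyplot as plt
import numpy as np
from mnist import MNIST

# Jupyter/IPython magic; remove this line when running as a plain script
%matplotlib inline

# Point this at the directory that contains the MNIST files
mndata = MNIST('Path to your data/MNIST')
images, labels = mndata.load_training()
images_test, labels_test = mndata.load_testing()
</code>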
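The loader hands back the images as a sequence of 784-value pixel sequences and the labels as a sequence of integers. The training code below expects one sample per column, so a small conversion step is useful. This is a minimal sketch under a few assumptions: 28×28 images flattened to 784 values, and pixel scaling to [0, 1], which the original loading code does not do. `to_columns` is a hypothetical helper name, and the random arrays only stand in for the real data.

```python
import numpy as np

def to_columns(images, labels):
    """Hypothetical helper: turn loader output into column-per-sample arrays.

    images: sequence of N pixel sequences (784 values each, 0-255)
    labels: sequence of N integer class labels
    Returns x with shape (784, N), scaled to [0, 1], and y with shape (N,).
    """
    x = np.asarray(images, dtype=float).T / 255.0  # one sample per column
    y = np.asarray(labels, dtype=int)
    return x, y

# Stand-in data: 5 fake "images" of 784 pixels and their labels.
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(5, 784)).tolist()
fake_labels = [3, 1, 4, 1, 5]

x, y = to_columns(fake_images, fake_labels)
print(x.shape, y.shape)  # (784, 5) (5,)
```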
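Next, compute the loss function. As before, the exponentials can overflow, so each column of scores is shifted by its maximum before `np.exp` is applied: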
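<code class="language-python">def calculate_loss(w, x, y):
    """Softmax cross-entropy loss summed over all samples.

    w: (D, C) weights, x: (D, N) samples as columns, y: (C, N) one-hot labels.
    """
    theta = np.dot(np.transpose(w), x)  # (C, N) class scores
    ### first part: score of the true class for each sample ###
    l1 = np.sum(np.multiply(theta, y), axis=0)
    ### second part: log-sum-exp, shifted by the column max to avoid overflow ###
    thetaMax = np.max(theta, axis=0)
    l2 = np.log(np.sum(np.exp(theta - thetaMax), axis=0))
    ### sum together ###
    L = np.sum(-l1 + l2 + thetaMax)
    return L
</code>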
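The max-subtraction relies on the identity $\log\sum_k e^{\theta_k} = m + \log\sum_k e^{\theta_k - m}$ with $m = \max_k \theta_k$. After the shift, the largest exponent is $e^0 = 1$, so `np.exp` can no longer overflow. The short comparison below illustrates this. The scores are made-up values chosen to overflow, and the two function names are hypothetical.

```python
import numpy as np

def logsumexp_naive(theta):
    # Direct formula: np.exp overflows to inf for large scores.
    return np.log(np.sum(np.exp(theta), axis=0))

def logsumexp_stable(theta):
    # Shift each column by its maximum before exponentiating.
    m = np.max(theta, axis=0)
    return m + np.log(np.sum(np.exp(theta - m), axis=0))

# Made-up class scores for one sample (one column), large enough to overflow.
theta = np.array([[1000.0], [1001.0], [1002.0]])

with np.errstate(over="ignore"):
    naive = logsumexp_naive(theta)
stable = logsumexp_stable(theta)
print(naive)   # [inf]
print(stable)  # about [1002.4076]
```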
Compute the gradient:
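<code class="language-python">def calculate_gradient(w, x_batch, y_batch):
    """Gradient of the summed softmax loss with respect to w, shape (D, C)."""
    theta = np.dot(np.transpose(w), x_batch)
    theta -= np.max(theta, axis=0)  # same overflow guard as in the loss
    mu = np.exp(theta)
    muSum = np.sum(mu, axis=0)
    mu = mu / muSum                 # softmax probabilities, (C, N)
    mu -= y_batch                   # predicted minus one-hot target
    dL = np.dot(x_batch, np.transpose(mu))
    return dL
</code>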
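The gradient code computes $\partial L/\partial W = X(\mu - Y)^\top$, where $\mu$ holds the softmax probabilities column by column and $Y$ the one-hot labels. A central finite-difference check is a cheap way to confirm that the loss and the gradient agree. The sketch below re-declares both functions compactly so it runs on its own. The tiny random problem is made up for the check, and the tolerance mentioned in the final comment is a rough expectation rather than a guarantee.

```python
import numpy as np

def calculate_loss(w, x, y):
    # Softmax cross-entropy summed over samples; x: (D, N), y: one-hot (C, N).
    theta = w.T @ x
    m = np.max(theta, axis=0)
    l2 = m + np.log(np.sum(np.exp(theta - m), axis=0))
    l1 = np.sum(theta * y, axis=0)  # score of the true class per sample
    return np.sum(l2 - l1)

def calculate_gradient(w, x, y):
    theta = w.T @ x
    theta -= np.max(theta, axis=0)
    mu = np.exp(theta)
    mu /= np.sum(mu, axis=0)
    return x @ (mu - y).T  # shape (D, C), same as w

rng = np.random.default_rng(1)
D, C, N = 4, 3, 6
x = rng.normal(size=(D, N))
y = np.zeros((C, N))
y[rng.integers(0, C, size=N), np.arange(N)] = 1
w = rng.normal(size=(D, C))

# Central differences, one weight at a time.
eps = 1e-6
num_grad = np.zeros_like(w)
for i in range(D):
    for j in range(C):
        w_plus, w_minus = w.copy(), w.copy()
        w_plus[i, j] += eps
        w_minus[i, j] -= eps
        num_grad[i, j] = (calculate_loss(w_plus, x, y)
                          - calculate_loss(w_minus, x, y)) / (2 * eps)

max_err = np.max(np.abs(num_grad - calculate_gradient(w, x, y)))
print(max_err)  # expected to be tiny, far below 1e-5
```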
Train with mini-batch SGD:
<code class="language-python">def train(x, y, batch_sz, lr, max_iter, loss_thresh):
    ### bias trick: append a constant-1 feature so the bias lives in w ###
    d_len = x.shape[1]                   # number of samples
    d_dim = x.shape[0] + 1               # feature dimension plus bias
    d_class = np.max(y) - np.min(y) + 1  # assumes labels 0..C-1
    x_b = np.concatenate((x, np.ones((1, d_len))), axis=0)
    ### one-hot labels: yc[k, n] = 1 if sample n has label k ###
    yc = np.zeros((d_class, d_len))
    yc[y, np.arange(d_len)] = 1
    w = np.zeros((d_dim, d_class), dtype=float)  # np.float was removed in NumPy 1.24
    Loss_old = 0
    Loss = []
    stepCnt = 0
    ### Run SGD ###
    for it in range(max_iter):
        ### sample a mini batch ###
        batch = np.random.permutation(d_len)[:batch_sz]
        x_batch = x_b[:, batch]
        y_batch = yc[:, batch]
        ### update weight ###
        dL = calculate_gradient(w, x_batch, y_batch)
        w -= lr * dL
        ### record loss changes (on the full training set) ###
        Loss.append(calculate_loss(w, x_b, yc))
        ### learning rate annealing: lr *= 0.8 every 10 steps ###
        stepCnt += 1
        if stepCnt == 10:
            stepCnt = 0
            lr *= 0.8
        ### Check if converge ###
        if abs(Loss[-1] - Loss_old) < loss_thresh:
            break
        Loss_old = Loss[-1]
    return w, Loss
</code>
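In `train`, the learning rate is multiplied by 0.8 after every 10 updates. The step used at iteration `it` (counting from 0) is therefore `lr * 0.8**(it // 10)`. The helper below is a hypothetical restatement of that schedule. It can be used to check how small the step becomes by the end of a run.

```python
def annealed_lr(lr0, it, every=10, factor=0.8):
    """Learning rate used at 0-based iteration `it` under step decay."""
    return lr0 * factor ** (it // every)

# With lr=0.5 and Max_iter=200 as in the main program:
print(annealed_lr(0.5, 0))    # 0.5
print(annealed_lr(0.5, 10))   # 0.4
print(annealed_lr(0.5, 199))  # about 0.0072
```

With these settings the last updates use a step roughly 70 times smaller than the first.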
Main program:
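<code class="language-python"># One sample per column, (784, N); pixels scaled to [0, 1] to keep SGD steps stable
x_train = np.transpose(np.array(images, dtype=float)) / 255.0
y_train = np.array(labels, dtype=int)
lr = 0.5
batch_sz = 50
Max_iter = 200
loss_thresh = 1e-3
w, Loss = train(x_train, y_train, batch_sz, lr, Max_iter, loss_thresh)
print(w)
plt.plot(Loss)  # training loss per iteration
plt.show()
</code>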
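The main program ends by plotting the loss. To measure how well the learned `w` classifies unseen digits, one option is to append the same row of ones used by the bias trick and take the highest-scoring class in each column. This is a minimal sketch: `predict` and `accuracy` are hypothetical helper names, and the toy weights are made up so the example runs without MNIST. Assuming the test set is preprocessed the same way as the training set, it could be called on the real data as `accuracy(w, np.array(images_test, dtype=float).T / 255.0, np.array(labels_test))`.

```python
import numpy as np

def predict(w, x):
    # x: (D, N) features; append a row of ones to match the bias trick in train().
    x_b = np.concatenate((x, np.ones((1, x.shape[1]))), axis=0)
    scores = w.T @ x_b                # (C, N) class scores
    return np.argmax(scores, axis=0)  # predicted label per column

def accuracy(w, x, y):
    # Fraction of columns whose predicted label matches y.
    return np.mean(predict(w, x) == y)

# Toy example: 2 features, 3 classes; the last row of w is the bias.
w = np.array([[1.0, -1.0, 0.0],
              [0.0,  0.0, 1.0],
              [0.0,  0.5, 0.0]])
x = np.array([[2.0, -2.0, 0.0],
              [0.0,  0.0, 3.0]])
y = np.array([0, 1, 2])
print(predict(w, x))      # [0 1 2]
print(accuracy(w, x, y))  # 1.0
```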