已屏蔽 原因:{{ notice.reason }}已屏蔽
{{notice.noticeContent}}
~~空空如也

在Python中使用MNIST较便捷的方法是安装MNIST工具包,在这里可以下载安装。

安装后便可如下直接导入训练数据和测试数据:

import matplotlib.pyplot as plt
import numpy as np
from mnist import MNIST
%matplotlib inline
mndata = MNIST('Path to your data/MNIST')
images, labels=mndata.load_training()
images_test, labels_test=mndata.load_testing()

计算Loss function,同样得解决指数函数数据溢出的问题:

def calculate_loss(w,x,y):
    theta=np.dot(np.transpose(w),x)
    ### first part
    l1=np.multiply(theta,y)
    ### second part
    thetaMax=np.max(theta,axis=0)
    theta-=thetaMax
    l2=np.log(np.sum(np.exp(theta),axis=0))
    ### sum together
    L=np.sum(-l1+l2+thetaMax)
    return L

计算导数:

def calculate_gradient(w,x_batch,y_batch):
    theta=np.dot(np.transpose(w),x_batch)
    theta-=np.max(theta,axis=0)
    mu=np.exp(theta)
    muSum=np.sum(mu,axis=0)
    mu=mu/muSum
    mu-=y_batch
    dL=np.dot(x_batch,np.transpose(mu))
    return dL

用SGD训练:

def train(x,y,batch_sz,lr,max_iter,loss_thresh):
    ### bias trick ###
    batch_sz=100
    d_len=x.shape[1]
    d_dim=x.shape[0]+1
    d_class=np.max(y)-np.min(y)+1
    x_b=np.concatenate((x,np.ones((1,d_len))),axis=0)
    yc=np.zeros((d_class,d_len))
    yc[y,np.arange(d_len)]=1;
    w=np.zeros((d_dim,d_class),dtype=np.float)
    
    Loss_old=0
    Loss=[]
    stepCnt=0
    ### Run SGD ###
    for iter in range(max_iter):
        ### sample a mini batch ###
        batch=np.arange(d_len)
        np.random.shuffle(batch)
        x_batch=x_b[:,batch[:batch_sz]]
        y_batch=yc[:,batch[:batch_sz]]
        ### update weight ###
        dL=calculate_gradient(w,x_batch,y_batch)
        w-=lr*dL
        ### record loss changes ###
        Loss.append(calculate_loss(w,x_b,yc))
        ### learning rate annealing ###
        stepCnt+=1
        if stepCnt==10:
            stepCnt=0
            lr*=0.8
        ### Check if converge ###
        if abs(Loss[-1]-Loss_old)<loss_thresh: break loss_old="Loss[-1]" return w,loss < code></loss_thresh:>

主函数:

x_train=np.transpose(np.matrix(images))
y_train=np.array(labels)
lr=0.5
batch_sz=50
Max_iter=200
loss_thresh=1e-3
w,Loss = train(x_train,y_train,batch_sz,lr,Max_iter,loss_thresh)
print w
plt.plot(Loss)
文号 / 823141

千古风流
名片发私信
学术分 2
总主题 34 帖总回复 364 楼拥有证书:专家 进士 老干部 学者 机友 笔友
注册于 2012-09-03 13:32最后登录 2024-07-10 05:46
主体类型:个人
所属领域:无
认证方式:手机号
IP归属地:未同步

个人简介

Machine Learning, computer vision enthusiast

Google

文件下载
加载中...
{{errorInfo}}
{{downloadWarning}}
你在 {{downloadTime}} 下载过当前文件。
文件名称:{{resource.defaultFile.name}}
下载次数:{{resource.hits}}
上传用户:{{uploader.username}}
所需积分:{{costScores}},{{holdScores}}下载当前附件免费{{description}}
积分不足,去充值
文件已丢失

当前账号的附件下载数量限制如下:
时段 个数
{{f.startingTime}}点 - {{f.endTime}}点 {{f.fileCount}}
视频暂不能访问,请登录试试
仅供内部学术交流或培训使用,请先保存到本地。本内容不代表科创观点,未经原作者同意,请勿转载。
音频暂不能访问,请登录试试
投诉或举报
加载中...
{{tip}}
请选择违规类型:
{{reason.type}}

空空如也

插入资源
全部
图片
视频
音频
附件
全部
未使用
已使用
正在上传
空空如也~
上传中..{{f.progress}}%
处理中..
上传失败,点击重试
等待中...
{{f.name}}
空空如也~
(视频){{r.oname}}
{{selectedResourcesId.indexOf(r.rid) + 1}}
处理中..
处理失败
插入表情
我的表情
共享表情
Emoji
上传
注意事项
最大尺寸100px,超过会被压缩。为保证效果,建议上传前自行处理。
建议上传自己DIY的表情,严禁上传侵权内容。
点击重试等待上传{{s.progress}}%处理中...已上传,正在处理中
空空如也~
处理中...
处理失败
加载中...
草稿箱
加载中...
此处只插入正文,如果要使用草稿中的其余内容,请点击继续创作。
{{fromNow(d.toc)}}
{{getDraftInfo(d)}}
标题:{{d.t}}
内容:{{d.c}}
继续创作
删除插入插入
插入公式
评论控制
加载中...
文号:{{pid}}
加载中...
详情
详情
推送到专栏从专栏移除
设为匿名取消匿名
查看作者
回复
只看作者
加入收藏取消收藏
收藏
取消收藏
折叠回复
置顶取消置顶
评学术分
鼓励
设为精选取消精选
管理提醒
编辑
通过审核
评论控制
退修或删除
历史版本
违规记录
投诉或举报
加入黑名单移除黑名单
查看IP
{{format('YYYY/MM/DD HH:mm:ss', toc)}}
ID: {{user.uid}}