已屏蔽 原因:{{ notice.reason }}已屏蔽
{{notice.noticeContent}}
~~空空如也

实现的实验代码是这两天下班后顺着思路往下写的,比较丑陋,回头整理封装一下。有几个值得一提的细节: 1 )虽然标准的VGG16图像识别应用中输入图片尺寸被限制在了224x224,但在这里你可以想用多大用多大(因为forward pass不会传到FC layer,btw无法想象生成224x224的艺术图有卵用?);2)原文使用的L-BFGS在tensorflow中并没有加入,所以使用AdamOptimizer取代;3)由于要对输入图片求导,所以不要用XXXXXnstant作为输入。。4)这玩意儿调参很麻烦

计算content feature的代码:

## compute content feature
with tf.Graph().as_default():
    image=tf.constant(content,dtype=tf.float32)
    with slim.arg_scope(vgg.vgg_arg_scope()):
        _, end_points=vgg.vgg_16(image,spatial_squeeze=False)
    ## load vgg16 model    
    init_fn = slim.assign_from_checkpoint_fn(
    os.path.join('./model/', 'vgg_16.ckpt'),
    slim.get_model_variables())
    ## compute content Feature
    content_feature=tf.reshape(end_points['vgg_16/conv4/conv4_2'],shape=[-1])
    with tf.Session() as sess:
        init_fn(sess)
        content_ft_val=sess.run(content_feature)    

计算style feature的代码(用了许多hard-code):

style_list=['vgg_16/conv1/conv1_1',
            'vgg_16/conv2/conv2_1',
            'vgg_16/conv3/conv3_1',
            'vgg_16/conv4/conv4_1',
            'vgg_16/conv5/conv5_1']
## compute style feature
def style_gram_compute(idx):
    dim=end_points[style_list[idx]].get_shape().as_list()
    style_feature=tf.reshape(end_points[style_list[idx]],shape=[-1,dim[-1]])
    style_gram=tf.reshape(
tf.matmul(style_feature,style_feature,transpose_a=True),
shape=[-1])/(dim[1]*dim[2])
    return style_gram,dim

## compute style feature
with tf.Graph().as_default():
    image=tf.constant(style,dtype=tf.float32)
    with slim.arg_scope(vgg.vgg_arg_scope()):
        _, end_points=vgg.vgg_16(image,spatial_squeeze=False)
    ## load vgg16 model    
    init_fn = slim.assign_from_checkpoint_fn(
    os.path.join('./model/', 'vgg_16.ckpt'),
    slim.get_model_variables())
    ## compute content Feature
    style_gram_1,dim1=style_gram_compute(0)
    style_gram_2,dim2=style_gram_compute(1)
    style_gram_3,dim3=style_gram_compute(2)
    style_gram_4,dim4=style_gram_compute(3)
    style_gram_5,dim5=style_gram_compute(4)
    with tf.Session() as sess:
        init_fn(sess)
        style_gram_val_1=sess.run(style_gram_1)    
        style_gram_val_2=sess.run(style_gram_2)
        style_gram_val_3=sess.run(style_gram_3)
        style_gram_val_4=sess.run(style_gram_4)
        style_gram_val_5=sess.run(style_gram_5)

主优化程序:

def style_loss_compute(style_gram_val,target_style_gram):
    style_gram_const=tf.constant(style_gram_val)
    style_loss=tf.reduce_mean(tf.squared_difference(target_style_gram,style_gram_const))
    return style_loss

with tf.Graph().as_default():
    ### generate target image ####
    image=tf.Variable(tf.random_normal(shape=(1,h,w,3),stddev=np.std(content)*0.1,dtype=tf.float32))
    with slim.arg_scope(vgg.vgg_arg_scope()):
        _, end_points=vgg.vgg_16(image,spatial_squeeze=False)
    ## load vgg16 model    
    init_fn = slim.assign_from_checkpoint_fn(
    os.path.join('./model/', 'vgg_16.ckpt'),
    slim.get_model_variables())
    ## compute target content Feature
    target_content_feature=tf.reshape(end_points['vgg_16/conv4/conv4_2'],shape=[-1])
    ## compute target style Feature
    target_style_gram_1,_=style_gram_compute(0)
    target_style_gram_2,_=style_gram_compute(1)
    target_style_gram_3,_=style_gram_compute(2)
    target_style_gram_4,_=style_gram_compute(3)
    target_style_gram_5,_=style_gram_compute(4)
    ## define content loss
    content_loss=tf.reduce_sum(tf.squared_difference(target_content_feature,content_ft_val))
    ## define style loss
    style_loss_1=style_loss_compute(style_gram_val_1,target_style_gram_1)
    style_loss_2=style_loss_compute(style_gram_val_2,target_style_gram_2)
    style_loss_3=style_loss_compute(style_gram_val_3,target_style_gram_3)
    style_loss_4=style_loss_compute(style_gram_val_4,target_style_gram_4)
    style_loss_5=style_loss_compute(style_gram_val_5,target_style_gram_5)
    w0=1.0/5
    style_loss=w0*style_loss_1+w0*style_loss_2+w0*style_loss_3+w0*style_loss_4+w0*style_loss_5
    ## define total loss
    total_loss=1e-3*content_loss+style_loss
    ## define train step
    global_step = tf.Variable(0, trainable=False)
    starter_learning_rate = 1.
    learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
                                           100, 0.95, staircase=True)
    train_op=tf.train.AdamOptimizer(learning_rate).minimize(loss=total_loss,
                                                            var_list=[image],
                                                            global_step=global_step)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        init_fn(sess)
        max_iter=1000
        for iter in range(max_iter):
            _,loss_content_val,loss_style_val=sess.run([train_op,content_loss,style_loss]) 
            sys.stdout.write('\r'+'iteration:%4d-----Loss_content: %0f6-----Loss_style: %0f6-----learning_rate: %0f5'
                             %(iter,loss_content_val,loss_style_val,learning_rate.eval()))
            output=image.eval()
文号 / 828501

千古风流
名片发私信
学术分 2
总主题 34 帖总回复 364 楼拥有证书:专家 进士 老干部 学者 机友 笔友
注册于 2012-09-03 13:32最后登录 2024-07-10 05:46
主体类型:个人
所属领域:无
认证方式:手机号
IP归属地:未同步

个人简介

Machine Learning, computer vision enthusiast

Google

文件下载
加载中...
{{errorInfo}}
{{downloadWarning}}
你在 {{downloadTime}} 下载过当前文件。
文件名称:{{resource.defaultFile.name}}
下载次数:{{resource.hits}}
上传用户:{{uploader.username}}
所需积分:{{costScores}},{{holdScores}}下载当前附件免费{{description}}
积分不足,去充值
文件已丢失

当前账号的附件下载数量限制如下:
时段 个数
{{f.startingTime}}点 - {{f.endTime}}点 {{f.fileCount}}
视频暂不能访问,请登录试试
仅供内部学术交流或培训使用,请先保存到本地。本内容不代表科创观点,未经原作者同意,请勿转载。
音频暂不能访问,请登录试试
投诉或举报
加载中...
{{tip}}
请选择违规类型:
{{reason.type}}

空空如也

插入资源
全部
图片
视频
音频
附件
全部
未使用
已使用
正在上传
空空如也~
上传中..{{f.progress}}%
处理中..
上传失败,点击重试
等待中...
{{f.name}}
空空如也~
(视频){{r.oname}}
{{selectedResourcesId.indexOf(r.rid) + 1}}
处理中..
处理失败
插入表情
我的表情
共享表情
Emoji
上传
注意事项
最大尺寸100px,超过会被压缩。为保证效果,建议上传前自行处理。
建议上传自己DIY的表情,严禁上传侵权内容。
点击重试等待上传{{s.progress}}%处理中...已上传,正在处理中
空空如也~
处理中...
处理失败
加载中...
草稿箱
加载中...
此处只插入正文,如果要使用草稿中的其余内容,请点击继续创作。
{{fromNow(d.toc)}}
{{getDraftInfo(d)}}
标题:{{d.t}}
内容:{{d.c}}
继续创作
删除插入插入
插入公式
评论控制
加载中...
文号:{{pid}}
加载中...
详情
详情
推送到专栏从专栏移除
设为匿名取消匿名
查看作者
回复
只看作者
加入收藏取消收藏
收藏
取消收藏
折叠回复
置顶取消置顶
评学术分
鼓励
设为精选取消精选
管理提醒
编辑
通过审核
评论控制
退修或删除
历史版本
违规记录
投诉或举报
加入黑名单移除黑名单
查看IP
{{format('YYYY/MM/DD HH:mm:ss', toc)}}
ID: {{user.uid}}