pytorch实现用Resnet提取特征并保存为txt文件的方法

接触pytorch一天,发现pytorch上手的确比TensorFlow更快。可以更方便地实现用预训练的网络提特征。

以下是提取一张jpg图像的特征的程序:

# -*- coding: utf-8 -*-

import os.path

import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable 

import numpy as np
from PIL import Image 

features_dir = './features'

img_path = "hymenoptera_data/train/ants/0013035.jpg"
file_name = img_path.split('/')[-1]
feature_path = os.path.join(features_dir, file_name + '.txt')

transform1 = transforms.Compose([
    transforms.Scale(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()  ]
)

img = Image.open(img_path)
img1 = transform1(img)

#resnet18 = models.resnet18(pretrained = True)
resnet50_feature_extractor = models.resnet50(pretrained = True)
resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
torch.nn.init.eye(resnet50_feature_extractor.fc.weight)

for param in resnet50_feature_extractor.parameters():
  param.requires_grad = False
#resnet152 = models.resnet152(pretrained = True)
#densenet201 = models.densenet201(pretrained = True)
x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
#y1 = resnet18(x)
y = resnet50_feature_extractor(x)
y = y.data.numpy()
np.savetxt(feature_path, y, delimiter=',')
#y3 = resnet152(x)
#y4 = densenet201(x)

y_ = np.loadtxt(feature_path, delimiter=',').reshape(1, 2048)

以下是提取一个文件夹下所有jpg、jpeg图像的程序:

# -*- coding: utf-8 -*-
import os, torch, glob
import numpy as np
from torch.autograd import Variable
from PIL import Image
from torchvision import models, transforms
import torch.nn as nn
import shutil
data_dir = './hymenoptera_data'
features_dir = './features'
shutil.copytree(data_dir, os.path.join(features_dir, data_dir[2:]))

def extractor(img_path, saved_path, net, use_gpu):
  transform = transforms.Compose([
      transforms.Scale(256),
      transforms.CenterCrop(224),
      transforms.ToTensor()  ]
  )

  img = Image.open(img_path)
  img = transform(img)

  x = Variable(torch.unsqueeze(img, dim=0).float(), requires_grad=False)
  if use_gpu:
    x = x.cuda()
    net = net.cuda()
  y = net(x).cpu()
  y = y.data.numpy()
  np.savetxt(saved_path, y, delimiter=',')

if __name__ == '__main__':
  extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']

  files_list = []
  sub_dirs = [x[0] for x in os.walk(data_dir) ]
  sub_dirs = sub_dirs[1:]
  for sub_dir in sub_dirs:
    for extention in extensions:
      file_glob = os.path.join(sub_dir, '*.' + extention)
      files_list.extend(glob.glob(file_glob))

  resnet50_feature_extractor = models.resnet50(pretrained = True)
  resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
  torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
  for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False  

  use_gpu = torch.cuda.is_available()

  for x_path in files_list:
    print(x_path)
    fx_path = os.path.join(features_dir, x_path[2:] + '.txt')
    extractor(x_path, fx_path, resnet50_feature_extractor, use_gpu)

另外最近发现一个很简单的提取不含FC层的网络的方法:

    resnet = models.resnet152(pretrained=True)
    modules = list(resnet.children())[:-1]   # delete the last fc layer.
    convnet = nn.Sequential(*modules)

另一种更简单的方法:

resnet = models.resnet152(pretrained=True)
del resnet.fc

以上这篇pytorch实现用Resnet提取特征并保存为txt文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • python PyTorch预训练示例

    前言 最近使用PyTorch感觉妙不可言,有种当初使用Keras的快感,而且速度还不慢.各种设计直接简洁,方便研究,比tensorflow的臃肿好多了.今天让我们来谈谈PyTorch的预训练,主要是自己写代码的经验以及论坛PyTorch Forums上的一些回答的总结整理. 直接加载预训练模型 如果我们使用的模型和原模型完全一样,那么我们可以直接加载别人训练好的模型: my_resnet = MyResNet(*args, **kwargs) my_resnet.load_state_dict(

  • 获取Pytorch中间某一层权重或者特征的例子

    问题:训练好的网络模型想知道中间某一层的权重或者看看中间某一层的特征,如何处理呢? 1.获取某一层权重,并保存到excel中; 以resnet18为例说明: import torch import pandas as pd import numpy as np import torchvision.models as models resnet18 = models.resnet18(pretrained=True) parm={} for name,parameters in resnet18

  • pytorch 输出中间层特征的实例

    pytorch 输出中间层特征: tensorflow输出中间特征,2种方式: 1. 保存全部模型(包括结构)时,需要之前先add_to_collection 或者 用slim模块下的end_points 2. 只保存模型参数时,可以读取网络结构,然后按照对应的中间层输出即可. but:Pytorch 论坛给出的答案并不好用,无论是hooks,还是重建网络并去掉某些层,这些方法都不好用(在我看来). 我们可以在创建网络class时,在forward时加入一个dict 或者 list,dict是将

  • pytorch实现用Resnet提取特征并保存为txt文件的方法

    接触pytorch一天,发现pytorch上手的确比TensorFlow更快.可以更方便地实现用预训练的网络提特征. 以下是提取一张jpg图像的特征的程序: # -*- coding: utf-8 -*- import os.path import torch import torch.nn as nn from torchvision import models, transforms from torch.autograd import Variable import numpy as np

  • Python3将数据保存为txt文件的方法

    Python3将数据保存为txt文件的方法,具体内容如下所示: f = open("data/model_Weight.txt",'a') #若文件不存在,系统自动创建.'a'表示可连续写入到文件,保留原内容,在原 #内容之后写入.可修改该模式('w+','w','wb'等) f.write("hello,sha") #将字符串写入文件中 f.write("\n") #换行 if __name__=='__main__': fw = open(&

  • python使用PyGame绘制图像并保存为图片文件的方法

    本文实例讲述了python使用PyGame绘制图像并保存为图片文件的方法.分享给大家供大家参考.具体实现方法如下: ''' pg_draw_circle_save101.py draw a blue solid circle on a white background save the drawing to an image file for result see http://prntscr.com/156wxi tested with Python 2.7 and PyGame 1.9.2

  • VB打开与保存txt文件的方法

    本文实例讲述了VB打开与保存txt文件的方法.分享给大家供大家参考.具体如下: Private Sub cmdsave_Click() Dim filelocation As String ' loads save as box commondialog1.ShowSave filelocation = commondialog1.FileName ' append saves over file if it assists Open filelocation For Append As #1

  • python保存字符串到文件的方法

    本文实例讲述了python保存字符串到文件的方法.分享给大家供大家参考.具体实现方法如下: def save(filename, contents): fh = open(filename, 'w') fh.write(contents) fh.close() save('file.name', 'some stuff') 希望本文所述对大家的Python程序设计有所帮助.

  • php限制上传文件类型并保存上传文件的方法

    本文实例讲述了php限制上传文件类型并保存上传文件的方法.分享给大家供大家参考.具体如下: 下面的代码演示了php中如何获取用户上传的文件,并限制文件类型的一般图片文件,最后保存到服务器 <?php $allowedExts = array("gif", "jpeg", "jpg", "png"); $extension = end(explode(".", $_FILES["file&qu

  • python将pandas datarame保存为txt文件的实例

    CSV means Comma Separated Values. It is plain text (ansi). The CSV ("Comma Separated Value") file format is often used to exchange data between disparate applications. The file format, as it is used in Microsoft Excel, has become a pseudo standa

  • Android将String保存为SD卡中TXT文件的方法

    如下所示: public static void stringTxt(String str){ try { FileWriter fw = new FileWriter("/sdcard/aaa" + "/cmd.txt");//SD卡中的路径 fw.flush(); fw.write(str); fw.close(); } catch (Exception e) { e.printStackTrace(); } } 以上这篇Android将String保存为SD卡

  • Vim 强制保存只读类型文件的方法

    发现问题: 在使用vim时,当我们以普通用户去打开一个只有root用户才有权限操作的文件时,我们编辑完成之后,正要保存,却发现,这个文件我们没有权限修改. 每次遇到这样的问题,我都很头疼,好不容易把文件编辑完了,却无法保存,就只能放弃,然后退出,再以root权限打开,重新编辑. 我总是相信,所有的问题都有解决的方法.通过查阅资料,终于解决了这个问题. 解决方案: 底行命令模式执行: :w !sudo tee % w: 表示保存文件 !: 表示执行外部命令 tee: linux命令,这个有点复杂,

  • PHP保存带BOM文件的方法

    文本开头加上chr ( 239 ) . chr ( 187 ) . chr ( 191 )即可.

随机推荐