使用pytorch自带的预训练模型提取特征

import torch
import torch.nn as nn
import torchvision

resnet = torchvision.models.resnet50(pretrained=True)
resnet.fc = nn.Linear(2048, 2048, bias=False)
torch.nn.init.eye(resnet.fc.weight)

使用pytorch自带的预训练模型来训练自己的分类器

import torch
import torch.nn as nn
import torchvision

resnet = torchvision.models.resnet50(pretrained=True)
resnet.fc = nn.Linear(resnet.fc.in_features , num)    # num是分类类别数