使用pytorch自带的预训练模型提取特征
import torch
import torch.nn as nn
import torchvision
resnet = torchvision.models.resnet50(pretrained=True)
resnet.fc = nn.Linear(2048, 2048, bias=False)
torch.nn.init.eye(resnet.fc.weight)
使用pytorch自带的预训练模型来训练自己的分类器
import torch
import torch.nn as nn
import torchvision
resnet = torchvision.models.resnet50(pretrained=True)
resnet.fc = nn.Linear(resnet.fc.in_features , num) # num是分类类别数