代码如下（用预训练 ResNet-18 的 conv1 层提取图像特征并可视化）：
import torch
import torchvision
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from PIL import Image
from torchvision.models.feature_extraction import create_feature_extractor
# Image to visualise and the ResNet-18 graph node whose output is plotted.
IMAGE_PATH = "dog.jpg"
NODE_NAME = "conv1"

# Per-channel mean/std used to normalise inputs for the pretrained weights.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def build_transform():
    """Return the resize / center-crop / to-tensor / normalize pipeline."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])


def load_image(path):
    """Open *path* and return it as a fully loaded RGB PIL image.

    ``convert("RGB")`` is required because grayscale, palette or RGBA files
    would otherwise become 1- or 4-channel tensors, which the 3-channel
    ``Normalize`` step and the model's first convolution reject. The ``with``
    block closes the underlying file; ``convert`` returns an independent,
    already-loaded copy, so it stays valid after the file is closed.
    """
    with Image.open(path) as im:
        return im.convert("RGB")


def channel_sum_map(features):
    """Collapse a (C, H, W) feature tensor into an (H, W) map by summing channels.

    Gives the same result as ``features.transpose(0, 1).sum(1)`` without the
    detour through a transposed view.
    """
    return features.sum(dim=0)


def main(image_path=IMAGE_PATH, node=NODE_NAME):
    """Plot the channel-summed output of ResNet-18 node *node* for one image.

    Args:
        image_path: Path of the image file to feed through the network.
        node: Name of the graph node passed to ``create_feature_extractor``.
    """
    model = torchvision.models.resnet18(
        weights=torchvision.models.ResNet18_Weights.DEFAULT)
    # .eval(): any BatchNorm/Dropout in the extracted graph must use
    # inference behaviour, not training behaviour.
    feature_extractor = create_feature_extractor(
        model, return_nodes={node: "output"}).eval()
    # unsqueeze(0) adds the batch dimension the model expects.
    img = build_transform()(load_image(image_path)).unsqueeze(0)
    # Pure inference: no need to record an autograd graph.
    with torch.no_grad():
        out = feature_extractor(img)
    # out["output"][0] drops the batch dimension before summing channels.
    plt.imshow(channel_sum_map(out["output"][0]).numpy())
    plt.show()


if __name__ == "__main__":
    main()
效果如下（conv1 输出的各通道求和后得到的特征图）：