代码如下:

import torchvision
from PIL import Image
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torchvision.models.feature_extraction import create_feature_extractor

# Standard ImageNet-style preprocessing: resize, take the central 224x224
# crop, convert to a tensor, then normalize each of the three channels with
# the usual ImageNet mean / std.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Pretrained ResNet-18. Switch to eval mode because the model is only used
# for inference here, never trained.
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
model.eval()

# Expose the activation of the "conv1" node under the key "output".
feature_extractor = create_feature_extractor(model, return_nodes={"conv1": "output"})

# Force three channels: grayscale, RGBA, palette or CMYK files would otherwise
# produce a tensor that the 3-channel Normalize and conv1 cannot handle.
# The context manager closes the underlying file; convert() returns an
# independent in-memory image that stays usable afterwards.
with Image.open("dog.jpg") as image_file:
    original_img = image_file.convert("RGB")

# Add a leading batch dimension, since the model expects batched input.
img = transform(original_img).unsqueeze(0)

out = feature_extractor(img)

# Take the first (only) batch item and sum over its leading channel axis to
# collapse all conv1 feature maps into a single 2-D image. This is the same
# reduction as the former transpose(0, 1).sum(1), written directly.
feature_map = out["output"][0].sum(0)

# detach() drops the autograd graph so the tensor can be converted to NumPy.
plt.imshow(feature_map.detach().numpy())

plt.show()

效果如下: