代码:
import cv2
import torchvision.transforms as transforms
import torch
# Preprocessing: HWC uint8 array -> CHW float tensor in [0, 1] (ToTensor),
# then per-channel normalization with the listed mean/std values.
to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

filename1 = "modelSegmentation.pt"
# The input tensor below is created on the CPU and the result is converted
# with .numpy(), so load the TorchScript module onto the CPU as well. Otherwise
# a model saved from a GPU would fail to load on a CPU-only machine or
# mismatch the input's device.
model = torch.jit.load(filename1, map_location="cpu")
model.eval()

image = cv2.imread('a.jpg')
if image is None:
    # cv2.imread reports a missing/unreadable file by returning None.
    raise FileNotFoundError("could not read image 'a.jpg'")
H, W = image.shape[:2]  # original size, used to restore the output size
# The third positional parameter of cv2.resize is `dst`, not the
# interpolation flag, so the flag must be passed by keyword to take effect.
image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_AREA)
origin_image = image.copy()  # BGR copy kept for display
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = to_tensor(image).unsqueeze(0)  # add the batch dimension

# Inference only: no autograd graph is needed.
with torch.no_grad():
    output = model(image)
print(output.size())

# Per-pixel class index over dim 1, with the batch axis dropped.
# Assumes the model returns (1, num_classes, 512, 512) logits -- TODO confirm.
label = output.argmax(1)[0].cpu().numpy()
# Paint background (class 0) white and leave every other pixel untouched.
# Scaling the image by the class id is avoided: it would overflow uint8 for
# class ids > 1.
origin_image[label == 0] = 255
origin_image = cv2.resize(origin_image, (W, H), interpolation=cv2.INTER_CUBIC)
cv2.imshow('test', origin_image)
cv2.waitKey()
cv2.destroyAllWindows()
结果: