m0_51335239
Wells0
3 years ago

Code (runs a TorchScript segmentation model on a.jpg and paints the background white):
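import cv2
import torch
import torchvision.transforms as transforms

# ToTensor: HWC uint8 in [0, 255] -> CHW float32 in [0, 1]
# Normalize: per-channel (x - mean) / std with the ImageNet statistics
to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Load the exported TorchScript model on the CPU and switch to inference mode
filename1 = "modelSegmentation.pt"
model = torch.jit.load(filename1, map_location="cpu")
model.eval()

image = cv2.imread('a.jpg')  # BGR, uint8, shape (H, W, 3)
if image is None:
    raise FileNotFoundError("could not read a.jpg")
H, W = image.shape[:2]
# Pass the interpolation flag by keyword: the third positional argument
# of cv2.resize is dst, so a positional flag is silently ignored
image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_AREA)
origin_image = image.copy()  # BGR copy used for display
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # the model expects RGB

image = to_tensor(image).unsqueeze(0)  # add batch dim -> (1, 3, 512, 512)

with torch.no_grad():  # no gradients are needed for inference
    output = model(image)
print(output.size())  # expected (1, num_classes, 512, 512)

# Per-pixel class index replicated to 3 channels -> (512, 512, 3) uint8
label = output.argmax(1) \
    .repeat(3, 1, 1) \
    .to(dtype=torch.uint8) \
    .permute((1, 2, 0)) \
    .numpy()
# Keep foreground pixels unchanged and paint background (class 0) white;
# masking avoids the overflow that origin_image *= label causes for labels > 1
origin_image[label == 0] = 255

# Scale back to the original resolution
origin_image = cv2.resize(origin_image, (W, H), interpolation=cv2.INTER_CUBIC)
cv2.imshow('test', origin_image)
cv2.waitKey()
cv2.destroyAllWindows()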
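The two numeric steps around the model call (the ToTensor + Normalize preprocessing and the argmax-based background whitening) can be reproduced with NumPy alone, which makes them easy to check without the .pt file. This is a sketch for illustration, not part of the original post: `normalize_like_torchvision` and `whiten_background` are hypothetical helper names, and it assumes, as the script does, that class 0 is the background.

```python
import numpy as np

# ImageNet statistics, the same values passed to transforms.Normalize above
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_like_torchvision(rgb: np.ndarray) -> np.ndarray:
    """HWC uint8 RGB -> CHW float32, mimicking ToTensor() followed by Normalize()."""
    x = rgb.astype(np.float32) / 255.0   # ToTensor: scale to [0, 1]
    x = (x - MEAN) / STD                 # Normalize: per-channel (x - mean) / std
    return np.transpose(x, (2, 0, 1))    # HWC -> CHW


def whiten_background(bgr: np.ndarray, logits: np.ndarray) -> np.ndarray:
    """Paint pixels whose arg-max class is 0 white; logits has shape (C, H, W)."""
    label = logits.argmax(axis=0)        # (H, W) class index per pixel
    out = bgr.copy()
    out[label == 0] = 255                # an (H, W) boolean mask selects all 3 channels
    return out


# Hypothetical usage on a tiny 2x2 image with 2 classes
img = np.full((2, 2, 3), 100, dtype=np.uint8)
logits = np.zeros((2, 2, 2), dtype=np.float32)
logits[1, 0, 0] = 1.0  # pixel (0, 0) -> class 1; ties elsewhere resolve to class 0
# pixel (0, 0) keeps the value 100; the other three pixels become 255
print(whiten_background(img, logits)[:, :, 0])
```

Using a 2-D boolean mask rather than a replicated 3-channel label keeps the foreground pixels untouched regardless of how many classes the model predicts.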

Result:
