写点什么

【深度学习】AI 一键换天

作者:逝缘~
  • 2022 年 7 月 07 日
  • 本文字数:4876 字

    阅读完需:约 16 分钟

【深度学习】AI一键换天

1.实验目标

1.了解图像分割的基本应用;

2.了解运动估计的基本应用;

3.了解图像混合的基本应用。

2.案例内容介绍

案例链接OBS - JupyterLab (huaweicloud.com)

无论是拍人拍景或是其他,“天空”都可以说是摄像中的关键元素。比如,一张平平无奇的景色图加上落日余晖的天空色调,是不是有内味了?(随手就可以变换出各种天空效果:晴天、彩虹、晚霞、暮光、夕阳等等)

当然,自然的天空还不是最酷炫的,今天给大家介绍一款基于原生视频的 AI 处理方法,不仅可以一键切置换天空背景,还可以打造任意“天空之城”。比如,《星际迷航》等科幻电影中经常出现的浩瀚星空、宇宙飞船,也可以利用这项技术融入随手拍的视频中,路人拍摄的公路片也能秒变科幻片,画面毫无违和感。好像只要脑洞够大,利用这项 AI 技术,可以创作无限种玩法。

基于视觉的视频天空替换和协调方法,该方法可以在具有可控风格的视频中自动生成逼真的天空背景。与以前的天空编辑方法专注于静态照片或需要集成在智能手机中的惯性测量装置拍摄视频不同,该方法完全基于视觉,对捕获设备没有任何要求,并且可以很好地应用于在线或离线处理场景。

算法流程大致可以分为三个步骤:

(1) 天空抠图

这一步主要是通过对蒙版数据集进行训练,将图片中的天空和其它物体进行像素级的划分,将天空部分从图片中分离。

(2) 运动估计

对图片中物体的位移情况进行分析,预估相机的移动方向,使替换后的天空和之前的天空位移一致。

(3) 图像混合

将去掉天空的原视频和要替换后的天空视频进行融合,同时对非天空的部分采用色彩叠加,使天空和其它物体的视觉效果相近,使视频效果更加逼真。


最后,算法使用数据增强的方法模拟出同一张图片在不同光照和天气的情况下的图片,使算法具有更强的适应性。

3.实验步骤

3.1 安装和导入依赖包

import osimport moxing as mox file_name = 'SkyAR'if not os.path.exists(file_name):    mox.file.copy('obs://modelarts-labs-bj4-v2/case_zoo/SkyAR/SkyAR.zip', 'SkyAR.zip')    os.system('unzip SkyAR.zip')    os.system('rm SkyAR.zip')mox.file.copy_parallel('obs://modelarts-labs-bj4-v2/case_zoo/SkyAR/resnet50-19c8e357.pth', '/home/ma-user/.cache/torch/checkpoints/resnet50-19c8e357.pth')
复制代码


!pip uninstall opencv-python -y!pip uninstall opencv-contrib-python -y
!pip install opencv-contrib-python==4.5.3.56
复制代码


cd SkyAR/
复制代码


import timeimport jsonimport base64import numpy as npimport matplotlib.pyplot as pltimport cv2import argparsefrom networks import *from skyboxengine import *import utilsimport torchfrom IPython.display import clear_output, Image, display, HTML %matplotlib inline # 如果存在GPU则在GPU上面运行device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
复制代码

3.2 设定算法参数 

SkyAR 算法提供了以下五个参数来调整换天的效果:

skybox_center_crop: 天空体中心偏移

auto_light_matching: 是否自动亮度匹配

relighting_factor: 补光

recoloring_factor: 重新着色

halo_effect: 是否开启光环效应

且提供了 datadir 和 skybox 两个参数来指定待处理的原视频和要替换的天空图片,通过路径进行指定即可,如下所示:

parameter = {  "net_G": "coord_resnet50",  "ckptdir": "./checkpoints_G_coord_resnet50",   "input_mode": "video",  "datadir": "./test_videos/sky.mp4",  # 待处理的原视频路径  "skybox": "sky.jpg",  # 要替换的天空图片路径   "in_size_w": 384,  "in_size_h": 384,  "out_size_w": 845,  "out_size_h": 480,   "skybox_center_crop": 0.5,  "auto_light_matching": False,  "relighting_factor": 0.8,  "recoloring_factor": 0.5,  "halo_effect": True,   "output_dir": "./jpg_output",  "save_jpgs": False} str_json = json.dumps(parameter)
复制代码

3.3 预览一下原视频

video_name = parameter['datadir'] def arrayShow(img):    img = cv2.resize(img, (0, 0), fx=0.25, fy=0.25, interpolation=cv2.INTER_NEAREST)    _,ret = cv2.imencode('.jpg', img)    return Image(data=ret) # 打开一个视频流cap = cv2.VideoCapture(video_name) frame_id = 0while True:    try:        clear_output(wait=True) # 清除之前的显示        ret, frame = cap.read() # 读取一帧图片        if ret:            frame_id += 1            if frame_id > 200:                break            cv2.putText(frame, str(frame_id), (5, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)  # 画frame_id            tmp = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # 转换色彩模式            img = arrayShow(frame)            display(img) # 显示图片            time.sleep(0.05) # 线程睡眠一段时间再处理下一帧图片        else:            break    except KeyboardInterrupt:        cap.release()cap.release()
复制代码

预览一下要替换的天空图片

img= cv2.imread(os.path.join('./skybox', parameter['skybox']))img2 = img[:, :, ::-1]plt.imshow(img2)
复制代码

3.4 定义 SkyFilter 类 

class Struct:    def __init__(self, **entries):        self.__dict__.update(entries)        def parse_config():    data = json.loads(str_json)    args = Struct(**data)     return args args = parse_config()
复制代码


class SkyFilter():     def __init__(self, args):         self.ckptdir = args.ckptdir        self.datadir = args.datadir        self.input_mode = args.input_mode         self.in_size_w, self.in_size_h = args.in_size_w, args.in_size_h        self.out_size_w, self.out_size_h = args.out_size_w, args.out_size_h         self.skyboxengine = SkyBox(args)         self.net_G = define_G(input_nc=3, output_nc=1, ngf=64, netG=args.net_G).to(device)        self.load_model()         self.video_writer = cv2.VideoWriter('out.avi',                                            cv2.VideoWriter_fourcc(*'MJPG'),                                            20.0,                                            (args.out_size_w, args.out_size_h))        self.video_writer_cat = cv2.VideoWriter('compare.avi',                                                cv2.VideoWriter_fourcc(*'MJPG'),                                                20.0,                                                (2*args.out_size_w, args.out_size_h))         if os.path.exists(args.output_dir) is False:            os.mkdir(args.output_dir)         self.output_img_list = []         self.save_jpgs = args.save_jpgs      def load_model(self):        # 加载预训练的天空抠图模型        print('loading the best checkpoint...')        checkpoint = torch.load(os.path.join(self.ckptdir, 'best_ckpt.pt'),                                map_location=device)        self.net_G.load_state_dict(checkpoint['model_G_state_dict'])        self.net_G.to(device)        self.net_G.eval()      def write_video(self, img_HD, syneth):         frame = np.array(255.0 * syneth[:, :, ::-1], dtype=np.uint8)        self.video_writer.write(frame)         frame_cat = np.concatenate([img_HD, syneth], axis=1)        frame_cat = np.array(255.0 * frame_cat[:, :, ::-1], dtype=np.uint8)        self.video_writer_cat.write(frame_cat)         # 定义结果缓冲区        self.output_img_list.append(frame_cat)      def synthesize(self, img_HD, img_HD_prev):         h, w, c = img_HD.shape         img = cv2.resize(img_HD, (self.in_size_w, self.in_size_h))         img = np.array(img, dtype=np.float32)        img = torch.tensor(img).permute([2, 0, 1]).unsqueeze(0)         with torch.no_grad():            G_pred = self.net_G(img.to(device))            G_pred = torch.nn.functional.interpolate(G_pred,                                                     (h, w),                                                     mode='bicubic',                                                     align_corners=False)            G_pred = G_pred[0, :].permute([1, 2, 0])            G_pred = torch.cat([G_pred, G_pred, G_pred], dim=-1)            G_pred = np.array(G_pred.detach().cpu())            G_pred = np.clip(G_pred, a_max=1.0, a_min=0.0)         skymask = self.skyboxengine.skymask_refinement(G_pred, img_HD)         syneth = self.skyboxengine.skyblend(img_HD, img_HD_prev, skymask)         return syneth, G_pred, skymask      def cvtcolor_and_resize(self, img_HD):         img_HD = cv2.cvtColor(img_HD, cv2.COLOR_BGR2RGB)        img_HD = np.array(img_HD / 255., dtype=np.float32)        img_HD = cv2.resize(img_HD, (self.out_size_w, self.out_size_h))         return img_HD             def process_video(self):        # 逐帧处理视频        cap = cv2.VideoCapture(self.datadir)        m_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))        img_HD_prev = None         for idx in range(m_frames):            ret, frame = cap.read()            if ret:                img_HD = self.cvtcolor_and_resize(frame)                 if img_HD_prev is None:                    img_HD_prev = img_HD                 syneth, G_pred, skymask = self.synthesize(img_HD, img_HD_prev)                 self.write_video(img_HD, syneth)                 img_HD_prev = img_HD                 if (idx + 1) % 50 == 0:                    print(f'processing video, frame {idx + 1} / {m_frames} ... ')             else:  # 如果到达最后一帧                break
复制代码

3.5 开始处理视频 

sf = SkyFilter(args)sf.process_video()
复制代码

3.6 对比原视频和处理后的视频

video_name = "compare.avi" def arrayShow(img):    _,ret = cv2.imencode('.jpg', img)    return Image(data=ret) # 打开一个视频流cap = cv2.VideoCapture(video_name) frame_id = 0while True:    try:        clear_output(wait=True) # 清除之前的显示        ret, frame = cap.read() # 读取一帧图片        if ret:            frame_id += 1            cv2.putText(frame, str(frame_id), (5, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)  # 画frame_id            tmp = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # 转换色彩模式            img = arrayShow(frame)            display(img) # 显示图片            time.sleep(0.05) # 线程睡眠一段时间再处理下一帧图片        else:            break    except KeyboardInterrupt:        cap.release()cap.release()
复制代码

3.7 生成你自己的换天视频

三个步骤实现自定义视频的换天效果:

(1)在自己本地电脑上准备好一个待处理的 mp4 视频文件和一张天空图片;

(2)参考此文档,将视频文件和图片文件分别上传到 ModelArts JupyterLab 的 SkyAR/test_videos 目录和 SkyAR/skybox 目录下;

(3)修改步骤 2 “设定算法参数” 中 datadir 和 skybox 两个参数的路径为你刚上传的视频和图片路径;

(4)重新运行步骤 2~6。

发布于: 刚刚阅读数: 5
用户头像

逝缘~

关注

还未添加个人签名 2022.07.01 加入

还未添加个人简介

评论

发布
暂无评论
【深度学习】AI一键换天_人工智能_逝缘~_InfoQ写作社区