A Look at MimicMotion, Tencent's Latest High-Precision Motion Imitation Model
MimicMotion is a high-quality human motion video generation framework developed by researchers at Tencent.
Given a single reference image and a sequence of target poses to imitate, MimicMotion generates high-quality, pose-guided videos of the person performing those motions.
Its core techniques include pose-guided video generation, confidence-aware pose guidance, regional loss amplification, a latent diffusion model, progressive latent fusion, reuse of pretrained models, a U-Net and PoseNet architecture, and cross-frame smoothing.
By incorporating pose confidence information, MimicMotion achieves better temporal smoothness and is more robust to noisy training data.
To generate long, smooth videos efficiently, MimicMotion produces video segments with overlapping frames and progressively fuses their latent representations, which keeps the computational cost under control (a toy illustration of this fusion step follows below).
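To make the progressive latent fusion idea more concrete, here is a minimal, illustrative sketch, not MimicMotion's actual implementation; the fuse_overlapping_latents helper and the tensor shapes are assumptions chosen only to show how the latents of two overlapping segments can be blended with linearly ramped weights:

import torch

def fuse_overlapping_latents(prev: torch.Tensor, curr: torch.Tensor, overlap: int) -> torch.Tensor:
    # Blend the last `overlap` frames of the previous segment with the first
    # `overlap` frames of the current one; tensors are (frames, channels, h, w).
    w = torch.linspace(0, 1, overlap).view(-1, 1, 1, 1)  # weight ramps 0 -> 1 for the new segment
    fused = prev[-overlap:] * (1 - w) + curr[:overlap] * w
    return torch.cat([prev[:-overlap], fused, curr[overlap:]], dim=0)

# Example: two 16-frame latent segments with 6 overlapping frames -> 26 fused frames.
a, b = torch.randn(16, 4, 72, 128), torch.randn(16, 4, 72, 128)
print(fuse_overlapping_latents(a, b, overlap=6).shape)  # torch.Size([26, 4, 72, 128])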
Its application scenarios are broad, including demonstration-style motions for social media, exercise-style motions for education, product-introduction motions for e-commerce, and more.
GitHub project: https://github.com/tencent/MimicMotion.
I. Environment Setup
1. Python environment
Python 3.10 or later is recommended.
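A quick, optional way to confirm the interpreter version before proceeding:
python -c "import sys; print(sys.version)"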
2. Installing the pip packages
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
pip install diffusers==0.27.0 transformers==4.32.1 decord==0.6.0 av jmespath accelerate einops omegaconf matplotlib opencv-python onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple
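After installation, a short optional sanity check (illustrative only) confirms that the CUDA build of PyTorch is active:

import torch
print(torch.__version__)          # expected: 2.0.1+cu118
print(torch.cuda.is_available())  # should print True on a machine with a working CUDA setup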
3. Download the MimicMotion_1-1 model:
wget -P models/ https://huggingface.co/ixaac/MimicMotion/resolve/main/MimicMotion_1-1.pth
4. Download the DWPose models:
wget https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true -O models/DWPose/yolox_l.onnx
wget https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true -O models/DWPose/dw-ll_ucoco_384.onnx
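Note that wget -O does not create missing directories, so create the target folder first if it does not exist yet:
mkdir -p models/DWPose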
5. Download the stable-video-diffusion-img2vid-xt-1-1 model:
git lfs install
git clone https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1
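Based on the paths referenced by the inference code below (the exact layout may differ in your setup), the downloaded weights are expected to end up roughly like this under models/, while the cloned stable-video-diffusion-img2vid-xt-1-1 folder can live anywhere as long as base_model_path in the config points to it:
models/
├── DWPose/
│   ├── dw-ll_ucoco_384.onnx
│   └── yolox_l.onnx
└── MimicMotion_1-1.pth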
II. Functionality Test
1. Running the test:
(1) Python code for direct prediction
import os
import subprocess
import time
import torch
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from datetime import datetime
from cog import BasePredictor, Input, Path
from torchvision.transforms.functional import pil_to_tensor, resize, center_crop

# Constants
ASPECT_RATIO = 16 / 9
MODEL_CACHE = "models"
BASE_URL = f"https://weights.replicate.delivery/default/MimicMotion/{MODEL_CACHE}/"# Environment setup
os.environ.update({"HF_DATASETS_OFFLINE": "1","TRANSFORMERS_OFFLINE": "1","HF_HOME": MODEL_CACHE,"TORCH_HOME": MODEL_CACHE,"HF_DATASETS_CACHE": MODEL_CACHE,"TRANSFORMERS_CACHE": MODEL_CACHE,"HUGGINGFACE_HUB_CACHE": MODEL_CACHE,
})def download_weights(url: str, dest: str) -> None:"""Download and extract model weights"""start = time.time()print(f"[!] Starting download from URL: {url}")print(f"[~] Destination path: {dest}")dest_dir = os.path.dirname(dest) if ".tar" in dest else destcommand = ["pget", "-vf" + ("x" if ".tar" in url else ""), url, dest]try:print(f"[~] Running command: {' '.join(command)}")subprocess.check_call(command)except subprocess.CalledProcessError as e:raise RuntimeError(f"Download failed. Command '{e.cmd}' exited with status {e.returncode}.") from eprint(f"[+] Download completed in {time.time() - start:.2f} seconds")class Predictor(BasePredictor):def setup(self):"""Load the model into memory for efficient predictions"""os.makedirs(MODEL_CACHE, exist_ok=True)model_files = ["DWPose.tar", "MimicMotion.pth", "MimicMotion_1-1.pth", "SVD.tar"]for model_file in model_files:url = BASE_URL + model_filepath = os.path.join(MODEL_CACHE, model_file)if not os.path.exists(path.replace(".tar", "")):download_weights(url, path)self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"Using device: {self.device}")# Import only after downloading weightsglobal MimicMotionPipeline, create_pipeline, save_to_mp4, get_video_pose, get_image_posefrom mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipelinefrom mimicmotion.utils.loader import create_pipelinefrom mimicmotion.utils.utils import save_to_mp4from mimicmotion.dwpose.preprocess import get_video_pose, get_image_pose# Load configself.config = OmegaConf.create({"base_model_path": os.path.join(MODEL_CACHE, "SVD/stable-video-diffusion-img2vid-xt-1-1"),"ckpt_path": os.path.join(MODEL_CACHE, "MimicMotion_1-1.pth"),})# Initialize pipelineself.pipeline = create_pipeline(self.config, self.device)self.current_checkpoint = "v1-1"self.current_dtype = torch.float32def predict(self,motion_video: Path = Input(description="Reference video file containing the motion to be mimicked"),appearance_image: Path = Input(description="Reference image file for the appearance of the generated video"),resolution: int = Input(description="Height of the output video in pixels. Width is automatic.", default=576, ge=64, le=1024),chunk_size: int = Input(description="Number of frames per processing chunk", default=16, ge=2),frames_overlap: int = Input(description="Overlapping frames between chunks to smooth transitions", default=6, ge=0),denoising_steps: int = Input(description="Number of denoising steps in diffusion", default=25, ge=1, le=100),noise_strength: float = Input(description="Noise augmentation strength", default=0.0, ge=0.0, le=1.0),guidance_scale: float = Input(description="Guidance scale towards reference", default=2.0, ge=0.1, le=10.0),sample_stride: int = Input(description="Sampling interval for reference video frames", default=2, ge=1),output_frames_per_second: int = Input(description="Frames per second of the output video", default=15, ge=1, le=60),seed: int = Input(description="Random seed. 
Leave blank for random", default=None),checkpoint_version: str = Input(description="Choose the checkpoint version", choices=["v1", "v1-1"], default="v1-1")) -> Path:if not motion_video.exists():raise FileNotFoundError(f"Reference video file does not exist: {motion_video}")if not appearance_image.exists():raise FileNotFoundError(f"Reference image file does not exist: {appearance_image}")if resolution % 8 != 0:raise ValueError(f"Resolution must be a multiple of 8, got {resolution}")if chunk_size <= frames_overlap:raise ValueError(f"Chunk size ({chunk_size}) must be greater than frames overlap ({frames_overlap})")seed = seed or int.from_bytes(os.urandom(2), "big")print(f"Using seed: {seed}")# Update checkpoint if neededif checkpoint_version != self.current_checkpoint:self.config.ckpt_path = os.path.join(MODEL_CACHE, f"MimicMotion_{checkpoint_version.replace('-', '_')}.pth")self.pipeline = create_pipeline(self.config, self.device)self.current_checkpoint = checkpoint_versionuse_fp16 = torch.float16if use_fp16 != self.current_dtype:torch.set_default_dtype(use_fp16)self.pipeline = create_pipeline(self.config, self.device)self.current_dtype = use_fp16# Preprocess inputspose_pixels, image_pixels = self.preprocess(motion_video, appearance_image, resolution, sample_stride)# Run the model pipelinevideo_frames = self.run_pipeline(image_pixels, pose_pixels, chunk_size, frames_overlap,denoising_steps, noise_strength, guidance_scale, seed)# Save output videooutput_path = f"/tmp/output_{datetime.now().strftime('%Y%m%d%H%M%S')}.mp4"save_to_mp4(video_frames, output_path, fps=output_frames_per_second)return Path(output_path)def preprocess(self, video_path, image_path, resolution=576, sample_stride=2):image = Image.open(image_path).convert("RGB")image_tensor = pil_to_tensor(image)h, w = image_tensor.shape[-2:]if h > w:w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64else:w_target, h_target = int(resolution / ASPECT_RATIO // 64) * 64, resolutionh_w_ratio = h / wif h_w_ratio < h_target / w_target:h_resize, w_resize = h_target, int(h_target / h_w_ratio)else:h_resize, w_resize = int(w_target * h_w_ratio), w_targetresized_image = resize(image_tensor, [h_resize, w_resize], antialias=True)cropped_image = center_crop(resized_image, [h_target, w_target])image_np = cropped_image.permute((1, 2, 0)).numpy()image_pose = get_image_pose(image_np)video_pose = get_video_pose(video_path, image_np, sample_stride=sample_stride)pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose])image_pixels = np.expand_dims(image_np, 0).transpose(0, 3, 1, 2)return torch.from_numpy(pose_pixels) / 127.5 - 1, torch.from_numpy(image_pixels) / 127.5 - 1def run_pipeline(self, image_pixels, pose_pixels, num_frames, frames_overlap,num_inference_steps, noise_aug_strength, guidance_scale, seed):image_list = [Image.fromarray((img.cpu().numpy().transpose(1, 2, 0) * 127.5 + 127.5).astype(np.uint8))for img in image_pixels]pose_pixels = pose_pixels.unsqueeze(0).to(self.device)generator = torch.Generator(device=self.device).manual_seed(seed)frames = self.pipeline(image_list,image_pose=pose_pixels,num_frames=pose_pixels.size(1),tile_size=num_frames,tile_overlap=frames_overlap,height=pose_pixels.shape[-2],width=pose_pixels.shape[-1],fps=7,noise_aug_strength=noise_aug_strength,num_inference_steps=num_inference_steps,generator=generator,min_guidance_scale=guidance_scale,max_guidance_scale=guidance_scale,decode_chunk_size=8,output_type="pt",device=self.device,).frames.cpu()video_frames = (frames * 
255.0).to(torch.uint8)return video_frames[0, 1:] # Remove the reference image frame
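This script follows Replicate's Cog predictor interface (cog.BasePredictor), so it is normally run through Cog (with a cog.yaml describing the environment) rather than executed directly; a typical invocation, with hypothetical input file names, would look like:
cog predict -i motion_video=@dance.mp4 -i appearance_image=@reference.jpg -i resolution=576 -i seed=42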
(2) Python code driven by a config file
import os
import argparse
import logging
import math
from omegaconf import OmegaConf
from datetime import datetime
from pathlib import Path
from typing import Tuple

import numpy as np
import torch
from torchvision.datasets.folder import pil_loader
from torchvision.transforms.functional import pil_to_tensor, resize, center_crop, to_pil_image

from mimicmotion.utils.geglu_patch import patch_geglu_inplace
patch_geglu_inplace()

from constants import ASPECT_RATIO
from mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline
from mimicmotion.utils.loader import create_pipeline
from mimicmotion.utils.utils import save_to_mp4
from mimicmotion.dwpose.preprocess import get_video_pose, get_image_pose

logging.basicConfig(level=logging.INFO, format="%(asctime)s: [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")def preprocess(video_path: str, image_path: str, resolution: int = 576, sample_stride: int = 2) -> Tuple[torch.Tensor, torch.Tensor]:"""Preprocess ref image pose and video poseArgs:video_path (str): Input video pose path.image_path (str): Reference image path.resolution (int, optional): Defaults to 576.sample_stride (int, optional): Defaults to 2.Returns:pose_pixels, image_pixels (Tuple[torch.Tensor, torch.Tensor]): Processed pose and image tensors."""image_pixels = pil_loader(image_path)image_pixels = pil_to_tensor(image_pixels) # (c, h, w)h, w = image_pixels.shape[-2:]# Compute target height and width according to original aspect ratioif h > w:w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64else:w_target, h_target = int(resolution / ASPECT_RATIO // 64) * 64, resolutionh_w_ratio = float(h) / float(w)if h_w_ratio < h_target / w_target:h_resize, w_resize = h_target, math.ceil(h_target / h_w_ratio)else:h_resize, w_resize = math.ceil(w_target * h_w_ratio), w_targetimage_pixels = resize(image_pixels, [h_resize, w_resize], antialias=None)image_pixels = center_crop(image_pixels, [h_target, w_target])image_pixels = image_pixels.permute((1, 2, 0)).numpy()# Get image and video pose valuesimage_pose = get_image_pose(image_pixels)video_pose = get_video_pose(video_path, image_pixels, sample_stride=sample_stride)pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose])image_pixels = np.transpose(np.expand_dims(image_pixels, 0), (0, 3, 1, 2))return torch.from_numpy(pose_pixels.copy()) / 127.5 - 1, torch.from_numpy(image_pixels) / 127.5 - 1def run_pipeline(pipeline: MimicMotionPipeline, image_pixels: torch.Tensor, pose_pixels: torch.Tensor, device: torch.device, task_config: OmegaConf) -> torch.Tensor:"""Run MimicMotion pipeline to generate video frames.Args:pipeline (MimicMotionPipeline): MimicMotion pipeline object.image_pixels (torch.Tensor): Processed image tensor.pose_pixels (torch.Tensor): Processed pose tensor.device (torch.device): Device to run the pipeline on.task_config (OmegaConf): Task-specific configuration.Returns:torch.Tensor: Generated video frames."""image_pixels = [to_pil_image(img.to(torch.uint8)) for img in (image_pixels + 1.0) * 127.5]generator = torch.Generator(device=device)generator.manual_seed(task_config.seed)frames = pipeline(image_pixels, image_pose=pose_pixels, num_frames=pose_pixels.size(0),tile_size=task_config.num_frames, tile_overlap=task_config.frames_overlap,height=pose_pixels.shape[-2], width=pose_pixels.shape[-1], fps=task_config.fps,noise_aug_strength=task_config.noise_aug_strength, num_inference_steps=task_config.num_inference_steps,generator=generator, min_guidance_scale=task_config.guidance_scale, max_guidance_scale=task_config.guidance_scale, decode_chunk_size=8, output_type="pt", device=device).frames.cpu()video_frames = (frames * 255.0).to(torch.uint8)return video_frames[:, 1:] # Exclude the first framedef process_task(pipeline: MimicMotionPipeline, task: OmegaConf, output_dir: str, device: torch.device):"""Process a single task for MimicMotion pipeline.Args:pipeline (MimicMotionPipeline): MimicMotion pipeline object.task (OmegaConf): Task-specific configuration.output_dir (str): Output directory to save the result.device (torch.device): Device to run the pipeline on."""pose_pixels, image_pixels = preprocess(task.ref_video_path, task.ref_image_path, resolution=task.resolution, sample_stride=task.sample_stride)video_frames = run_pipeline(pipeline, 
image_pixels, pose_pixels, device, task)output_path = f"{output_dir}/{os.path.basename(task.ref_video_path).split('.')[0]}_{datetime.now().strftime('%Y%m%d%H%M%S')}.mp4"save_to_mp4(video_frames, output_path, fps=task.fps)def parse_arguments() -> argparse.Namespace:"""Parse command-line arguments."""parser = argparse.ArgumentParser()parser.add_argument("--log_file", type=str, default=None)parser.add_argument("--inference_config", type=str, default="configs/test.yaml")parser.add_argument("--output_dir", type=str, default="outputs/", help="Path to output directory")parser.add_argument("--no_use_float16", action="store_true", help="Whether to use float16 to speed up inference")return parser.parse_args()def set_logger(log_file: str = None, log_level: int = logging.INFO):"""Set up the logger.Args:log_file (str, optional): Path to log file. Defaults to None.log_level (int, optional): Log level. Defaults to logging.INFO."""if log_file:log_handler = logging.FileHandler(log_file, "w")else:log_handler = logging.FileHandler(f"{args.output_dir}/{datetime.now().strftime('%Y%m%d%H%M%S')}.log", "w")log_handler.setFormatter(logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s]: %(message)s"))log_handler.setLevel(log_level)logger.addHandler(log_handler)@torch.no_grad()
def main():args = parse_arguments()if not args.no_use_float16:torch.set_default_dtype(torch.float16)infer_config = OmegaConf.load(args.inference_config)pipeline = create_pipeline(infer_config, device)Path(args.output_dir).mkdir(parents=True, exist_ok=True)set_logger(args.log_file)for task in infer_config.test_case:process_task(pipeline, task, args.output_dir, device)logger.info("Inference completed successfully.")if __name__ == "__main__":main()
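The script loads its tasks from --inference_config (default configs/test.yaml). A minimal config sketch with placeholder paths, covering exactly the fields the code above reads, might look like:

base_model_path: models/SVD/stable-video-diffusion-img2vid-xt-1-1   # cloned SVD weights
ckpt_path: models/MimicMotion_1-1.pth                                # MimicMotion checkpoint
test_case:
  - ref_video_path: assets/example_video.mp4    # placeholder: motion reference video
    ref_image_path: assets/example_image.jpg    # placeholder: appearance reference image
    resolution: 576
    sample_stride: 2
    num_frames: 16            # tile size per segment
    frames_overlap: 6         # overlapping frames between segments
    num_inference_steps: 25
    noise_aug_strength: 0.0
    guidance_scale: 2.0
    fps: 15
    seed: 42

Assuming the script above is saved as inference.py, it can then be run with:
python inference.py --inference_config configs/test.yaml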
To be continued...
For more details, follow: 杰哥新技术