import json
import os
import shutil
from pathlib import Path

import cv2
import gradio_client.client as gradio_client
import gradio_client.utils as gradio_utils
import requests
# Define the Cosmos service URL and token.
# COSMOS_SERVICE_URL = "http://xxxxxx"  # Replace with the actual service URL
# EAS_TOKEN = "your_eas_token"  # Replace with the actual EAS token
# Replay-relative directories holding the RGB image sequences to process.
# Each entry is joined onto a replay's timestamp directory by the workflow.
RGB_TARGETS = [
    "state/rgb/robot.front_camera.left.rgb_image",
    "state/rgb/robot.front_camera.right.rgb_image",
]
# --- Module 1: image sequence -> video ---
def convert_sequence_to_video(input_dir: Path, output_path: Path, fps: int, image_format: str) -> bool:
    """Encode the images in ``input_dir`` into one ``mp4v`` video file.

    Images matching ``*.<image_format>`` are taken in sorted filename order.
    The frame size of the video is the size of the first image.

    Args:
        input_dir: Directory containing the image sequence.
        output_path: Destination video file; parent directories are created.
        fps: Frame rate passed to the video writer.
        image_format: File extension (without the dot) of the images.

    Returns:
        True if the video was written, False if there were no images, the
        first image could not be read, or the video writer could not be
        opened.
    """
    print(f" - Converting to video: {input_dir.name} (Format: {image_format})")
    image_files = sorted(input_dir.glob(f"*.{image_format}"))
    if not image_files:
        print(f" -> No '{image_format}' images found. Skipping.")
        return False

    try:
        first_img = cv2.imread(str(image_files[0]))
        if first_img is None:
            raise IOError("Cannot read the first image.")
        height, width = first_img.shape[:2]
    except (OSError, cv2.error) as e:
        print(f" -> Error reading first image: {e}. Skipping.")
        return False

    output_path.parent.mkdir(parents=True, exist_ok=True)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
    # A writer that failed to open accepts write() calls without producing a
    # usable file, so check explicitly instead of reporting a false success.
    if not out.isOpened():
        out.release()
        print(f" -> Error: Could not open video writer for {output_path}. Skipping.")
        return False

    skipped = 0
    try:
        for image_file in image_files:
            frame = cv2.imread(str(image_file))
            if frame is None:
                skipped += 1
                continue
            # The writer was opened for (width, height); bring any frame of a
            # different size to that size so it is not lost from the video.
            if frame.shape[:2] != (height, width):
                frame = cv2.resize(frame, (width, height))
            out.write(frame)
    finally:
        # Always finalize the container, even if reading a frame raised.
        out.release()

    if skipped:
        print(f" -> Warning: {skipped} unreadable image(s) were skipped.")
    print(f" -> Video created: {output_path}")
    return True
# --- New module: video -> image sequence ---
def split_video_to_frames(video_path: Path, output_dir: Path, original_image_dir: Path, image_format: str):
    """Split a video into frames named after the original image sequence.

    The sorted filenames of ``*.<image_format>`` files in
    ``original_image_dir`` are used, in order, as the output filenames for
    successive video frames. Extraction stops at whichever ends first: the
    video or the list of original filenames.

    Args:
        video_path: Video file to read.
        output_dir: Directory the frames are written to (created if needed).
        original_image_dir: Directory whose image filenames are reused.
        image_format: File extension (without the dot) of the original images.

    Returns:
        None. Progress, warnings and errors are reported via ``print``.
    """
    print(f" - Splitting video back to frames: {video_path.name}")

    # 1. Collect the original filenames to use as a naming template.
    original_filenames = [p.name for p in sorted(original_image_dir.glob(f"*.{image_format}"))]
    if not original_filenames:
        print(f" -> Warning: Could not find original images in {original_image_dir} to use for naming. Skipping frame splitting.")
        return

    # 2. Prepare the output directory.
    output_dir.mkdir(parents=True, exist_ok=True)

    # 3. Open the video and read it frame by frame.
    cap = cv2.VideoCapture(str(video_path))
    frame_index = 0
    failed_writes = 0
    try:
        if not cap.isOpened():
            print(f" -> Error: Could not open video file {video_path}. Skipping.")
            return

        while True:
            ret, frame = cap.read()
            if not ret:
                break  # End of video.

            if frame_index >= len(original_filenames):
                # More video frames than original images: stop so that no
                # frame is written without a matching original filename.
                print(f" -> Warning: Video contains more frames than original image sequence. Stopping at frame {frame_index}.")
                break

            output_filepath = output_dir / original_filenames[frame_index]
            # imwrite signals failure through its return value, not an exception.
            if not cv2.imwrite(str(output_filepath), frame):
                failed_writes += 1
                print(f" -> Error: Could not write frame to {output_filepath}.")
            frame_index += 1
    finally:
        # Release the capture on every exit path, including errors.
        cap.release()

    if failed_writes:
        print(f" -> Finished with errors: {frame_index - failed_writes} of {frame_index} frames saved to: {output_dir}")
    else:
        print(f" -> Success! {frame_index} frames saved to: {output_dir}")
# --- 模块2: 调用Cosmos服务 ---def cosmos_sync_with_upload(client, rgb_video_path, seg_video_path, output_dir, original_rgb_dir): """上传视频,调用API,下载结果,并触发视频到帧的转换。""" def upload_file(filepath: Path): if not filepath or not filepath.exists(): return None print(f" - Uploading: {filepath.name}") file_desc = gradio_utils.handle_file(str(filepath)) result_str = client.predict(file_desc, api_name="/upload_file") return json.loads(result_str).get("path")
remote_rgb_path = upload_file(rgb_video_path) remote_seg_path = upload_file(seg_video_path)
if not remote_rgb_path or not remote_seg_path: return False, "视频上传失败"
request_dict = create_cosmos_request(remote_rgb_path, remote_seg_path) print(" - Sending generation request to Cosmos service...") result = client.predict(json.dumps(request_dict), api_name="/generate_video") if isinstance(result, tuple) and len(result) >= 2 and isinstance(result[0], dict): video_path = result[0].get("video") if not video_path: return False, f"API did not return a video path. Message: {result[1]}"
output_file = Path(output_dir) / f"{rgb_video_path.stem}_cosmos_enhanced.mp4" # 统一处理下载或复制的逻辑 success = False if video_path.startswith(("http://", "https://")): try: resp = requests.get(video_path, stream=True, timeout=300) resp.raise_for_status() with open(output_file, "wb") as f: shutil.copyfileobj(resp.raw, f) success = True except requests.exceptions.RequestException as e: return False, f"Failed to download video: {e}" else: source_path = Path(video_path) if source_path.exists(): shutil.copy2(source_path, output_file) success = True else: return False, f"API returned a local path that does not exist: {video_path}"
if success: print(f" -> Augmented video saved to: {output_file}") # 定义新帧的输出目录,例如 .../robot.front_camera.left.rgb_image_cosmos new_frames_output_dir = original_rgb_dir.parent / f"{original_rgb_dir.name}_cosmos" split_video_to_frames( video_path=output_file, output_dir=new_frames_output_dir, original_image_dir=original_rgb_dir, image_format="jpg" # RGB图像的原始格式 ) return True, str(output_file) else: return False, "Failed to retrieve the generated video file." else: return False, f"Unexpected API response format: {result}"
# Default text prompts sent with every generation request.
_DEFAULT_COSMOS_PROMPT = (
    "A realistic warehouse environment with consistent lighting, perspective, and camera motion. "
    "Preserve the original structure, object positions, and layout from the input video. "
    "Ensure the output exactly matches the segmentation video frame-by-frame in timing and content. "
    "Camera movement must follow the original path precisely."
)
_DEFAULT_COSMOS_NEGATIVE_PROMPT = (
    "The video captures a game playing, with bad crappy graphics and cartoonish frames. "
    "It represents a recording of old outdated games. "
    "The images are very pixelated and of poor CG quality. "
    "There are many subtitles in the footage. "
    "Overall, the video is unrealistic and appears cg. "
    "Plane background."
)


def create_cosmos_request(
    remote_rgb_path,
    remote_seg_path,
    *,
    prompt=_DEFAULT_COSMOS_PROMPT,
    negative_prompt=_DEFAULT_COSMOS_NEGATIVE_PROMPT,
):
    """Build the Cosmos request dict for a main video and a segmentation video.

    Args:
        remote_rgb_path: Remote path of the main (RGB) input video.
        remote_seg_path: Remote path of the segmentation control video.
        prompt: Positive text prompt; defaults to the warehouse prompt.
        negative_prompt: Negative text prompt; defaults to the built-in one.

    Returns:
        A dict ready to be JSON-serialized as the request payload.
    """
    return {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "sigma_max": 80,
        "guidance": 7,
        "input_video_path": remote_rgb_path,  # Main video path.
        "blur_strength": "low",
        "canny_threshold": "low",
        "edge": {"control_weight": 0.3},
        "seg": {
            "control_weight": 1.0,
            "input_control": remote_seg_path,  # Segmentation video path.
        },
    }
# --- Module 3: main workflow controller ---
def process_and_augment_replays(
    output_dir: str,
    fps: int = 30,
    *,
    source_dir: str = "/root/MobilityGenData/replays",
    service_url=None,
    token=None,
):
    """Convert each replay's image sequences to video and augment them via Cosmos.

    For every sub-directory of ``source_dir`` and every entry of
    ``RGB_TARGETS``, the RGB (jpg) and segmentation (png) sequences are
    encoded as videos under ``output_dir/<replay name>/`` and passed to
    ``cosmos_sync_with_upload``. Targets missing either directory, or whose
    video conversion fails, are skipped.

    Args:
        output_dir: Root directory for generated videos.
        fps: Frame rate of the intermediate videos.
        source_dir: Directory containing one sub-directory per replay.
        service_url: Cosmos service URL. When omitted, a module-level
            ``COSMOS_SERVICE_URL`` is used if defined, then the environment
            variable of the same name.
        token: Service token, resolved like ``service_url`` from ``EAS_TOKEN``.

    Raises:
        RuntimeError: If there are replays to process but no service URL is
            configured.
    """
    source_root = Path(source_dir)
    output_root = Path(output_dir)

    if not source_root.is_dir():
        print(f"Replay source directory not found: {source_root}")
        return
    # Sorted so replays are processed in a reproducible order.
    timestamp_dirs = sorted(d for d in source_root.iterdir() if d.is_dir())
    if not timestamp_dirs:
        print(f"No replay directories found in: {source_root}")
        return

    # Look the settings up by name instead of referencing the globals
    # directly, so an unset value yields a clear error rather than NameError.
    service_url = service_url or globals().get("COSMOS_SERVICE_URL") or os.environ.get("COSMOS_SERVICE_URL")
    token = token or globals().get("EAS_TOKEN") or os.environ.get("EAS_TOKEN")
    if not service_url:
        raise RuntimeError(
            "Cosmos service URL is not configured: pass service_url=..., define "
            "COSMOS_SERVICE_URL in this module, or set the COSMOS_SERVICE_URL "
            "environment variable."
        )
    client = gradio_client.Client(service_url, hf_token=token)

    for ts_dir in timestamp_dirs:
        print(f"\nProcessing replay: {ts_dir.name}")
        final_output_dir = output_root / ts_dir.name
        # parents=True: the output root itself may not exist yet.
        final_output_dir.mkdir(parents=True, exist_ok=True)

        for rgb_rel_path_str in RGB_TARGETS:
            rgb_image_dir = ts_dir / rgb_rel_path_str
            # Replaces every "rgb" in the relative path (both the folder and
            # the "rgb_image" suffix) to derive the segmentation directory.
            seg_rel_path_str = rgb_rel_path_str.replace("rgb", "segmentation")
            seg_image_dir = ts_dir / seg_rel_path_str
            if not (rgb_image_dir.is_dir() and seg_image_dir.is_dir()):
                continue

            rgb_video_path = final_output_dir / f"{rgb_image_dir.name}.mp4"
            seg_video_path = final_output_dir / f"{seg_image_dir.name}.mp4"
            rgb_ok = convert_sequence_to_video(rgb_image_dir, rgb_video_path, fps, "jpg")
            seg_ok = convert_sequence_to_video(seg_image_dir, seg_video_path, fps, "png")
            if not (rgb_ok and seg_ok):
                continue

            ok, detail = cosmos_sync_with_upload(
                client,
                rgb_video_path,
                seg_video_path,
                final_output_dir,
                original_rgb_dir=rgb_image_dir,
            )
            if not ok:
                # Report the failure instead of discarding the returned message.
                print(f" -> Cosmos augmentation failed for {rgb_image_dir.name}: {detail}")

    print("\n" + "=" * 20 + " 全部处理完成 " + "=" * 20)
# --- Program entry point ---
if __name__ == "__main__":
    output_directory = "/root/MobilityGenData/cosmos_augmented_videos"
    # Create the output directory with pathlib: the notebook-style
    # `!mkdir -p ...` shell magic is not valid syntax in a plain Python script.
    Path(output_directory).mkdir(parents=True, exist_ok=True)
    process_and_augment_replays(output_dir=output_directory)
# "评论" ("Comments"): stray label left over from the web page this script was copied from; not code.