import os
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path

import torch
import tyro
from transformers import TrainingArguments

from gr00t.data.dataset import LeRobotSingleDataset
from gr00t.data.schema import EmbodimentTag
from gr00t.experiment.data_config import DATA_CONFIG_MAP
from gr00t.experiment.runner import TrainRunner
from gr00t.model.gr00t_n1 import GR00T_N1
from gr00t.utils.peft import get_lora_model
@dataclass
class Config:
    """Configuration for GR00T model fine-tuning."""

    # Dataset parameters
    dataset_path: str
    """Path to the dataset directory."""

    output_dir: str = "/tmp/gr00t"
    """Directory to save model checkpoints."""

    data_config: str = "gr1_arms_only"
    """Data configuration name from DATA_CONFIG_MAP."""

    # Training parameters
    batch_size: int = 16
    """Batch size per GPU for training."""

    max_steps: int = 10000
    """Maximum number of training steps."""

    num_gpus: int = 1
    """Number of GPUs to use for training."""

    save_steps: int = 500
    """Number of steps between saving checkpoints."""

    # Model parameters
    base_model_path: str = "nvidia/GR00T-N1-2B"
    """Path or HuggingFace model ID for the base model."""

    tune_llm: bool = False
    """Whether to fine-tune the language model backbone."""

    tune_visual: bool = True
    """Whether to fine-tune the vision tower."""

    tune_projector: bool = True
    """Whether to fine-tune the projector."""

    tune_diffusion_model: bool = True
    """Whether to fine-tune the diffusion model."""

    resume: bool = False
    """Whether to resume from a checkpoint."""

    # Advanced training parameters
    learning_rate: float = 1e-4
    """Learning rate for training."""

    weight_decay: float = 1e-5
    """Weight decay for the AdamW optimizer."""

    warmup_ratio: float = 0.05
    """Ratio of total training steps used for warmup."""
    lora_rank: int = 0
    """Rank of the LoRA adapters; 0 disables LoRA and fully fine-tunes the selected modules."""

    lora_alpha: int = 16
    """Alpha scaling value for LoRA."""

    lora_dropout: float = 0.1
    """Dropout rate for LoRA."""
    dataloader_num_workers: int = 8
    """Number of workers for data loading."""

    report_to: str = "wandb"
    """Where to report training metrics (e.g., 'wandb', 'tensorboard')."""
    # Data loading parameters
    embodiment_tag: str = "new_embodiment"
    """Embodiment tag to use for training, e.g. 'new_embodiment' or 'gr1'."""

    video_backend: str = "decord"
    """Video backend to use for training ('decord' or 'torchvision_av')."""
#####################################################################################
# main training function
#####################################################################################
def main(config: Config):
    """Main training function."""
    # ------------ step 1: load dataset ------------
    embodiment_tag = EmbodimentTag(config.embodiment_tag)

    # 1.1 modality configs and transforms
    data_config_cls = DATA_CONFIG_MAP[config.data_config]
    modality_configs = data_config_cls.modality_config()
    transforms = data_config_cls.transform()
    # 1.2 data loader
    train_dataset = LeRobotSingleDataset(
        dataset_path=config.dataset_path,
        modality_configs=modality_configs,
        transforms=transforms,
        embodiment_tag=embodiment_tag,  # overrides the dataset's own embodiment tag (e.g. with "new_embodiment")
        video_backend=config.video_backend,
    )
    # ------------ step 2: load model ------------
    model = GR00T_N1.from_pretrained(
        pretrained_model_name_or_path=config.base_model_path,
        tune_llm=config.tune_llm,  # backbone's LLM
        tune_visual=config.tune_visual,  # backbone's vision tower
        tune_projector=config.tune_projector,  # action head's projector
        tune_diffusion_model=config.tune_diffusion_model,  # action head's DiT
    )

    # Set the model's compute_dtype to bfloat16
    model.compute_dtype = "bfloat16"
    model.config.compute_dtype = "bfloat16"
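
    # With the default lora_rank=0 the block below is skipped and the modules
    # selected by the tune_* flags above are fine-tuned directly; lora_rank > 0
    # switches to parameter-efficient LoRA adapters instead.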
    if config.lora_rank > 0:
        model = get_lora_model(
            model,
            rank=config.lora_rank,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
        )
    # 2.1 modify training args
    training_args = TrainingArguments(
        output_dir=config.output_dir,
        run_name=None,
        remove_unused_columns=False,
        deepspeed="",
        gradient_checkpointing=False,
        bf16=True,
        tf32=True,
        per_device_train_batch_size=config.batch_size,
        gradient_accumulation_steps=1,
        dataloader_num_workers=config.dataloader_num_workers,
        dataloader_pin_memory=False,
        dataloader_persistent_workers=True,
        optim="adamw_torch",
        adam_beta1=0.95,
        adam_beta2=0.999,
        adam_epsilon=1e-8,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        lr_scheduler_type="cosine",
        logging_steps=10.0,
        num_train_epochs=300,
        max_steps=config.max_steps,
        save_strategy="steps",
        save_steps=config.save_steps,
        save_total_limit=8,
        report_to=config.report_to,
        seed=42,
        do_eval=False,
        ddp_find_unused_parameters=False,
        ddp_bucket_cap_mb=100,
        torch_compile_mode=None,
    )
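
    # Note: the effective batch size is per_device_train_batch_size * num_gpus,
    # since gradient_accumulation_steps is fixed at 1; and because max_steps
    # overrides num_train_epochs in the HF Trainer, training always stops after
    # config.max_steps steps.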
    # 2.2 create the experiment runner
    experiment = TrainRunner(
        train_dataset=train_dataset,
        model=model,
        training_args=training_args,
        resume_from_checkpoint=config.resume,
    )
    # 2.3 run experiment
    experiment.train()
if __name__ == "__main__":
    # Parse arguments using tyro
    config = tyro.cli(Config)

    # Print the parsed configuration
    print("\n" + "=" * 50)
    print("GR00T FINE-TUNING CONFIGURATION:")
    print("=" * 50)
    for key, value in vars(config).items():
        print(f"{key}: {value}")
    print("=" * 50 + "\n")
available_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
    # Validate GPU configuration
    assert (
        config.num_gpus <= available_gpus
    ), f"Number of GPUs requested ({config.num_gpus}) is greater than the available GPUs ({available_gpus})"
    assert config.num_gpus > 0, "Number of GPUs must be greater than 0"
    print(f"Using {config.num_gpus} GPUs")
    if config.num_gpus == 1:
        # Single GPU mode - set CUDA_VISIBLE_DEVICES=0
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        # Run the script normally
        main(config)
    else:
        if os.environ.get("IS_TORCHRUN", "0") == "1":
            # Already running inside a torchrun worker, so train directly
            main(config)
        else:
            # Multi-GPU mode - re-launch this script under torchrun
            script_path = Path(__file__).absolute()
            # Remove any existing CUDA_VISIBLE_DEVICES so torchrun manages devices
            if "CUDA_VISIBLE_DEVICES" in os.environ:
                del os.environ["CUDA_VISIBLE_DEVICES"]
            # Use subprocess.run instead of os.system
            cmd = [
                "torchrun",
                "--standalone",
                f"--nproc_per_node={config.num_gpus}",
                "--nnodes=1",  # default to 1 node for now
                str(script_path),
            ]

            # Convert config to command line arguments
            for key, value in vars(config).items():
                if isinstance(value, bool):
                    # For boolean values, use --flag or --no-flag format
                    if value:
                        cmd.append(f"--{key.replace('_', '-')}")
                    else:
                        cmd.append(f"--no-{key.replace('_', '-')}")
                else:
                    # For non-boolean values, use --key value format
                    cmd.append(f"--{key.replace('_', '-')}")
                    cmd.append(str(value))
            print("Running torchrun command: ", cmd)
            env = os.environ.copy()
            env["IS_TORCHRUN"] = "1"
            sys.exit(subprocess.run(cmd, env=env).returncode)
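
# Example multi-GPU launch (filename is an assumption): the branch above
# re-executes this file under torchrun with IS_TORCHRUN=1 set, so each spawned
# worker takes the direct-training branch instead of re-launching:
#
#   python gr00t_finetune.py --dataset-path <dataset-dir> --num-gpus 4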