Training with PPO
XLRON uses Proximal Policy Optimization (PPO) to train reinforcement learning agents for optical network resource allocation. The entire training loop — environment simulation, rollout collection, advantage estimation, and gradient updates — runs end-to-end in JAX, compiled as a single program that executes on GPU/TPU.
Training is accessed through the train.py script. The environment and traffic flags are the same as those used for heuristic evaluation — this page focuses on the training-specific flags.
Building a Training Command
A training command combines:
- Environment flags — topology, resources, traffic (see Heuristic Evaluation)
- Parallelism and scale — number of environments, timesteps, rollout length
- Core PPO hyperparameters — learning rate, discount factor, clipping
- Algorithmic features — action masking, reward centering, prioritized sampling
- Schedules — learning rate, entropy, GAE-lambda annealing
- Model architecture — MLP, GNN, or Transformer
- Logging and output — W&B, model saving, diagnostics
A minimal training command:
```bash
python -m xlron.train.train \
  --env_type=rmsa \
  --topology_name=nsfnet_deeprmsa_directed \
  --link_resources=100 \
  --k=5 \
  --load=250 \
  --continuous_operation \
  --TOTAL_TIMESTEPS=5000000 \
  --NUM_ENVS=16 \
  --ROLLOUT_LENGTH=100 \
  --LR=5e-4
```
1. Parallelism and Scale
--NUM_ENVS
Number of parallel environment instances per learner. Each environment runs an independent simulation. The batch size for each update is NUM_ENVS * ROLLOUT_LENGTH. More environments improve sample diversity and statistical stability, but increase memory usage. Typical values: 16 to 2000.
--ROLLOUT_LENGTH
Number of environment steps collected per rollout before a PPO update. Longer rollouts capture more temporal context and allow better advantage estimation, but delay updates. Default: 150. Typical values: 100 to 256.
--TOTAL_TIMESTEPS
Total environment steps across all environments. The number of PPO updates is TOTAL_TIMESTEPS / (NUM_ENVS * ROLLOUT_LENGTH). Default: 1000000.
--STEPS_PER_INCREMENT
Steps per logging increment. Training is divided into TOTAL_TIMESTEPS / STEPS_PER_INCREMENT increments, with metrics reported after each. Default: 100000.
--UPDATE_EPOCHS
Number of passes over the rollout buffer per update step. Multiple epochs reuse the same data, improving sample efficiency at the risk of becoming too off-policy. Default: 1. Typical values: 1 to 10.
--NUM_MINIBATCHES
Number of minibatches to split the rollout buffer into per epoch. The minibatch size is (NUM_ENVS * ROLLOUT_LENGTH) / NUM_MINIBATCHES. Default: 1.
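The relationships between these scale flags can be sketched as simple arithmetic (flag names are taken from this page; the function itself is just illustrative, not part of XLRON):

```python
# Sketch: how the scale flags determine batch sizes and update counts,
# mirroring the formulas stated above.
def training_shape(total_timesteps, num_envs, rollout_length,
                   update_epochs=1, num_minibatches=1):
    batch_size = num_envs * rollout_length          # samples per PPO update
    num_updates = total_timesteps // batch_size     # update loops
    minibatch_size = batch_size // num_minibatches  # samples per gradient step
    gradient_steps = num_updates * update_epochs * num_minibatches
    return batch_size, num_updates, minibatch_size, gradient_steps

# The minimal command above: 5M steps, 16 envs, rollout length 100
print(training_shape(5_000_000, 16, 100))  # (1600, 3125, 1600, 3125)
```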
--NUM_LEARNERS
Number of independent learners, each with their own neural network parameters and set of environments. Useful for running multiple seeds in parallel or for meta-learning. Default: 1.
--NUM_DEVICES
Number of accelerator devices to use. Environments and learners are distributed across devices. Default: 1.
2. Core PPO Hyperparameters
--LR
Initial learning rate for the Adam optimizer. This is arguably the most important hyperparameter. Default: 5e-4. Typical range for sweeps: 1e-4 to 1e-3.
--GAMMA
Discount factor for future rewards. Values close to 1.0 make the agent more far-sighted. Default: 0.999. Typical sweep range: 0.99 to 0.9999.
--GAE_LAMBDA
Lambda parameter for Generalized Advantage Estimation (GAE). Controls the bias-variance tradeoff in advantage estimation. Higher values reduce bias but increase variance. Default: None (uses automatic annealing from INITIAL_LAMBDA to FINAL_LAMBDA). Set explicitly to disable annealing, e.g. --GAE_LAMBDA=0.95. Typical sweep range: 0.9 to 0.999.
--INITIAL_LAMBDA / --FINAL_LAMBDA
When GAE_LAMBDA is not set, lambda is annealed from INITIAL_LAMBDA (default 0.9) to FINAL_LAMBDA (default 0.98) over the course of training using a hyperbolic secant schedule. This starts with lower variance (shorter horizon) and gradually increases the horizon as the value function improves. The annealing speed is controlled by --LAMBDA_SCHEDULE_MULTIPLIER.
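The annealing above can be sketched as follows. Note the rate constant and exact parameterization here are assumptions for illustration, not XLRON's implementation; only the overall shape (sech decay from INITIAL_LAMBDA toward FINAL_LAMBDA) follows the description:

```python
import math

# Sketch of hyperbolic-secant lambda annealing: sech decays from 1 to 0,
# so lambda moves from initial_lambda toward final_lambda over training.
def gae_lambda(step, total_steps, initial_lambda=0.9, final_lambda=0.98,
               schedule_multiplier=1.0):
    progress = step / (total_steps * schedule_multiplier)
    sech = 1.0 / math.cosh(4.0 * progress)  # 4.0 is an assumed rate constant
    return final_lambda - (final_lambda - initial_lambda) * sech

print(round(gae_lambda(0, 1000), 3))     # 0.9 at the start
print(round(gae_lambda(1000, 1000), 3))  # approaches 0.98 by the end
```

A larger `--LAMBDA_SCHEDULE_MULTIPLIER` stretches the horizon, so lambda rises more slowly.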
--CLIP_EPS
PPO clipping parameter. Limits how much the policy ratio can deviate from 1.0 per update. Smaller values are more conservative. Default: 0.2. Typical sweep values: 0.02, 0.04, 0.08, 0.16.
--VF_COEF
Value function loss coefficient. Scales the critic loss relative to the actor loss. Default: 0.5. Typical sweep range: 0.1 to 1.0.
--ENT_COEF
Entropy bonus coefficient. Encourages exploration by penalising deterministic policies. Default: 0.0. Typical sweep range: 0.001 to 0.1. Can be scheduled (see Schedules).
--MAX_GRAD_NORM
Maximum gradient norm for gradient clipping. Prevents destructively large updates. Default: 0.5.
--ADAM_EPS / --ADAM_BETA1 / --ADAM_BETA2
Adam optimizer parameters. Defaults: 1e-5, 0.9, 0.999. Typical sweep ranges: ADAM_BETA1 in [0.8, 0.99], ADAM_BETA2 in [0.9, 0.999].
--WEIGHT_DECAY
Weight decay (L2 regularization) for the AdamW optimizer. Default: 0.0.
--REWARD_SCALE
Multiply all rewards by this factor. Can help with training stability when rewards are very small or very large. Default: 1.0. Typical sweep values: 1, 10.
--LOGR_CLIP / --ADV_CLIP
Clip the log probability ratio to [-LOGR_CLIP, +LOGR_CLIP] and normalized advantages to [-ADV_CLIP, +ADV_CLIP]. Prevents numerical instabilities from extreme values. Defaults: 10.0, 10.0.
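Putting the clipping pieces together, here is a per-sample sketch of the clipped actor loss showing where CLIP_EPS, LOGR_CLIP, and ADV_CLIP enter. This is illustrative only; XLRON's actual loss is vectorized in JAX and includes the value, entropy, and gating terms described elsewhere on this page:

```python
import math

# Sketch of the clipped PPO actor objective for one sample.
def ppo_actor_loss(logp_new, logp_old, advantage,
                   clip_eps=0.2, logr_clip=10.0, adv_clip=10.0):
    # LOGR_CLIP: bound the log ratio before exponentiating
    log_ratio = max(-logr_clip, min(logr_clip, logp_new - logp_old))
    # ADV_CLIP: bound the (assumed already normalized) advantage
    adv = max(-adv_clip, min(adv_clip, advantage))
    ratio = math.exp(log_ratio)
    # CLIP_EPS: limit how far the ratio can move from 1.0
    clipped = max(1.0 - clip_eps, min(1.0 + clip_eps, ratio))
    # Pessimistic (min) objective, negated to form a loss
    return -min(ratio * adv, clipped * adv)

print(ppo_actor_loss(0.0, 0.0, 1.0))  # -1.0 when the policy is unchanged
```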
3. Algorithmic Features
Invalid Action Masking
--ACTION_MASKING
Enable invalid action masking (default: True). The environment provides a mask of valid actions at each step; invalid actions receive logits of -1e8, ensuring they are never selected. This is critical for optical network environments where most actions are invalid at any given step.
--OFF_POLICY_IAM
Off-policy invalid action masking (default: False). When enabled, the PPO importance ratio is computed against the unmasked (behavior) policy rather than the masked policy, so the ratio reflects the probability the unmasked policy would have assigned to the taken action. This can be beneficial when the valid action set changes significantly between rollout and update time.
Valid Mass and Gating
The PPO loss uses a gating mechanism to handle steps where the agent has very few or no valid actions:
- Steps with 0 valid actions are completely gated out (no gradient signal)
- Steps with only 1 valid action are gated out of the actor/entropy loss (the action is forced, so no learning signal)
- Steps with low "valid mass" (probability the unmasked policy places on valid actions) have their actor/entropy loss contribution damped
--VALID_MASS_TARGET
Threshold below which the actor/entropy loss is linearly damped based on valid mass. A step with valid_mass = VALID_MASS_TARGET gets full weight; valid_mass = 0 gets zero weight. Default: 0.05. Tune in range 0.02 to 0.1.
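The gating rules above can be sketched as a per-step weight on the actor/entropy loss. This is illustrative pseudocode of the described behavior, not XLRON's actual implementation:

```python
# Sketch: per-step weight applied to the actor/entropy loss terms.
def actor_loss_weight(num_valid_actions, valid_mass, valid_mass_target=0.05):
    if num_valid_actions <= 1:
        # 0 valid actions: fully gated out; 1 valid action: the action is
        # forced, so there is no actor/entropy learning signal either
        return 0.0
    # Below the target valid mass, damp the contribution linearly
    return min(1.0, valid_mass / valid_mass_target)

print(actor_loss_weight(1, 0.9))    # 0.0 (forced action)
print(actor_loss_weight(5, 0.025))  # 0.5 (half the target mass)
print(actor_loss_weight(5, 0.5))    # 1.0 (full weight)
```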
--VALID_MASS_LOSS_COEF
Coefficient for an auxiliary loss that encourages the unmasked policy to place probability on valid actions. This helps prevent the unmasked logits from drifting away from valid actions, which would cause the valid mass to collapse and reduce learning signal. Default: 0.0 (disabled). Enable with values like 0.01 to 0.1.
--VML_SCHEDULE / --VML_END_FRACTION
Schedule for the valid mass loss coefficient. Supports constant, linear, cosine. VML_END_FRACTION > 1.0 anneals the coefficient upward (e.g. 10.0 means the final coefficient is 10x the initial). Default: constant, 10.0.
Reward Centering
--REWARD_CENTERING
Enable reward centering (default: False). Subtracts a running estimate of the average reward from all rewards before computing TD errors. This can stabilize training in environments with non-zero average reward by keeping advantage estimates centered. The average reward estimate is updated using a decaying step size.
--INITIAL_AVERAGE_REWARD / --REWARD_STEPSIZE
Initial average reward estimate and step size for the exponential moving average update. Defaults: 0.0, 0.001.
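A minimal sketch of reward centering as described: maintain a running average-reward estimate and subtract it from each reward before computing TD errors. This uses a fixed-step EMA for simplicity; XLRON's decaying step size may differ in detail:

```python
# Sketch of reward centering with an exponential moving average.
class RewardCenterer:
    def __init__(self, initial_average_reward=0.0, reward_stepsize=0.001):
        self.avg = initial_average_reward
        self.stepsize = reward_stepsize

    def center(self, reward):
        # Update the running average-reward estimate, then subtract it
        self.avg += self.stepsize * (reward - self.avg)
        return reward - self.avg

rc = RewardCenterer()
centered = [rc.center(1.0) for _ in range(3)]  # drifts toward 0 over time
```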
Prioritized Experience Replay
--PRIO_ALPHA
Priority exponent for prioritized sampling of the rollout buffer. 0.0 = uniform sampling (default), 1.0 = fully prioritized by absolute advantage. Samples with higher absolute advantage are replayed more often, with importance sampling corrections to maintain unbiased gradients.
--PRIO_BETA0
Initial importance sampling correction exponent. Annealed from PRIO_BETA0 to 1.0 over training. 1.0 = full correction from the start (default). Lower values allow more biased but potentially faster early learning.
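A sketch of advantage-prioritized sampling with importance-sampling corrections, matching the PRIO_ALPHA / PRIO_BETA0 description above. The epsilon and max-normalization choices are conventional assumptions, not necessarily XLRON's exact code:

```python
# Sketch: sampling probabilities and IS weights from absolute advantages.
def priorities_and_weights(advantages, alpha=0.6, beta=0.4, eps=1e-6):
    prios = [(abs(a) + eps) ** alpha for a in advantages]
    total = sum(prios)
    probs = [p / total for p in prios]                  # sampling distribution
    n = len(advantages)
    weights = [(1.0 / (n * p)) ** beta for p in probs]  # IS correction
    max_w = max(weights)
    return probs, [w / max_w for w in weights]          # normalize to <= 1

probs, weights = priorities_and_weights([0.1, 1.0, 2.0])
# Larger |advantage| -> higher sampling probability, smaller IS weight
```

With `alpha=0.0` all priorities collapse to 1 and sampling is uniform; with `beta=1.0` the correction is full from the start.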
VTrace / Puffer Advantage
--RHO_CLIP / --C_CLIP
Clipping parameters for VTrace-style off-policy correction in the advantage calculation. When both are set to positive values, the importance ratios in the GAE calculation are clipped:
- `RHO_CLIP`: Clips the ratio used in the TD error: `delta_t = rho_t * (r_t + gamma * V(s_{t+1}) - V(s_t))`
- `C_CLIP`: Clips the ratio used in the GAE accumulation: `A_t = delta_t + gamma * lambda * c_t * A_{t+1}`
When both are <= 0 (default), standard unclipped GAE is used. When lambda=1 and both clips are set, this reduces to the VTrace algorithm. Clipping is useful when running multiple update epochs, since the policy changes between epochs and the rollout data becomes increasingly off-policy.
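A backward-recursion sketch of GAE with the rho/c clipping described above (illustrative, not XLRON's vectorized JAX implementation):

```python
# Sketch of VTrace-style clipped GAE over one trajectory.
def clipped_gae(rewards, values, last_value, ratios,
                gamma=0.999, lam=0.95, rho_clip=1.0, c_clip=1.0):
    advantages = [0.0] * len(rewards)
    next_value, next_adv = last_value, 0.0
    for t in reversed(range(len(rewards))):
        # Clip the importance ratios only when the clip values are positive
        rho = min(ratios[t], rho_clip) if rho_clip > 0 else ratios[t]
        c = min(ratios[t], c_clip) if c_clip > 0 else ratios[t]
        delta = rho * (rewards[t] + gamma * next_value - values[t])
        next_adv = delta + gamma * lam * c * next_adv
        advantages[t] = next_adv
        next_value = values[t]
    return advantages

# With all ratios == 1 the clips are inactive and this is standard GAE
adv = clipped_gae([1.0, 1.0], [0.5, 0.5], 0.5, [1.0, 1.0])
```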
--include_no_op
Add a "no operation" action to the action space. Default: False.
--STEP_ON_GRADIENT
When True, the schedule step counter increments on each gradient update (each minibatch). When False (default), it increments once per update loop. This affects how quickly learning rate and other schedules progress.
4. Schedules
XLRON supports scheduling for several hyperparameters. All schedules operate over the total number of gradient steps (NUM_UPDATES * UPDATE_EPOCHS * NUM_MINIBATCHES).
Learning Rate Schedule
--LR_SCHEDULE
| Value | Description |
|---|---|
| `cosine` | Cosine decay from LR to LR * LR_END_FRACTION (default) |
| `warmup_cosine` | Linear warmup to LR * WARMUP_MULTIPLIER, then cosine decay |
| `linear` | Linear decay from LR to LR * LR_END_FRACTION |
| `constant` | Fixed learning rate |
--LR_END_FRACTION
Final learning rate as a fraction of initial LR. Default: 0.1 (i.e. LR decays to 10% of its initial value).
--WARMUP_MULTIPLIER
Peak learning rate during warmup, as a multiple of initial LR. Default: 1.0. Set to e.g. 2.0 to warm up to 2x the initial LR before decaying.
--WARMUP_STEPS_FRACTION
Fraction of total training steps used for warmup. Default: 0.2.
--LR_SCHEDULE_MULTIPLIER
Multiply the schedule horizon by this factor. Values > 1.0 slow down the schedule (decay takes longer). Default: 1.0.
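The `warmup_cosine` option can be sketched as below. Whether warmup starts from 0 or from LR, and the exact interpolation, are assumptions here; only the described shape (linear warmup to LR * WARMUP_MULTIPLIER, then cosine decay to LR * LR_END_FRACTION) follows this page:

```python
import math

# Sketch of a warmup-then-cosine learning rate schedule.
def warmup_cosine_lr(step, total_steps, lr=5e-4, warmup_multiplier=2.0,
                     warmup_steps_fraction=0.2, lr_end_fraction=0.1):
    warmup_steps = int(total_steps * warmup_steps_fraction)
    peak, floor = lr * warmup_multiplier, lr * lr_end_fraction
    if step < warmup_steps:
        # Linear warmup from the initial LR up to the peak (assumed start)
        return lr + (peak - lr) * step / warmup_steps
    # Cosine decay from the peak down to the floor
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return floor + 0.5 * (peak - floor) * (1 + math.cos(math.pi * progress))
```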
Entropy Schedule
--ENT_SCHEDULE
Schedule for the entropy coefficient. Supports constant (default), linear, cosine. With linear or cosine, the entropy coefficient decays from ENT_COEF to ENT_COEF * ENT_END_FRACTION.
--ENT_END_FRACTION
Final entropy coefficient as a fraction of initial. Default: 0.1. Typical sweep values: 0.1, 0.2, 0.5.
--ENT_SCHEDULE_MULTIPLIER
Multiply the entropy schedule horizon. Default: 1.0.
Separate Value Function Optimizer
--SEPARATE_VF_OPTIMIZER
Use a separate optimizer for the critic (value function) with its own hyperparameters. Default: False. When enabled, the following flags control the critic optimizer:
- `--VF_LR` — Learning rate (default: `LR / 3`)
- `--VF_LR_SCHEDULE` — Schedule type (default: same as `LR_SCHEDULE`)
- `--VF_LR_END_FRACTION` — End fraction (default: same as `LR_END_FRACTION`)
- `--VF_WARMUP_MULTIPLIER` — Warmup peak (default: same as `WARMUP_MULTIPLIER`)
- `--VF_WARMUP_STEPS_FRACTION` — Warmup fraction (default: same as `WARMUP_STEPS_FRACTION`)
- `--VF_ADAM_EPS`, `--VF_ADAM_BETA1`, `--VF_ADAM_BETA2` — Adam parameters
- `--VF_MAX_GRAD_NORM` — Gradient clipping
- `--VF_WEIGHT_DECAY` — Weight decay
- `--VF_SCHEDULE_MULTIPLIER` — Schedule horizon multiplier
5. Model Architecture
XLRON provides three neural network architectures. All follow an actor-critic design with shared or separate encoders.
MLP (Default)
A simple multi-layer perceptron. Used by default when neither --USE_GNN nor --USE_TRANSFORMER is set.
| Flag | Description | Default |
|---|---|---|
| `--NUM_LAYERS` | Number of hidden layers (shared by actor and critic) | 2 |
| `--NUM_UNITS` | Hidden units per layer | 64 |
| `--ACTIVATION` | Activation function (`tanh`, `relu`) | `tanh` |
The MLP takes a flat observation vector as input. Its simplicity makes it fast to train and a good baseline.
GNN (Graph Neural Network)
A Jraph-based graph neural network that operates directly on the network topology. Enabled with --USE_GNN.
| Flag | Description | Default |
|---|---|---|
| `--message_passing_steps` | Rounds of message passing | 3 |
| `--edge_embedding_size` | Edge embedding dimension | 128 |
| `--edge_mlp_layers` / `--edge_mlp_latent` | Edge update MLP depth/width | 2 / 128 |
| `--node_embedding_size` | Node embedding dimension | 16 |
| `--node_mlp_layers` / `--node_mlp_latent` | Node update MLP depth/width | 2 / 128 |
| `--global_embedding_size` | Global embedding dimension | 8 |
| `--global_mlp_layers` / `--global_mlp_latent` | Global update MLP depth/width | 1 / 16 |
| `--attn_mlp_layers` / `--attn_mlp_latent` | Attention MLP depth/width (0 to disable attention) | 1 / 64 |
| `--num_spectral_features` | Number of spectral features per node | 8 |
| `--gnn_layer_norm` | Layer normalization in GNN | True |
| `--mlp_layer_norm` | Layer normalization in MLPs | False |
| `--normalize_by_link_length` | Normalize features by link length | False |
| `--DISABLE_NODE_FEATURES` | Disable node features (use only edge features) | False |
The GNN has an inductive bias for graph-structured problems — it can generalize across topologies and naturally captures link adjacency. When attn_mlp_layers > 0, it uses GATv2-style attention on edges.
Transformer
A transformer encoder that processes link-level tokens with graph-aware positional encodings (WIRE — Wavelet-Induced Rotary Encodings). Enabled with --USE_TRANSFORMER.
| Flag | Description | Default |
|---|---|---|
| `--transformer_num_layers` | Number of transformer encoder layers | 1 |
| `--transformer_num_heads` | Number of attention heads | 4 |
| `--transformer_embedding_size` | Token embedding dimension | 128 |
| `--transformer_intermediate_size` | Feed-forward hidden dimension | 256 |
| `--transformer_obs_type` | Observation type: `departure` (departure times), `occupancy` (binary), or `capacity` (for RWA-LR) | `departure` |
| `--num_wire_features` | Number of spectral features for WIRE positional encoding | 8 |
| `--transformer_actor_mlp_width` / `--transformer_actor_mlp_depth` | Actor head MLP width/depth | 128 / 1 |
| `--transformer_critic_mlp_width` / `--transformer_critic_mlp_depth` | Critic head MLP width/depth | 128 / 2 |
| `--transformer_share_layers` | Share encoder layers between actor and critic | False |
| `--transformer_dropout_rate` / `--transformer_attention_dropout_rate` | Dropout rates | 0.1 / 0.1 |
| `--transformer_enable_dropout` | Enable dropout during training | False |
The transformer processes each link as a token, with per-link spectral features and request-specific information. WIRE positional encodings inject graph structure through rotary embeddings derived from the graph Laplacian.
Slot Aggregation
--aggregate_slots
Reduce the action space by grouping frequency slots into blocks of N. The agent selects a block, and first-fit allocation selects the specific slot within the block. Default: 1 (no aggregation). Typical values for large action spaces: 10, 20, 100.
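The block-to-slot mapping can be sketched as follows. `block_to_slot` and `slot_is_free` are hypothetical helpers for illustration, not XLRON functions:

```python
# Sketch: the agent picks a block; first-fit chooses the concrete slot.
def block_to_slot(block_index, aggregate_slots, slot_is_free):
    start = block_index * aggregate_slots
    for slot in range(start, start + aggregate_slots):
        if slot_is_free(slot):
            return slot  # first free slot within the chosen block
    return None  # block fully occupied -> request is blocked

# 100 slots aggregated into blocks of 20 -> 5 block actions per path
occupied = {0, 1, 2}
print(block_to_slot(0, 20, lambda s: s not in occupied))  # 3
```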
6. Logging, Saving, and Evaluation
Weights & Biases
--WANDB
Enable logging to W&B. See Understanding XLRON for setup and sweep configuration.
--LOG_LOSS_INFO
Log loss metrics (actor loss, value loss, entropy, etc.) to W&B. Default: True.
--ENHANCED_LOGGING
Enable detailed diagnostic logging including valid fraction, clip fraction, ratio statistics, valid mass statistics, advantage statistics, and more. Useful for debugging training dynamics. Default: False.
Model Saving and Loading
--SAVE_MODEL
Save the best model (by blocking probability) during training. Saved locally to --MODEL_PATH and uploaded to W&B if --WANDB is enabled. Default: False.
--MODEL_PATH
Path to save/load model files. If unspecified, models are saved to models/<EXPERIMENT_NAME>.eqx.
--RETRAIN_MODEL
Load a previously saved model and continue training. Restores model weights but not optimizer state.
--EVAL_MODEL
Load a saved model for evaluation only (no training). Uses the same evaluation infrastructure as heuristic evaluation.
--KEEP_VF
When loading a model, keep only the pre-trained value function (critic) and reinitialize the actor. Useful for transfer learning.
Evaluation During Training
--EVAL_DURING_TRAINING
Run periodic evaluation using the current policy without exploration noise. This tracks the best-performing model across training. Default: False.
--EVAL_TIMESTEPS
Number of timesteps per evaluation run. Default: 10000.
--EVAL_FREQUENCY
Run evaluation every N increments. Default: 1.
Other Logging Flags
--DATA_OUTPUT_FILE / --TRAJ_DATA_OUTPUT_FILE
Save per-episode or per-step metrics to CSV files.
--DOWNSAMPLE_FACTOR
Block-average every N consecutive data points before uploading to W&B. Default: 1.
--log_actions / --log_path_lengths
Log detailed per-step action and path length data. Default: False.
7. Hyperparameter Sweeps
XLRON supports W&B hyperparameter sweeps. Define a sweep configuration YAML file and use the wandb sweep command to launch. See Understanding XLRON for full details.
An example sweep config is provided at examples/sweep.yaml. The key hyperparameters to sweep are:
| Parameter | Typical Range | Notes |
|---|---|---|
| `LR` | 1e-4 to 1e-3 | Most impactful hyperparameter |
| `GAMMA` | 0.99 to 0.9999 | Discount factor |
| `GAE_LAMBDA` | 0.9 to 0.999 | Advantage estimation horizon |
| `CLIP_EPS` | 0.02, 0.04, 0.08, 0.16 | PPO clipping |
| `VF_COEF` | 0.1 to 1.0 | Value loss weight |
| `ENT_COEF` | 0.001 to 0.1 | Exploration |
| `ENT_END_FRACTION` | 0.1, 0.2, 0.5 | Entropy decay |
| `REWARD_SCALE` | 1, 10 | Reward magnitude |
| `OFF_POLICY_IAM` | true, false | Action masking style |
| `ADAM_BETA1` | 0.8 to 0.99 | Optimizer momentum |
| `ADAM_BETA2` | 0.9 to 0.999 | Optimizer second moment |
| `ROLLOUT_LENGTH` | 128, 256 | Trajectory length |
| `aggregate_slots` | 10, 20, 100 | Action space reduction |
For transformer-specific sweeps, also consider: transformer_num_heads, transformer_embedding_size, transformer_obs_type, transformer_actor_mlp_width, transformer_actor_mlp_depth.
To launch a sweep:
```bash
wandb sweep examples/sweep.yaml
wandb agent <SWEEP_ID>
```
8. Flagfiles
You can save a set of flags to a text file and load them with --flagfile=PATH:
```
# my_flags.txt
--env_type=rmsa
--topology_name=nsfnet_deeprmsa_directed
--link_resources=100
--k=50
--load=250
--continuous_operation
--ENV_WARMUP_STEPS=3000
```

```bash
python -m xlron.train.train --flagfile=my_flags.txt --LR=1e-4 --NUM_ENVS=64
```
Command-line flags override flagfile values.
Examples
Example 1: DeepRMSA Reproduction
Reproduce the DeepRMSA paper training setup:
```bash
python -m xlron.train.train \
  --env_type=deeprmsa \
  --topology_name=nsfnet_deeprmsa_directed \
  --link_resources=100 \
  --k=5 \
  --load=250 \
  --mean_service_holding_time=25 \
  --max_requests=1000 \
  --continuous_operation \
  --ENV_WARMUP_STEPS=3000 \
  --NUM_LAYERS=5 \
  --NUM_UNITS=128 \
  --NUM_ENVS=16 \
  --ROLLOUT_LENGTH=100 \
  --TOTAL_TIMESTEPS=5000000 \
  --LR=5e-5 \
  --LR_SCHEDULE=linear \
  --WARMUP_MULTIPLIER=2 \
  --UPDATE_EPOCHS=10 \
  --GAE_LAMBDA=0.9 \
  --GAMMA=0.95 \
  --ACTION_MASKING
```
Example 2: RMSA with Transformer and Reward Centering
Train a transformer agent with reward centering and entropy scheduling:
```bash
python -m xlron.train.train \
  --env_type=rmsa \
  --topology_name=nsfnet_deeprmsa_directed \
  --link_resources=100 \
  --k=50 \
  --load=200 \
  --mean_service_holding_time=1 \
  --continuous_operation \
  --ENV_WARMUP_STEPS=3000 \
  --truncate_holding_time \
  --USE_TRANSFORMER \
  --transformer_num_layers=2 \
  --transformer_num_heads=4 \
  --transformer_embedding_size=128 \
  --transformer_obs_type=departure \
  --aggregate_slots=20 \
  --NUM_ENVS=64 \
  --TOTAL_TIMESTEPS=60000000 \
  --UPDATE_EPOCHS=1 \
  --LR=5e-4 \
  --REWARD_CENTERING \
  --ENT_COEF=0.01 \
  --ENT_SCHEDULE=cosine \
  --ENT_END_FRACTION=0.1 \
  --WANDB \
  --LOG_LOSS_INFO \
  --DOWNSAMPLE_FACTOR=100
```
Example 3: GNN Agent
Train a GNN agent with graph attention:
```bash
python -m xlron.train.train \
  --env_type=rmsa \
  --topology_name=nsfnet_deeprmsa_directed \
  --link_resources=100 \
  --k=5 \
  --load=250 \
  --continuous_operation \
  --ENV_WARMUP_STEPS=3000 \
  --USE_GNN \
  --message_passing_steps=3 \
  --edge_embedding_size=128 \
  --node_embedding_size=16 \
  --attn_mlp_layers=1 \
  --num_spectral_features=8 \
  --NUM_ENVS=32 \
  --TOTAL_TIMESTEPS=10000000 \
  --LR=3e-4 \
  --GAMMA=0.999 \
  --WANDB
```
Example 4: Training with Advanced Features
Combine multiple algorithmic features:
```bash
python -m xlron.train.train \
  --env_type=rmsa \
  --topology_name=nsfnet_deeprmsa_directed \
  --link_resources=100 \
  --k=50 \
  --load=250 \
  --continuous_operation \
  --ENV_WARMUP_STEPS=3000 \
  --truncate_holding_time \
  --NUM_ENVS=64 \
  --ROLLOUT_LENGTH=128 \
  --TOTAL_TIMESTEPS=20000000 \
  --LR=3e-4 \
  --LR_SCHEDULE=warmup_cosine \
  --WARMUP_MULTIPLIER=2 \
  --WARMUP_STEPS_FRACTION=0.1 \
  --GAMMA=0.999 \
  --CLIP_EPS=0.08 \
  --VF_COEF=0.5 \
  --ENT_COEF=0.01 \
  --ENT_SCHEDULE=cosine \
  --ENT_END_FRACTION=0.2 \
  --REWARD_CENTERING \
  --OFF_POLICY_IAM \
  --VALID_MASS_LOSS_COEF=0.01 \
  --REWARD_SCALE=10 \
  --WANDB \
  --ENHANCED_LOGGING \
  --SAVE_MODEL
```
Example 5: Save and Evaluate a Model
Save during training:
```bash
python -m xlron.train.train \
  --env_type=rmsa \
  --topology_name=nsfnet_deeprmsa_directed \
  --link_resources=100 \
  --k=50 \
  --load=250 \
  --continuous_operation \
  --ENV_WARMUP_STEPS=3000 \
  --NUM_ENVS=64 \
  --TOTAL_TIMESTEPS=10000000 \
  --LR=3e-4 \
  --SAVE_MODEL \
  --MODEL_PATH=models/my_agent.eqx
```
Then evaluate:
```bash
python -m xlron.train.train \
  --env_type=rmsa \
  --topology_name=nsfnet_deeprmsa_directed \
  --link_resources=100 \
  --k=50 \
  --load=250 \
  --continuous_operation \
  --ENV_WARMUP_STEPS=3000 \
  --NUM_ENVS=2000 \
  --TOTAL_TIMESTEPS=20000000 \
  --EVAL_MODEL \
  --MODEL_PATH=models/my_agent.eqx
```