Environments

EnvParams

Dataclass to hold environment parameters. Parameters are immutable.

Parameters (all required):

    max_requests (Scalar): Maximum number of requests in an episode
    incremental_loading (Scalar): Incremental increase in traffic load (non-expiring requests)
    end_first_blocking (Scalar): End episode on first blocking event
    continuous_operation (Scalar): If True, do not reset the environment at the end of an episode
    edges (Array): Two-column array defining source-dest node-pair edges of the graph
    slot_size (Scalar): Spectral width of a frequency slot in GHz
    consider_modulation_format (Scalar): If True, consider modulation format to determine required slots
    link_length_array (Array): Array of link lengths
    aggregate_slots (Scalar): Number of slots to aggregate into a single action (First-Fit with aggregation)
    guardband (Scalar): Guard band in slots
    directed_graph (bool): Whether the graph is directed (one fibre per link per transmission direction)
    temperature (Scalar): Temperature used for the softmax differentiable approximation
    window_size (Scalar): Window size for weighted average of neighbouring cells in differentiable indexing
Source code in xlron/environments/dataclasses.py
@struct.dataclass
class EnvParams:
    """Dataclass to hold environment parameters. Parameters are immutable.

    Args:
        max_requests (chex.Scalar): Maximum number of requests in an episode
        incremental_loading (chex.Scalar): Incremental increase in traffic load (non-expiring requests)
        end_first_blocking (chex.Scalar): End episode on first blocking event
        continuous_operation (chex.Scalar): If True, do not reset the environment at the end of an episode
        edges (chex.Array): Two column array defining source-dest node-pair edges of the graph
        slot_size (chex.Scalar): Spectral width of frequency slot in GHz
        consider_modulation_format (chex.Scalar): If True, consider modulation format to determine required slots
        link_length_array (chex.Array): Array of link lengths
        aggregate_slots (chex.Scalar): Number of slots to aggregate into a single action (First-Fit with aggregation)
        guardband (chex.Scalar): Guard band in slots
        directed_graph (bool): Whether graph is directed (one fibre per link per transmission direction)
        temperature (chex.Scalar): Temp. used for softmax differentiable approximation
        window_size (chex.Scalar): Window size for weighted average of neighbouring cells in differentiable indexing
    """

    num_nodes: int = struct.field(pytree_node=False)
    num_links: int = struct.field(pytree_node=False)
    max_requests: int = struct.field(pytree_node=False)
    incremental_loading: bool = struct.field(pytree_node=False)
    end_first_blocking: bool = struct.field(pytree_node=False)
    terminate_on_episode_end: bool = struct.field(pytree_node=False)
    continuous_operation: bool = struct.field(pytree_node=False)
    edges: HashableArrayWrapper = struct.field(pytree_node=False)
    slot_size: int = struct.field(pytree_node=False)
    consider_modulation_format: bool = struct.field(pytree_node=False)
    link_length_array: HashableArrayWrapper = struct.field(pytree_node=False)
    aggregate_slots: int = struct.field(pytree_node=False)
    guardband: int = struct.field(pytree_node=False)
    directed_graph: bool = struct.field(pytree_node=False)
    maximise_throughput: bool = struct.field(pytree_node=False)
    reward_type: str = struct.field(pytree_node=False)
    values_bw: HashableArrayWrapper = struct.field(pytree_node=False)
    truncate_holding_time: bool = struct.field(pytree_node=False)
    traffic_array: bool = struct.field(pytree_node=False)
    pack_path_bits: bool = struct.field(pytree_node=False)
    relative_arrival_times: bool = struct.field(pytree_node=False)
    temperature: float = struct.field(pytree_node=False)
    differentiable: bool = struct.field(pytree_node=False)
    num_spectral_features: int = struct.field(pytree_node=False)
    line_graph_spectral_features: HashableArrayWrapper | None = struct.field(pytree_node=False)
    path_link_array: HashableArrayWrapper = struct.field(pytree_node=False)
    path_se_array: HashableArrayWrapper = struct.field(pytree_node=False)
    unique_se_values: HashableArrayWrapper = struct.field(pytree_node=False)
    k_paths: int = struct.field(pytree_node=False)
    link_resources: int = struct.field(pytree_node=False)
    mean_service_holding_time: float = struct.field(pytree_node=False)
    load: float = struct.field(pytree_node=False)
    arrival_rate: float = struct.field(pytree_node=False)
    random_traffic: bool = struct.field(pytree_node=False)
    include_no_op: bool = struct.field(pytree_node=False)  # Include a "no op" action
    transformer_obs_type: str = struct.field(pytree_node=False)
    use_gnn: bool = struct.field(pytree_node=False)
    profile: bool = struct.field(pytree_node=False)
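
The distinction drawn by `pytree_node=False` is the same one `jax.jit` makes between static and traced arguments: static values must be hashable and are baked into the compiled program (changing them triggers recompilation), while array-valued state is traced. A minimal sketch of the idea (not XLRON's actual API; the function and names are hypothetical):

```python
from functools import partial

import jax
import jax.numpy as jnp


# guardband plays the role of a pytree_node=False parameter: a hashable
# Python int treated as static, so changing it causes retracing.
# link_slot_array is traced, like the arrays held in EnvState.
@partial(jax.jit, static_argnames="guardband")
def free_slots_after_guardband(link_slot_array, guardband):
    free = jnp.sum(link_slot_array == 0)  # count unoccupied slots
    return free - guardband


slots = jnp.array([0, 0, 1, 0])
result = free_slots_after_guardband(slots, guardband=1)  # 3 free slots - 1 guard slot
```

This is why non-array fields such as `edges` and `link_length_array` are stored as `HashableArrayWrapper`: static fields must be hashable.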

EnvState

Dataclass to hold environment state. State is mutable and arrays are traced on JIT compilation.

Parameters (all required):

    current_time (Scalar): Current time in environment
    holding_time (Scalar): Holding time of current request
    total_timesteps (Scalar): Total timesteps in environment
    total_requests (Scalar): Total requests in environment
    graph (GraphsTuple): Graph tuple representing network state
    full_link_slot_mask (Array): Action mask for link slot action (including if slot actions are aggregated)
    accepted_services (Array): Number of accepted services
    accepted_bitrate (Array): Accepted bitrate
Source code in xlron/environments/dataclasses.py
@struct.dataclass
class EnvState:
    """Dataclass to hold environment state. State is mutable and arrays are traced on JIT compilation.

    Args:
        current_time (chex.Scalar): Current time in environment
        holding_time (chex.Scalar): Holding time of current request
        total_timesteps (chex.Scalar): Total timesteps in environment
        total_requests (chex.Scalar): Total requests in environment
        graph (jraph.GraphsTuple): Graph tuple representing network state
        full_link_slot_mask (chex.Array): Action mask for link slot action (including if slot actions are aggregated)
        accepted_services (chex.Array): Number of accepted services
        accepted_bitrate (chex.Array): Accepted bitrate
    """

    current_time: chex.Array
    holding_time: chex.Array
    arrival_time: chex.Array
    total_timesteps: chex.Array
    total_requests: chex.Array
    graph: jraph.GraphsTuple
    full_link_slot_mask: chex.Array
    accepted_services: chex.Array
    accepted_bitrate: chex.Array
    total_bitrate: chex.Array
    list_of_requests: chex.Array
    link_slot_array: chex.Array
    request_array: chex.Array
    link_slot_departure_array: chex.Array
    link_slot_mask: chex.Array
    traffic_matrix: chex.Array
    valid_mass: chex.Array
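
Although the docstring describes the state as mutable, a `@struct.dataclass` is frozen: each step produces a new state object rather than mutating in place. A two-field sketch of the pattern using stdlib dataclasses (hypothetical `MiniState`; flax's `@struct.dataclass` generates an equivalent `.replace` method):

```python
import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class MiniState:
    current_time: jnp.ndarray
    accepted_services: jnp.ndarray


state = MiniState(current_time=jnp.array(0.0), accepted_services=jnp.array(0))

# "Updating" returns a fresh state; the original is untouched, which is
# what allows the state to be carried through jitted scan/step functions.
new_state = dataclasses.replace(
    state, accepted_services=state.accepted_services + 1
)
```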

RSAMultibandEnvParams

Bases: RSAEnvParams

Dataclass to hold environment parameters for MultiBandRSA (RBSA).

Source code in xlron/environments/dataclasses.py
@struct.dataclass
class RSAMultibandEnvParams(RSAEnvParams):
    """Dataclass to hold environment parameters for MultiBandRSA (RBSA)."""

    gap_starts: HashableArrayWrapper = struct.field(pytree_node=False)
    gap_widths: HashableArrayWrapper = struct.field(pytree_node=False)

RSAMultibandEnvState

Bases: RSAEnvState

Dataclass to hold environment state for MultiBandRSA (RBSA).

Source code in xlron/environments/dataclasses.py
@struct.dataclass
class RSAMultibandEnvState(RSAEnvState):
    """Dataclass to hold environment state for MultiBandRSA (RBSA)."""

    pass

Dataclasses

DeepRMSAEnvState

Bases: RSAEnvState

Dataclass to hold environment state for DeepRMSA.

Parameters (all required):

    path_stats (Array): Path stats array containing, for each path: required slots, total available slots, size of the first free spectrum block, and average free block size
Source code in xlron/environments/dataclasses.py
@struct.dataclass
class DeepRMSAEnvState(RSAEnvState):
    """Dataclass to hold environment state for DeepRMSA.

    Args:
        path_stats (chex.Array): Path stats array containing
        1. Required slots on path
        2. Total available slots on path
        3. Size of 1st free spectrum block
        4. Avg. free block size
    """

    path_stats: chex.Array

GNModelEnvParams

Bases: RSAEnvParams

Dataclass to hold environment parameters for GN model environments.

Source code in xlron/environments/dataclasses.py
@struct.dataclass
class GNModelEnvParams(RSAEnvParams):
    """Dataclass to hold environment parameters for GN model environments."""

    ref_lambda: chex.Scalar = struct.field(pytree_node=False)
    max_spans: chex.Scalar = struct.field(pytree_node=False)
    max_span_length: chex.Scalar = struct.field(pytree_node=False)
    nonlinear_coeff: chex.Scalar = struct.field(pytree_node=False)
    raman_gain_slope: chex.Scalar = struct.field(pytree_node=False)
    attenuation: chex.Scalar = struct.field(pytree_node=False)
    attenuation_bar: chex.Scalar = struct.field(pytree_node=False)
    dispersion_coeff: chex.Scalar = struct.field(pytree_node=False)
    dispersion_slope: chex.Scalar = struct.field(pytree_node=False)
    transceiver_snr: HashableArrayWrapper = struct.field(pytree_node=False)
    amplifier_noise_figure: HashableArrayWrapper = struct.field(pytree_node=False)
    coherent: bool = struct.field(pytree_node=False)
    num_roadms: chex.Scalar = struct.field(pytree_node=False)
    roadm_loss: chex.Scalar = struct.field(pytree_node=False)
    roadm_express_loss: HashableArrayWrapper = struct.field(pytree_node=False)
    roadm_add_drop_loss: HashableArrayWrapper = struct.field(pytree_node=False)
    roadm_noise_figure: HashableArrayWrapper = struct.field(pytree_node=False)
    num_spans: chex.Scalar = struct.field(pytree_node=False)
    launch_power_type: str = struct.field(pytree_node=False)
    snr_margin: chex.Scalar = struct.field(pytree_node=False)
    max_snr: chex.Scalar = struct.field(pytree_node=False)
    max_power: chex.Scalar = struct.field(pytree_node=False)
    min_power: chex.Scalar = struct.field(pytree_node=False)
    step_power: chex.Scalar = struct.field(pytree_node=False)
    last_fit: bool = struct.field(pytree_node=False)
    max_power_per_fibre: chex.Scalar = struct.field(pytree_node=False)
    default_launch_power: chex.Scalar = struct.field(pytree_node=False)
    power_per_channel: chex.Scalar = struct.field(pytree_node=False)  # linear Watts
    mod_format_correction: bool = struct.field(pytree_node=False)
    monitor_active_lightpaths: bool = struct.field(
        pytree_node=False
    )  # Monitor active lightpaths for throughput calculation
    gap_starts: HashableArrayWrapper = struct.field(pytree_node=False)
    gap_widths: HashableArrayWrapper = struct.field(pytree_node=False)
    uniform_spans: bool = struct.field(pytree_node=False)
    min_snr: chex.Scalar = struct.field(pytree_node=False)
    fec_threshold: chex.Scalar = struct.field(pytree_node=False)
    band_slot_order_ff: HashableArrayWrapper = struct.field(
        pytree_node=False
    )  # Slot permutation for band-preference first-fit (empty if unused)
    band_slot_order_lf: HashableArrayWrapper = struct.field(
        pytree_node=False
    )  # Slot permutation for band-preference last-fit (empty if unused)
    slot_centre_freq_array: HashableArrayWrapper = struct.field(
        pytree_node=False
    )  # Per-slot centre frequencies in relative GHz offset from ref_lambda
    num_subchannels: int = struct.field(pytree_node=False)  # Nyquist subchannels per slot for SPM

GNModelEnvState

Bases: RSAEnvState

Dataclass to hold environment state for RSA with GN model.

Source code in xlron/environments/dataclasses.py
@struct.dataclass
class GNModelEnvState(RSAEnvState):
    """Dataclass to hold environment state for RSA with GN model."""

    link_snr_array: chex.Array  # Available SNR on each link
    channel_centre_bw_array: chex.Array  # Channel centre bandwidth for each active connection
    path_index_array: (
        chex.Array
    )  # Contains indices of lightpaths in use on slots (used for lightpath SNR calculation)
    channel_power_array: chex.Array  # Channel power for each active connection
    channel_centre_bw_array_prev: (
        chex.Array
    )  # Channel centre bandwidth for each active connection in previous timestep
    path_index_array_prev: (
        chex.Array
    )  # Contains indices of lightpaths in use on slots in previous timestep
    channel_power_array_prev: (
        chex.Array
    )  # Channel power for each active connection in previous timestep
    channel_centre_freq_array: chex.Array  # Per-slot centre frequency in GHz
    channel_centre_freq_array_prev: chex.Array  # Previous timestep centre frequency for undo
    launch_power_array: chex.Array  # Launch power array

HashableArrayWrapper

Bases: Generic[T]

Wrapper for making arrays hashable. In order to access pre-computed data, such as shortest paths between node-pairs or the constituent links of a path, within a jitted function, we need to make the arrays containing this data hashable. By defining this wrapper, we can define a __hash__ method that returns a hash of the array's bytes, thus making the array hashable. From: https://github.com/google/jax/issues/4572#issuecomment-709677518

Source code in xlron/environments/dataclasses.py
class HashableArrayWrapper(Generic[T]):
    """Wrapper for making arrays hashable.
    In order to access pre-computed data, such as shortest paths between node-pairs or the constituent links of a path,
    within a jitted function, we need to make the arrays containing this data hashable. By defining this wrapper, we can
    define a __hash__ method that returns a hash of the array's bytes, thus making the array hashable.
    From: https://github.com/google/jax/issues/4572#issuecomment-709677518
    """

    def __init__(self, val: Array):
        self.val = val

    def __getattribute__(self, prop):
        if prop == "val" or prop == "__hash__" or prop == "__eq__":
            return super(HashableArrayWrapper, self).__getattribute__(prop)
        return getattr(self.val, prop)

    def __getitem__(self, key):
        return self.val[key]

    def __setitem__(self, key, val):
        self.val[key] = val

    def __hash__(self):
        return hash(self.val.tobytes())

    def __eq__(self, other):
        if isinstance(other, HashableArrayWrapper):
            return self.__hash__() == other.__hash__()

        f = getattr(self.val, "__eq__")
        return f(other)
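
A quick sanity check of the hashing trick, using a minimal re-implementation with plain NumPy (no XLRON imports, for illustration only): two wrappers around equal-content arrays hash identically, so they can be passed as static (hashable) arguments to jitted functions.

```python
import numpy as np


class HashableArrayWrapper:
    """Minimal illustrative version of the wrapper above."""

    def __init__(self, val):
        self.val = val

    def __hash__(self):
        # Hash the raw array bytes, so equal-content arrays hash equally.
        return hash(self.val.tobytes())

    def __eq__(self, other):
        return isinstance(other, HashableArrayWrapper) and hash(self) == hash(other)


a = HashableArrayWrapper(np.arange(5))
b = HashableArrayWrapper(np.arange(5))
# a and b wrap distinct array objects with identical contents,
# yet hash (and compare) equal.
```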

LogEnvState

Dataclass to hold environment state for logging.

Parameters (all required):

    env_state (EnvState): Environment state
    lengths (Scalar): Episode lengths
    returns (Scalar): Returns
    cum_returns (Scalar): Cumulative returns
    accepted_services (Scalar): Accepted services
    accepted_bitrate (Scalar): Accepted bitrate
    total_bitrate (Scalar): Total bitrate requested
    utilisation (Scalar): Network utilisation
    terminal (Scalar): Terminal flag (true termination condition met)
    truncated (Scalar): Truncated flag (max steps reached)
Source code in xlron/environments/dataclasses.py
@struct.dataclass
class LogEnvState:
    """Dataclass to hold environment state for logging.

    Args:
        env_state (EnvState): Environment state
        lengths (chex.Scalar): Lengths
        returns (chex.Scalar): Returns
        cum_returns (chex.Scalar): Cumulative returns
        accepted_services (chex.Scalar): Accepted services
        accepted_bitrate (chex.Scalar): Accepted bitrate
        total_bitrate (chex.Scalar): Total bitrate requested
        utilisation (chex.Scalar): Network utilisation
        terminal (chex.Scalar): Terminal flag (true termination condition met)
        truncated (chex.Scalar): Truncated flag (max steps reached)
    """

    env_state: EnvState
    lengths: chex.Array
    returns: chex.Array
    cum_returns: chex.Array
    accepted_services: chex.Array
    accepted_bitrate: chex.Array
    total_bitrate: chex.Array
    utilisation: chex.Array
    terminal: chex.Array
    truncated: chex.Array

MultiBandRSAEnvParams

Bases: RSAEnvParams

Dataclass to hold environment parameters for MultiBandRSA (RBSA).

Source code in xlron/environments/dataclasses.py
@struct.dataclass
class MultiBandRSAEnvParams(RSAEnvParams):
    """Dataclass to hold environment parameters for MultiBandRSA (RBSA)."""

    gap_start: chex.Scalar = struct.field(pytree_node=False)
    gap_width: chex.Scalar = struct.field(pytree_node=False)

MultiBandRSAEnvState

Bases: RSAEnvState

Dataclass to hold environment state for MultiBandRSA (RBSA).

Source code in xlron/environments/dataclasses.py
@struct.dataclass
class MultiBandRSAEnvState(RSAEnvState):
    """Dataclass to hold environment state for MultiBandRSA (RBSA)."""

    pass

RMSAGNModelEnvParams

Bases: GNModelEnvParams

Dataclass to hold environment params for RMSA with GN model.

Parameters (all required):

    modulations_array (Array): Modulations array
    fec_rate (Scalar): Forward error correction rate
Source code in xlron/environments/dataclasses.py
@struct.dataclass
class RMSAGNModelEnvParams(GNModelEnvParams):
    """Dataclass to hold environment params for RMSA with GN model.

    Args:
        modulations_array (HashableArrayWrapper): Modulations array
        fec_rate (chex.Scalar): Forward error correction rate
    """

    modulations_array: HashableArrayWrapper = struct.field(pytree_node=False)
    fec_rate: chex.Scalar = struct.field(pytree_node=False)

RMSAGNModelEnvState

Bases: GNModelEnvState

Dataclass to hold environment state for RMSA with GN model.

Parameters (all required):

    link_snr_array (Array): Link SNR array
Source code in xlron/environments/dataclasses.py
@struct.dataclass
class RMSAGNModelEnvState(GNModelEnvState):
    """Dataclass to hold environment state for RMSA with GN model.

    Args:
        link_snr_array (chex.Array): Link SNR array
    """

    modulation_format_index_array: chex.Array  # Modulation format index for each active connection
    modulation_format_index_array_prev: (
        chex.Array
    )  # Modulation format index for each active connection in previous timestep
    mod_format_mask: chex.Array  # Modulation format mask

RSAEnvParams

Bases: EnvParams

Dataclass to hold environment parameters for RSA.

Parameters (all required):

    num_nodes (Scalar): Number of nodes
    num_links (Scalar): Number of links
    link_resources (Scalar): Number of link resources
    k_paths (Scalar): Number of paths
    mean_service_holding_time (Scalar): Mean service holding time
    load (Scalar): Load
    arrival_rate (Scalar): Arrival rate
    path_link_array (Array): Path link array
    random_traffic (bool): Generate a random traffic matrix for RSA on each reset (else uniform or custom)
    max_slots (Scalar): Maximum number of slots
    path_se_array (Array): Path spectral efficiency array
    deterministic_requests (bool): If True, use deterministic requests
    multiple_topologies (bool): If True, use multiple topologies
Source code in xlron/environments/dataclasses.py
@struct.dataclass
class RSAEnvParams(EnvParams):
    """Dataclass to hold environment parameters for RSA.

    Args:
        num_nodes (chex.Scalar): Number of nodes
        num_links (chex.Scalar): Number of links
        link_resources (chex.Scalar): Number of link resources
        k_paths (chex.Scalar): Number of paths
        mean_service_holding_time (chex.Scalar): Mean service holding time
        load (chex.Scalar): Load
        arrival_rate (chex.Scalar): Arrival rate
        path_link_array (chex.Array): Path link array
        random_traffic (bool): Random traffic matrix for RSA on each reset (else uniform or custom)
        max_slots (chex.Scalar): Maximum number of slots
        path_se_array (chex.Array): Path spectral efficiency array
        deterministic_requests (bool): If True, use deterministic requests
        multiple_topologies (bool): If True, use multiple topologies
    """

    max_slots: chex.Scalar = struct.field(pytree_node=False)
    deterministic_requests: bool = struct.field(pytree_node=False)
    multiple_topologies: bool = struct.field(pytree_node=False)
    log_actions: bool = struct.field(pytree_node=False)
    disable_node_features: bool = struct.field(pytree_node=False)

RSAEnvState

Bases: EnvState

Dataclass to hold environment state for RSA.

Parameters (all required):

    link_slot_array (Array): Link slot array
    request_array (Array): Request array
    link_slot_departure_array (Array): Link slot departure array
    link_slot_mask (Array): Link slot mask
    traffic_matrix (Array): Traffic matrix
Source code in xlron/environments/dataclasses.py
@struct.dataclass
class RSAEnvState(EnvState):
    """Dataclass to hold environment state for RSA.

    Args:
        link_slot_array (chex.Array): Link slot array
        request_array (chex.Array): Request array
        link_slot_departure_array (chex.Array): Link slot departure array
        link_slot_mask (chex.Array): Link slot mask
        traffic_matrix (chex.Array): Traffic matrix
    """

    pass

RSAGNModelEnvParams

Bases: GNModelEnvParams

Dataclass to hold environment params for RSA with GN model.

Source code in xlron/environments/dataclasses.py
@struct.dataclass
class RSAGNModelEnvParams(GNModelEnvParams):
    """Dataclass to hold environment params for RSA with GN model."""

    pass

RSAGNModelEnvState

Bases: GNModelEnvState

Dataclass to hold environment state for RSA with GN model.

Source code in xlron/environments/dataclasses.py
@struct.dataclass
class RSAGNModelEnvState(GNModelEnvState):
    """Dataclass to hold environment state for RSA with GN model."""

    active_lightpaths_array: chex.Array  # Active lightpath array. 1 x M array. Each value is a lightpath index. Used to calculate total throughput.
    active_lightpaths_array_departure: chex.Array  # Active lightpath array departure time.
    throughput: chex.Array  # Current network throughput

RWALightpathReuseEnvState

Bases: RSAEnvState

Dataclass to hold environment state for RWA with lightpath reuse.

Parameters (all required):

    path_index_array (Array): Contains indices of lightpaths in use on slots
    path_capacity_array (Array): Contains remaining capacity of each lightpath
    link_capacity_array (Array): Contains remaining capacity of lightpath on each link-slot
Source code in xlron/environments/dataclasses.py
@struct.dataclass
class RWALightpathReuseEnvState(RSAEnvState):
    """Dataclass to hold environment state for RWA with lightpath reuse.

    Args:
        path_index_array (chex.Array): Contains indices of lightpaths in use on slots
        path_capacity_array (chex.Array): Contains remaining capacity of each lightpath
        link_capacity_array (chex.Array): Contains remaining capacity of lightpath on each link-slot
    """

    path_index_array: chex.Array  # Contains indices of lightpaths in use on slots
    path_capacity_array: chex.Array  # Contains remaining capacity of each lightpath
    link_capacity_array: chex.Array  # Contains remaining capacity of lightpath on each link-slot
    time_since_last_departure: chex.Array  # Time since last departure

VONEEnvParams

Bases: RSAEnvParams

Dataclass to hold environment parameters for VONE.

Parameters:

Name Type Description Default
node_resources Scalar

Number of node resources

required
max_edges Scalar

Maximum number of edges

required
min_node_resources Scalar

Minimum number of node resources

required
max_node_resources Scalar

Maximum number of node resources

required
Source code in xlron/environments/dataclasses.py
@struct.dataclass
class VONEEnvParams(RSAEnvParams):
    """Dataclass to hold environment parameters for VONE.

    Args:
        node_resources (chex.Scalar): Number of node resources
        max_edges (chex.Scalar): Maximum number of edges
        min_node_resources (chex.Scalar): Minimum number of node resources
        max_node_resources (chex.Scalar): Maximum number of node resources
    """

    node_resources: chex.Scalar = struct.field(pytree_node=False)
    max_edges: chex.Scalar = struct.field(pytree_node=False)
    min_node_resources: chex.Scalar = struct.field(pytree_node=False)
    max_node_resources: chex.Scalar = struct.field(pytree_node=False)

VONEEnvState

Bases: RSAEnvState

Dataclass to hold environment state for VONE.

Parameters:

Name Type Description Default
node_capacity_array Array

Node capacity array

required
node_resource_array Array

Node resource array

required
node_departure_array Array

Node departure array

required
action_counter Array

Action counter

required
action_history Array

Action history

required
node_mask_s Array

Node mask for source node

required
node_mask_d Array

Node mask for destination node

required
virtual_topology_patterns Array

Virtual topology patterns

required
values_nodes Array

Values for nodes

required
Source code in xlron/environments/dataclasses.py
@struct.dataclass
class VONEEnvState(RSAEnvState):
    """Dataclass to hold environment state for VONE.

    Args:
        node_capacity_array (chex.Array): Node capacity array
        node_resource_array (chex.Array): Node resource array
        node_departure_array (chex.Array): Node departure array
        action_counter (chex.Array): Action counter
        action_history (chex.Array): Action history
        node_mask_s (chex.Array): Node mask for source node
        node_mask_d (chex.Array): Node mask for destination node
        virtual_topology_patterns (chex.Array): Virtual topology patterns
        values_nodes (chex.Array): Values for nodes
    """

    node_capacity_array: chex.Array
    node_resource_array: chex.Array
    node_departure_array: chex.Array
    action_counter: chex.Array
    action_history: chex.Array
    node_mask_s: chex.Array
    node_mask_d: chex.Array
    virtual_topology_patterns: chex.Array
    values_nodes: chex.Array

Environment wrappers

JitProfiler

Wall-clock profiler for JAX JIT-compiled code.

On CPU

• Uses host callbacks for fine-grained section timing.

On GPU

• Automatically switches to first-call-only timing.
• Measures compilation + first execution latency.
• Fine-grained per-call timings are intentionally disabled to avoid misleading synchronization artifacts.

This profiler records host-side timestamps using jax.debug.callback, allowing coarse wall-clock profiling of sections inside JIT-compiled code. Profiling is designed to be gated by a static Python boolean (e.g. params.profile) so that all profiling logic is resolved at trace time and introduces zero runtime overhead when disabled.

Features

• Manual markers via mark(), start(), and end()
• Function-level profiling via call()
• Automatic jax.named_scope integration for clearer JAX traces
• Safety checks to ensure the profiling flag is static at trace time
• Aggregation across repeated calls with a readable summary table

Basic usage inside JIT (manual markers):

if params.profile:
    jit_profiler.mark("process_action:start")
with jax.named_scope("process_action"):
    ...
if params.profile:
    jit_profiler.mark("process_action:end")

Function-wrapping usage inside JIT (recommended):

action_mask = jit_profiler.call(
    params.profile,
    mask_slots,
    state,
    params,
    name="mask_actions",  # optional, defaults to fn.__name__
)

Block-style usage inside JIT:

jit_profiler.start(params.profile, "action_logic")
...
jit_profiler.end(params.profile, "action_logic")

Notes

• The enabled flag must be a Python bool known at trace time (e.g. params.profile with pytree_node=False).
• Passing a traced or JAX boolean will raise a TypeError.
• All timing is wall-clock time measured on the host, not device time.

After execution (outside JIT):

jit_profiler.summary()
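The start/end marker pairing that summary() performs can be sketched standalone: timestamps are (label, time) pairs, "name:start" labels open a pending section, and the matching "name:end" closes it and accumulates the elapsed time.

```python
from collections import defaultdict

# Standalone sketch of the marker-pairing logic inside summary().
# Two calls to the "step" section: 0.5 s and 0.8 s elapsed.
timestamps = [
    ("step:start", 0.0), ("step:end", 0.5),
    ("step:start", 1.0), ("step:end", 1.8),
]

totals = defaultdict(float)
counts = defaultdict(int)
pending = {}
for label, ts in timestamps:
    if label.endswith(":start"):
        pending[label[:-len(":start")]] = ts
    elif label.endswith(":end"):
        name = label[:-len(":end")]
        if name in pending:
            totals[name] += ts - pending.pop(name)
            counts[name] += 1

print(totals["step"], counts["step"])  # 1.3 2
```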
Source code in xlron/environments/wrappers.py
class JitProfiler:
    """Wall-clock profiler for JAX JIT-compiled code.

    On CPU:
        • Uses host callbacks for fine-grained section timing.

    On GPU:
        • Automatically switches to *first-call-only* timing.
        • Measures compilation + first execution latency.
        • Fine-grained per-call timings are intentionally disabled
          to avoid misleading synchronization artifacts.

    This profiler records host-side timestamps using `jax.debug.callback`,
    allowing coarse wall-clock profiling of sections inside JIT-compiled code.
    Profiling is designed to be gated by a *static* Python boolean (e.g.
    `params.profile`) so that all profiling logic is resolved at trace time and
    introduces zero runtime overhead when disabled.

    Features
    --------
    • Manual markers via `mark()`, `start()`, and `end()`
    • Function-level profiling via `call()`
    • Automatic `jax.named_scope` integration for clearer JAX traces
    • Safety checks to ensure the profiling flag is static at trace time
    • Aggregation across repeated calls with a readable summary table

    Basic usage inside JIT (manual markers):

        if params.profile:
            jit_profiler.mark("process_action:start")
        with jax.named_scope("process_action"):
            ...
        if params.profile:
            jit_profiler.mark("process_action:end")

    Function-wrapping usage inside JIT (recommended):

        action_mask = jit_profiler.call(
            params.profile,
            mask_slots,
            state,
            params,
            name="mask_actions",  # optional, defaults to fn.__name__
        )

    Block-style usage inside JIT:

        jit_profiler.start(params.profile, "action_logic")
        ...
        jit_profiler.end(params.profile, "action_logic")

    Notes
    -----
    • The `enabled` flag *must* be a Python `bool` known at trace time
        (e.g. `params.profile` with `pytree_node=False`).
    • Passing a traced or JAX boolean will raise a `TypeError`.
    • All timing is wall-clock time measured on the host, not device time.

    After execution (outside JIT):

        jit_profiler.summary()
    """

    def __init__(self):
        self._timestamps: list[tuple[str, float]] = []
        self._backend = jax.default_backend()
        self._is_gpu = self._backend == "gpu"
        self._warned_gpu = False
        self._seen_first_call: set[str] = set()

        if self._is_gpu and not self._warned_gpu:
            print(
                "JitProfiler warning:\n"
                "  GPU backend detected. Fine-grained host-side timings inside JIT\n"
                "  are unreliable due to asynchronous execution.\n"
                "  `call()` will record ONLY first-call (compile + first execution)\n"
                "  latency per section.\n"
                "  Use jax.profiler / TensorBoard / Nsight for steady-state GPU timing."
            )
            self._warned_gpu = True

    def _record(self, label):
        self._timestamps.append((str(label), time.time()))

    @staticmethod
    def _assert_static_bool(x, name="enabled"):
        if not isinstance(x, bool):
            raise TypeError(
                f"JitProfiler.call(): `{name}` must be a Python bool "
                "(static at trace time), e.g. params.profile"
            )

    def mark(self, label: str):
        """Insert a timing marker. Safe to call inside JIT."""
        jax.debug.callback(self._record, label)

    def reset(self):
        """Clear all recorded timestamps and first-call tracking."""
        self._timestamps.clear()
        self._seen_first_call.clear()

    def start(self, enabled: bool, name: str):
        """Insert a start marker for a block."""
        self._assert_static_bool(enabled)
        if enabled:
            self.mark(f"{name}:start")

    def end(self, enabled: bool, name: str):
        """Insert an end marker for a block."""
        self._assert_static_bool(enabled)
        if enabled:
            self.mark(f"{name}:end")

    # -----------------------
    # GPU-only first-call path
    # -----------------------

    def _call_first_only(self, enabled: bool, fn, *args, name=None, **kwargs):
        """Record compile + first execution latency (GPU only)."""
        self._assert_static_bool(enabled)

        section = name or fn.__name__

        if section in self._seen_first_call or not enabled:
            return fn(*args, **kwargs)

        self._seen_first_call.add(section)

        start_time = time.time()
        out = fn(*args, **kwargs)
        out = jax.block_until_ready(out)
        elapsed = time.time() - start_time

        # Append synthetic start/end entries so summary sees them
        self._timestamps.append((f"{section}:start", start_time))
        self._timestamps.append((f"{section}:end", start_time + elapsed))

        # Also keep a :first entry for clarity
        self._timestamps.append((f"{section}:first", elapsed))

        return out

    # -----------------------
    # Public API
    # -----------------------

    def call(self, enabled: bool, fn, *args, name: str | None = None, **kwargs):
        """Profile a function call inside JIT-compiled code.

        CPU:
            Fine-grained section timing via callbacks.

        GPU:
            Records only first-call (compile + first execution) latency.
        """
        self._assert_static_bool(enabled)

        if self._is_gpu:
            return self._call_first_only(enabled, fn, *args, name=name, **kwargs)

        # CPU path
        section = name or fn.__name__
        if enabled:
            self.mark(f"{section}:start")
        with jax.named_scope(section):
            out = fn(*args, **kwargs)
        if enabled:
            self.mark(f"{section}:end")
        return out

    # -----------------------
    # Summary
    # -----------------------

    def summary(self):
        """Print timing breakdown from collected start/end marker pairs.

        Expects markers in the format "name:start" and "name:end".
        Aggregates across repeated calls (e.g. many step_env invocations).

        Also reports GPU `:first` entries if present.
        """
        if len(self._timestamps) < 2:
            print("JitProfiler: not enough markers recorded.")
            return

        totals: dict[str, float] = defaultdict(float)
        counts: dict[str, int] = defaultdict(int)
        order: list[str] = []
        pending: dict[str, float] = {}

        # Aggregate :start / :end
        for label, ts in self._timestamps:
            if label.endswith(":start"):
                name = label[: -len(":start")]
                pending[name] = ts
            elif label.endswith(":end"):
                name = label[: -len(":end")]
                if name in pending:
                    elapsed = ts - pending.pop(name)
                    totals[name] += elapsed
                    counts[name] += 1
                    if name not in order:
                        order.append(name)

        # Include any :first entries for GPU
        for label, val in self._timestamps:
            if label.endswith(":first"):
                name = label[: -len(":first")]
                totals[name] += val
                counts[name] += 1
                if name not in order:
                    order.append(name)

        if not totals:
            print("JitProfiler: no matched start/end pairs found.")
            return

        total_wall = sum(totals.values())
        header = f"{'Section':<30} {'Calls':>8} {'Total (s)':>10} {'Mean (us)':>10} {'%':>6}"
        print("\n" + "=" * len(header))
        print("JIT PROFILER SUMMARY")
        if self._backend.upper() == "GPU":
            print("  (GPU backend --> profile only indicates first-call latencies)")
        print("=" * len(header))
        print(header)
        print("-" * len(header))
        for name in order:
            n = counts[name]
            total_t = totals[name]
            mean_us = 1e6 * total_t / n
            pct = 100.0 * total_t / total_wall if total_wall > 0 else 0.0
            print(f"{name:<30} {n:>8} {total_t:>10.4f} {mean_us:>10.1f} {pct:>5.1f}%")
        print("-" * len(header))
        print(f"{'TOTAL':<30} {'':>8} {total_wall:>10.4f}")
        print("=" * len(header) + "\n")

call(enabled, fn, *args, name=None, **kwargs)

Profile a function call inside JIT-compiled code.

CPU

Fine-grained section timing via callbacks.

GPU

Records only first-call (compile + first execution) latency.

Source code in xlron/environments/wrappers.py
def call(self, enabled: bool, fn, *args, name: str | None = None, **kwargs):
    """Profile a function call inside JIT-compiled code.

    CPU:
        Fine-grained section timing via callbacks.

    GPU:
        Records only first-call (compile + first execution) latency.
    """
    self._assert_static_bool(enabled)

    if self._is_gpu:
        return self._call_first_only(enabled, fn, *args, name=name, **kwargs)

    # CPU path
    section = name or fn.__name__
    if enabled:
        self.mark(f"{section}:start")
    with jax.named_scope(section):
        out = fn(*args, **kwargs)
    if enabled:
        self.mark(f"{section}:end")
    return out

end(enabled, name)

Insert an end marker for a block.

Source code in xlron/environments/wrappers.py
def end(self, enabled: bool, name: str):
    """Insert an end marker for a block."""
    self._assert_static_bool(enabled)
    if enabled:
        self.mark(f"{name}:end")

mark(label)

Insert a timing marker. Safe to call inside JIT.

Source code in xlron/environments/wrappers.py
def mark(self, label: str):
    """Insert a timing marker. Safe to call inside JIT."""
    jax.debug.callback(self._record, label)

reset()

Clear all recorded timestamps and first-call tracking.

Source code in xlron/environments/wrappers.py
def reset(self):
    """Clear all recorded timestamps and first-call tracking."""
    self._timestamps.clear()
    self._seen_first_call.clear()

start(enabled, name)

Insert a start marker for a block.

Source code in xlron/environments/wrappers.py
def start(self, enabled: bool, name: str):
    """Insert a start marker for a block."""
    self._assert_static_bool(enabled)
    if enabled:
        self.mark(f"{name}:start")

summary()

Print timing breakdown from collected start/end marker pairs.

Expects markers in the format "name:start" and "name:end". Aggregates across repeated calls (e.g. many step_env invocations).

Also reports GPU :first entries if present.

Source code in xlron/environments/wrappers.py
def summary(self):
    """Print timing breakdown from collected start/end marker pairs.

    Expects markers in the format "name:start" and "name:end".
    Aggregates across repeated calls (e.g. many step_env invocations).

    Also reports GPU `:first` entries if present.
    """
    if len(self._timestamps) < 2:
        print("JitProfiler: not enough markers recorded.")
        return

    totals: dict[str, float] = defaultdict(float)
    counts: dict[str, int] = defaultdict(int)
    order: list[str] = []
    pending: dict[str, float] = {}

    # Aggregate :start / :end
    for label, ts in self._timestamps:
        if label.endswith(":start"):
            name = label[: -len(":start")]
            pending[name] = ts
        elif label.endswith(":end"):
            name = label[: -len(":end")]
            if name in pending:
                elapsed = ts - pending.pop(name)
                totals[name] += elapsed
                counts[name] += 1
                if name not in order:
                    order.append(name)

    # Include any :first entries for GPU
    for label, val in self._timestamps:
        if label.endswith(":first"):
            name = label[: -len(":first")]
            totals[name] += val
            counts[name] += 1
            if name not in order:
                order.append(name)

    if not totals:
        print("JitProfiler: no matched start/end pairs found.")
        return

    total_wall = sum(totals.values())
    header = f"{'Section':<30} {'Calls':>8} {'Total (s)':>10} {'Mean (us)':>10} {'%':>6}"
    print("\n" + "=" * len(header))
    print("JIT PROFILER SUMMARY")
    if self._backend.upper() == "GPU":
        print("  (GPU backend --> profile only indicates first-call latencies)")
    print("=" * len(header))
    print(header)
    print("-" * len(header))
    for name in order:
        n = counts[name]
        total_t = totals[name]
        mean_us = 1e6 * total_t / n
        pct = 100.0 * total_t / total_wall if total_wall > 0 else 0.0
        print(f"{name:<30} {n:>8} {total_t:>10.4f} {mean_us:>10.1f} {pct:>5.1f}%")
    print("-" * len(header))
    print(f"{'TOTAL':<30} {'':>8} {total_wall:>10.4f}")
    print("=" * len(header) + "\n")

LogWrapper

Bases: GymnaxWrapper

Log the episode returns and lengths. Modified from: https://github.com/RobertTLange/gymnax/blob/master/gymnax/wrappers/purerl.py

Source code in xlron/environments/wrappers.py
class LogWrapper(GymnaxWrapper):
    """Log the episode returns and lengths.
    Modified from: https://github.com/RobertTLange/gymnax/blob/master/gymnax/wrappers/purerl.py
    """

    def __init__(self, env: environment.Environment):
        super().__init__(env)

    @partial(jax.jit, static_argnums=(0,))
    def reset(
        self,
        key: chex.PRNGKey,
        params: Optional[RSAEnvParams] = None,
        state: Optional[RSAEnvState] = None,
    ) -> Tuple[chex.Array, LogEnvState]:
        obs, env_state = self._env.reset(key, params, state)
        log_state = LogEnvState(
            env_state=env_state,
            lengths=jnp.array(0, dtype=dtype_config.LARGE_INT_DTYPE),
            returns=jnp.array(0, dtype=dtype_config.REWARD_DTYPE),
            cum_returns=jnp.array(0, dtype=dtype_config.LARGE_FLOAT_DTYPE),
            accepted_services=jnp.array(0, dtype=dtype_config.LARGE_INT_DTYPE),
            accepted_bitrate=jnp.array(0, dtype=dtype_config.LARGE_FLOAT_DTYPE),
            total_bitrate=jnp.array(0, dtype=dtype_config.LARGE_FLOAT_DTYPE),
            utilisation=jnp.array(0, dtype=dtype_config.LARGE_FLOAT_DTYPE),
            terminal=jnp.array(False),
            truncated=jnp.array(False),
        )
        return obs, log_state

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        log_state: LogEnvState,
        action: Union[int, float] | Tuple[Union[int, float], Union[int, float]],
        params: RSAEnvParams,
    ) -> Tuple[Array, LogEnvState, float, bool, bool, dict]:
        obs, env_state, reward, terminal, truncated, info = self._env.step(
            key, log_state.env_state, action, params
        )
        done = jnp.logical_or(terminal, truncated)
        log_state = LogEnvState(
            env_state=env_state,
            lengths=log_state.lengths * (1 - done) + 1,
            returns=jnp.asarray(reward, dtype=dtype_config.REWARD_DTYPE),
            cum_returns=log_state.cum_returns * (1 - done) + reward,
            accepted_services=env_state.accepted_services,
            accepted_bitrate=env_state.accepted_bitrate,
            total_bitrate=env_state.total_bitrate,
            utilisation=jnp.count_nonzero(env_state.link_slot_array)
            / env_state.link_slot_array.size,
            terminal=terminal,
            truncated=truncated,
        )
        info["lengths"] = log_state.lengths
        info["returns"] = log_state.returns
        info["cum_returns"] = log_state.cum_returns
        info["accepted_services"] = log_state.accepted_services
        info["accepted_bitrate"] = log_state.accepted_bitrate
        info["total_bitrate"] = log_state.total_bitrate
        info["utilisation"] = log_state.utilisation
        info["terminal"] = terminal
        info["truncated"] = truncated
        # First check if we're dealing with RSAGNModelEnvParams
        is_gn_params = isinstance(params, RSAGNModelEnvParams)

        # For GN-model params, unpack the (path, power) action tuple
        if is_gn_params:
            action, power_action = action
            info["launch_power"] = power_action

        # Now, if we need to log actions OR we have GN-model params, compute the common fields
        if is_gn_params or params.log_actions:
            # Compute common fields
            nodes_sd, dr_request = read_rsa_request(log_state.env_state.request_array)
            source, dest = nodes_sd
            i = get_path_indices(
                params, source, dest, params.k_paths, params.num_nodes, directed=params.directed_graph
            ).astype(jnp.int32)
            path_index, slot_index = process_path_action(log_state.env_state, params, action)

            # Set common info
            info["path_index"] = i + path_index
            info["slot_index"] = slot_index
            info["source"] = source
            info["dest"] = dest
            info["data_rate"] = dr_request

            # GN-model-specific throughput info
            if is_gn_params:
                info["throughput"] = env_state.throughput

            # Logging-specific info
            if params.log_actions:
                # GN-model-specific logging
                if is_gn_params:
                    path = params.path_link_array.val[path_index.astype(jnp.int32)]
                    info["path_snr"] = get_snr_for_path(path, env_state.link_snr_array, params)[
                        slot_index.astype(jnp.int32)
                    ]
                # Common logging fields
                info["arrival_time"] = env_state.current_time[0]
                info["departure_time"] = env_state.current_time[0] + env_state.holding_time[0]
        return obs, log_state, reward, terminal, truncated, info

    def _tree_flatten(self) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
        children = ()  # arrays / dynamic values
        aux_data = (self._env,)  # static values, e.g. env params
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data: Tuple[Any, ...], children: Tuple[Any, ...]) -> "LogWrapper":
        return cls(*children, *aux_data)
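The done-gated accumulation used for `lengths` and `cum_returns` in step() above can be sketched in plain Python: multiplying by `(1 - done)` zeroes the running statistics whenever an episode terminates, so they restart from the current step's values.

```python
# Standalone sketch of the done-gated accumulation in LogWrapper.step:
# done = 1 resets the running stats; done = 0 accumulates them.
def update_stats(lengths, cum_returns, reward, done):
    lengths = lengths * (1 - done) + 1
    cum_returns = cum_returns * (1 - done) + reward
    return lengths, cum_returns

lengths, cum_returns = 0, 0.0
for reward, done in [(1.0, 0), (0.5, 0), (2.0, 1)]:
    lengths, cum_returns = update_stats(lengths, cum_returns, reward, done)

# After the terminal step, the counters have restarted from that step
print(lengths, cum_returns)  # 1 2.0
```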

Profiler

Simple wall-clock profiler that tracks named sections.

Usage

profiler = Profiler()

with profiler.section("compilation"):
    ...

for i in range(10):
    with profiler.section("training_step", frames=1000):
        ...

profiler.summary()

Source code in xlron/environments/wrappers.py
class Profiler:
    """Simple wall-clock profiler that tracks named sections.

    Usage:
        profiler = Profiler()

        with profiler.section("compilation"):
            ...

        for i in range(10):
            with profiler.section("training_step", frames=1000):
                ...

        profiler.summary()
    """

    def __init__(self, enabled: bool = True):
        self.enabled = enabled
        # Each key maps to a list of (elapsed_secs, frames) tuples
        self._records: dict[str, list[tuple[float, int | None]]] = {}
        self._order: list[str] = []  # Insertion order of section names

    def section(self, tag: str, frames: int | None = None) -> "_ProfileSection":
        """Return a context manager that times the enclosed block.

        Args:
            tag: Name for this section. Repeated uses accumulate.
            frames: Optional work-unit count (e.g. timesteps) for throughput.
        """
        return _ProfileSection(self, tag, frames)

    def _record(self, tag: str, elapsed: float, frames: int | None):
        if not self.enabled:
            return
        if tag not in self._records:
            self._records[tag] = []
            self._order.append(tag)
        self._records[tag].append((elapsed, frames))

    def summary(self):
        """Print a table summarising all recorded sections."""
        if not self._records:
            return
        total_wall = sum(e for entries in self._records.values() for e, _ in entries)
        header = (
            f"{'Section':<30} {'Calls':>6} {'Total (s)':>10} {'Mean (s)':>10} {'%':>6} {'FPS':>12}"
        )
        print("\n" + "=" * len(header))
        print("PROFILER SUMMARY")
        print("=" * len(header))
        print(header)
        print("-" * len(header))
        for tag in self._order:
            entries = self._records[tag]
            n = len(entries)
            total_t = sum(e for e, _ in entries)
            mean_t = total_t / n
            pct = 100.0 * total_t / total_wall if total_wall > 0 else 0.0
            total_frames = sum(f for _, f in entries if f is not None)
            fps_str = f"{total_frames / total_t:.2e}" if total_frames and total_t > 0 else ""
            print(f"{tag:<30} {n:>6} {total_t:>10.2f} {mean_t:>10.4f} {pct:>5.1f}% {fps_str:>12}")
        print("-" * len(header))
        print(f"{'TOTAL':<30} {'':>6} {total_wall:>10.2f}")
        print("=" * len(header) + "\n")

section(tag, frames=None)

Return a context manager that times the enclosed block.

Parameters:

Name Type Description Default
tag str

Name for this section. Repeated uses accumulate.

required
frames int | None

Optional work-unit count (e.g. timesteps) for throughput.

None
Source code in xlron/environments/wrappers.py
def section(self, tag: str, frames: int | None = None) -> "_ProfileSection":
    """Return a context manager that times the enclosed block.

    Args:
        tag: Name for this section. Repeated uses accumulate.
        frames: Optional work-unit count (e.g. timesteps) for throughput.
    """
    return _ProfileSection(self, tag, frames)

summary()

Print a table summarising all recorded sections.

Source code in xlron/environments/wrappers.py
def summary(self):
    """Print a table summarising all recorded sections."""
    if not self._records:
        return
    total_wall = sum(e for entries in self._records.values() for e, _ in entries)
    header = (
        f"{'Section':<30} {'Calls':>6} {'Total (s)':>10} {'Mean (s)':>10} {'%':>6} {'FPS':>12}"
    )
    print("\n" + "=" * len(header))
    print("PROFILER SUMMARY")
    print("=" * len(header))
    print(header)
    print("-" * len(header))
    for tag in self._order:
        entries = self._records[tag]
        n = len(entries)
        total_t = sum(e for e, _ in entries)
        mean_t = total_t / n
        pct = 100.0 * total_t / total_wall if total_wall > 0 else 0.0
        total_frames = sum(f for _, f in entries if f is not None)
        fps_str = f"{total_frames / total_t:.2e}" if total_frames and total_t > 0 else ""
        print(f"{tag:<30} {n:>6} {total_t:>10.2f} {mean_t:>10.4f} {pct:>5.1f}% {fps_str:>12}")
    print("-" * len(header))
    print(f"{'TOTAL':<30} {'':>6} {total_wall:>10.2f}")
    print("=" * len(header) + "\n")

TimeIt

Context manager for timing execution of code blocks.

Source code in xlron/environments/wrappers.py
class TimeIt:
    """Context manager for timing execution of code blocks."""

    def __init__(self, tag, frames=None):
        self.tag = tag
        self.frames = frames

    def __enter__(self):
        self.start = timeit.default_timer()
        return self

    def __exit__(self, *args):
        self.elapsed_secs = timeit.default_timer() - self.start
        msg = self.tag + (": Elapsed time=%.2fs" % self.elapsed_secs)
        if self.frames:
            msg += ", FPS=%.2e" % (self.frames / self.elapsed_secs)
        print(msg)

Environment functions

aggregate_slots(full_mask, params)

Aggregate slot mask via max-pooling.

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(1,))
def aggregate_slots(full_mask: Array, params: EnvParams) -> Array:
    """Aggregate slot mask via max-pooling."""
    num_actions = math.ceil(params.link_resources / params.aggregate_slots)

    # Full mask has shape (k_paths * link_resources,)
    # Pad to make divisible by aggregate_slots
    pad_size = num_actions * params.aggregate_slots - params.link_resources
    if pad_size > 0:
        full_mask = full_mask.reshape((params.k_paths, params.link_resources))
        full_mask = jnp.pad(full_mask, ((0, 0), (0, pad_size)), constant_values=0)

    # Reshape to (k_paths, num_actions, aggregate_slots) and max over windows
    reshaped = full_mask.reshape(params.k_paths, num_actions, params.aggregate_slots)
    agg_mask = jnp.max(reshaped, axis=2).reshape(-1)

    return agg_mask
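The same max-pooling can be sketched with NumPy in place of `jnp` (shapes and values here are illustrative):

```python
import math
import numpy as np

def aggregate_slots_np(full_mask, k_paths, link_resources, aggregate_slots):
    """NumPy sketch of the max-pooling aggregation above."""
    num_actions = math.ceil(link_resources / aggregate_slots)
    pad = num_actions * aggregate_slots - link_resources
    mask = full_mask.reshape(k_paths, link_resources)
    if pad > 0:
        # Pad with zeros so the slot axis divides evenly into windows
        mask = np.pad(mask, ((0, 0), (0, pad)), constant_values=0)
    # Max over each window of aggregate_slots consecutive slots
    return mask.reshape(k_paths, num_actions, aggregate_slots).max(axis=2).reshape(-1)

# k_paths=1, 5 slots, aggregated 2 at a time -> 3 actions
full_mask = np.array([0, 1, 0, 0, 1])
agg = aggregate_slots_np(full_mask, k_paths=1, link_resources=5, aggregate_slots=2)
# agg == [1, 0, 1]: windows (0,1), (0,0), (1, pad 0)
```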

check_action_rmsa_gn_model(state, action_info, params)

Check whether an action is valid for the RMSA GN model. Returns True if the action is invalid (any of the RSA, SNR-sufficiency, or per-fibre power checks fails), False otherwise.

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(2,))
def check_action_rmsa_gn_model(state: EnvState, action_info: ActionInfo, params: EnvParams) -> bool:
    """Check if action is valid for RSA GN model
    Args:
        state (EnvState): Environment state
        params (EnvParams): Environment parameters
        action (chex.Array): Action array
    Returns:
        bool: True if action is invalid, False if action is valid
    """
    # Check if action is valid
    # TODO - log failure reasons in info
    snr_sufficient_check = check_snr_sufficient(state, params)
    rsa_check = check_action_rsa(state, action_info, params)
    # Check total power per link doesn't exceed max_power_per_fibre
    total_power = compute_total_power_per_link(state.channel_power_array, state.path_index_array)
    power_check = jnp.any(total_power > params.max_power_per_fibre)
    return jnp.any(
        jnp.stack(
            (
                rsa_check,
                snr_sufficient_check,
                power_check,
            )
        )
    )

check_action_rsa(state, action_info, params)

Differentiable version of check_action_rsa.

Parameters:

Name Type Description Default
state

Current environment state

required
action_info

Action information for the candidate action

required
params

Environment parameters (params.temperature controls sharpness of the sigmoid)

required

Returns:

Type Description

Continuous value that behaves like the original boolean check

Source code in xlron/environments/env_funcs.py
def check_action_rsa(state, action_info, params):
    """
    Differentiable version of check_action_rsa.

    Args:
        state: Current environment state
        action_info: Action information for the candidate action
        params: Environment parameters (params.temperature controls sharpness of the sigmoid)

    Returns:
        Continuous value that behaves like the original boolean check
    """
    # Calculate differentiable version of each check
    spectrum_reuse_check = differentiable_check_no_spectrum_reuse(state, action_info, params)
    overflow_check = check_slot_overflow(state, action_info, params)
    no_action_check = check_no_op(state, action_info, params)
    unique_path_check = check_real_path(state, action_info, params)
    # For multiple checks, use a differentiable version of "any"
    # Instead of jnp.any, use max to combine checks
    combined_check = jnp.max(
        jnp.stack(
            [
                spectrum_reuse_check,
                overflow_check,
                no_action_check,
                unique_path_check,
            ]
        )
    )
    return combined_check

check_no_op(state, action_info, params)

Check for the "NO OP" action. This will be the maximum valid action idex + 1, resulting in a path index exceeding K paths.

Source code in xlron/environments/env_funcs.py
def check_no_op(state: EnvState, action_info: ActionInfo, params: EnvParams):
    """Check for the "NO OP" action.
    This will be the maximum valid action index + 1,
    resulting in a path index exceeding K paths."""
    overflow = differentiable_compare(
        action_info.path_index,
        params.k_paths,
        op_type=">=",
        temperature=params.temperature,
        differentiable=params.differentiable,
    )
    return overflow

check_no_spectrum_reuse(state, action_info, params)

A slot is decremented by 1 when used (slot -= 1) and should be zero when unoccupied, so reuse is detected by checking for any value below -1 in the slot array.

Parameters:

Name Type Description Default
state

Environment state containing the link slot array (L x S), where L is the number of links and S is the number of slots

required
action_info

Action information for the candidate action

required
params

Environment parameters

required

Returns:

Name Type Description
bool bool

True if check failed, False if check passed

Source code in xlron/environments/env_funcs.py
def check_no_spectrum_reuse(state: EnvState, action_info: ActionInfo, params: EnvParams) -> bool:
    """slot-=1 when used, should be zero when unoccupied, so check if any < -1 in slot array.

    Args:
        link_slot_array: Link slot array (L x S) where L is number of links and S is number of slots

    Returns:
        bool: True if check failed, False if check passed
    """
    path_mask = action_info.path[:, None]  # (num_links, 1)
    slots = path_mask * state.link_slot_array  # (num_links, link_resources)
    check = differentiable_compare(
        jnp.max(jnp.max(slots, axis=0)), 1, ">", params.temperature, params.differentiable
    )
    return check
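The masking step can be sketched with NumPy (the slot values below are illustrative; the check mirrors the `> 1` threshold passed to `differentiable_compare` in the code):

```python
import numpy as np

# Restrict the slot array to the links on the candidate path, then flag
# any masked value greater than 1 (indicating a double-booked slot).
path = np.array([1, 0, 1])               # links used by the candidate path
link_slot_array = np.array([[0, 2],      # value 2 on link 0, slot 1
                            [0, 9],      # link 1 is off-path, so ignored
                            [0, 1]])
slots = path[:, None] * link_slot_array  # zero out off-path links
reuse = bool(slots.max() > 1)
# reuse is True: link 0 (on the path) holds the value 2 in slot 1
```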

check_real_path(state, action_info, params)

Check if path is a dummy (all-zeros). A valid path always uses at least one link.

Source code in xlron/environments/env_funcs.py
def check_real_path(state: EnvState, action_info: ActionInfo, params: EnvParams):
    """Check if path is a dummy (all-zeros). A valid path always uses at least one link."""
    is_dummy = differentiable_compare(
        jnp.max(action_info.path),
        0,
        op_type="==",
        temperature=params.temperature,
        differentiable=params.differentiable,
    )
    return is_dummy

check_slot_overflow(state, action_info, params)

If the action selects a slot near the end of the array, the required slots can overflow and start filling from the start of the array, which might be free. To prevent this, we check that the action's initial slot index plus the required number of slots does not exceed the number of slots per link.

Source code in xlron/environments/env_funcs.py
def check_slot_overflow(state: EnvState, action_info: ActionInfo, params: EnvParams):
    """If the action selects slot near the end, then the required slots can
    overflow and start filling from the start of the array, which might be free!
    To prevent this, we check the action index + required slots
    """
    overflow = differentiable_compare(
        action_info.initial_slot_index + action_info.num_slots,
        params.link_resources,
        op_type=">",
        temperature=params.temperature,
        differentiable=params.differentiable,
    )
    return overflow
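The underlying condition is simple arithmetic; a plain-Python (non-differentiable) sketch:

```python
# Non-differentiable sketch of the overflow condition above.
def slot_overflow(initial_slot_index, num_slots, link_resources):
    # True when the requested block would run past the end of the slot array
    return initial_slot_index + num_slots > link_resources

fits = slot_overflow(96, 4, 100)      # slots 96..99 fit exactly -> False
wraps = slot_overflow(98, 4, 100)     # slots 98..101 would wrap -> True
```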

check_snr_sufficient(state, params)

Check if SNR is sufficient for all active connections. Returns 1 if any active connection has insufficient SNR, else 0.

Source code in xlron/environments/env_funcs.py
def check_snr_sufficient(state: RSAGNModelEnvState, params: RSAGNModelEnvParams) -> chex.Array:
    """Check if SNR is sufficient for all active connections.
    Args:
        state (EnvState): Environment state
        params (EnvParams): Environment parameters
    Returns:
        jnp.array: 1 if any active connection has insufficient SNR, else 0
    """
    required_snr_array = get_required_snr_se_kurtosis_array(
        state.modulation_format_index_array, 2, params
    )
    lightpath_snr_array = get_lightpath_snr(state, params)
    # Only check slots that are actually occupied (have nonzero channel bandwidth)
    is_occupied = state.channel_centre_bw_array != 0
    snr_insufficient = jnp.where(lightpath_snr_array >= required_snr_array, 0, 1)
    snr_insufficient = jnp.where(is_occupied, snr_insufficient, 0)
    return jnp.any(snr_insufficient)
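The occupancy-masked comparison can be sketched with NumPy (all values are illustrative):

```python
import numpy as np

# NumPy sketch of the masked SNR check above.
required_snr = np.array([20.0, 18.0, 0.0])    # dB required per slot
lightpath_snr = np.array([21.0, 17.5, 0.0])   # dB achieved per slot
channel_bw = np.array([50.0, 50.0, 0.0])      # 0 => slot unoccupied

# Flag slots whose achieved SNR falls below the requirement ...
snr_insufficient = np.where(lightpath_snr >= required_snr, 0, 1)
# ... but only count slots that are actually occupied
snr_insufficient = np.where(channel_bw != 0, snr_insufficient, 0)
any_insufficient = bool(np.any(snr_insufficient))
# True: the second occupied channel achieves 17.5 dB < 18.0 dB required
```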

complete_step_rmsa_gn_model(state, action_info, check, params)

Complete step for RMSA GN-model environments.

Same as RSA GN-model, plus modulation_format_index_array restoration.

Source code in xlron/environments/env_funcs.py
def complete_step_rmsa_gn_model(
    state: RSAGNModelEnvState,
    action_info: ActionInfo,
    check: chex.Array,
    params: RSAGNModelEnvParams,
) -> RSAGNModelEnvState:
    """Complete step for RMSA GN-model environments.

    Same as RSA GN-model, plus modulation_format_index_array restoration.
    """
    fail = check
    success = 1 - check

    fail_f_dep = fail.astype(state.link_slot_departure_array.dtype)
    fail_f_slots = fail.astype(state.link_slot_array.dtype)

    # --- Undo partial RSA allocation on failure ---
    state = state.replace(
        link_slot_array=state.link_slot_array - (fail_f_slots * action_info.affected_slots_mask),
        link_slot_departure_array=state.link_slot_departure_array
        - (
            fail_f_dep * action_info.affected_slots_mask * (state.current_time + state.holding_time)
        ),
    )

    # --- Restore GN-model-specific state on failure ---
    one_m_fail = 1 - fail
    one_m_fail_f = one_m_fail.astype(state.channel_power_array.dtype)

    state = state.replace(
        channel_centre_bw_array=state.channel_centre_bw_array * one_m_fail_f
        + state.channel_centre_bw_array_prev * fail.astype(state.channel_centre_bw_array.dtype),
        channel_power_array=state.channel_power_array * one_m_fail_f
        + state.channel_power_array_prev * fail.astype(state.channel_power_array.dtype),
        channel_centre_freq_array=state.channel_centre_freq_array * one_m_fail_f
        + state.channel_centre_freq_array_prev * fail.astype(state.channel_centre_freq_array.dtype),
        path_index_array=state.path_index_array * one_m_fail.astype(state.path_index_array.dtype)
        + state.path_index_array_prev * fail.astype(state.path_index_array.dtype),
        modulation_format_index_array=state.modulation_format_index_array
        * one_m_fail.astype(state.modulation_format_index_array.dtype)
        + state.modulation_format_index_array_prev
        * fail.astype(state.modulation_format_index_array.dtype),
    )

    # --- Book-keeping (always) ---
    state = state.replace(
        accepted_services=state.accepted_services + success,
        accepted_bitrate=state.accepted_bitrate
        + (success * action_info.requested_datarate * params.fec_rate),
        total_bitrate=state.total_bitrate + action_info.requested_datarate,
        total_timesteps=state.total_timesteps + 1,
    )
    return state

complete_step_rsa(state, action_info, check, params)

If the request is unsuccessful, i.e. the checks fail, remove the partial (unfinalised) resource allocation. A partial allocation is indicated by a negative time in the link slot departure array: where such values are found, increase link_slot_array by +1 and link_slot_departure_array by the current_time + holding_time of the current request.

Parameters:

Name Type Description Default
state EnvState

Environment state

required
action_info ActionInfo

Action information for the candidate action

required
check Array

1 if the action failed validation, else 0

required
params EnvParams

Environment parameters

required

Returns:

Type Description
EnvState

Updated environment state

Source code in xlron/environments/env_funcs.py
def complete_step_rsa(
    state: EnvState, action_info: ActionInfo, check: Array, params: EnvParams
) -> EnvState:
    """If the request is unsuccessful i.e. checks fail, then remove the partial (unfinalised) resource allocation.
    Partial resource allocation is indicated by negative time in link slot departure array.
    Check for values in link_slot_departure_array that are less than zero.
    If found, increase link_slot_array by +1 and link_slot_departure_array by current_time + holding_time of current request.

    Args:
        state: Environment state

    Returns:
        Updated environment state
    """
    fail = check
    success = 1 - check
    state = state.replace(
        link_slot_array=state.link_slot_array - (fail * action_info.affected_slots_mask),
        link_slot_departure_array=state.link_slot_departure_array
        - (fail * action_info.affected_slots_mask * (state.current_time + state.holding_time)),
        accepted_services=state.accepted_services + success,
        accepted_bitrate=state.accepted_bitrate + (success * action_info.requested_datarate),
        total_bitrate=state.total_bitrate + action_info.requested_datarate,
        total_timesteps=state.total_timesteps + 1,
    )
    return state

complete_step_rsa_gn_model(state, action_info, check, params)

Complete step for RSA GN-model environments.

On failure (check==1), undo partial slot allocation and restore GN-model auxiliary arrays from their *_prev snapshots.

Source code in xlron/environments/env_funcs.py
def complete_step_rsa_gn_model(
    state: RSAGNModelEnvState,
    action_info: ActionInfo,
    check: chex.Array,
    params: RSAGNModelEnvParams,
) -> RSAGNModelEnvState:
    """Complete step for RSA GN-model environments.

    On failure (check==1), undo partial slot allocation and restore GN-model
    auxiliary arrays from their *_prev snapshots.
    """
    fail = check
    success = 1 - check

    # Cast once for broadcasting/multiplication
    fail_f_dep = fail.astype(state.link_slot_departure_array.dtype)
    fail_f_slots = fail.astype(state.link_slot_array.dtype)

    # --- Undo partial RSA allocation on failure (same pattern as complete_step_rsa) ---
    state = state.replace(
        link_slot_array=state.link_slot_array - (fail_f_slots * action_info.affected_slots_mask),
        link_slot_departure_array=state.link_slot_departure_array
        - (
            fail_f_dep * action_info.affected_slots_mask * (state.current_time + state.holding_time)
        ),
    )

    # --- Restore GN-model-specific state on failure (blend current vs *_prev) ---
    one_m_fail = 1 - fail
    one_m_fail_f = one_m_fail.astype(state.channel_power_array.dtype)

    state = state.replace(
        channel_centre_bw_array=state.channel_centre_bw_array * one_m_fail_f
        + state.channel_centre_bw_array_prev * fail.astype(state.channel_centre_bw_array.dtype),
        channel_power_array=state.channel_power_array * one_m_fail_f
        + state.channel_power_array_prev * fail.astype(state.channel_power_array.dtype),
        channel_centre_freq_array=state.channel_centre_freq_array * one_m_fail_f
        + state.channel_centre_freq_array_prev * fail.astype(state.channel_centre_freq_array.dtype),
        path_index_array=state.path_index_array * one_m_fail.astype(state.path_index_array.dtype)
        + state.path_index_array_prev * fail.astype(state.path_index_array.dtype),
    )

    if params.monitor_active_lightpaths:
        # Only undo partially-added lightpaths (negative departure), and only if fail==1
        neg = (state.active_lightpaths_array_departure < 0).astype(
            state.active_lightpaths_array_departure.dtype
        )
        do_undo_dep = neg * fail.astype(state.active_lightpaths_array_departure.dtype)
        do_undo_lp = do_undo_dep.astype(state.active_lightpaths_array.dtype)

        state = state.replace(
            active_lightpaths_array=state.active_lightpaths_array * (1 - do_undo_lp)
            + jnp.array(-1, dtype=state.active_lightpaths_array.dtype) * do_undo_lp,
            active_lightpaths_array_departure=state.active_lightpaths_array_departure
            + do_undo_dep * (state.current_time + state.holding_time),
        )

    # --- Book-keeping (always) ---
    state = state.replace(
        accepted_services=state.accepted_services + success,
        accepted_bitrate=state.accepted_bitrate + (success * action_info.requested_datarate),
        total_bitrate=state.total_bitrate + action_info.requested_datarate,
        total_timesteps=state.total_timesteps + 1,
    )
    return state
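The blend-restore idiom used for the `*_prev` snapshots can be sketched in isolation (values are illustrative):

```python
import numpy as np

# Sketch of the blend-restore pattern above: when the check fails (fail == 1)
# each array element is replaced by its *_prev snapshot; when it succeeds
# (fail == 0) the current value is kept.
def blend(current, prev, fail):
    one_m_fail = (1 - fail).astype(current.dtype)
    return current * one_m_fail + prev * fail.astype(current.dtype)

current = np.array([0.5, 0.7], dtype=np.float32)
prev = np.array([0.4, 0.6], dtype=np.float32)

restored = blend(current, prev, np.float32(1.0))  # failure: equals prev
kept = blend(current, prev, np.float32(0.0))      # success: equals current
```

Because the blend is pure arithmetic rather than a branch, it stays JIT- and vmap-friendly under JAX.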

complete_step_rwalr(state, action_info, check, params)

Complete step for RWA-LR environments. Unlike complete_step_rsa, this does not modify link_slot_array on failure, because implement_action_rwalr already handles the undo via blending and link_slot_array stores a capacity mask (not an occupancy counter).

Source code in xlron/environments/env_funcs.py
def complete_step_rwalr(
    state: EnvState, action_info: ActionInfo, check: Array, params: EnvParams
) -> EnvState:
    """Complete step for RWA-LR environments.
    Unlike complete_step_rsa, this does not modify link_slot_array on failure,
    because implement_action_rwalr already handles the undo via blending and
    link_slot_array stores a capacity mask (not an occupancy counter).
    """
    fail = check
    success = 1 - check
    state = state.replace(
        link_slot_departure_array=state.link_slot_departure_array
        - (fail * action_info.affected_slots_mask * (state.current_time + state.holding_time)),
        accepted_services=state.accepted_services + success,
        accepted_bitrate=state.accepted_bitrate + (success * action_info.requested_datarate),
        total_bitrate=state.total_bitrate + action_info.requested_datarate,
        total_timesteps=state.total_timesteps + 1,
    )
    return state

compute_band_gaps_from_csv(link_resources, ref_lambda, slot_size, band_data_filepath=None)

Compute band gap slot positions from band definition CSV data.

Reads the band data CSV which defines frequency ranges for each optical band. Any slot whose centre frequency falls between bands is marked as a gap slot.

Parameters:

Name Type Description Default
link_resources int

Number of frequency slots per link.

required
ref_lambda float

Reference wavelength (m).

required
slot_size float

Slot width in GHz.

required
band_data_filepath str | None

Optional path to band data CSV file. Defaults to built-in band_data.csv.

None

Returns:

Type Description
Tuple[list, list]

Tuple of (gap_start_slots, gap_width_slots) as Python lists of ints.

Source code in xlron/environments/env_funcs.py
def compute_band_gaps_from_csv(
    link_resources: int,
    ref_lambda: float,
    slot_size: float,
    band_data_filepath: str | None = None,
) -> Tuple[list, list]:
    """Compute band gap slot positions from band definition CSV data.

    Reads the band data CSV which defines frequency ranges for each optical band.
    Any slot whose centre frequency falls between bands is marked as a gap slot.

    Args:
        link_resources: Number of frequency slots per link.
        ref_lambda: Reference wavelength (m).
        slot_size: Slot width in GHz.
        band_data_filepath: Optional path to band data CSV file. Defaults to
            built-in ``band_data.csv``.

    Returns:
        Tuple of (gap_start_slots, gap_width_slots) as Python lists of ints.
    """
    f = (
        pathlib.Path(band_data_filepath)
        if band_data_filepath
        else (pathlib.Path(__file__).parents[1].absolute() / "data" / "gn_model" / "band_data.csv")
    )
    band_data = np.genfromtxt(f, delimiter=",", skip_header=1, usecols=(1, 2, 3, 4))
    # Columns: wavelength_min_nm, wavelength_max_nm, frequency_min_ghz, frequency_max_ghz
    band_freq_lo = band_data[:, 2]  # frequency_min_ghz
    band_freq_hi = band_data[:, 3]  # frequency_max_ghz

    # Compute per-slot absolute frequencies in GHz
    slot_centres = (np.arange(link_resources) - (link_resources - 1) / 2) * slot_size
    ref_frequency_ghz = c / ref_lambda / 1e9
    slot_frequencies_ghz = ref_frequency_ghz + slot_centres

    # Mark each slot as covered if its centre falls within any band's frequency range
    covered = np.zeros(link_resources, dtype=bool)
    for i, freq in enumerate(slot_frequencies_ghz):
        for j in range(len(band_freq_lo)):
            if band_freq_lo[j] <= freq <= band_freq_hi[j]:
                covered[i] = True
                break

    # Find contiguous runs of uncovered slots
    gap_start_slots = []
    gap_width_slots = []
    in_gap = False
    gap_start = 0
    for i in range(link_resources):
        if not covered[i] and not in_gap:
            gap_start = i
            in_gap = True
        elif covered[i] and in_gap:
            gap_start_slots.append(gap_start)
            gap_width_slots.append(i - gap_start)
            in_gap = False
    if in_gap:
        gap_start_slots.append(gap_start)
        gap_width_slots.append(link_resources - gap_start)

    return gap_start_slots, gap_width_slots
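The run-detection loop at the end can be isolated as a small helper (the coverage pattern below is illustrative):

```python
def gap_runs(covered):
    """Find contiguous runs of uncovered slots, as in the loop above."""
    starts, widths = [], []
    in_gap, gap_start = False, 0
    for i, is_covered in enumerate(covered):
        if not is_covered and not in_gap:
            gap_start, in_gap = i, True      # a gap begins here
        elif is_covered and in_gap:
            starts.append(gap_start)         # a gap just ended
            widths.append(i - gap_start)
            in_gap = False
    if in_gap:                               # gap runs to the end of the array
        starts.append(gap_start)
        widths.append(len(covered) - gap_start)
    return starts, widths

covered = [True, True, False, False, True, False]
runs = gap_runs(covered)
# runs == ([2, 5], [2, 1]): a 2-slot gap at index 2, a 1-slot gap at index 5
```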

compute_band_layout(slot_size, band_preference, inter_band_gap_ghz=25.0, band_data_filepath=None)

Compute band layout: link_resources, ref_lambda, slot centre frequencies, gaps, and orderings.

Given a slot size and selected bands, this function:

1. Determines how many slots fit in each band
2. Inserts 1 gap slot per inter-band boundary (representing inter_band_gap_ghz)
3. Computes absolute centre frequencies for every slot
4. Returns all derived quantities needed by make_env

Parameters:

Name Type Description Default
slot_size float

Spectral width of a frequency slot in GHz.

required
band_preference str

Comma-separated band names in preference order (e.g. "C,L,S").

required
inter_band_gap_ghz float

Physical spectral width of inter-band gap in GHz (~0.2 nm).

25.0
band_data_filepath str | None

Optional path to band data CSV. Defaults to built-in.

None

Returns:

Type Description
dict

Dict with keys: link_resources, ref_lambda, slot_centre_freq_array (relative GHz), gap_start_slots, gap_width_slots, band_slot_order_ff, band_slot_order_lf.

Source code in xlron/environments/env_funcs.py
def compute_band_layout(
    slot_size: float,
    band_preference: str,
    inter_band_gap_ghz: float = 25.0,
    band_data_filepath: str | None = None,
) -> dict:
    """Compute band layout: link_resources, ref_lambda, slot centre frequencies, gaps, and orderings.

    Given a slot size and selected bands, this function:
    1. Determines how many slots fit in each band
    2. Inserts 1 gap slot per inter-band boundary (representing ``inter_band_gap_ghz``)
    3. Computes absolute centre frequencies for every slot
    4. Returns all derived quantities needed by make_env

    Args:
        slot_size: Spectral width of a frequency slot in GHz.
        band_preference: Comma-separated band names in preference order (e.g. "C,L,S").
        inter_band_gap_ghz: Physical spectral width of inter-band gap in GHz (~0.2 nm).
        band_data_filepath: Optional path to band data CSV. Defaults to built-in.

    Returns:
        Dict with keys: link_resources, ref_lambda, slot_centre_freq_array (relative GHz),
        gap_start_slots, gap_width_slots, band_slot_order_ff, band_slot_order_lf.
    """
    f = (
        pathlib.Path(band_data_filepath)
        if band_data_filepath
        else (pathlib.Path(__file__).parents[1].absolute() / "data" / "gn_model" / "band_data.csv")
    )
    band_names_raw = np.genfromtxt(f, delimiter=",", skip_header=1, usecols=(0,), dtype=str)
    band_data_num = np.genfromtxt(f, delimiter=",", skip_header=1, usecols=(1, 2, 3, 4))
    band_freq_lo = band_data_num[:, 2]  # frequency_min_ghz
    band_freq_hi = band_data_num[:, 3]  # frequency_max_ghz

    # Build lookup: band_name -> (freq_lo, freq_hi)
    band_info = {}
    for i, name in enumerate(band_names_raw):
        band_info[name.upper()] = (band_freq_lo[i], band_freq_hi[i])

    # Parse preference list and filter to selected bands
    preference_list = [b.strip().upper() for b in band_preference.split(",")]
    selected = []
    for name in preference_list:
        if name not in band_info:
            raise ValueError(
                f"Band '{name}' not found in band data CSV. Available: {list(band_info.keys())}"
            )
        selected.append((name, band_info[name][0], band_info[name][1]))

    # Sort selected bands by frequency (ascending)
    selected.sort(key=lambda x: x[1])

    # Compute slots per band and build the layout
    slot_centres_abs_ghz = []  # absolute centre frequencies in GHz
    gap_start_slots = []
    gap_width_slots = []
    band_slot_ranges = {}  # band_name -> list of slot indices (for ordering)

    slot_idx = 0
    for i, (name, freq_lo, freq_hi) in enumerate(selected):
        band_width = freq_hi - freq_lo
        num_slots_in_band = int(math.floor(band_width / slot_size))

        # Slot centres within this band: start half a slot_size from the low edge
        band_start = freq_lo + slot_size / 2
        band_slots_indices = list(range(slot_idx, slot_idx + num_slots_in_band))
        band_slot_ranges[name] = band_slots_indices

        for j in range(num_slots_in_band):
            slot_centres_abs_ghz.append(band_start + j * slot_size)

        slot_idx += num_slots_in_band

        # Insert gap slot between this band and the next (if not the last band)
        if i < len(selected) - 1:
            next_freq_lo = selected[i + 1][1]
            gap_centre = (freq_hi + next_freq_lo) / 2
            gap_start_slots.append(slot_idx)
            gap_width_slots.append(1)
            slot_centres_abs_ghz.append(gap_centre)
            slot_idx += 1

    link_resources = slot_idx
    slot_centres_abs_ghz = np.array(slot_centres_abs_ghz, dtype=np.float64)

    # Compute ref_lambda as centre of the total frequency range
    total_freq_min = slot_centres_abs_ghz[0] - slot_size / 2
    total_freq_max = slot_centres_abs_ghz[-1] + slot_size / 2
    centre_freq_ghz = (total_freq_min + total_freq_max) / 2
    ref_lambda = c / (centre_freq_ghz * 1e9)

    # Convert to relative GHz offsets from ref_lambda
    ref_freq_ghz = c / ref_lambda / 1e9
    slot_centre_freq_rel_ghz = slot_centres_abs_ghz - ref_freq_ghz

    # Build band-preference slot orderings
    order_ff = []
    order_lf = []
    for name in preference_list:
        slots = band_slot_ranges.get(name, [])
        order_ff.extend(slots)
        order_lf.extend(reversed(slots))
    # Add any bands not in preference list (shouldn't happen, but for safety)
    for name, _, _ in selected:
        if name not in preference_list:
            slots = band_slot_ranges.get(name, [])
            order_ff.extend(slots)
            order_lf.extend(reversed(slots))
    # Gap slots last
    for gs in gap_start_slots:
        order_ff.append(gs)
        order_lf.append(gs)

    assert len(order_ff) == link_resources, (
        f"band_slot_order length {len(order_ff)} != link_resources {link_resources}"
    )

    return {
        "link_resources": link_resources,
        "ref_lambda": ref_lambda,
        "slot_centre_freq_array": slot_centre_freq_rel_ghz.astype(np.float32),
        "gap_start_slots": gap_start_slots,
        "gap_width_slots": gap_width_slots,
        "band_slot_order_ff": np.array(order_ff, dtype=np.int32),
        "band_slot_order_lf": np.array(order_lf, dtype=np.int32),
    }

compute_band_slot_order(link_resources, ref_lambda, slot_size, band_preference, band_data_filepath=None)

Compute band-preference-ordered slot index arrays for first-fit and last-fit.

Parameters:

Name Type Description Default
link_resources int

Number of frequency slots per link.

required
ref_lambda float

Reference wavelength (m).

required
slot_size float

Slot width in GHz.

required
band_preference str

Comma-separated band names in preference order (e.g. "C,L,S").

required
band_data_filepath str | None

Optional path to band data CSV. Defaults to built-in.

None

Returns:

Type Description
Tuple[ndarray, ndarray]

Tuple of (band_slot_order_ff, band_slot_order_lf) as numpy int32 arrays of shape (link_resources,). band_slot_order_ff has slots ascending within each band, bands in preference order; band_slot_order_lf has slots descending within each band, bands in preference order.

Source code in xlron/environments/env_funcs.py
def compute_band_slot_order(
    link_resources: int,
    ref_lambda: float,
    slot_size: float,
    band_preference: str,
    band_data_filepath: str | None = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute band-preference-ordered slot index arrays for first-fit and last-fit.

    Args:
        link_resources: Number of frequency slots per link.
        ref_lambda: Reference wavelength (m).
        slot_size: Slot width in GHz.
        band_preference: Comma-separated band names in preference order (e.g. "C,L,S").
        band_data_filepath: Optional path to band data CSV. Defaults to built-in.

    Returns:
        Tuple of (band_slot_order_ff, band_slot_order_lf) as numpy int32 arrays
        of shape (link_resources,).
        band_slot_order_ff has slots ascending within each band, bands in preference order.
        band_slot_order_lf has slots descending within each band, bands in preference order.
    """
    f = (
        pathlib.Path(band_data_filepath)
        if band_data_filepath
        else (pathlib.Path(__file__).parents[1].absolute() / "data" / "gn_model" / "band_data.csv")
    )
    # Read band names (string column)
    band_names_raw = np.genfromtxt(f, delimiter=",", skip_header=1, usecols=(0,), dtype=str)
    band_names = list(band_names_raw)
    # Read numeric columns
    band_data_num = np.genfromtxt(f, delimiter=",", skip_header=1, usecols=(1, 2, 3, 4))
    band_freq_lo = band_data_num[:, 2]  # frequency_min_ghz
    band_freq_hi = band_data_num[:, 3]  # frequency_max_ghz

    # Compute per-slot frequencies
    slot_centres = (np.arange(link_resources) - (link_resources - 1) / 2) * slot_size
    ref_freq_ghz = c / ref_lambda / 1e9
    slot_freq_ghz = ref_freq_ghz + slot_centres

    # Assign each slot to a band
    band_slots = {name: [] for name in band_names}
    uncovered = []
    for i, freq in enumerate(slot_freq_ghz):
        assigned = False
        for j, name in enumerate(band_names):
            if band_freq_lo[j] <= freq <= band_freq_hi[j]:
                band_slots[name].append(i)
                assigned = True
                break
        if not assigned:
            uncovered.append(i)

    preference_list = [b.strip().upper() for b in band_preference.split(",")]

    order_ff = []
    order_lf = []
    # Preferred bands first
    for band_name in preference_list:
        slots = sorted(band_slots.get(band_name, []))
        order_ff.extend(slots)
        order_lf.extend(reversed(slots))
    # Then any bands not in the preference list (in CSV order)
    for name in band_names:
        if name not in preference_list:
            slots = sorted(band_slots.get(name, []))
            order_ff.extend(slots)
            order_lf.extend(reversed(slots))
    # Finally uncovered/gap slots
    order_ff.extend(uncovered)
    order_lf.extend(uncovered)

    assert len(order_ff) == link_resources, (
        f"band_slot_order length {len(order_ff)} != link_resources {link_resources}"
    )

    return np.array(order_ff, dtype=np.int32), np.array(order_lf, dtype=np.int32)
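A minimal sketch of the ordering logic with a hypothetical two-band slot assignment (not read from band_data.csv), showing how first-fit and last-fit orders are built per band in preference order:

```python
# Hypothetical slot-to-band assignment: slots 0-2 fall in band "C", slots 3-5 in band "L"
band_slots = {"C": [0, 1, 2], "L": [3, 4, 5]}
preference = ["L", "C"]  # prefer the L-band

order_ff, order_lf = [], []
for band in preference:
    slots = sorted(band_slots[band])
    order_ff.extend(slots)            # ascending within each band -> first-fit order
    order_lf.extend(reversed(slots))  # descending within each band -> last-fit order

print(order_ff)  # [3, 4, 5, 0, 1, 2]
print(order_lf)  # [5, 4, 3, 2, 1, 0]
```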

compute_total_power_per_link(channel_power_array, path_index_array)

Compute total optical power per link by summing one power value per channel.

Each channel spans multiple contiguous slots with the same power value and the same path_index. We identify channel starts (where path_index differs from previous slot and is >= 0) and sum their power values.

Parameters:

Name Type Description Default
channel_power_array

(num_links, link_resources) per-channel power in linear Watts

required
path_index_array

(num_links, link_resources) lightpath index (-1 for empty)

required

Returns: (num_links,) total optical power per link in linear Watts

Source code in xlron/environments/env_funcs.py
def compute_total_power_per_link(channel_power_array, path_index_array):
    """Compute total optical power per link by summing one power value per channel.

    Each channel spans multiple contiguous slots with the same power value and
    the same path_index. We identify channel starts (where path_index differs
    from previous slot and is >= 0) and sum their power values.

    Args:
        channel_power_array: (num_links, link_resources) per-channel power in linear Watts
        path_index_array: (num_links, link_resources) lightpath index (-1 for empty)
    Returns:
        (num_links,) total optical power per link in linear Watts
    """
    occupied = path_index_array >= 0
    prev_path_idx = jnp.concatenate(
        [
            jnp.full((path_index_array.shape[0], 1), -1, dtype=path_index_array.dtype),
            path_index_array[:, :-1],
        ],
        axis=1,
    )
    is_channel_start = occupied & (path_index_array != prev_path_idx)
    channel_start_powers = jnp.where(is_channel_start, channel_power_array, 0.0)
    return jnp.sum(channel_start_powers, axis=1)
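A NumPy sketch of the channel-start summation above, using toy values (one link of eight slots with two hypothetical lightpaths):

```python
import numpy as np

# One link, 8 slots: lightpath 7 occupies slots 1-3, lightpath 2 occupies slots 5-6
path_index = np.array([[-1, 7, 7, 7, -1, 2, 2, -1]])
power = np.array([[0.0, 1e-3, 1e-3, 1e-3, 0.0, 2e-3, 2e-3, 0.0]])  # Watts, repeated per slot

occupied = path_index >= 0
prev = np.concatenate([np.full((1, 1), -1), path_index[:, :-1]], axis=1)
is_start = occupied & (path_index != prev)          # True only at slots 1 and 5
total = np.where(is_start, power, 0.0).sum(axis=1)  # one power value per channel
print(total)  # [0.003]
```

Summing only at channel starts avoids counting the per-slot replicated power value once per slot.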

convert_node_probs_to_traffic_matrix(node_probs)

Convert list of node probabilities to symmetric traffic matrix.

Parameters:

Name Type Description Default
node_probs list

node probabilities

required

Returns:

Name Type Description
traffic_matrix Array

traffic matrix

Source code in xlron/environments/env_funcs.py
def convert_node_probs_to_traffic_matrix(node_probs: list) -> chex.Array:
    """Convert list of node probabilities to symmetric traffic matrix.

    Args:
        node_probs: node probabilities

    Returns:
        traffic_matrix: traffic matrix
    """
    matrix = jnp.outer(node_probs, node_probs).astype(dtype_config.SMALL_FLOAT_DTYPE)
    # Set lead diagonal to zero
    matrix = jnp.where(jnp.eye(matrix.shape[0]) == 1, 0, matrix)
    matrix = normalise_traffic_matrix(matrix)
    return matrix
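A NumPy sketch of the construction (assuming the normalisation step divides by the matrix sum so entries sum to 1):

```python
import numpy as np

node_probs = np.array([0.5, 0.3, 0.2])
matrix = np.outer(node_probs, node_probs)  # symmetric by construction
np.fill_diagonal(matrix, 0.0)              # zero the lead diagonal (no self-traffic)
matrix = matrix / matrix.sum()             # assumed normalisation: entries sum to 1
```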

count_until_next_one(array, position, temperature, differentiable=True)

Counts positions until the next 1 in the array. Made differentiable using straight-through gradient trick.

Parameters:

Name Type Description Default
array Array

Input array

required
position int

Starting position for counting

required
temperature float

Temperature for differentiable approximation

required
differentiable bool

If False, use non-differentiable operations

True

Returns:

Type Description
Array

Number of positions until the next 1

Source code in xlron/environments/env_funcs.py
def count_until_next_one(
    array: chex.Array, position: int, temperature: float, differentiable: bool = True
) -> chex.Array:
    """
    Counts positions until the next 1 in the array.
    Made differentiable using straight-through gradient trick.

    Args:
        array: Input array
        position: Starting position for counting
        temperature: Temperature for differentiable approximation
        differentiable: If False, use non-differentiable operations

    Returns:
        Number of positions until the next 1
    """
    # Add 1s to end so that end block is counted and slice shape can be fixed
    shape = array.shape[0]
    array = jnp.concatenate([array, jnp.ones(array.shape[0], dtype=dtype_config.LARGE_INT_DTYPE)])
    # Find the indices of 1 in the array
    one_indices = jax.lax.dynamic_slice(array, (position,), (shape,))
    # Use our differentiable_argmax helper
    next_one_idx = differentiable_argmax(
        one_indices, temperature=temperature, differentiable=differentiable
    )
    return next_one_idx + 1
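The non-differentiable branch reduces to a padded fixed-size slice plus an argmax; a NumPy sketch with toy values:

```python
import numpy as np

array = np.array([0, 0, 1, 0, 0, 0, 1, 0])
position = 3
# Pad with ones so a trailing run of zeros is still terminated by a 1
padded = np.concatenate([array, np.ones_like(array)])
window = padded[position:position + len(array)]  # fixed-length slice starting at `position`
count = int(np.argmax(window)) + 1               # argmax finds the first 1 in the window
print(count)  # 4
```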

count_until_previous_one(array, position, temperature, differentiable=True)

Counts positions until the previous 1 in the array. Made differentiable using straight-through gradient trick.

Parameters:

Name Type Description Default
array Array

Input array

required
position int

Starting position for counting backwards

required
temperature float

Temperature for differentiable approximation

required
differentiable bool

If False, use non-differentiable operations

True

Returns:

Type Description
int

Number of positions until the previous 1

Source code in xlron/environments/env_funcs.py
def count_until_previous_one(
    array: chex.Array, position: int, temperature: float, differentiable: bool = True
) -> int:
    """
    Counts positions until the previous 1 in the array.
    Made differentiable using straight-through gradient trick.

    Args:
        array: Input array
        position: Starting position for counting backwards
        temperature: Temperature for differentiable approximation
        differentiable: If False, use non-differentiable operations

    Returns:
        Number of positions until the previous 1
    """
    # Add 1s to start so that end block is counted and slice shape can be fixed
    shape = array.shape[0]
    array = jnp.concatenate([jnp.ones(array.shape[0], dtype=dtype_config.LARGE_INT_DTYPE), array])
    # Find the indices of 1 in the array
    one_indices = jax.lax.dynamic_slice(array, (-shape - position,), (shape,))
    one_indices = jnp.flip(one_indices)
    # Use our differentiable_argmax helper
    prev_one_idx = differentiable_argmax(
        one_indices, temperature=temperature, differentiable=differentiable
    )
    return prev_one_idx + 1

create_run_name(config)

Create name for run based on config flags

Source code in xlron/environments/env_funcs.py
def create_run_name(config: Union[box.Box, dict]) -> str:
    """Create name for run based on config flags"""
    env_type = config["env_type"]
    topology = config["topology_name"]
    slots = config["link_resources"]
    gnn = "_GNN" if config["USE_GNN"] else ""
    incremental = "_INC" if config["incremental_loading"] else ""
    run_name = f"{env_type}_{topology}_{slots}{gnn}{incremental}".upper()
    if config["EVAL_HEURISTIC"]:
        run_name += f"_{config['path_heuristic']}"
        if env_type.lower() == "vone":
            run_name += f"_{config['node_heuristic']}"
    elif config["EVAL_MODEL"]:
        run_name += "_EVAL"
    return run_name
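A pure-Python illustration of the name construction with hypothetical config values (note the heuristic suffix is appended after the upper-casing, so it keeps its case):

```python
config = {
    "env_type": "rsa", "topology_name": "nsfnet", "link_resources": 100,
    "USE_GNN": True, "incremental_loading": False,
    "EVAL_HEURISTIC": True, "EVAL_MODEL": False, "path_heuristic": "ksp_ff",
}
run_name = f"{config['env_type']}_{config['topology_name']}_{config['link_resources']}"
run_name += ("_GNN" if config["USE_GNN"] else "") + ("_INC" if config["incremental_loading"] else "")
run_name = run_name.upper()  # heuristic suffixes added afterwards keep their case
if config["EVAL_HEURISTIC"]:
    run_name += f"_{config['path_heuristic']}"
print(run_name)  # RSA_NSFNET_100_GNN_ksp_ff
```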

differentiable_check_no_spectrum_reuse(state, action_info, params)

Differentiable version of check_no_spectrum_reuse with improved gradient properties.

Parameters:

Name Type Description Default
state EnvState

Environment state; state.link_slot_array (L x S, where L is number of links and S is number of slots) is checked for violations

required
action_info ActionInfo

Action information for the candidate assignment

required
params EnvParams

Environment parameters; params.temperature controls the sharpness of the gradient response and params.differentiable toggles the soft approximation

required

Returns:

Type Description

A value that behaves like the original boolean check in forward pass

but has zero gradient when there are no violations and otherwise

has gradient pointing toward reducing violations

Source code in xlron/environments/env_funcs.py
def differentiable_check_no_spectrum_reuse(
    state: EnvState, action_info: ActionInfo, params: EnvParams
):
    """
    Differentiable version of check_no_spectrum_reuse with improved gradient properties.

    Args:
        state: Environment state; state.link_slot_array (L x S) is checked for violations
        action_info: Action information for the candidate assignment
        params: Environment parameters; params.temperature controls the sharpness of the
            gradient response and params.differentiable toggles the soft approximation

    Returns:
        A value that behaves like the original boolean check in forward pass
        but has zero gradient when there are no violations and otherwise
        has gradient pointing toward reducing violations
    """
    # Hard result for forward pass (original behavior)
    hard_result = check_no_spectrum_reuse(state, action_info, params)

    # If not differentiable mode, return hard result directly
    if not params.differentiable:
        return hard_result

    # Measure violations (how much each element exceeds the threshold of -1)
    violations = jnp.maximum(0, -1 - state.link_slot_array)

    # Any violation is considered a violation (alternatively can sum to discourage more egregious violations)
    # TODO - see if sum vs. max makes a difference in solution quality
    total_violation = jnp.max(violations)

    # Scale violations by temperature
    scaled_violation = params.temperature * total_violation

    # Use a function with zero gradient at zero: x²/(1+x²)
    # This function:
    # - Equals 0 when there are no violations
    # - Has gradient 0 when there are no violations
    # - Grows monotonically toward 1 as violations increase
    soft_result = (scaled_violation**2) / (1 + scaled_violation**2)

    # Apply straight-through trick
    return straight_through(hard_result, soft_result)
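The soft surrogate x²/(1+x²) can be checked in isolation; a small sketch of its key properties (zero value and zero gradient at zero, saturation toward 1):

```python
def soft_penalty(total_violation, temperature=1.0):
    """x**2 / (1 + x**2): zero value and zero gradient at x = 0, saturating at 1."""
    x = temperature * total_violation
    return x**2 / (1 + x**2)

print(soft_penalty(0.0))   # 0.0 -> no violation, no penalty (and zero gradient)
print(soft_penalty(1.0))   # 0.5
print(soft_penalty(10.0))  # ~0.99 -> large violations saturate toward 1
```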

find_block_ends(path_slots)

Finds the end positions of blocks in the path slots.

Parameters:

Name Type Description Default
path_slots Array

Array of path slots

required

Returns:

Type Description
Array

Array with 1s at the end positions of blocks

Source code in xlron/environments/env_funcs.py
def find_block_ends(path_slots: Array) -> Array:
    """
    Finds the end positions of blocks in the path slots.

    Args:
        path_slots: Array of path slots

    Returns:
        Array with 1s at the end positions of blocks
    """
    transitions = jnp.diff(path_slots, append=1)  # Find transition 0 to 1
    return jnp.clip(transitions, 0, 1)

find_block_starts(path_slots)

Finds the starting positions of blocks in the path slots. Args: path_slots: Array of path slots

Returns:

Type Description
Array

Array with 1s at the starting positions of blocks

Source code in xlron/environments/env_funcs.py
def find_block_starts(path_slots: Array) -> Array:
    """
    Finds the starting positions of blocks in the path slots.
    Args:
        path_slots: Array of path slots

    Returns:
        Array with 1s at the starting positions of blocks
    """
    # Add a [1] at the beginning to find transitions from 1 to 0
    transitions = jnp.clip(jnp.diff(path_slots, prepend=1), -1, 0)  # Find transitions (1 to 0)
    return jnp.abs(transitions)
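The two block helpers pair up; a NumPy sketch marking the starts and ends of the zero-valued runs ("blocks") in a toy slot array:

```python
import numpy as np

# Runs of zeros are the "blocks": here indices 1-2 and index 5
path_slots = np.array([1, 0, 0, 1, 1, 0, 1, 1])
starts = np.abs(np.clip(np.diff(path_slots, prepend=1), -1, 0))  # 1 at each block start
ends = np.clip(np.diff(path_slots, append=1), 0, 1)              # 1 at each block end
print(starts)  # [0 1 0 0 0 1 0 0]
print(ends)    # [0 0 1 0 0 1 0 0]
```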

generate_arrival_holding_times(key, params)

Generate arrival and holding times based on Poisson distributed events. To understand how sampling from e^-x can be transformed to sample from lambda*e^-(x/lambda) see: https://en.wikipedia.org/wiki/Inverse_transform_sampling#Examples Basically, inverse transform sampling is used to sample from a distribution with CDF F(x). The CDF of the exponential distribution (lambda*e^-{lambda*x}) is F(x) = 1 - e^-{lambda*x}. Therefore, the inverse CDF is x = -ln(1-u)/lambda, where u is a sample from a uniform distribution. Therefore, we need to divide jax.random.exponential() by lambda in order to scale the standard exponential CDF. Experimental histograms of this method compared to random.expovariate() in Python's random library show that the two methods are equivalent. Also see: https://numpy.org/doc/stable/reference/random/generated/numpy.random.exponential.html https://jax.readthedocs.io/en/latest/_autosummary/jax.random.exponential.html

Parameters:

Name Type Description Default
key

PRNG key

required
params

Environment parameters

required

Returns:

Name Type Description
arrival_time

Arrival time

holding_time

Holding time

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(1,))
def generate_arrival_holding_times(key, params):
    """
    Generate arrival and holding times based on Poisson distributed events.
    To understand how sampling from e^-x can be transformed to sample from lambda*e^-(x/lambda) see:
    https://en.wikipedia.org/wiki/Inverse_transform_sampling#Examples
    Basically, inverse transform sampling is used to sample from a distribution with CDF F(x).
    The CDF of the exponential distribution (lambda*e^-{lambda*x}) is F(x) = 1 - e^-{lambda*x}.
    Therefore, the inverse CDF is x = -ln(1-u)/lambda, where u is sample from uniform distribution.
    Therefore, we need to divide jax.random.exponential() by lambda in order to scale the standard exponential CDF.
    Experimental histograms of this method compared to random.expovariate() in Python's random library show that
    the two methods are equivalent.
    Also see: https://numpy.org/doc/stable/reference/random/generated/numpy.random.exponential.html
    https://jax.readthedocs.io/en/latest/_autosummary/jax.random.exponential.html

    Args:
        key: PRNG key
        params: Environment parameters

    Returns:
        arrival_time: Arrival time
        holding_time: Holding time
    """
    key_arrival, key_holding = jax.random.split(key, 2)
    arrival_time = (
        jax.random.exponential(key_arrival, shape=(1,), dtype=dtype_config.SMALL_FLOAT_DTYPE)
        / params.arrival_rate
    )  # Divide because it is rate (lambda)
    if params.truncate_holding_time:
        # For DeepRMSA, need to generate holding times that are less than 2*mean_service_holding_time
        key_holding = jax.random.split(key, 5)
        holding_times = jax.vmap(
            lambda x: jax.random.exponential(x, shape=(1,)) * params.mean_service_holding_time
        )(key_holding).reshape(-1)
        holding_times = jnp.where(
            holding_times < 2 * params.mean_service_holding_time, holding_times, zero
        )
        # Get first non-zero value in holding_times
        holding_time_indices = differentiable_where(
            holding_times > 0,
            jnp.arange(holding_times.shape[0]),
            0,
            threshold=0.5,
            temperature=params.temperature,
            differentiable=params.differentiable,
        )
        non_zero_index = differentiable_argmax(
            holding_time_indices,
            temperature=params.temperature,
            differentiable=params.differentiable,
        )
        holding_time = differentiable_indexing(
            jnp.squeeze(holding_times),
            (non_zero_index,),
            params.temperature,
            params.differentiable,
        )
    else:
        holding_time = (
            jax.random.exponential(key_holding, shape=(1,), dtype=dtype_config.SMALL_FLOAT_DTYPE)
            * params.mean_service_holding_time
        )  # Multiply because it is mean (1/lambda)
    return arrival_time, holding_time
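The inverse-transform scaling described above (divide a standard exponential sample by the rate, multiply by the mean) can be sanity-checked with NumPy:

```python
import numpy as np

rng = np.random.default_rng(0)
arrival_rate = 4.0    # lambda: mean inter-arrival time is 1 / lambda = 0.25
mean_holding = 10.0   # mean service holding time (1 / mu)

n = 200_000
arrivals = rng.exponential(1.0, n) / arrival_rate   # divide because lambda is a rate
holdings = rng.exponential(1.0, n) * mean_holding   # multiply because it is a mean

print(round(arrivals.mean(), 2))  # ~0.25
print(round(holdings.mean(), 1))  # ~10.0
```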

get_best_modulation_format_simple(state, path, initial_slot_index, params)

Get modulation format for lightpath. Assume worst case (least Gaussian) modulation format when calculating SNR. Args: state (EnvState): Environment state path (chex.Array): Path array initial_slot_index (int): Initial slot index params (EnvParams): Environment parameters Returns: jnp.array: Acceptable modulation format indices

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(3,))
def get_best_modulation_format_simple(
    state: RSAGNModelEnvState,
    path: chex.Array,
    initial_slot_index: int,
    params: RSAGNModelEnvParams,
) -> chex.Array:
    """Get modulation format for lightpath.
    Assume worst case (least Gaussian) modulation format when calculating SNR.
    Args:
        state (EnvState): Environment state
        path (chex.Array): Path array
        initial_slot_index (int): Initial slot index
        params (EnvParams): Environment parameters
    Returns:
        jnp.array: Acceptable modulation format indices
    """
    link_snr_array = get_snr_link_array(state, params)
    snr_value = (
        get_snr_for_path(path, link_snr_array, params, state)[initial_slot_index]
        - params.snr_margin
    )  # Margin
    mod_format_count = params.modulations_array.val.shape[0]
    acceptable_mod_format_indices = jnp.arange(mod_format_count)
    req_snr = params.modulations_array.val[:, 2] + params.snr_margin
    acceptable_mod_format_indices = jnp.where(
        snr_value >= req_snr,
        acceptable_mod_format_indices,
        jnp.full((mod_format_count,), -2),
    )
    return acceptable_mod_format_indices

get_centre_frequency(initial_slot_index, num_slots, params)

Get centre frequency for new lightpath.

Looks up pre-computed per-slot centre frequencies from params.slot_centre_freq_array and returns the midpoint of the first and last slot in the channel. This correctly handles non-uniform slot spacing (e.g. inter-band gap slots).

Parameters:

Name Type Description Default
initial_slot_index Array

Index of the first slot of the channel.

required
num_slots float

Number of slots occupied by the channel.

required
params RSAGNModelEnvParams

Environment parameters.

required

Returns:

Type Description
Array

chex.Array: Centre frequency for new lightpath (relative GHz offset

Array

from ref_lambda).

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(2,))
def get_centre_frequency(
    initial_slot_index: int, num_slots: int, params: RSAGNModelEnvParams
) -> chex.Array:
    """Get centre frequency for new lightpath.

    Looks up pre-computed per-slot centre frequencies from
    ``params.slot_centre_freq_array`` and returns the midpoint of the first
    and last slot in the channel.  This correctly handles non-uniform slot
    spacing (e.g. inter-band gap slots).

    Args:
        initial_slot_index (chex.Array): Index of the first slot of the channel.
        num_slots (float): Number of slots occupied by the channel.
        params (RSAGNModelEnvParams): Environment parameters.

    Returns:
        chex.Array: Centre frequency for new lightpath (relative GHz offset
        from ref_lambda).
    """
    slot_centres = params.slot_centre_freq_array.val  # (link_resources,) relative GHz
    initial_slot_index = jnp.asarray(initial_slot_index, dtype=dtype_config.INDEX_DTYPE)
    num_slots = jnp.asarray(num_slots, dtype=dtype_config.INDEX_DTYPE)
    first_slot_centre = slot_centres[initial_slot_index]
    last_slot_idx = jnp.minimum(initial_slot_index + num_slots - 1, params.link_resources - 1)
    last_slot_centre = slot_centres[last_slot_idx]
    return (first_slot_centre + last_slot_centre) / 2
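A NumPy sketch of the midpoint lookup, assuming a hypothetical uniform 12.5 GHz grid for slot_centre_freq_array (the real array may have non-uniform inter-band gaps):

```python
import numpy as np

# Hypothetical uniform 12.5 GHz grid of 8 slots, centred on the reference frequency
slot_centres = (np.arange(8) - (8 - 1) / 2) * 12.5  # GHz offsets from ref_lambda
initial_slot, num_slots = 2, 3                      # channel occupies slots 2, 3, 4
first = slot_centres[initial_slot]
last = slot_centres[min(initial_slot + num_slots - 1, len(slot_centres) - 1)]
centre = (first + last) / 2
print(centre)  # -6.25
```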

get_edge_disjoint_paths(graph)

Get edge disjoint paths between all nodes in graph.

Parameters:

Name Type Description Default
graph Graph

graph

required

Returns:

Name Type Description
dict Dict[int, Dict[int, List[Tuple[int, int]]]]

edge disjoint paths (path is list of edges)

Source code in xlron/environments/env_funcs.py
def get_edge_disjoint_paths(graph: nx.Graph) -> Dict[int, Dict[int, List[Tuple[int, int]]]]:
    """Get edge disjoint paths between all nodes in graph.

    Args:
        graph: graph

    Returns:
        dict: edge disjoint paths (path is list of edges)
    """
    result = {n: {} for n in graph}
    for n1, n2 in itertools.combinations(graph, 2):
        # Sort by number of links in path
        # TODO - sort by path length
        result[n1][n2] = sorted(list(nx.edge_disjoint_paths(graph, n1, n2)), key=len)
        result[n2][n1] = sorted(list(nx.edge_disjoint_paths(graph, n2, n1)), key=len)
    return result

get_launch_power(state, path_action, power_action, params)

Get launch power for new lightpath. N.B. launch power is specified in dBm but is converted to linear units when stored in channel_power_array. This func returns linear units (mW). Path action is used to determine the launch power in the case of tabular launch power type. Power action is used to determine the launch power in the case of RL launch power type. During masking, power action is set as state.launch_power_array[0], which is set by the RL agent. Args: state (EnvState): Environment state path_action (chex.Array): Action specifying path index (0 to k_paths-1) power_action (chex.Array): Action specifying launch power in dBm params (EnvParams): Environment parameters Returns: chex.Array: Launch power for new lightpath

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(3,))
def get_launch_power(
    state: EnvState,
    path_action: chex.Array,
    power_action: chex.Array,
    params: EnvParams,
) -> chex.Array:
    """Get launch power for new lightpath. N.B. launch power is specified in dBm but is converted to linear units
    when stored in channel_power_array. This func returns linear units (mW).
    Path action is used to determine the launch power in the case of tabular launch power type.
    Power action is used to determine the launch power in the case of RL launch power type. During masking,
    power action is set as state.launch_power_array[0], which is set by the RL agent.
    Args:
        state (EnvState): Environment state
        path_action (chex.Array): Action specifying path index (0 to k_paths-1)
        power_action (chex.Array): Action specifying launch power in dBm
        params (EnvParams): Environment parameters
    Returns:
        chex.Array: Launch power for new lightpath
    """
    k_path_index, _ = process_path_action(state, params, path_action)
    if params.launch_power_type == "fixed":
        return state.launch_power_array[0]
    elif params.launch_power_type == "tabular":
        nodes_sd, requested_datarate = read_rsa_request(state.request_array)
        source, dest = nodes_sd
        i = get_path_indices(
            params,
            source,
            dest,
            params.k_paths,
            params.num_nodes,
            directed=params.directed_graph,
        ).astype(jnp.int32)
        return state.launch_power_array[i + k_path_index]
    elif params.launch_power_type == "rl":
        return power_action
    elif params.launch_power_type == "scaled":
        nodes_sd, requested_datarate = read_rsa_request(state.request_array)
        source, dest = nodes_sd
        i = get_path_indices(
            params,
            source,
            dest,
            params.k_paths,
            params.num_nodes,
            directed=params.directed_graph,
        )
        # Get path length
        link_length_array = jnp.sum(params.link_length_array.val, axis=1, promote_integers=False)
        path_length = jnp.sum(link_length_array[i + k_path_index], promote_integers=False)
        path_link_array = (
            jnp.unpackbits(params.path_link_array.val)[:, params.num_links]
            if params.pack_path_bits
            else params.path_link_array.val
        )
        maximum_path_length = jnp.max(jnp.dot(path_link_array, params.link_length_array.val))
        return state.launch_power_array[0] * (path_length / maximum_path_length)
    else:
        raise ValueError(
            f"Invalid launch_power_type '{params.launch_power_type}'. "
            "Must be 'fixed', 'tabular', 'rl', or 'scaled'."
        )
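The dBm-to-linear conversion mentioned in the docstring is the standard one; a small self-contained helper (not part of the XLRON API) for reference:

```python
import math

def dbm_to_mw(p_dbm):
    """Convert power in dBm to linear milliwatts."""
    return 10 ** (p_dbm / 10)

def mw_to_dbm(p_mw):
    """Convert power in linear milliwatts to dBm."""
    return 10 * math.log10(p_mw)

print(dbm_to_mw(0))   # 1.0 -> 0 dBm is 1 mW
print(dbm_to_mw(3))   # ~2.0 -> +3 dB roughly doubles the power
print(mw_to_dbm(10))  # 10.0
```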

get_lightpath_snr(state, params)

Get SNR for each link on path. N.B. that in most cases it is more efficient to calculate the SNR for every possible path, rather than on a slot-by-slot basis. But in some cases slot-by-slot is better, i.e. when k*N(N-1)/2 > L*S Args: state (RSAGNModelEnvState): Environment state params (RSAGNModelEnvParams): Environment parameters

Returns:

Type Description
Array

chex.array: SNR for each link on path

Source code in xlron/environments/env_funcs.py
def get_lightpath_snr(state: RSAGNModelEnvState, params: RSAGNModelEnvParams) -> chex.Array:
    """Get SNR for each link on path.
    N.B. that in most cases it is more efficient to calculate the SNR for every possible path, rather than a slot-by-slot basis.
    But in some cases slot-by-slot is better i.e. when k*N(N-1)/2 > L*S
    Args:
        state (RSAGNModelEnvState): Environment state
        params (RSAGNModelEnvParams): Environment parameters

    Returns:
        chex.array: SNR for each link on path
    """
    # Get the SNR for the channel that the path occupies
    path_snr_array = jax.vmap(get_snr_for_path, in_axes=(0, None, None, None))(
        params.path_link_array.val, state.link_snr_array, params, state
    )
    # Where value in path_index_array matches index of path_snr_array, substitute in SNR value
    slot_indices = jnp.arange(params.link_resources)
    lightpath_snr_array = jax.vmap(
        jax.vmap(lambda x, si: path_snr_array[x][si], in_axes=(0, 0)), in_axes=(0, None)
    )(state.path_index_array, slot_indices)
    return lightpath_snr_array

get_line_graph_laplacian(graph)

Compute the Laplacian matrix of the line graph.

Parameters:

Name Type Description Default
graph Graph

NetworkX graph (original topology)

required

Returns:

Type Description
Array

Laplacian matrix of the line graph as a JAX array

Source code in xlron/environments/env_funcs.py
def get_line_graph_laplacian(graph: nx.Graph) -> chex.Array:
    """Compute the Laplacian matrix of the line graph.

    Args:
        graph: NetworkX graph (original topology)

    Returns:
        Laplacian matrix of the line graph as a JAX array
    """
    line_graph = make_line_graph(graph)
    if line_graph.is_directed():
        laplacian = nx.directed_laplacian_matrix(line_graph)
    else:
        laplacian = nx.laplacian_matrix(line_graph).todense()
    return jnp.array(laplacian, dtype=dtype_config.LARGE_FLOAT_DTYPE)

get_line_graph_spectral_features(graph, num_features)

Compute spectral features for edges using the line graph Laplacian.

These features are used as positional encodings for transformer architectures with WiRE (Wavelet-Induced Rotary Encodings).

Parameters:

Name Type Description Default
graph Graph

NetworkX graph (original topology)

required
num_features int

Number of spectral features to compute

required

Returns:

Type Description
Array

Array of shape (num_edges, num_features) containing eigenvectors of the

Array

line graph Laplacian, ordered by ascending eigenvalue magnitude.

Array

These serve as positional encodings for edge/link tokens.

Source code in xlron/environments/env_funcs.py
def get_line_graph_spectral_features(graph: nx.Graph, num_features: int) -> chex.Array:
    """Compute spectral features for edges using the line graph Laplacian.

    These features are used as positional encodings for transformer architectures
    with WiRE (Wavelet-Induced Rotary Encodings).

    Args:
        graph: NetworkX graph (original topology)
        num_features: Number of spectral features to compute

    Returns:
        Array of shape (num_edges, num_features) containing eigenvectors of the
        line graph Laplacian, ordered by ascending eigenvalue magnitude.
        These serve as positional encodings for edge/link tokens.
    """
    line_laplacian = get_line_graph_laplacian(graph)
    num_edges = line_laplacian.shape[0]
    # Clamp num_features to available dimensions
    actual_features = min(num_features, num_edges)
    eigenvalues, eigenvectors = jnp.linalg.eigh(line_laplacian)
    return eigenvectors[:, :actual_features].astype(dtype_config.LARGE_FLOAT_DTYPE)

get_link_relevance_array(paths, paths_se, requested_datarate, params)

Compute 4 link relevance features for the current request.

Parameters:

Name Type Description Default
paths Array

(k, E) binary path-link indicators

required
paths_se Array

(k, 1) spectral efficiency per path

required
requested_datarate Array

(1,)

required
params RSAEnvParams

environment parameters

required

Returns:

Type Description

(E, 4) array with columns:

0: weighted_relevance - combined rank/SE weighted sum across paths
1: path_count - fraction of k paths using each link
2: best_rank - 1 - min_rank/k for links on any path, 0 otherwise
3: best_se - max SE among paths through link, normalized

Source code in xlron/environments/env_funcs.py
def get_link_relevance_array(
    paths: Array, paths_se: Array, requested_datarate: Array, params: RSAEnvParams
):
    """Compute 4 link relevance features for the current request.

    Args:
        paths: (k, E) binary path-link indicators
        paths_se: (k, 1) spectral efficiency per path
        requested_datarate: (1,)
        params: environment parameters

    Returns:
        (E, 4) array with columns:
            0: weighted_relevance - combined rank/SE weighted sum across paths
            1: path_count - fraction of k paths using each link
            2: best_rank - 1 - min_rank/k for links on any path, 0 otherwise
            3: best_se - max SE among paths through link, normalized
    """
    k = params.k_paths

    # --- Feature 1: Weighted relevance (existing logic) ---
    ranks = jnp.arange(k)
    rank_weights = 1.0 / (ranks + 1.0)
    num_slots = jax.vmap(
        lambda x: required_slots(
            requested_datarate,
            x,
            params.slot_size,
            guardband=params.guardband,
            temperature=params.temperature,
        )
    )(paths_se.flatten())
    slot_weights = 1.0 / num_slots
    weights = rank_weights * slot_weights.flatten()
    weights = weights / (jnp.sum(weights) + 1e-8)
    weighted_paths = paths * weights[:, None]
    weighted_relevance = jnp.sum(weighted_paths, axis=0)  # (E,)

    # --- Feature 2: Path count - fraction of k paths using each link ---
    path_count = jnp.sum(paths, axis=0) / k  # (E,)

    # --- Feature 3: Best rank - 1 - min_rank/k for links on any path, 0 otherwise ---
    rank_per_path = jnp.arange(k).reshape(k, 1)  # (k, 1)
    # Where path uses link, use rank; otherwise use k (sentinel)
    rank_masked = jnp.where(paths > 0, rank_per_path, k)  # (k, E)
    min_rank = jnp.min(rank_masked, axis=0)  # (E,)
    on_any_path = (path_count > 0).astype(jnp.float32)  # (E,)
    best_rank = on_any_path * (1.0 - min_rank / k)  # (E,)

    # --- Feature 4: Best SE - max SE among paths through link, normalized ---
    se_vals = paths_se.flatten()  # (k,)
    max_se = jnp.max(se_vals) + 1e-8
    se_masked = jnp.where(paths > 0, se_vals[:, None], 0.0)  # (k, E)
    best_se = jnp.max(se_masked, axis=0) / max_se  # (E,)

    return jnp.stack([weighted_relevance, path_count, best_rank, best_se], axis=-1)  # (E, 4)
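The rank and SE features above can be checked by hand on a toy instance. The sketch below re-derives path_count, best_rank and best_se for a 2-path, 4-link example; plain NumPy is used here purely for illustration, with names and shapes following the function above.

```python
import numpy as np

# Toy instance: k = 2 candidate paths over E = 4 links.
# Row i is the binary link-usage indicator of path i.
paths = np.array([
    [1, 1, 0, 0],   # path 0 uses links 0, 1
    [0, 1, 1, 0],   # path 1 uses links 1, 2
], dtype=float)
paths_se = np.array([[4.0], [2.0]])  # spectral efficiency per path
k = paths.shape[0]

# path_count: fraction of the k paths crossing each link
path_count = paths.sum(axis=0) / k                      # [0.5, 1.0, 0.5, 0.0]

# best_rank: 1 - min_rank/k for links on any path, 0 otherwise
ranks = np.arange(k).reshape(k, 1)
rank_masked = np.where(paths > 0, ranks, k)             # sentinel k where unused
min_rank = rank_masked.min(axis=0)
best_rank = (path_count > 0) * (1.0 - min_rank / k)     # [1.0, 1.0, 0.5, 0.0]

# best_se: max SE among paths through each link, normalized by overall max
se_vals = paths_se.ravel()
se_masked = np.where(paths > 0, se_vals[:, None], 0.0)
best_se = se_masked.max(axis=0) / (se_vals.max() + 1e-8)  # ~[1.0, 1.0, 0.5, 0.0]
```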

get_minimum_snr_of_channels_on_path(state, path, slot_index, req_slots, params)

Get the minimum value of the SNR on newly assigned channels. N.B. this requires the link_snr_array to have already been calculated and present in state.

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(2,))
def get_minimum_snr_of_channels_on_path(
    state: RSAGNModelEnvState,
    path: chex.Array,
    slot_index: chex.Array,
    req_slots: int,
    params: RSAGNModelEnvParams,
) -> chex.Array:
    """Get the minimum value of the SNR on newly assigned channels.
    N.B. this requires the link_snr_array to have already been calculated and present in state."""
    snr_value_all_channels = get_snr_for_path(path, state.link_snr_array, params, state)
    min_snr_value_sub_channels = jnp.min(
        jnp.concatenate(
            [
                snr_value_all_channels[slot_index].reshape((1,)),
                snr_value_all_channels[slot_index + req_slots - 1].reshape((1,)),
            ],
            axis=0,
        )
    )
    return min_snr_value_sub_channels
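Note that the minimum is taken over only the two end sub-channels of the newly assigned block (indices slot_index and slot_index + req_slots - 1), not every slot in between. A minimal NumPy sketch of that selection, with illustrative values:

```python
import numpy as np

snr_all = np.array([20.0, 18.5, 17.0, 19.0, 21.0])  # per-channel SNR on the path
slot_index, req_slots = 1, 3                         # block occupies slots 1..3

# Only the two edge sub-channels of the block are compared
min_snr = min(snr_all[slot_index], snr_all[slot_index + req_slots - 1])
```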

get_obs_transformer(state, params)

Retrieves observation for transformer model.

Creates tokens for each link/edge. Column order: [wire_features | edge_features | traffic_marginals | request-specific...], where request-specific features are at the end so the critic can strip them.

Request-specific columns (stripped for critic):
- holding_time (1 col, departure mode only)
- request_size (1 col)
- link_relevance (4 cols)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | RSAEnvState | Environment state | required |
| params | RSAEnvParams | Environment parameters | required |

Returns:

| Name | Type | Description |
|---|---|---|
| tokens | Array | Array of shape (num_links, input_size) |

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(1,))
def get_obs_transformer(state: RSAEnvState, params: RSAEnvParams) -> chex.Array:
    """Retrieves observation for transformer model.

    Creates tokens for each link/edge. Column order:
        [wire_features | edge_features | traffic_marginals | request-specific...]
    where request-specific features are at the end so the critic can strip them.

    Request-specific columns (stripped for critic):
        - holding_time (1 col, departure mode only)
        - request_size (1 col)
        - link_relevance (4 cols)

    Args:
        state: Environment state
        params: Environment parameters

    Returns:
        tokens: Array of shape (num_links, input_size)
    """
    # Get line graph spectral features (WiRE positional encodings)
    wire_features = params.line_graph_spectral_features.val

    # Get edge features based on traffic type (WITHOUT holding time - that's request-specific)
    if params.transformer_obs_type == "occupancy":
        edge_features = state.link_slot_array
    elif params.transformer_obs_type == "capacity":
        edge_features = state.link_capacity_array / 1e6
    else:
        # Dynamic traffic: normalized departure times only
        edge_features = state.link_slot_departure_array / params.mean_service_holding_time

    # --- Traffic matrix node marginal features (NOT request-specific) ---
    send_load = jnp.sum(state.traffic_matrix, axis=1)  # (N,) row marginals
    recv_load = jnp.sum(state.traffic_matrix, axis=0)  # (N,) col marginals
    send_load = send_load / (jnp.sum(send_load) + 1e-8)
    recv_load = recv_load / (jnp.sum(recv_load) + 1e-8)
    link_src = params.edges.val[:, 0].astype(jnp.int32)  # (E,)
    link_dst = params.edges.val[:, 1].astype(jnp.int32)  # (E,)
    endpoint_send = (send_load[link_src] + send_load[link_dst]).reshape(-1, 1)  # (E, 1)
    endpoint_recv = (recv_load[link_src] + recv_load[link_dst]).reshape(-1, 1)  # (E, 1)
    traffic_marginal_features = jnp.concatenate([endpoint_send, endpoint_recv], axis=-1)  # (E, 2)

    # --- Request-specific features (critic should NOT see these) ---
    nodes_sd, requested_datarate = read_rsa_request(state.request_array)

    # Normalized request size
    max_bw = jnp.max(params.values_bw.val)
    request_size_feature = jnp.full(
        (params.num_links, 1),
        requested_datarate / (max_bw + 1e-8),
    )  # (E, 1)

    # Link relevance (4 features)
    paths_se = get_paths_se(params, nodes_sd)
    paths = get_paths(params, nodes_sd)
    link_relevance_features = get_link_relevance_array(
        paths, paths_se, requested_datarate, params
    )  # (E, 4)

    # Concatenation: shared features first, request-specific features last
    # Shared: wire_features, edge_features, traffic_marginals
    # Request-specific: [holding_time (departure only),] request_size, link_relevance
    shared = [wire_features, edge_features, traffic_marginal_features]
    request_specific = [request_size_feature, link_relevance_features]
    if params.transformer_obs_type == "departure":
        holding_time_col = jnp.full(
            (params.num_links, 1),
            state.holding_time / params.mean_service_holding_time,
        )  # (E, 1)
        request_specific = [holding_time_col] + request_specific

    tokens = jnp.concatenate(shared + request_specific, axis=-1)

    return tokens
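Because the request-specific columns always sit at the end of each token, a request-agnostic critic can drop them with a single slice. A minimal sketch, with illustrative shapes (6 request-specific columns in departure mode, per the docstring above):

```python
import numpy as np

E = 3                                     # number of links/tokens
shared = np.ones((E, 8))                  # wire + edge + traffic-marginal features
request_specific = np.full((E, 6), 0.5)   # holding_time + request_size + 4 relevance cols

tokens = np.concatenate([shared, request_specific], axis=-1)  # (E, 14)

# Critic view: strip the trailing request-specific columns
n_request_cols = 6
critic_tokens = tokens[:, :-n_request_cols]                   # (E, 8)
```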

get_path(params, nodes, k_path_index)

Get the k_path_index-th path between source and destination

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(0,))
def get_path(params: EnvParams, nodes: Array, k_path_index: int) -> Array:
    """Get k paths between source and destination"""
    path_index = get_path_index(params, nodes, k_path_index)
    path = differentiable_indexing(
        params.path_link_array.val,
        path_index,
        params.temperature,
        params.differentiable,
    )
    if params.pack_path_bits:
        path = jnp.unpackbits(path)[: params.num_links]
    return path

get_path_and_se(params, nodes, k_path_index)

Get the k_path_index-th path and its spectral efficiency between source and destination

Source code in xlron/environments/env_funcs.py
def get_path_and_se(params: EnvParams, nodes: Array, k_path_index: int) -> Tuple[Array, Array]:
    """Get k paths and their spectral efficiencies between source and destination"""
    path_index = get_path_index(params, nodes, k_path_index)
    path = differentiable_indexing(
        params.path_link_array.val,
        path_index,
        params.temperature,
        params.differentiable,
    )
    if params.pack_path_bits:
        path = jnp.unpackbits(path)[: params.num_links]
    se = differentiable_indexing(
        params.path_se_array.val, path_index, params.temperature, params.differentiable
    )
    return path, se

get_path_from_path_index_array(path_index_array, path_link_array)

Get path from path index array.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| path_index_array | Array | Path index array | required |
| path_link_array | Array | Path link array | required |

Returns:

| Type | Description |
|---|---|
| Array | Path index values replaced with binary path-link arrays |

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(1,))
def get_path_from_path_index_array(
    path_index_array: chex.Array, path_link_array: chex.Array
) -> chex.Array:
    """Get path from path index array.
    Args:
        path_index_array (chex.Array): Path index array
        path_link_array (chex.Array): Path link array

    Returns:
        jnp.array: path index values replaced with binary path-link arrays
    """

    # TODO - support unpacking bits (if this function ends up being used)
    def get_index_from_link(link):
        return jax.vmap(lambda x: path_link_array[x], in_axes=(0,))(link)

    return jax.vmap(get_index_from_link, in_axes=(0,))(path_index_array)
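The nested vmap above amounts to a fancy-index lookup: each integer in the path-index array is replaced by the corresponding row of the path-link table. A NumPy equivalent, with illustrative shapes:

```python
import numpy as np

path_link_array = np.array([[1, 0], [0, 1], [1, 1]])  # 3 paths, 2 links
path_index_array = np.array([[0, 2],
                             [1, 1]])

# Replace each path index with its binary path-link row
result = path_link_array[path_index_array]  # shape (2, 2, 2)
```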

get_path_index(params, nodes, k_path_index)

Get the index of the k_path_index-th path between source and destination

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(0,))
def get_path_index(params: EnvParams, nodes: Array, k_path_index: int) -> Array:
    """Get k paths between source and destination"""
    source, dest = nodes
    starting_index = get_path_indices(
        params,
        source,
        dest,
        params.k_paths,
        params.num_nodes,
        directed=params.directed_graph,
    ).astype(jnp.int32)
    path_index = starting_index + k_path_index
    return path_index
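get_path_index relies on the path table storing the k candidate paths of each node pair as a contiguous block of k rows, so one path is addressed as starting_index + k_path_index. A toy sketch of that layout; the pair_id-to-starting-row map below is hypothetical (in XLRON the starting index is computed by get_path_indices from source and destination):

```python
import numpy as np

k = 3  # candidate paths per node pair
# Rows 0..2 belong to pair 0; rows 3..5 to pair 1; etc.
path_link_array = np.arange(18).reshape(6, 3)  # 2 pairs * k paths, 3 links

def path_index(pair_id: int, k_path_index: int) -> int:
    # Hypothetical pair_id -> starting row of the pair's k-row block
    starting_index = pair_id * k
    return starting_index + k_path_index

row = path_index(pair_id=1, k_path_index=2)   # 3rd path of the 2nd pair
path = path_link_array[row]
```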

get_path_index_array(params, nodes)

Indices of paths between source and destination from path array

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(0,))
def get_path_index_array(params: EnvParams, nodes: Array) -> Array:
    """Indices of paths between source and destination from path array"""
    # get source and destination nodes in order (for accurate indexing of path-link array)
    source, dest = nodes.astype(dtype_config.LARGE_INT_DTYPE)
    i = get_path_indices(
        params, source, dest, params.k_paths, params.num_nodes, directed=params.directed_graph
    )
    index_array = differentiable_indexing(
        jnp.arange(0, params.path_link_array.shape[0], dtype=dtype_config.LARGE_INT_DTYPE),
        i + jnp.arange(params.k_paths, dtype=dtype_config.LARGE_FLOAT_DTYPE),
        params.temperature,
        params.differentiable,
    )
    return index_array

get_path_se(params, nodes, k_path_index)

Get the spectral efficiency of the k_path_index-th path between source and destination

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(0,))
def get_path_se(params, nodes, k_path_index):
    """Get k paths between source and destination"""
    path_index = get_path_index(params, nodes, k_path_index)
    return differentiable_indexing(
        params.path_se_array.val, path_index, params.temperature, params.differentiable
    )

get_path_slots(link_slot_array, params, nodes_sd, i, agg_func='max')

Get slots on each constituent link of path from link_slot_array (L x S), then aggregate to get (S x 1) representation of slots on path.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| link_slot_array | Array | link-slot array | required |
| params | EnvParams | environment parameters | required |
| nodes_sd | Array | source-destination nodes | required |
| i | int | path index | required |
| agg_func | str | aggregation function ('max', 'sum', 'mean' or 'min'). If max, result will be available slots on path. If sum, result will contain information on edge features. If mean, result will be the mean value across links. | 'max' |

Returns:

| Name | Type | Description |
|---|---|---|
| slots | Array | slots on path |

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(1, 4))
def get_path_slots(
    link_slot_array: chex.Array,
    params: EnvParams,
    nodes_sd: chex.Array,
    i: int,
    agg_func: str = "max",
) -> chex.Array:
    """Get slots on each constitutent link of path from link_slot_array (L x S),
    then aggregate to get (S x 1) representation of slots on path.

    Args:
        link_slot_array: link-slot array
        params: environment parameters
        nodes_sd: source-destination nodes
        i: path index
        agg_func: aggregation function ('max', 'sum', 'mean' or 'min').
            If max, result will be available slots on path.
            If sum, result will contain information on edge features.
            If mean, result will be the mean value across links.
            If min, result will be the minimum value across links.

    Returns:
        slots: slots on path
    """
    path = get_path(params, nodes_sd, i)
    slots = path[:, None] * link_slot_array
    # Make any -1s positive then get max for each slot across links
    if agg_func == "max":
        # Use this for getting slots from link_slot_array
        slots = jnp.max(jnp.absolute(slots), axis=0)
    elif agg_func == "sum":
        # TODO - consider using an RNN (or S5) to aggregate edge features
        # Use this (or alternative) for aggregating edge features from GNN
        slots = jnp.sum(slots, axis=0, promote_integers=False)
    elif agg_func == "mean":
        # Use this for getting mean value in slot index along path
        slots = jnp.mean(slots, axis=0)
    elif agg_func == "min":
        # Use this for getting slots from link_slot_array
        slots = jnp.min(slots, axis=0)
    else:
        raise ValueError("agg_func must be 'max', 'sum', 'mean' or 'min'")
    return slots
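The masking-then-aggregation step can be sketched in NumPy: multiply the path's binary link mask into the (L, S) link-slot array, then reduce across links. With 'max' over absolute values, a slot is marked occupied if any link on the path uses it (illustrative shapes and values):

```python
import numpy as np

link_slot_array = np.array([
    [0, -1,  0],   # link 0: slot 1 occupied (active services marked -1)
    [0,  0, -1],   # link 1: slot 2 occupied
    [-1, 0,  0],   # link 2: slot 0 occupied, but link is not on the path
])
path = np.array([1, 1, 0])  # path uses links 0 and 1 only

slots = path[:, None] * link_slot_array
# 'max' aggregation over |values|: 1 where any link on the path occupies the slot
occupied = np.max(np.abs(slots), axis=0)  # [0, 1, 1]
```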

get_paths(params, nodes)

Get k paths between source and destination

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(0,))
def get_paths(params: EnvParams, nodes: Array) -> Array:
    """Get k paths between source and destination"""
    index_array = get_path_index_array(params, nodes)
    paths = differentiable_indexing(
        params.path_link_array.val,
        index_array,
        params.temperature,
        params.differentiable,
    )
    if params.pack_path_bits:
        paths = jnp.unpackbits(paths, axis=1)[:, : params.num_links]
    return paths

get_paths_obs_gn_model(state, params)

Get observation space for launch power optimization (with numerical stability).

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(1,))
def get_paths_obs_gn_model(state: RSAGNModelEnvState, params: RSAGNModelEnvParams) -> chex.Array:
    # TODO - make this just show the stats from just one path at a time
    """Get observation space for launch power optimization (with numerical stability)."""
    request_array = state.request_array.reshape((-1,))
    path_stats = calculate_path_stats(state, params, request_array)
    # Remove first 3 items of path stats for each path
    path_stats = path_stats[:, 3:]
    link_length_array = jnp.sum(params.link_length_array.val, axis=1, promote_integers=False)
    lightpath_snr_array = get_lightpath_snr(state, params)
    nodes_sd, requested_datarate = read_rsa_request(request_array)
    source, dest = nodes_sd

    def calculate_gn_path_stats(k_path_index, init_val):
        # Get path index
        path_index = (
            get_path_indices(params, source, dest, params.k_paths, params.num_nodes) + k_path_index
        )
        path_link_array = (
            jnp.unpackbits(params.path_link_array.val, axis=1)[:, : params.num_links]
            if params.pack_path_bits
            else params.path_link_array.val
        )
        path = path_link_array[path_index]
        path_length = jnp.dot(path, link_length_array)
        max_path_length = jnp.max(jnp.dot(path_link_array, link_length_array))
        # Normalize path length by the longest path in the table
        path_length = path_length / max_path_length
        max_path_length_hops = jnp.max(jnp.sum(path_link_array, axis=1, promote_integers=False))
        path_length_hops_norm = (
            jnp.sum(path, promote_integers=False).astype(dtype_config.LARGE_FLOAT_DTYPE)
            / max_path_length_hops
        )
        # Connections on path
        num_connections = jnp.where(
            path == 1,
            jnp.where(state.channel_power_array > 0, one, zero).sum(axis=1),
            zero,
        ).sum()
        num_connections_norm = num_connections / jnp.array(
            params.link_resources, dtype=dtype_config.LARGE_FLOAT_DTYPE
        )
        # Mean power of connections on path
        # make path with row length equal to link_resource (+1 to avoid zero division)
        mean_power_norm = jnp.where(
            path == one, state.channel_power_array.sum(axis=1), zero
        ).sum() / (jnp.where(num_connections > zero, num_connections, one) * params.max_power)
        # Mean SNR of connections on the path links
        max_snr = jnp.array(
            50, dtype=dtype_config.LARGE_FLOAT_DTYPE
        )  # Nominal value for max GSNR in dB
        mean_snr_norm = jnp.where(path == one, lightpath_snr_array.sum(axis=1), zero).sum(
            promote_integers=False
        ) / (jnp.where(num_connections > zero, num_connections, one) * max_snr)
        return jax.lax.dynamic_update_slice(
            init_val,
            jnp.array(
                [
                    [
                        path_length,
                        path_length_hops_norm,
                        num_connections_norm,
                        mean_power_norm,
                        mean_snr_norm,
                    ]
                ]
            ),
            (k_path_index, 0),
        )

    gn_path_stats = jnp.zeros((params.k_paths, 5), dtype=dtype_config.LARGE_FLOAT_DTYPE)
    gn_path_stats = jax.lax.fori_loop(0, params.k_paths, calculate_gn_path_stats, gn_path_stats)
    all_stats = jnp.concatenate([path_stats, gn_path_stats], axis=1)
    return jnp.concatenate(
        (
            jnp.array([source]),
            jnp.reshape(requested_datarate / 100.0, (-1,)),
            jnp.array([dest]),
            jnp.reshape(state.holding_time, (-1,)),
            jnp.reshape(all_stats, (-1,)),
        ),
        axis=0,
    )

get_paths_se(params, nodes)

Get max. spectral efficiency of modulation format on k paths between source and destination

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(0,))
def get_paths_se(params, nodes):
    """Get max. spectral efficiency of modulation format on k paths between source and destination"""
    # get source and destination nodes in order (for accurate indexing of path-link array)
    index_array = get_path_index_array(params, nodes)
    return differentiable_indexing(
        params.path_se_array.val, index_array, params.temperature, params.differentiable
    )

get_required_snr_se_kurtosis_array(modulation_format_index_array, col_index, params)

Convert modulation format index to required SNR or spectral efficiency. Modulation format index array contains the index of the modulation format used by the channel. The modulation index references a row in the modulations array, which contains SNR and SE values.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| modulation_format_index_array | Array | Modulation format index array | required |
| col_index | int | Column index for required SNR or spectral efficiency | required |
| params | RSAGNModelEnvParams | Environment parameters | required |

Returns:

| Type | Description |
|---|---|
| Array | Required SNR for each channel (min. SNR for empty channel (mod. index 0)) |

Source code in xlron/environments/env_funcs.py
@partial(
    jax.jit,
    static_argnums=(
        1,
        2,
    ),
)
def get_required_snr_se_kurtosis_array(
    modulation_format_index_array: chex.Array,
    col_index: int,
    params: RSAGNModelEnvParams,
) -> chex.Array:
    """Convert modulation format index to required SNR or spectral efficiency.
    Modulation format index array contains the index of the modulation format used by the channel.
    The modulation index references a row in the modulations array, which contains SNR and SE values.

    Args:
        modulation_format_index_array (chex.Array): Modulation format index array
        col_index (int): Column index for required SNR or spectral efficiency
        params (RSAGNModelEnvParams): Environment parameters

    Returns:
        jnp.array: Required SNR for each channel (min. SNR for empty channel (mod. index 0))
    """
    return jax.vmap(get_required_snr_se_kurtosis_on_link, in_axes=(0, None, None))(
        modulation_format_index_array, col_index, params
    )

get_snr_link_array(state, params)

Get SNR per link.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | EnvState | Environment state | required |
| params | EnvParams | Environment parameters | required |

Returns:

| Type | Description |
|---|---|
| Array | SNR per link |

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(1,))
def get_snr_link_array(state: EnvState, params: EnvParams) -> chex.Array:
    """Get SNR per link
    Args:
        state (EnvState): Environment state
        params (EnvParams): Environment parameters
    Returns:
        jnp.array: SNR per link
    """

    def get_link_snr(link_index, state, params):
        # Get channel power, channel centre, bandwidth, and noise figure
        link_lengths = params.link_length_array[link_index, :]  # in metres
        num_spans = jnp.ceil(jnp.sum(link_lengths) / params.max_span_length).astype(
            dtype_config.LARGE_INT_DTYPE
        )
        if params.mod_format_correction:
            mod_format_link = state.modulation_format_index_array[link_index, :]
            kurtosis_link = get_required_snr_se_kurtosis_on_link(mod_format_link, 4, params)
            se_link = get_required_snr_se_kurtosis_on_link(mod_format_link, 1, params)
        else:
            kurtosis_link = jnp.zeros(params.link_resources).astype(jnp.float32)
            se_link = jnp.ones(params.link_resources, dtype=jnp.float32)
        bw_link = state.channel_centre_bw_array[link_index, :]
        ch_power_link = state.channel_power_array[link_index, :]
        ch_centres_link = state.channel_centre_freq_array[link_index, :]

        # Calculate SNR
        P = dict(
            num_channels=params.link_resources,
            num_spans=num_spans,
            max_spans=params.max_spans,
            ref_lambda=params.ref_lambda,
            length=link_lengths,
            attenuation_i=jnp.array(params.attenuation),
            attenuation_bar_i=jnp.array(params.attenuation_bar),
            nonlinear_coeff=jnp.array(params.nonlinear_coeff),
            raman_gain_slope_i=jnp.array(params.raman_gain_slope),
            dispersion_coeff=jnp.array(params.dispersion_coeff),
            dispersion_slope=jnp.array(params.dispersion_slope),
            coherent=params.coherent,
            num_roadms=params.num_roadms,
            roadm_loss=params.roadm_loss,
            amplifier_noise_figure=params.amplifier_noise_figure.val,
            transceiver_snr=params.transceiver_snr.val,
            mod_format_correction=params.mod_format_correction,
            ch_power_w_i=ch_power_link,
            ch_centre_i=ch_centres_link * 1e9,
            ch_bandwidth_i=bw_link * 1e9,
            excess_kurtosis_i=kurtosis_link,
            uniform_spans=params.uniform_spans,
            num_subchannels=params.num_subchannels,
        )
        snr = isrs_gn_model.get_snr(**P)[0]

        return snr

    link_snr_array = jax.vmap(get_link_snr, in_axes=(0, None, None))(
        jnp.arange(params.num_links), state, params
    )
    link_snr_array = jnp.nan_to_num(link_snr_array, nan=1e-5)
    return link_snr_array

get_snr_link_array_fused(state, params)

Get SNR per link using fused computation (uniform spans, no mod_format_correction).

Drop-in replacement for get_snr_link_array that uses get_snr_fused to reduce XLA op count and kernel launch overhead on GPU.

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(1,))
def get_snr_link_array_fused(state: EnvState, params: EnvParams) -> chex.Array:
    """Get SNR per link using fused computation (uniform spans, no mod_format_correction).

    Drop-in replacement for get_snr_link_array that uses get_snr_fused to
    reduce XLA op count and kernel launch overhead on GPU.
    """
    # Precompute per-link num_spans and span_length from static link_length_array
    # link_length_array is (L, max_spans) in metres
    link_lengths_m = params.link_length_array.val  # (L, max_spans)
    total_link_length_m = jnp.sum(link_lengths_m, axis=1)  # (L,)
    num_spans_per_link = jnp.ceil(total_link_length_m / params.max_span_length).astype(
        dtype_config.LARGE_INT_DTYPE
    )  # (L,)
    # span_length in metres to match GN model expectations
    span_length_per_link = total_link_length_m / jnp.maximum(num_spans_per_link, 1).astype(
        jnp.float32
    )  # (L,) in metres

    # Per-link channel data from state: (L, N)
    ch_power_all = state.channel_power_array  # (L, N)
    bw_all = state.channel_centre_bw_array  # (L, N) in GHz

    # Use cached centre frequencies from state
    ch_centres_hz = state.channel_centre_freq_array * 1e9  # (L, N) GHz -> Hz
    bw_hz = bw_all * 1e9  # (L, N)

    # Tile amplifier noise figure and transceiver SNR to (L, N) for vmap
    amp_nf = jnp.broadcast_to(params.amplifier_noise_figure.val, ch_power_all.shape)
    trx_snr = jnp.broadcast_to(params.transceiver_snr.val, ch_power_all.shape)

    def _link_snr_fused(ch_pow, ch_centre, ch_bw, n_spans, s_length, amp_nf_link, trx_snr_link):
        return isrs_gn_model.get_snr_fused(
            ch_power_w_i=ch_pow,
            ch_centre_i=ch_centre,
            ch_bandwidth_i=ch_bw,
            num_spans=n_spans,
            span_length=s_length,
            num_channels=params.link_resources,
            ref_lambda=params.ref_lambda,
            attenuation=params.attenuation,
            attenuation_bar=params.attenuation_bar,
            nonlinear_coeff=params.nonlinear_coeff,
            raman_gain_slope=params.raman_gain_slope,
            dispersion_coeff=params.dispersion_coeff,
            dispersion_slope=params.dispersion_slope,
            amplifier_noise_figure=amp_nf_link,
            transceiver_snr=trx_snr_link,
            roadm_loss=params.roadm_loss,
            num_roadms=params.num_roadms,
            coherent=params.coherent,
            num_subchannels=params.num_subchannels,
        )

    link_snr_array = jax.vmap(_link_snr_fused)(
        ch_power_all,
        ch_centres_hz,
        bw_hz,
        num_spans_per_link,
        span_length_per_link,
        amp_nf,
        trx_snr,
    )
    link_snr_array = jnp.nan_to_num(link_snr_array, nan=1e-5)
    return link_snr_array

get_spectral_features(laplacian, num_features)

Compute spectral node features from symmetric normalized graph Laplacian.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| laplacian | Array | Symmetric normalized Laplacian of the graph | required |
| num_features | int | Number of eigenvector features to extract | required |

Returns:

| Type | Description |
|---|---|
| Array | Array of shape (n_nodes, num_features) containing eigenvectors corresponding to the smallest non-zero eigenvalues of the graph Laplacian. If the graph has fewer nodes than num_features, the result is zero-padded to have num_features columns. |

Notes:
- Skips trivial eigenvectors (those with near-zero eigenvalues)
- Eigenvectors are ordered by ascending eigenvalue magnitude
- Runtime is O(n^3) - use only for small/medium graphs
- Eigenvector signs are arbitrary (may vary between runs)
Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(1,))
def get_spectral_features(laplacian: Array, num_features: int) -> Array:
    """Compute spectral node features from symmetric normalized graph Laplacian.

    Args:
        laplacian: Symmetric normalized Laplacian of the graph
        num_features: Number of eigenvector features to extract

    Returns:
        Array of shape (n_nodes, num_features) containing eigenvectors corresponding
        to the smallest non-zero eigenvalues of the graph Laplacian. If the graph has
        fewer nodes than num_features, the result is zero-padded to have num_features columns.

    Notes:
        - Skips trivial eigenvectors (those with near-zero eigenvalues)
        - Eigenvectors are ordered by ascending eigenvalue magnitude
        - Runtime is O(n^3) - use only for small/medium graphs
        - Eigenvector signs are arbitrary (may vary between runs)
    """
    eigenvalues, eigenvectors = jnp.linalg.eigh(laplacian)
    n_nodes = laplacian.shape[0]
    # If graph has fewer nodes than requested features, pad with zeros
    if n_nodes < num_features:
        padding = jnp.zeros((n_nodes, num_features - n_nodes), dtype=dtype_config.LARGE_FLOAT_DTYPE)
        return jnp.concatenate([eigenvectors, padding], axis=-1).astype(
            dtype_config.LARGE_FLOAT_DTYPE
        )
    return eigenvectors[:, :num_features].astype(dtype_config.LARGE_FLOAT_DTYPE)
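As a concrete check of the padding branch, the sketch below builds the symmetric normalized Laplacian of a 3-node path graph and requests more features than there are nodes, so the result is zero-padded on the right (plain NumPy for illustration):

```python
import numpy as np

# 3-node path graph: 0 - 1 - 2
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)
deg = adj.sum(axis=1)
d_inv_sqrt = np.diag(1.0 / np.sqrt(deg))
laplacian = np.eye(3) - d_inv_sqrt @ adj @ d_inv_sqrt  # symmetric normalized

num_features = 5
eigenvalues, eigenvectors = np.linalg.eigh(laplacian)  # ascending eigenvalues
n = laplacian.shape[0]
if n < num_features:
    # Fewer nodes than requested features: pad columns with zeros
    padding = np.zeros((n, num_features - n))
    features = np.concatenate([eigenvectors, padding], axis=-1)
else:
    features = eigenvectors[:, :num_features]
```

A connected graph's normalized Laplacian always has a zero smallest eigenvalue, which is why the function's notes mention skipping trivial (near-zero) eigenvectors.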

implement_action_rmsa_gn_model(state, action_info, params)

Implement action for RSA GN model. Updates the following arrays:
- link_slot_array
- link_slot_departure_array
- link_snr_array
- modulation_format_index_array
- channel_power_array
- active_path_array

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | EnvState | Environment state | required |
| action_info | ActionInfo | Action info (first is path action, second is launch_power) | required |
| params | EnvParams | Environment parameters | required |

Returns:

| Type | Description |
|---|---|
| EnvState | Updated environment state |

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(2,))
def implement_action_rmsa_gn_model(
    state: RSAGNModelEnvState, action_info: ActionInfo, params: RSAGNModelEnvParams
) -> EnvState:
    """Implement action for RSA GN model. Update following arrays:
    - link_slot_array
    - link_slot_departure_array
    - link_snr_array
    - modulation_format_index_array
    - channel_power_array
    - active_path_array
    Args:
        state (RSAGNModelEnvState): Environment state
        action_info (ActionInfo): Action info (path action and launch power)
        params (RSAGNModelEnvParams): Environment parameters
    Returns:
        EnvState: Updated environment state
    """
    path_action = action_info.action.astype(dtype_config.LARGE_INT_DTYPE)
    lightpath_index = get_lightpath_index(params, action_info.nodes_sd, action_info.path_index)
    launch_power = get_launch_power(state, path_action, action_info.power_action, params)
    # TODO(GN MODEL) - get mod. format based on maximum reach
    mod_format_index = jax.lax.dynamic_slice(state.mod_format_mask, (path_action,), (1,)).astype(
        dtype_config.LARGE_INT_DTYPE
    )[0]
    # Update link_slot_array and link_slot_departure_array, then other arrays
    state = implement_path_action(state, action_info, params)
    state = state.replace(
        path_index_array=set_path_links(
            state.path_index_array,
            action_info.affected_slots_mask,
            lightpath_index,
        ),
        channel_power_array=set_path_links(
            state.channel_power_array,
            action_info.affected_slots_mask,
            launch_power,
        ),
        modulation_format_index_array=set_path_links(
            state.modulation_format_index_array,
            action_info.affected_slots_mask,
            mod_format_index,
        ),
        channel_centre_bw_array=set_path_links(
            state.channel_centre_bw_array,
            action_info.affected_slots_mask,
            params.slot_size,
        ),
        channel_centre_freq_array=set_path_links(
            state.channel_centre_freq_array,
            action_info.affected_slots_mask,
            get_centre_frequency(action_info.initial_slot_index, action_info.num_slots, params),
        ),
    )
    # Update link_snr_array
    state = state.replace(link_snr_array=get_snr_link_array(state, params))
    # jax.debug.print("launch_power {}", launch_power, ordered=True)
    # jax.debug.print("mod_format_index {}", mod_format_index, ordered=True)
    # jax.debug.print("initial_slot_index {}", initial_slot_index, ordered=True)
    # jax.debug.print("state.mod_format_mask {}", state.mod_format_mask, ordered=True)
    # jax.debug.print("path_snr {}", get_snr_for_path(path, state.link_snr_array, params, state), ordered=True)
    # jax.debug.print("required_snr {}", params.modulations_array.val[mod_format_index][2] + params.snr_margin, ordered=True)
    return state

implement_action_rsa(state, action_info, params)

Implement action to assign slots on links.

Parameters:

Name Type Description Default
state RSAEnvState

current state

required
action_info ActionInfo

action to implement

required
params RSAEnvParams

environment parameters

required

Returns:

Name Type Description
state EnvState

updated state

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(2,), donate_argnums=(0,))
def implement_action_rsa(
    state: RSAEnvState,
    action_info: ActionInfo,
    params: RSAEnvParams,
) -> EnvState:
    """Implement action to assign slots on links.

    Args:
        state: current state
        action_info: action to implement
        params: environment parameters

    Returns:
        state: updated state
    """
    if params.__class__.__name__ == "RWALightpathReuseEnvParams":
        state = state.replace(
            link_capacity_array=update_path_links(
                state.link_capacity_array,
                action_info,
                action_info.requested_datarate,
            )
        )
        # TODO (Dynamic-RWALR) - to support diverse requested_datarates for RWA-LR, need to update masking
        # TODO (Dynamic-RWALR) - In order to enable dynamic RWA with lightpath reuse (as opposed to just incremental loading),
        #  need to keep track of active requests OR just randomly remove connections
        #  (could do this by using the link_slot_departure array in a novel way... i.e. don't fill it with departure time but current bw)
        capacity_mask = jnp.where(state.link_capacity_array <= 0.0, 1.0, 0.0)
        over_capacity_mask = jnp.where(state.link_capacity_array < 0.0, 1.0, 0.0)
        total_mask = capacity_mask + over_capacity_mask
        state = state.replace(
            link_slot_array=total_mask,
            link_slot_departure_array=update_path_links(
                state.link_slot_departure_array,
                action_info,
                state.current_time + state.holding_time,
            ),
        )
    else:
        state = implement_path_action(state, action_info, params)
    return state
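The capacity-mask arithmetic in the lightpath-reuse branch above can be traced with made-up numbers. This is an illustrative sketch only (the array values are not from the library), using plain NumPy in place of `jax.numpy`, which has the same `where` API:

```python
import numpy as np

# Remaining capacity per (link, slot): positive = spare capacity,
# zero = exactly full, negative = oversubscribed (illustrative values)
link_capacity_array = np.array([[100., 0., -50.],
                                [25., 0., 10.]])

capacity_mask = np.where(link_capacity_array <= 0.0, 1.0, 0.0)      # full or over
over_capacity_mask = np.where(link_capacity_array < 0.0, 1.0, 0.0)  # strictly over
total_mask = capacity_mask + over_capacity_mask

# Free slots stay 0.0, exactly-full slots become 1.0, oversubscribed slots 2.0
print(total_mask)  # [[0. 1. 2.] [0. 1. 0.]]
```

Summing the two masks lets the resulting link_slot_array distinguish exactly-full slots (1.0) from oversubscribed ones (2.0).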

implement_action_rsa_gn_model(state, action_info, params)

Implement action for the RSA GN model. Updates the following arrays:

- link_slot_array
- link_slot_departure_array
- link_snr_array
- modulation_format_index_array
- channel_power_array
- active_path_array

Args:
    state (RSAGNModelEnvState): Environment state
    action_info (ActionInfo): Action info (path action and launch power)
    params (RSAGNModelEnvParams): Environment parameters
Returns:
    EnvState: Updated environment state

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(2,))
def implement_action_rsa_gn_model(
    state: RSAGNModelEnvState, action_info: ActionInfo, params: RSAGNModelEnvParams
) -> EnvState:
    """Implement action for RSA GN model. Update following arrays:
    - link_slot_array
    - link_slot_departure_array
    - link_snr_array
    - modulation_format_index_array
    - channel_power_array
    - active_path_array
    Args:
        state (RSAGNModelEnvState): Environment state
        action_info (ActionInfo): Action info (path action and launch power)
        params (RSAGNModelEnvParams): Environment parameters
    Returns:
        EnvState: Updated environment state
    """
    path_action = action_info.action.astype(dtype_config.LARGE_INT_DTYPE)
    lightpath_index = get_lightpath_index(params, action_info.nodes_sd, action_info.path_index)
    launch_power = get_launch_power(state, path_action, action_info.power_action, params)
    # Update link_slot_array and link_slot_departure_array, then other arrays
    state = implement_path_action(state, action_info, params)
    state = state.replace(
        path_index_array=set_path_links(
            state.path_index_array,
            action_info.affected_slots_mask,
            lightpath_index,
        ),
        channel_power_array=set_path_links(
            state.channel_power_array,
            action_info.affected_slots_mask,
            launch_power,
        ),
        channel_centre_bw_array=set_path_links(
            state.channel_centre_bw_array,
            action_info.affected_slots_mask,
            params.slot_size,
        ),
        channel_centre_freq_array=set_path_links(
            state.channel_centre_freq_array,
            action_info.affected_slots_mask,
            get_centre_frequency(action_info.initial_slot_index, action_info.num_slots, params),
        ),
    )
    if params.monitor_active_lightpaths:
        state = state.replace(
            active_lightpaths_array=update_active_lightpaths_array(
                state,
                lightpath_index,
                action_info.initial_slot_index,
                action_info.num_slots - params.guardband,
            ),
            active_lightpaths_array_departure=update_active_lightpaths_array_departure(
                state, -state.current_time - state.holding_time
            ),
        )
        # No need to check SNR until end of episode
        return state
    # Update link_snr_array
    state = state.replace(link_snr_array=get_snr_link_array(state, params))
    return state

init_active_lightpaths_array(params)

Initialise active lightpath array. Stores path indices of all active paths on the network in an M x 3 array. M is MIN(max_requests, num_links * link_resources / min_slots). min_slots is the number of slots required for the largest request, i.e. max(values_bw) / slot_size.

Parameters:

Name Type Description Default
params RSAGNModelEnvParams

Environment parameters

required

Returns: jnp.array: Active path array (default value -1, empty path)

Source code in xlron/environments/env_funcs.py
def init_active_lightpaths_array(params: RSAGNModelEnvParams):
    """Initialise active lightpath array. Stores path indices of all active paths on the network in a 1 x M array.
    M is MIN(max_requests, num_links * link_resources / min_slots).
    min_slots is the minimum number of slots required for a lightpath i.e. max(values_bw)/ slot_size.

    Args:
        params (RSAGNModelEnvParams): Environment parameters
    Returns:
        jnp.array: Active path array (default value -1, empty path)
    """
    total_slots = params.num_links * params.link_resources  # total slots on networks
    min_slots = (
        jnp.max(params.values_bw.val) / params.slot_size
    )  # minimum number of slots required for lightpath
    return jnp.full((int(total_slots / min_slots), 3), -1, dtype=dtype_config.LARGE_INT_DTYPE)
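The sizing logic above can be traced with illustrative parameter values. A sketch only (the numbers are not from the library), using NumPy in place of `jax.numpy`:

```python
import numpy as np

# Illustrative parameters (not from the library)
num_links = 10
link_resources = 100                      # slots per link
values_bw = np.array([25., 50., 100.])    # possible request data rates
slot_size = 12.5                          # GHz per slot

total_slots = num_links * link_resources          # 1000 slots network-wide
min_slots = np.max(values_bw) / slot_size         # 8 slots for the largest request
m = int(total_slots / min_slots)                  # bound on concurrent lightpaths

active_lightpaths = np.full((m, 3), -1)
print(active_lightpaths.shape)  # (125, 3)
```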

init_active_lightpaths_array_departure(params)

Initialise active lightpath departure array. Stores departure times of all active paths on the network in an M x 3 array, matching the entries of the active lightpaths array. M is MIN(max_requests, num_links * link_resources / min_slots). min_slots is the number of slots required for the largest request, i.e. max(values_bw) / slot_size.

Parameters:

Name Type Description Default
params RSAGNModelEnvParams

Environment parameters

required

Returns: jnp.array: Active lightpath departure array (default value 0.0)

Source code in xlron/environments/env_funcs.py
def init_active_lightpaths_array_departure(params: RSAGNModelEnvParams):
    """Initialise active lightpath array. Stores path indices of all active paths on the network in a 1 x M array.
    M is MIN(max_requests, num_links * link_resources / min_slots).
    min_slots is the minimum number of slots required for a lightpath i.e. max(values_bw)/ slot_size.

    Args:
        params (RSAGNModelEnvParams): Environment parameters
    Returns:
        jnp.array: Active lightpath departure array (default value 0.0)
    """
    total_slots = params.num_links * params.link_resources  # total slots on networks
    min_slots = (
        jnp.max(params.values_bw.val) / params.slot_size
    )  # minimum number of slots required for lightpath
    return jnp.full((int(total_slots / min_slots), 3), 0.0, dtype=dtype_config.SMALL_FLOAT_DTYPE)

init_active_path_array(params)

Initialise active path array. Stores details of the full path utilised by the lightpath on each frequency slot.

Args:
    params (EnvParams): Environment parameters
Returns:
    jnp.array: Active path array

Source code in xlron/environments/env_funcs.py
def init_active_path_array(params: EnvParams):
    """Initialise active path array. Stores details of full path utilised by lightpath on each frequency slot.
    Args:
        params (EnvParams): Environment parameters
    Returns:
        jnp.array: Active path array
    """
    return jnp.full(
        (params.num_links, params.link_resources, params.num_links),
        -1,
        dtype=dtype_config.LARGE_INT_DTYPE,
    )

init_channel_centre_bw_array(params)

Initialise channel centre bandwidth array.

Args:
    params (EnvParams): Environment parameters
Returns:
    jnp.array: Channel centre bandwidth array

Source code in xlron/environments/env_funcs.py
def init_channel_centre_bw_array(params: EnvParams):
    """Initialise channel centre array.
    Args:
        params (EnvParams): Environment parameters
    Returns:
        jnp.array: Channel centre array
    """
    return jnp.full(
        (params.num_links, params.link_resources), 0.0, dtype=dtype_config.LARGE_FLOAT_DTYPE
    )

init_channel_centre_freq_array(params)

Initialise channel centre frequency array.

Args:
    params (EnvParams): Environment parameters
Returns:
    jnp.array: Channel centre frequency array (GHz)

Source code in xlron/environments/env_funcs.py
def init_channel_centre_freq_array(params: EnvParams):
    """Initialise channel centre frequency array.
    Args:
        params (EnvParams): Environment parameters
    Returns:
        jnp.array: Channel centre frequency array (GHz)
    """
    return jnp.full(
        (params.num_links, params.link_resources), 0.0, dtype=dtype_config.LARGE_FLOAT_DTYPE
    )

init_channel_power_array(params)

Initialise channel power array.

Parameters:

Name Type Description Default
params EnvParams

Environment parameters

required

Returns: jnp.array: Channel power array

Source code in xlron/environments/env_funcs.py
def init_channel_power_array(params: EnvParams):
    """Initialise channel power array.

    Args:
        params (EnvParams): Environment parameters
    Returns:
        jnp.array: Channel power array
    """
    return jnp.full(
        (params.num_links, params.link_resources), 0.0, dtype=dtype_config.LARGE_FLOAT_DTYPE
    )

init_graph_tuple(state, params, adj, exclude_source_dest=False)

Initialise graph tuple for use with Jraph GNNs.

Args:
    state (EnvState): Environment state
    params (EnvParams): Environment parameters
    adj (jnp.array): Adjacency matrix of the graph
    exclude_source_dest (bool, optional): If True, zero out the source/destination node features. Defaults to False.
Returns:
    jraph.GraphsTuple: Graph tuple

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(1, 3))
def init_graph_tuple(
    state: EnvState,
    params: EnvParams,
    adj: Array,
    exclude_source_dest: bool = False,
) -> jraph.GraphsTuple:
    """Initialise graph tuple for use with Jraph GNNs.
    Args:
        state (EnvState): Environment state
        params (EnvParams): Environment parameters
        adj (jnp.array): Adjacency matrix of the graph
        exclude_source_dest (bool, optional): If True, zero out the source/destination node features. Defaults to False.
    Returns:
        jraph.GraphsTuple: Graph tuple
    """
    senders = params.edges.val.T[0].astype(dtype_config.LARGE_INT_DTYPE)
    receivers = params.edges.val.T[1].astype(dtype_config.LARGE_INT_DTYPE)

    # Get source and dest from request array
    # VONE has 2D request_array (2, max_edges*2+1), use first row for node info
    request_array = state.request_array
    if request_array.ndim == 2:
        request_array = request_array[0]
    source_dest, datarate = read_rsa_request(request_array)
    # Global feature is normalised data rate of current request
    globals = jnp.array(
        [datarate / jnp.max(params.values_bw.val)], dtype=dtype_config.LARGE_FLOAT_DTYPE
    )

    if exclude_source_dest:
        source_dest_features = jnp.zeros(
            (params.num_nodes, 2), dtype=dtype_config.LARGE_FLOAT_DTYPE
        )
    else:
        source, dest = source_dest[0], source_dest[2]
        # One-hot encode source and destination (2 additional features)
        source_dest_features = jnp.zeros(
            (params.num_nodes, 2), dtype=dtype_config.LARGE_FLOAT_DTYPE
        )
        source_dest_features = source_dest_features.at[
            source.astype(dtype_config.INDEX_DTYPE), 0
        ].set(1)
        source_dest_features = source_dest_features.at[
            dest.astype(dtype_config.INDEX_DTYPE), 1
        ].set(-1)

    spectral_features = get_spectral_features(adj, num_features=params.num_spectral_features)

    # For dynamic traffic, edge_features are normalised remaining holding time instead of link_slot_array
    holding_time_edge_features = state.link_slot_departure_array / params.mean_service_holding_time

    if params.__class__.__name__ in ["RSAGNModelEnvParams", "RMSAGNModelEnvParams"]:
        # Normalize by max parameters (converted to linear units)
        max_power = isrs_gn_model.from_dbm(params.max_power)
        normalized_power = jnp.round(state.channel_power_array / max_power, 3)
        max_snr = isrs_gn_model.from_db(params.max_snr)
        normalized_snr = jnp.round(state.link_snr_array / max_snr, 3)
        edge_features = jnp.stack([normalized_snr, normalized_power], axis=-1)
        node_features = jnp.concatenate([spectral_features, source_dest_features], axis=-1)
    elif params.__class__.__name__ == "VONEEnvParams":
        edge_features = (
            state.link_slot_array
            if params.mean_service_holding_time > 1e5
            else holding_time_edge_features
        )
        node_features = getattr(
            state,
            "node_capacity_array",
            jnp.zeros(params.num_nodes, dtype=dtype_config.LARGE_FLOAT_DTYPE),
        )
        node_features = node_features.reshape(-1, 1)
        node_features = jnp.concatenate(
            [node_features, spectral_features, source_dest_features], axis=-1
        )
    else:
        edge_features = (
            state.link_slot_array
            if params.mean_service_holding_time > 1e5
            else holding_time_edge_features
        )
        # [n_edges] or [n_edges, ...]
        node_features = jnp.concatenate([spectral_features, source_dest_features], axis=-1)

    if params.disable_node_features:
        node_features = jnp.zeros((1,), dtype=dtype_config.LARGE_FLOAT_DTYPE)

    # Handle undirected graphs (duplicate edges after normalization)
    if not params.directed_graph:
        senders_ = jnp.concatenate([senders, receivers])
        receivers = jnp.concatenate([receivers, senders])
        senders = senders_
        edge_features = jnp.repeat(edge_features, 2, axis=0)

    return jraph.GraphsTuple(
        nodes=node_features,
        edges=edge_features,
        senders=senders,
        receivers=receivers,
        n_node=jnp.reshape(params.num_nodes, (1,)).astype(dtype_config.LARGE_INT_DTYPE),
        n_edge=jnp.reshape(len(senders), (1,)).astype(dtype_config.LARGE_INT_DTYPE),
        globals=globals,
    )
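The source/destination node encoding used above can be sketched standalone. The node indices are illustrative, and plain NumPy in-place assignment stands in for the `jax.numpy` `.at[...].set(...)` updates in the source:

```python
import numpy as np

num_nodes = 5
source, dest = 1, 3  # illustrative request

# Two extra node-feature columns: +1 marks the source, -1 marks the destination
source_dest_features = np.zeros((num_nodes, 2))
source_dest_features[source, 0] = 1
source_dest_features[dest, 1] = -1

print(source_dest_features)
```

Using opposite signs in separate columns keeps the source and destination markers distinguishable after message passing.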

init_link_capacity_array(params)

Initialise link capacity array. Represents the available data rate for a lightpath on each link. The default is a high value (1e6) for unoccupied slots. Once a lightpath is established, capacity is determined by the corresponding entry in the path capacity array.

Source code in xlron/environments/env_funcs.py
def init_link_capacity_array(params):
    """Initialise link capacity array. Represents available data rate for lightpath on each link.
    Default is high value (1e6) for unoccupied slots. Once lightpath established, capacity is determined by
    corresponding entry in path capacity array."""
    return jnp.full((params.num_links, params.link_resources), 1e6)

init_link_length_array(graph)

Initialise link length array.

Args:
    graph (nx.Graph): NetworkX graph
Returns:
    chex.Array: Array of link lengths

Source code in xlron/environments/env_funcs.py
def init_link_length_array(graph: nx.Graph) -> chex.Array:
    """Initialise link length array.
    Args:
        graph (nx.Graph): NetworkX graph
    Returns:
        chex.Array: Array of link lengths
    """
    link_lengths = []
    for edge in sorted(graph.edges):
        link_lengths.append(graph.edges[edge]["distance"])
    return jnp.array(link_lengths, dtype=dtype_config.LARGE_INT_DTYPE)

init_link_length_array_gn_model(graph, max_span_length, max_spans)

Initialise link length array for environments that use the GN model of the physical layer. We assume each link has spans of equal length.

Parameters:

Name Type Description Default
graph Graph

NetworkX graph

required
max_span_length int

Maximum span length in metres

required
max_spans int

Maximum number of spans per link

required

Returns: jnp.array: Link length array (L x max_spans) in metres

Source code in xlron/environments/env_funcs.py
def init_link_length_array_gn_model(
    graph: nx.Graph, max_span_length: int, max_spans: int
) -> chex.Array:
    """Initialise link length array for environements that use GN model of physical layer.
    We assume each link has spans of equal length.

    Args:
        graph (nx.Graph): NetworkX graph
        max_span_length (int): Maximum span length in metres
        max_spans (int): Maximum number of spans per link
    Returns:
        jnp.array: Link length array (L x max_spans) in metres
    """
    link_lengths = []
    directed = graph.is_directed()
    graph = graph.to_undirected()
    edges = sorted(graph.edges)
    for edge in edges:
        # Topology distances are in km; convert to metres for GN model
        link_lengths.append(graph.edges[edge]["distance"] * 1e3)
    if directed:
        for edge in edges:
            link_lengths.append(graph.edges[edge]["distance"] * 1e3)
    span_length_array = []
    for length in link_lengths:
        num_spans = math.ceil(length / max_span_length)
        avg_span_length = length / num_spans
        span_lengths = [avg_span_length] * num_spans
        span_lengths.extend([0] * (max_spans - num_spans))
        span_length_array.append(span_lengths)
    return jnp.array(span_length_array, dtype=dtype_config.LARGE_INT_DTYPE)
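The span-splitting step above can be checked for one link with made-up values: a 370 km link with 100 km maximum spans divides into 4 equal spans of 92.5 km, then zero-pads to `max_spans`:

```python
import math

link_length = 370e3      # 370 km link, in metres (illustrative)
max_span_length = 100e3  # 100 km maximum span
max_spans = 6

num_spans = math.ceil(link_length / max_span_length)  # 4 spans needed
avg_span_length = link_length / num_spans             # 92500.0 m per span
# Pad with zero-length spans so every link row has max_spans entries
span_lengths = [avg_span_length] * num_spans + [0] * (max_spans - num_spans)

print(span_lengths)  # [92500.0, 92500.0, 92500.0, 92500.0, 0, 0]
```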

init_link_slot_array(params)

Initialize empty (all zeroes) link-slot array. 0 means the slot is free, -1 means occupied.

Args:
    params (EnvParams): Environment parameters
Returns:
    jnp.array: Link slot array (E x S) where E is number of edges and S is number of slots

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(0,))
def init_link_slot_array(params: EnvParams):
    """Initialize empty (all zeroes) link-slot array. 0 means slot is free, -1 means occupied.
    Args:
        params (EnvParams): Environment parameters
    Returns:
        jnp.array: Link slot array (E x S) where E is number of edges and S is number of slots"""
    return jnp.zeros(
        (params.num_links, params.link_resources), dtype=dtype_config.LARGE_FLOAT_DTYPE
    )

init_link_slot_mask(params, include_no_op=False, agg=1.0)

Initialize link slot mask

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(0, 1, 2))
def init_link_slot_mask(params: EnvParams, include_no_op: bool = False, agg: float = 1.0):
    """Initialize link mask"""
    return jnp.ones(
        params.k_paths * math.ceil(params.link_resources / agg) + (1 * include_no_op),
        dtype=dtype_config.LARGE_FLOAT_DTYPE,
    )
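The mask length combines the path count, the aggregated slot count, and an optional no-op action. A sketch with illustrative parameter values (not from the library):

```python
import math

# Illustrative parameters (not from the library)
k_paths = 5
link_resources = 100
agg = 4.0             # aggregate every 4 slots into one action
include_no_op = True  # reserve one extra action for "do nothing"

# One action per (path, aggregated slot block), plus the optional no-op
mask_size = k_paths * math.ceil(link_resources / agg) + (1 * include_no_op)
print(mask_size)  # 126
```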

init_link_snr_array(params)

Initialise signal-to-noise ratio (SNR) array.

Args:
    params (EnvParams): Environment parameters
Returns:
    jnp.array: SNR array

Source code in xlron/environments/env_funcs.py
def init_link_snr_array(params: EnvParams):
    """Initialise signal-to-noise ratio (SNR) array.
    Args:
        params (EnvParams): Environment parameters
    Returns:
        jnp.array: SNR array
    """
    # The SNR is kept in linear units to allow summation of 1/SNR across links
    return jnp.full(
        (params.num_links, params.link_resources), -1e5, dtype=dtype_config.LARGE_FLOAT_DTYPE
    )

init_mod_format_mask(params)

Initialize modulation format mask

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(0,))
def init_mod_format_mask(params: EnvParams):
    """Initialize link mask"""
    return jnp.full(
        (params.k_paths * params.link_resources,), -1.0, dtype=dtype_config.LARGE_FLOAT_DTYPE
    )

init_modulation_format_index_array(params)

Initialise modulation format index array.

Args:
    params (EnvParams): Environment parameters
Returns:
    jnp.array: Modulation format index array

Source code in xlron/environments/env_funcs.py
def init_modulation_format_index_array(params: EnvParams):
    """Initialise modulation format index array.
    Args:
        params (EnvParams): Environment parameters
    Returns:
        jnp.array: Modulation format index array
    """
    return jnp.full(
        (params.num_links, params.link_resources), -1, dtype=dtype_config.LARGE_INT_DTYPE
    )  # -1 so that highest order is assumed (closest to Gaussian)

init_modulations_array(modulations_filepath=None)

Initialise array of maximum spectral efficiency for modulation format on path.

Parameters:

Name Type Description Default
modulations_filepath str

Path to CSV file containing modulation formats. Defaults to None.

None

Returns: jnp.array: Array of maximum spectral efficiency for modulation format on path. First two columns are maximum path length and spectral efficiency.

Source code in xlron/environments/env_funcs.py
def init_modulations_array(modulations_filepath: str | None = None) -> Array:
    """Initialise array of maximum spectral efficiency for modulation format on path.

    Args:
        modulations_filepath (str, optional): Path to CSV file containing modulation formats. Defaults to None.
    Returns:
        jnp.array: Array of maximum spectral efficiency for modulation format on path.
        First two columns are maximum path length and spectral efficiency.
    """
    f = (
        pathlib.Path(modulations_filepath)
        if modulations_filepath
        else (
            pathlib.Path(__file__).parents[1].absolute()
            / "data"
            / "modulations"
            / "modulations.csv"
        )
    )
    modulations = np.genfromtxt(f, delimiter=",")
    # Drop empty first row (headers) and column (name)
    modulations = modulations[1:, 1:]
    return jnp.array(modulations, dtype=dtype_config.LARGE_FLOAT_DTYPE)

init_path_capacity_array(link_length_array, path_link_array, min_request=1, scale_factor=1.0, alpha=0.0002, NF=4.5, B=10000000000000.0, R_s=100000000000.0, beta_2=-2.17e-26, gamma=0.0012, L_s=100000.0, lambda0=1.55e-06)

Calculated from Nevin paper: https://api.repository.cam.ac.uk/server/api/core/bitstreams/b80e7a9c-a86b-4b30-a6d6-05017c60b0c8/content

Parameters:

Name Type Description Default
link_length_array Array

Array of link lengths

required
path_link_array Array

Array of links on paths

required
min_request int

Minimum data rate request size. Defaults to 1.

1
scale_factor float

Scale factor for link capacity. Defaults to 1.0.

1.0
alpha float

Fibre attenuation coefficient. Defaults to 0.2e-3 /m

0.0002
NF float

Amplifier noise figure. Defaults to 4.5 dB.

4.5
B float

Total modulated bandwidth. Defaults to 10e12 Hz.

10000000000000.0
R_s float

Symbol rate. Defaults to 100e9 Baud.

100000000000.0
beta_2 float

Dispersion parameter. Defaults to -21.7e-27 s^2/m.

-2.17e-26
gamma float

Nonlinear coefficient. Defaults to 1.2e-3 /W/m.

0.0012
L_s float

Span length. Defaults to 100e3 m.

100000.0
lambda0 float

Wavelength. Defaults to 1550e-9 m.

1.55e-06

Returns:

Type Description
Array

chex.Array: Array of path capacities in Gbps

Source code in xlron/environments/env_funcs.py
def init_path_capacity_array(
    link_length_array: chex.Array,
    path_link_array: chex.Array,
    min_request=1,  # Minimum data rate request size
    scale_factor=1.0,  # Scale factor for link capacity
    alpha=0.2e-3,  # Fibre attenuation coefficient
    NF=4.5,  # Amplifier noise figure
    B=10e12,  # Total modulated bandwidth
    R_s=100e9,  # Symbol rate
    beta_2=-21.7e-27,  # Dispersion parameter
    gamma=1.2e-3,  # Nonlinear coefficient
    L_s=100e3,  # span length
    lambda0=1550e-9,  # Wavelength
) -> chex.Array:
    """Calculated from Nevin paper:
    https://api.repository.cam.ac.uk/server/api/core/bitstreams/b80e7a9c-a86b-4b30-a6d6-05017c60b0c8/content

    Args:
        link_length_array (chex.Array): Array of link lengths
        path_link_array (chex.Array): Array of links on paths
        min_request (int, optional): Minimum data rate request size. Defaults to 1.
        scale_factor (float, optional): Scale factor for link capacity. Defaults to 1.0.
        alpha (float, optional): Fibre attenuation coefficient. Defaults to 0.2e-3 /m
        NF (float, optional): Amplifier noise figure. Defaults to 4.5 dB.
        B (float, optional): Total modulated bandwidth. Defaults to 10e12 Hz.
        R_s (float, optional): Symbol rate. Defaults to 100e9 Baud.
        beta_2 (float, optional): Dispersion parameter. Defaults to -21.7e-27 s^2/m.
        gamma (float, optional): Nonlinear coefficient. Defaults to 1.2e-3 /W/m.
        L_s (float, optional): Span length. Defaults to 100e3 m.
        lambda0 (float, optional): Wavelength. Defaults to 1550e-9 m.

    Returns:
        chex.Array: Array of path capacities in Gbps
    """
    path_length_array = jnp.dot(path_link_array, link_length_array)
    path_capacity_array = calculate_path_capacity(
        path_length_array,
        min_request=min_request,
        scale_factor=scale_factor,
        alpha=alpha,
        NF=NF,
        B=B,
        R_s=R_s,
        beta_2=beta_2,
        gamma=gamma,
        L_s=L_s,
        lambda0=lambda0,
    )
    return path_capacity_array.astype(dtype_config.LARGE_FLOAT_DTYPE)

init_path_index_array(params)

Initialise path index array. Represents index of lightpath occupying each slot.

Source code in xlron/environments/env_funcs.py
def init_path_index_array(params):
    """Initialise path index array. Represents index of lightpath occupying each slot."""
    return jnp.full((params.num_links, params.link_resources), -1)

init_path_length_array(path_link_array, graph)

Initialise path length array.

Parameters:

Name Type Description Default
path_link_array Array

Path-link array

required
graph Graph

NetworkX graph

required

Returns: chex.Array: Path length array

Source code in xlron/environments/env_funcs.py
def init_path_length_array(path_link_array: chex.Array, graph: nx.Graph) -> chex.Array:
    """Initialise path length array.

    Args:
        path_link_array (chex.Array): Path-link array
        graph (nx.Graph): NetworkX graph
    Returns:
        chex.Array: Path length array
    """
    link_length_array = init_link_length_array(graph)
    path_lengths = jnp.dot(path_link_array, link_length_array)
    return path_lengths
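The dot product above works because each row of the path-link array is a 0/1 indicator over links, so multiplying by the link-length vector sums the lengths of exactly the links on that path. A toy example (values made up), using NumPy in place of `jax.numpy`:

```python
import numpy as np

link_length_array = np.array([100., 200., 300.])  # illustrative link lengths
path_link_array = np.array([[1, 0, 1],            # path 0 uses links 0 and 2
                            [0, 1, 1]])           # path 1 uses links 1 and 2

# Row-by-vector dot product sums the lengths of the links each path uses
path_lengths = path_link_array @ link_length_array
print(path_lengths)  # [400. 500.]
```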

init_path_link_array(graph, k, disjoint=False, path_sort_criteria='', directed=False, modulations_array=None, rwa_lr=False, scale_factor=1.0, path_snr=False)

Initialise path-link array. Each path is defined by a link utilisation array (one row in the path-link array). 1 indicates the link corresponding to the index is used, 0 indicates it is not used.

Parameters:

Name Type Description Default
graph Graph

NetworkX graph

required
k int

Number of paths

required
disjoint bool

Whether to use edge-disjoint paths. Defaults to False.

False
path_sort_criteria str

Sort paths by edge attribute. Defaults to "".

''
directed bool

Whether graph is directed. Defaults to False.

False
modulations_array Array

Array of maximum spectral efficiency for modulation format on path. Defaults to None.

None
rwa_lr bool

Whether the environment is RWA with lightpath reuse (affects path ordering).

False
path_snr bool

If GN model is used, include extra row of zeroes for unutilised paths

False

Returns:

Type Description
Array

chex.Array: Path-link array (N(N-1)*k x E) where N is number of nodes, E is number of edges, k is number of shortest paths

Source code in xlron/environments/env_funcs.py
def init_path_link_array(
    graph: nx.Graph,
    k: int,
    disjoint: bool = False,
    path_sort_criteria: str = "",
    directed: bool = False,
    modulations_array: None | chex.Array = None,
    rwa_lr: bool = False,
    scale_factor: float = 1.0,
    path_snr: bool = False,
) -> chex.Array:
    """Initialise path-link array.
    Each path is defined by a link utilisation array (one row in the path-link array).
    1 indicates link corresponding to index is used, 0 indicates not used.

    Args:
        graph (nx.Graph): NetworkX graph
        k (int): Number of paths
        disjoint (bool, optional): Whether to use edge-disjoint paths. Defaults to False.
        path_sort_criteria (str, optional): Criterion to sort paths by ("spectral_resources", "hops", "distance", "hops_distance", or "capacity"). Defaults to "".
        directed (bool, optional): Whether graph is directed. Defaults to False.
        modulations_array (chex.Array, optional): Array of maximum spectral efficiency for modulation format on path. Defaults to None.
        rwa_lr (bool, optional): Whether the environment is RWA with lightpath reuse (affects path ordering). Defaults to False.
        scale_factor (float, optional): Scale factor applied in the path capacity calculation. Defaults to 1.0.
        path_snr (bool, optional): If GN model is used, include extra row of zeroes for unutilised paths
        to ensure correct SNR calculation for empty paths (path index -1).

    Returns:
        chex.Array: Path-link array (N(N-1)*k x E) where N is number of nodes, E is number of edges, k is number of shortest paths
    """
    # Assert that sort_criteria is one of the allowed values
    assert path_sort_criteria in [
        "spectral_resources",
        "hops",
        "distance",
        "hops_distance",
        "capacity",
    ], (
        f"path_sort_criteria must be one of 'spectral_resources', 'hops', 'distance', 'hops_distance', or 'capacity' got '{path_sort_criteria}'"
    )

    # Set weight based on sort_criteria
    weight = (
        ""
        if path_sort_criteria in ["spectral_resources", "hops", "hops_distance", "capacity"]
        else "distance"
    )

    def path_weight(g, path, weight):
        return sum(g[u][v].get(weight, 1) for u, v in zip(path, path[1:]))

    def path_hash(p):
        return int(hashlib.sha256(str(p).encode()).hexdigest(), 16)

    def get_k_shortest_paths(
        g: nx.Graph, source: int, target: int, k: int, weight: str | None
    ) -> List[List[Tuple[int, int]]]:
        paths = list(islice(nx.shortest_simple_paths(g, source, target, weight=weight), k))
        # Ensure deterministic sorting. Sort first by weight (if any), then hops, then hash of path (random)
        paths.sort(
            key=lambda p: (
                path_weight(g, p, weight),
                len(p),
                path_hash(
                    p
                ),  # N.B. that code used for JOCN "Hype or Hope?" did not include this criterion.
            )
        )
        return paths

    def get_k_disjoint_shortest_paths(
        g: nx.Graph, source: int, target: int, k: int, weight: str | None
    ) -> List[List[Tuple[int, int]]]:
        k_paths_disjoint_unsorted = list(nx.edge_disjoint_paths(g, source, target))
        k_paths_shortest = get_k_shortest_paths(g, source, target, k, weight=weight)

        # Keep disjoint paths and add unique shortest paths until k paths reached
        disjoint_ids = [tuple(path) for path in k_paths_disjoint_unsorted]
        k_paths = k_paths_disjoint_unsorted
        for path in k_paths_shortest:
            if tuple(path) not in disjoint_ids:
                k_paths.append(path)
        k_paths = k_paths[:k]
        return k_paths

    paths = []
    edges = sorted(graph.edges)

    # Get the k-shortest paths for each node pair
    k_path_collections = []
    get_paths = get_k_disjoint_shortest_paths if disjoint else get_k_shortest_paths
    for node_pair in combinations(graph.nodes, 2):
        k_paths = get_paths(graph, node_pair[0], node_pair[1], k, weight=weight)
        k_path_collections.append(k_paths)

    if directed:  # Get paths in reverse direction
        for node_pair in combinations(graph.nodes, 2):
            k_paths_rev = get_paths(graph, node_pair[1], node_pair[0], k, weight=weight)
            k_path_collections.append(k_paths_rev)

    # Sort the paths for each node pair
    for k_paths in k_path_collections:
        source, dest = k_paths[0][0], k_paths[0][-1]

        # Get path lengths
        path_distance = [nx.path_weight(graph, path, weight="distance") for path in k_paths]

        # Get path num hops
        path_hops = [len(path) - 1 for path in k_paths]

        # Get spectral efficiency of each path
        if modulations_array is not None:
            path_se = []
            # Reverse into a local copy so the shortest-reach format is checked first.
            # (Reassigning modulations_array here would flip the order on every
            # iteration of the enclosing loop over node pairs.)
            mods_ascending_reach = modulations_array[::-1]
            for length in path_distance:
                for modulation in mods_ascending_reach:
                    if length <= modulation[0]:
                        path_se.append(modulation[1])
                        break
        else:
            path_se = [1] * len(path_distance)

        if rwa_lr:
            path_capacity = [
                float(calculate_path_capacity(path_length, scale_factor=scale_factor)) + 1e-6
                for path_length in path_distance
            ]
        else:
            path_capacity = [1] * len(path_distance)

        # If fewer than k unique paths, add dummy paths (so each node pair still has k rows in the array)
        empty_path = [0] * len(graph.edges)
        num_missing_paths = k - len(k_paths)
        k_paths = k_paths + [empty_path] * num_missing_paths
        path_distance = path_distance + [1e6] * num_missing_paths
        path_hops = path_hops + [1e6] * num_missing_paths
        path_se = path_se + [0] * num_missing_paths
        path_capacity = path_capacity + [0] * num_missing_paths

        # Zip the paths with potential sort criteria
        unsorted_paths = zip(k_paths, path_distance, path_hops, path_se, path_capacity)

        def determine_sort_criteria(x, path_sort_criteria):
            if path_sort_criteria == "spectral_resources":
                # Sort by ratio of hops/se or hops/capacity
                # Use max(..., 1) to avoid division by zero for dummy/padded paths
                return (x[2] / max(x[3], 1)) if not rwa_lr else (x[2] / max(x[4], 1))
            elif path_sort_criteria == "distance":
                return x[1]
            elif path_sort_criteria == "hops":
                return x[2]
            elif path_sort_criteria == "hops_distance":
                return (x[2], x[1])
            elif path_sort_criteria == "capacity":
                return x[4]
            else:
                raise ValueError(f"Path sort criteria: {path_sort_criteria}")

        k_paths_sorted = [
            (source, dest, distance, hops, se, capacity, path)
            for path, distance, hops, se, capacity in sorted(
                unsorted_paths, key=lambda x: determine_sort_criteria(x, path_sort_criteria)
            )
        ]

        # Keep only first k paths
        k_paths_sorted = k_paths_sorted[:k]

        for k_path in k_paths_sorted:
            k_path = k_path[-1]
            link_usage = [0] * len(graph.edges)  # Initialise empty path
            if sum(k_path) == 0:
                link_usage = empty_path
            else:
                for i in range(len(k_path) - 1):
                    s, d = k_path[i], k_path[i + 1]
                    for edge_index, edge in enumerate(edges):
                        condition = (
                            (edge[0] == s and edge[1] == d)
                            if directed
                            else (
                                (edge[0] == s and edge[1] == d) or (edge[0] == d and edge[1] == s)
                            )
                        )
                        if condition:
                            link_usage[edge_index] = 1
            path = link_usage
            paths.append(path)

    # If using GN model, add extra row of zeroes for empty paths for SNR calculation
    if path_snr:
        empty_path = [0] * len(graph.edges)
        paths.append(empty_path)

    return jnp.array(paths, dtype=dtype_config.BINARY_DTYPE)
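The mapping from a node-sequence path to one binary row of the array can be sketched without JAX. This is a simplified stand-in for the inner loop above, using a hypothetical 4-node ring with edges ordered as `sorted(graph.edges)`:

```python
# Hypothetical 4-node ring; edges sorted as in init_path_link_array
edges = [(0, 1), (0, 3), (1, 2), (2, 3)]

def path_to_link_usage(path, edges, directed=False):
    """Convert a node-sequence path to a binary link-usage row."""
    usage = [0] * len(edges)
    for s, d in zip(path, path[1:]):          # consecutive node pairs = hops
        for i, (u, v) in enumerate(edges):
            # Undirected graphs match the edge in either orientation
            if (u, v) == (s, d) or (not directed and (u, v) == (d, s)):
                usage[i] = 1
    return usage

print(path_to_link_usage([0, 1, 2], edges))  # [1, 0, 1, 0]
```

Stacking one such row per (node pair, path) gives the N(N-1)*k x E array returned above.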

init_path_se_array(path_length_array, modulations_array)

Initialise array of maximum spectral efficiency for highest-order modulation format on path.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| path_length_array | array | Array of path lengths | required |
| modulations_array | array | Array of maximum spectral efficiency for modulation format on path | required |

Returns:

| Type | Description |
| --- | --- |
| Array | jnp.array: Array of maximum spectral efficiency for highest-order modulation format on each path |

Source code in xlron/environments/env_funcs.py
def init_path_se_array(path_length_array: Array, modulations_array: Array) -> Array:
    """Initialise array of maximum spectral efficiency for highest-order modulation format on path.

    Args:
        path_length_array (jnp.array): Array of path lengths
        modulations_array (jnp.array): Array of maximum spectral efficiency for modulation format on path

    Returns:
        jnp.array: Array of maximum spectral efficiency for highest-order modulation format on each path
    """
    """
    se_list = []
    # Flip the modulation array so that the shortest path length is first
    modulations_array = modulations_array[::-1]
    for length in path_length_array:
        for modulation in modulations_array:
            if length <= modulation[0]:
                se_list.append(modulation[1])
                break
    return jnp.array(se_list, dtype=dtype_config.SMALL_INT_DTYPE)
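The selection logic picks the highest-order modulation format whose reach still covers the path. A NumPy sketch with a hypothetical modulation table (reach and SE values are illustrative only, not from the library's data):

```python
import numpy as np

# Hypothetical modulations_array rows: [max_reach_km, spectral_efficiency],
# ordered longest reach first, as init_path_se_array expects
modulations_array = np.array([
    [4000.0, 1.0],   # longest reach, lowest SE
    [2000.0, 2.0],
    [1000.0, 4.0],   # shortest reach, highest SE
])

def max_se_for_length(length, modulations_array):
    # Flip so the shortest reach comes first; the first format whose
    # reach covers the path is then the highest-SE feasible one.
    for reach, se in modulations_array[::-1]:
        if length <= reach:
            return se
    return 0.0  # path too long for any format

print(max_se_for_length(800.0, modulations_array))   # 4.0
print(max_se_for_length(1500.0, modulations_array))  # 2.0
```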

init_rsa_request_array()

Initialize request array

Source code in xlron/environments/env_funcs.py
def init_rsa_request_array():
    """Initialize request array"""
    return jnp.zeros(3, dtype=dtype_config.LARGE_INT_DTYPE)

init_traffic_matrix(key, params)

Initialize traffic matrix. Allows for random traffic matrix or uniform traffic matrix. Source-dest traffic requests are sampled probabilistically from the resulting traffic matrix.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| key | PRNGKey | PRNG key | required |
| params | EnvParams | Environment parameters | required |

Returns:

| Type | Description |
| --- | --- |
| Array | jnp.array: Traffic matrix |

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(1,))
def init_traffic_matrix(key: chex.PRNGKey, params: EnvParams) -> Array:
    """Initialize traffic matrix. Allows for random traffic matrix or uniform traffic matrix.
    Source-dest traffic requests are sampled probabilistically from the resulting traffic matrix.

    Args:
        key (chex.PRNGKey): PRNG key
        params (EnvParams): Environment parameters

    Returns:
        jnp.array: Traffic matrix
    """
    if params.random_traffic:
        traffic_matrix = jax.random.uniform(
            key, shape=(params.num_nodes, params.num_nodes), dtype=dtype_config.SMALL_FLOAT_DTYPE
        )
    else:
        traffic_matrix = jnp.ones(
            (params.num_nodes, params.num_nodes), dtype=dtype_config.SMALL_FLOAT_DTYPE
        )
    diag_elements = jnp.diag_indices_from(traffic_matrix)
    # Set main diagonal to zero so no requests from node to itself
    traffic_matrix = traffic_matrix.at[diag_elements].set(0)
    traffic_matrix = normalise_traffic_matrix(traffic_matrix)
    return traffic_matrix.astype(jnp.float32)
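The sampling behaviour can be sketched in NumPy (the exact form of normalise_traffic_matrix is not shown here; sum-to-one normalisation is assumed, and the 4-node size is hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
num_nodes = 4

# Random traffic matrix, zero diagonal, normalised to sum to one
traffic_matrix = rng.uniform(size=(num_nodes, num_nodes))
np.fill_diagonal(traffic_matrix, 0.0)   # no requests from a node to itself
traffic_matrix /= traffic_matrix.sum()  # assumed normalisation

# Sample a source-dest pair with probability proportional to its entry
flat_idx = rng.choice(num_nodes * num_nodes, p=traffic_matrix.ravel())
source, dest = divmod(flat_idx, num_nodes)
assert source != dest  # diagonal entries have zero probability
```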

init_transceiver_amplifier_noise_arrays(link_resources, ref_lambda, slot_size, noise_data_filepath=None, slot_frequencies_ghz=None)

Initialise transceiver, amplifier, and ROADM noise arrays from per-band CSV data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| link_resources | int | Number of link resources. | required |
| ref_lambda | float | Reference wavelength. | required |
| slot_size | float | Slot size in GHz. | required |
| noise_data_filepath | str | Path to CSV file. Defaults to None. | None |
| slot_frequencies_ghz | ndarray | Pre-computed absolute slot centre frequencies in GHz. When provided, these are used directly instead of computing from the uniform formula. | None |

Returns:

| Type | Description |
| --- | --- |
| Tuple[Array, Array, Array, Array, Array] | Tuple of per-slot arrays: (transceiver_snr, amplifier_noise_figure, roadm_express_loss, roadm_add_drop_loss, roadm_noise_figure) |

Source code in xlron/environments/env_funcs.py
def init_transceiver_amplifier_noise_arrays(
    link_resources: int,
    ref_lambda: float,
    slot_size: float,
    noise_data_filepath: str | None = None,
    slot_frequencies_ghz: np.ndarray | None = None,
) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array, chex.Array]:
    """Initialise transceiver, amplifier, and ROADM noise arrays from per-band CSV data.

    Args:
        link_resources (int): Number of link resources.
        ref_lambda (float): Reference wavelength.
        slot_size (float): Slot size in GHz.
        noise_data_filepath (str, optional): Path to CSV file. Defaults to None.
        slot_frequencies_ghz (np.ndarray, optional): Pre-computed absolute slot
            centre frequencies in GHz.  When provided, these are used directly
            instead of computing from the uniform formula.

    Returns:
        Tuple of per-slot arrays: (transceiver_snr, amplifier_noise_figure,
            roadm_express_loss, roadm_add_drop_loss, roadm_noise_figure)
    """
    f = (
        pathlib.Path(noise_data_filepath)
        if noise_data_filepath
        else (
            pathlib.Path(__file__).parents[1].absolute()
            / "data"
            / "gn_model"
            / "transceiver_amplifier_data.csv"
        )
    )
    noise_data = np.genfromtxt(f, delimiter=",")
    # Drop empty first row (headers) and column (name)
    noise_data = noise_data[1:, 1:]
    # Columns are: wavelength_min_nm,wavelength_max_nm,frequency_min_ghz,frequency_max_ghz,
    #   NF_ASE_dB,SNR_TRX_dB,roadm_express_loss_dB,roadm_add_drop_loss_dB,roadm_NF_dB
    frequency_min_ghz = noise_data[:, 2]
    frequency_max_ghz = noise_data[:, 3]
    amplifier_noise_db = noise_data[:, 4]  # NF_ASE_dB
    transceiver_snr_db = noise_data[:, 5]  # SNR_TRX_dB
    roadm_express_loss_db = noise_data[:, 6]  # roadm_express_loss_dB
    roadm_add_drop_loss_db = noise_data[:, 7]  # roadm_add_drop_loss_dB
    roadm_nf_db = noise_data[:, 8]  # roadm_NF_dB

    if slot_frequencies_ghz is None:
        # Legacy uniform formula
        slot_centres = (jnp.arange(link_resources) - (link_resources - 1) / 2) * slot_size
        ref_frequency_ghz = c / ref_lambda / 1e9
        slot_frequencies_ghz = ref_frequency_ghz + slot_centres

    # Initialize output arrays
    transceiver_snr_array = jnp.zeros(link_resources)
    amplifier_noise_figure_array = jnp.zeros(link_resources)
    roadm_express_loss_array = jnp.zeros(link_resources)
    roadm_add_drop_loss_array = jnp.zeros(link_resources)
    roadm_noise_figure_array = jnp.zeros(link_resources)

    # For each slot, find which band it belongs to
    for i, freq in enumerate(slot_frequencies_ghz):
        # Find the band this frequency falls into
        found = False
        for j in range(len(frequency_min_ghz)):
            if frequency_min_ghz[j] <= freq <= frequency_max_ghz[j]:
                transceiver_snr_array = transceiver_snr_array.at[i].set(transceiver_snr_db[j])
                amplifier_noise_figure_array = amplifier_noise_figure_array.at[i].set(
                    amplifier_noise_db[j]
                )
                roadm_express_loss_array = roadm_express_loss_array.at[i].set(
                    roadm_express_loss_db[j]
                )
                roadm_add_drop_loss_array = roadm_add_drop_loss_array.at[i].set(
                    roadm_add_drop_loss_db[j]
                )
                roadm_noise_figure_array = roadm_noise_figure_array.at[i].set(roadm_nf_db[j])
                found = True
                break
        if not found:
            # Gap slots fall between bands — leave at zero (they are never occupied)
            pass

    return (
        transceiver_snr_array,
        amplifier_noise_figure_array,
        roadm_express_loss_array,
        roadm_add_drop_loss_array,
        roadm_noise_figure_array,
    )
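The per-slot band lookup can be illustrated with a hypothetical two-band table (all frequency and noise-figure values below are made up for illustration, not taken from the real CSV):

```python
import numpy as np

# Hypothetical band table rows: [freq_min_ghz, freq_max_ghz, NF_ASE_dB]
bands = np.array([
    [184_000.0, 191_000.0, 5.5],   # L-band-like
    [191_500.0, 196_000.0, 4.5],   # C-band-like
])

slot_frequencies_ghz = np.array([185_000.0, 191_250.0, 195_000.0])
nf = np.zeros(len(slot_frequencies_ghz))

# Same per-slot lookup as the function: first matching band wins;
# gap slots between bands stay at zero (they are never occupied)
for i, f in enumerate(slot_frequencies_ghz):
    for fmin, fmax, nf_db in bands:
        if fmin <= f <= fmax:
            nf[i] = nf_db
            break

print(nf)  # [5.5 0.  4.5]
```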

make_graph(topology_name='conus', topology_directory=None)

Create graph from topology definition. Topologies must be defined in JSON format in the topologies directory and named as the topology name with .json extension.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| topology_name | str | topology name | 'conus' |
| topology_directory | str or None | topology directory | None |

Returns:

| Name | Description |
| --- | --- |
| graph | graph |

Source code in xlron/environments/env_funcs.py
def make_graph(topology_name: str = "conus", topology_directory: str | None = None):
    """Create graph from topology definition.
    Topologies must be defined in JSON format in the topologies directory and
    named as the topology name with .json extension.

    Args:
        topology_name: topology name
        topology_directory: topology directory

    Returns:
        graph: graph
    """
    topology_path = (
        pathlib.Path(topology_directory)
        if topology_directory
        else (pathlib.Path(__file__).parents[1].absolute() / "data" / "topologies")
    )
    # Create topology
    if topology_name == "4node":
        # 4 node ring
        graph = nx.from_numpy_array(
            np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]])
        )
        # Add edge weights to graph
        nx.set_edge_attributes(graph, {(0, 1): 4, (1, 2): 3, (2, 3): 2, (3, 0): 1}, "distance")
    elif topology_name == "7node":
        # 7 node ring
        graph = nx.from_numpy_array(
            jnp.array(
                [
                    [0, 1, 0, 0, 0, 0, 1],
                    [1, 0, 1, 0, 0, 0, 0],
                    [0, 1, 0, 1, 0, 0, 0],
                    [0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 0, 1, 0],
                    [0, 0, 0, 0, 1, 0, 1],
                    [1, 0, 0, 0, 0, 1, 0],
                ]
            )
        )
        # Add edge weights to graph
        nx.set_edge_attributes(
            graph,
            {
                (0, 1): 4,
                (1, 2): 3,
                (2, 3): 2,
                (3, 4): 1,
                (4, 5): 2,
                (5, 6): 3,
                (6, 0): 4,
            },
            "distance",
        )
    else:
        with open(topology_path / f"{topology_name}.json") as f:
            graph = nx.node_link_graph(json.load(f), edges="links")
    return graph

make_line_graph(graph)

Create the line graph of a NetworkX graph.

The line graph L(G) has: - One node for each edge in the original graph G - An edge between two nodes in L(G) if the corresponding edges in G share a node

This is used for transformer architectures where we treat edges (links) as tokens and need positional encodings based on edge relationships.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| graph | Graph | NetworkX graph (original topology) | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| line_graph | Graph | NetworkX line graph where nodes correspond to edges in the original graph |

Source code in xlron/environments/env_funcs.py
def make_line_graph(graph: nx.Graph) -> nx.Graph:
    """Create the line graph of a NetworkX graph.

    The line graph L(G) has:
    - One node for each edge in the original graph G
    - An edge between two nodes in L(G) if the corresponding edges in G share a node

    This is used for transformer architectures where we treat edges (links) as tokens
    and need positional encodings based on edge relationships.

    Args:
        graph: NetworkX graph (original topology)

    Returns:
        line_graph: NetworkX line graph where nodes correspond to edges in the original graph
    """
    return nx.line_graph(graph)
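For a concrete sense of the construction, a dependency-free sketch of the line-graph rule on a hypothetical triangle topology (nx.line_graph implements this for general graphs):

```python
from itertools import combinations

# Hypothetical triangle topology: every pair of edges shares a node,
# so its line graph is also a triangle
edges = [(0, 1), (1, 2), (0, 2)]

def line_graph_edges(edges):
    """Edges of L(G): connect two G-edges iff they share an endpoint."""
    return [
        (e1, e2)
        for e1, e2 in combinations(edges, 2)
        if set(e1) & set(e2)
    ]

print(len(line_graph_edges(edges)))  # 3
```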

mask_slots_rmsa_gn_model(state, params, request)

Compute action mask for RMSA with GN model physical layer.

For each (path, modulation_format) pair, finds first-fit and last-fit candidate slot positions, evaluates them via the ISRS GN model, and builds a mask indicating which slots are valid (with the modulation format index stored in mod_format_mask).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | RSAGNModelEnvState | Environment state | required |
| params | RSAGNModelEnvParams | Environment parameters | required |
| request | Array | Request array in format [source_node, data-rate, destination_node] | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| state | EnvState | Updated environment state with link_slot_mask and mod_format_mask |
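The availability check at the heart of this function uses a cumulative-sum sliding window: a block of req_slots contiguous slots starting at index i is free iff the window sum of the occupancy vector over [i, i + req_slots) is zero. A minimal NumPy sketch with a hypothetical occupancy vector:

```python
import numpy as np

# Hypothetical occupancy vector (1 = occupied) and request size
occupied = np.array([0, 0, 1, 0, 0, 0, 1, 0])
req_slots = 3

# Pad with a leading zero (so cumsum[i] = sum of occupied[:i]) and trailing
# ones (so windows running off the end are never counted as free)
padded = np.concatenate([[0], occupied, np.ones(req_slots - 1)])
cumsum = np.cumsum(padded)

starts = np.arange(len(occupied))
window_sums = cumsum[starts + req_slots] - cumsum[starts]
print(np.flatnonzero(window_sums == 0))  # [3]  (slots 3, 4, 5 are free)
```

In the function itself this is vectorised across all k paths and all modulation formats at once, with argmax over the boolean availability array yielding the first-fit and last-fit candidates.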

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(1,))
def mask_slots_rmsa_gn_model(
    state: RSAGNModelEnvState, params: RSAGNModelEnvParams, request: chex.Array
) -> EnvState:
    """Compute action mask for RMSA with GN model physical layer.

    For each (path, modulation_format) pair, finds first-fit and last-fit candidate
    slot positions, evaluates them via the ISRS GN model, and builds a mask indicating
    which slots are valid (with the modulation format index stored in mod_format_mask).

    Args:
        state: Environment state
        params: Environment parameters
        request: Request array in format [source_node, data-rate, destination_node]

    Returns:
        state: Updated environment state with link_slot_mask and mod_format_mask
    """
    nodes_sd, requested_datarate = read_rsa_request(request)
    num_mods = params.modulations_array.val.shape[0]

    # --- Phase 1: Vectorized slot availability via cumsum ---
    paths = get_paths(params, nodes_sd)  # (k, num_links)
    slots_occupied = (state.link_slot_array != 0).astype(dtype_config.LARGE_FLOAT_DTYPE)
    occupied = (paths @ slots_occupied) > 0  # (k, link_resources) True if occupied on any path link

    # Modulation format data
    mod_formats = params.modulations_array.val  # (num_mods, >=3)
    se_values = mod_formats[:, 1]  # (num_mods,)
    req_snr_values = mod_formats[:, 2] + params.snr_margin  # (num_mods,) in dB

    # Required slots per modulation format
    all_req_slots = jax.vmap(
        lambda se: required_slots(
            requested_datarate,
            se,
            params.slot_size,
            guardband=params.guardband,
            temperature=params.temperature,
        )
    )(se_values)  # (num_mods,)

    # Cumsum sliding window for contiguous free slots
    padded = jnp.concatenate(
        [
            jnp.zeros((params.k_paths, 1), dtype=dtype_config.LARGE_FLOAT_DTYPE),
            occupied.astype(dtype_config.LARGE_FLOAT_DTYPE),
            jnp.ones((params.k_paths, params.max_slots - 1), dtype=dtype_config.LARGE_FLOAT_DTYPE),
        ],
        axis=1,
    )  # (k, link_resources + max_slots)
    cumsum = jnp.cumsum(padded, axis=1)

    slot_indices = jnp.arange(params.link_resources)
    end_indices = slot_indices[None, :] + all_req_slots[:, None]  # (num_mods, link_resources)

    cumsum_at_end = cumsum[:, end_indices]  # (k, num_mods, link_resources)
    cumsum_at_start = cumsum[:, slot_indices]  # (k, link_resources)
    window_sums = cumsum_at_end - cumsum_at_start[:, None, :]  # (k, num_mods, link_resources)
    slot_available = window_sums == 0  # (k, num_mods, link_resources)

    # Zero out dummy (all-zero) paths
    path_valid = jnp.max(paths, axis=1) > 0  # (k,)
    slot_available = slot_available & path_valid[:, None, None]

    # Extract FF and LF slot indices
    has_candidate = jnp.any(slot_available, axis=2)  # (k, num_mods)

    # FF: first True along axis=2 (append False sentinel so argmax returns link_resources when empty)
    ff_sentinel = jnp.zeros((params.k_paths, num_mods, 1), dtype=bool)
    ff_indices = jnp.argmax(
        jnp.concatenate([slot_available, ff_sentinel], axis=2), axis=2
    )  # (k, num_mods)

    # LF: last True along axis=2
    lf_sentinel = jnp.zeros((params.k_paths, num_mods, 1), dtype=bool)
    lf_from_end = jnp.argmax(
        jnp.concatenate([jnp.flip(slot_available, axis=2), lf_sentinel], axis=2), axis=2
    )  # (k, num_mods)
    lf_indices = params.link_resources - 1 - lf_from_end  # (k, num_mods)

    # Skip LF evaluation when FF == LF (same slot, would be duplicate)
    is_same = ff_indices == lf_indices

    # --- Phase 2: Batch candidate construction ---
    # Flatten to (2 * k * M) candidates: [FF candidates..., LF candidates...]
    # Order within each half: path0_mod0, path0_mod1, ..., pathK_modM
    flat_ff_indices = ff_indices.reshape(-1)  # (k*M,)
    flat_lf_indices = lf_indices.reshape(-1)  # (k*M,)
    all_slot_indices = jnp.concatenate([flat_ff_indices, flat_lf_indices])  # (2*k*M,)

    # Path index for each candidate
    path_idx_per_km = jnp.repeat(jnp.arange(params.k_paths), num_mods)  # (k*M,)
    path_idx_all = jnp.concatenate([path_idx_per_km, path_idx_per_km])  # (2*k*M,)
    all_paths = paths[path_idx_all]  # (2*k*M, num_links)

    # Mod format index for each candidate
    mod_idx_per_km = jnp.tile(jnp.arange(num_mods), params.k_paths)  # (k*M,)
    mod_idx_all = jnp.concatenate([mod_idx_per_km, mod_idx_per_km])  # (2*k*M,)

    # Required slots and SNR per candidate
    all_req_slots_flat = all_req_slots[mod_idx_all]  # (2*k*M,)
    all_req_snr_flat = req_snr_values[mod_idx_all]  # (2*k*M,)

    # Lightpath indices per candidate
    all_lightpath_indices = jax.vmap(lambda i: get_lightpath_index(params, nodes_sd, i))(
        path_idx_all
    )  # (2*k*M,)

    # Launch power per path
    if params.launch_power_type == "fixed":
        all_launch_powers = jnp.broadcast_to(
            state.launch_power_array[0], (2 * params.k_paths * num_mods,)
        )
    else:
        per_path_launch_powers = jax.vmap(
            lambda i: get_launch_power(
                state,
                i * (params.link_resources // params.aggregate_slots),
                state.launch_power_array[i],
                params,
            )
        )(jnp.arange(params.k_paths))  # (k,)
        all_launch_powers = per_path_launch_powers[path_idx_all]  # (2*k*M,)

    # Validity flags: FF candidates use has_candidate, LF also requires FF != LF
    flat_has_ff = has_candidate.reshape(-1)  # (k*M,)
    flat_has_lf = (has_candidate & ~is_same).reshape(-1)  # (k*M,)
    all_has_candidate = jnp.concatenate([flat_has_ff, flat_has_lf])  # (2*k*M,)

    # Build affected_slots_masks for all candidates
    all_masks = jax.vmap(lambda si, rs, p: get_affected_slots_mask(si, rs, p, params))(
        all_slot_indices, all_req_slots_flat, all_paths
    )  # (2*k*M, num_links, link_resources)

    # Construct modified state arrays for all candidates
    all_ch_bw = jax.vmap(
        lambda mask: set_path_links(state.channel_centre_bw_array, mask, params.slot_size)
    )(all_masks)  # (2*k*M, num_links, link_resources)

    all_ch_power = jax.vmap(lambda mask, lp: set_path_links(state.channel_power_array, mask, lp))(
        all_masks, all_launch_powers
    )  # (2*k*M, num_links, link_resources)

    all_path_idx_arrays = jax.vmap(
        lambda mask, li: set_path_links(state.path_index_array, mask, li)
    )(all_masks, all_lightpath_indices)  # (2*k*M, num_links, link_resources)

    all_mod_fmt_arrays = jax.vmap(
        lambda mask, mi: set_path_links(
            state.modulation_format_index_array,
            mask,
            mi.astype(state.modulation_format_index_array.dtype),
        )
    )(all_masks, mod_idx_all)  # (2*k*M, num_links, link_resources)

    # Compute centre frequencies for each candidate placement
    all_centre_freqs = jax.vmap(
        lambda mask, si, rs: set_path_links(
            state.channel_centre_freq_array,
            mask,
            get_centre_frequency(si, rs, params),
        )
    )(all_masks, all_slot_indices, all_req_slots_flat)  # (2*k*M, num_links, link_resources)

    # --- Phase 3: Vmapped GN model evaluation ---
    def evaluate_one_candidate(
        ch_bw,
        ch_power,
        path_idx_arr,
        mod_fmt_arr,
        centre_freq_arr,
        slot_idx,
        req_slots_val,
        req_snr_val,
        path_vec,
        has_cand,
    ):
        """Evaluate a single candidate placement. Returns 1.0 if valid, 0.0 if not."""
        temp_state = state.replace(
            channel_centre_bw_array=ch_bw,
            channel_power_array=ch_power,
            path_index_array=path_idx_arr,
            modulation_format_index_array=mod_fmt_arr,
            channel_centre_freq_array=centre_freq_arr,
        )

        # Compute SNR for all links
        if params.uniform_spans and not params.mod_format_correction:
            link_snr = get_snr_link_array_fused(temp_state, params)
        else:
            link_snr = get_snr_link_array(temp_state, params)
        temp_state = temp_state.replace(link_snr_array=link_snr)

        # Check 1: New lightpath SNR meets modulation format threshold
        # (inlined get_minimum_snr_of_channels_on_path to avoid static_argnums issue in vmap)
        snr_all_channels = get_snr_for_path(path_vec, temp_state.link_snr_array, params, temp_state)
        new_snr = jnp.min(
            jnp.concatenate(
                [
                    snr_all_channels[slot_idx].reshape((1,)),
                    snr_all_channels[slot_idx + req_slots_val - 1].reshape((1,)),
                ],
                axis=0,
            )
        )
        new_ok = (new_snr >= req_snr_val).astype(jnp.float32)

        # Check 2: Existing lightpaths still meet their SNR thresholds
        existing_fail = check_snr_sufficient(temp_state, params)
        existing_ok = (1.0 - existing_fail).astype(jnp.float32)

        # Check 3: Total power on each link doesn't exceed max_power_per_fibre
        total_power = compute_total_power_per_link(ch_power, path_idx_arr)
        power_ok = (~jnp.any(total_power > params.max_power_per_fibre)).astype(jnp.float32)

        return new_ok * existing_ok * power_ok * has_cand.astype(jnp.float32)

    all_results = jax.vmap(evaluate_one_candidate)(
        all_ch_bw,
        all_ch_power,
        all_path_idx_arrays,
        all_mod_fmt_arrays,
        all_centre_freqs,
        all_slot_indices,
        all_req_slots_flat,
        all_req_snr_flat,
        all_paths,
        all_has_candidate,
    )  # (2*k*M,)

    # --- Phase 4: Assemble final mask ---
    ff_results = all_results[: params.k_paths * num_mods].reshape(params.k_paths, num_mods)
    lf_results = all_results[params.k_paths * num_mods :].reshape(params.k_paths, num_mods)

    # Build mod_format_mask: for each (path, slot), store winning mod format index or -1
    slot_idx_range = jnp.arange(params.link_resources, dtype=dtype_config.LARGE_INT_DTYPE)

    def build_path_mask(path_idx):
        path_mask = jnp.full((params.link_resources,), -1.0, dtype=dtype_config.LARGE_FLOAT_DTYPE)

        def apply_mod(mod_idx, mask):
            ff_idx = ff_indices[path_idx, mod_idx]
            lf_idx = lf_indices[path_idx, mod_idx]
            ff_ok = ff_results[path_idx, mod_idx]
            lf_ok = lf_results[path_idx, mod_idx]

            # Set FF slot position if valid
            mask = jnp.where(
                (slot_idx_range == ff_idx) & (ff_ok > 0),
                mod_idx.astype(dtype_config.LARGE_FLOAT_DTYPE),
                mask,
            )
            # Set LF slot position if valid
            mask = jnp.where(
                (slot_idx_range == lf_idx) & (lf_ok > 0),
                mod_idx.astype(dtype_config.LARGE_FLOAT_DTYPE),
                mask,
            )
            return mask

        path_mask = jax.lax.fori_loop(0, num_mods, apply_mod, path_mask)
        return path_mask

    mod_format_mask = jax.vmap(build_path_mask)(jnp.arange(params.k_paths)).reshape(
        -1
    )  # (k * link_resources,)

    link_slot_mask = jnp.where(mod_format_mask >= 0, 1.0, 0.0)
    if params.aggregate_slots > 1:
        state = state.replace(full_link_slot_mask=link_slot_mask)
        link_slot_mask, _ = aggregate_slots(link_slot_mask.reshape(params.k_paths, -1), params)
        link_slot_mask = link_slot_mask.reshape(-1)
    if params.include_no_op:
        link_slot_mask = jnp.hstack([link_slot_mask, jnp.ones((1,))])
    state = state.replace(
        link_slot_mask=link_slot_mask,
        mod_format_mask=mod_format_mask,
    )
    return state
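The per-path mask construction above can be sketched in plain NumPy. The first-fit slot indices and validity flags below are invented values for illustration, not outputs of the real pipeline:

```python
import numpy as np

# Sketch of build_path_mask: for each slot on one path, record the index of
# the winning modulation format, or -1.0 where no valid placement starts.
link_resources, num_mods = 8, 2
slot_idx_range = np.arange(link_resources)
path_mask = np.full(link_resources, -1.0)

ff_indices = [2, 5]   # hypothetical first-fit slot per modulation format
ff_ok = [1.0, 1.0]    # hypothetical validity flag per modulation format

for mod_idx in range(num_mods):
    hit = (slot_idx_range == ff_indices[mod_idx]) & (ff_ok[mod_idx] > 0)
    path_mask = np.where(hit, float(mod_idx), path_mask)

print(path_mask.tolist())  # [-1.0, -1.0, 0.0, -1.0, -1.0, 1.0, -1.0, -1.0]
```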

mask_slots_rwalr(state, params, request)

For use in RWALightpathReuseEnv. Each lightpath has a maximum capacity, defined in path_capacity_array, which is updated when a lightpath is assigned. If the remaining path capacity is less than the current request, the corresponding link-slots are masked out. If a link-slot is in use by another lightpath for a different source-destination node pair (even if not full), it is masked out.

Step 1: mask out slots that are not valid based on path capacity (check link_capacity_array).
Step 2: mask out slots that are not valid based on lightpath reuse (check path_index_array).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | EnvState | Environment state | required |
| params | EnvParams | Environment parameters | required |
| request | chex.Array | Request array in format [source_node, data-rate, destination_node] | required |

Returns:

| Name | Type | Description |
|---|---|---|
| state | EnvState | Updated environment state |

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(1,))
def mask_slots_rwalr(state: EnvState, params: EnvParams, request: chex.Array) -> EnvState:
    """For use in RWALightpathReuseEnv.
    Each lightpath has a maximum capacity defined in path_capacity_array. This is updated when a lightpath is assigned.
    If remaining path capacity is less than current request, corresponding link-slots are masked out.
    If link-slot is in use by another lightpath for a different source and destination node (even if not full) it is masked out.
    Step 1:
    - Mask out slots that are not valid based on path capacity (check link_capacity_array)
    Step 2:
    - Mask out slots that are not valid based on lightpath reuse (check path_index_array)

    Args:
        state: Environment state
        params: Environment parameters
        request: Request array in format [source_node, data-rate, destination_node]

    Returns:
        state: Updated environment state
    """
    nodes_sd, requested_datarate = read_rsa_request(request)
    source, dest = nodes_sd
    path_start_index = get_path_indices(
        params,
        source,
        dest,
        params.k_paths,
        params.num_nodes,
        directed=params.directed_graph,
    ).astype(dtype_config.INDEX_DTYPE)
    # Step 1 - capacity mask (computed once, shared across paths)
    capacity_mask = (state.link_capacity_array < requested_datarate).astype(
        dtype_config.LARGE_FLOAT_DTYPE
    )

    # Step 2 - lightpath reuse masks (computed once, indexed per path)
    empty_mask = (state.path_index_array != -1).astype(dtype_config.LARGE_FLOAT_DTYPE)

    def single_path(i):
        capacity_slots = get_path_slots(capacity_mask, params, nodes_sd, i)

        lightpath_index = path_start_index + i
        lightpath_mask = (
            1.0 - (state.path_index_array == lightpath_index).astype(dtype_config.LARGE_FLOAT_DTYPE)
        ) * empty_mask
        lightpath_slots = get_path_slots(lightpath_mask, params, nodes_sd, i)

        # Combine: masked if either mask is active, then invert (1 = valid)
        combined = jnp.maximum(capacity_slots, lightpath_slots)
        return 1.0 - jnp.minimum(combined, 1.0)

    link_slot_mask = jax.vmap(single_path)(jnp.arange(params.k_paths)).reshape(-1)

    if params.aggregate_slots > 1:
        state = state.replace(full_link_slot_mask=link_slot_mask)
        link_slot_mask, _ = aggregate_slots(link_slot_mask.reshape(params.k_paths, -1), params)
        link_slot_mask = link_slot_mask.reshape(-1)

    if params.include_no_op:
        link_slot_mask = jnp.concatenate([link_slot_mask, jnp.ones((1,))])

    return state.replace(link_slot_mask=link_slot_mask)
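The two masking steps reduce to simple array arithmetic. A toy NumPy illustration for one path's slots, using invented capacities and path indices:

```python
import numpy as np

# Toy illustration of the two RWA-LR masking steps on one path's slots
requested_datarate = 100
link_capacity = np.array([50.0, 150.0, 200.0, 80.0])  # remaining capacity per slot
path_index = np.array([3, -1, 3, 7])                  # lightpath using each slot (-1 = empty)
this_lightpath = 3                                    # lightpath index for this source-dest pair

# Step 1: insufficient remaining capacity
capacity_mask = (link_capacity < requested_datarate).astype(float)
# Step 2: slot occupied by a different lightpath
empty_mask = (path_index != -1).astype(float)
lightpath_mask = (1.0 - (path_index == this_lightpath)) * empty_mask

combined = np.maximum(capacity_mask, lightpath_mask)
valid = 1.0 - np.minimum(combined, 1.0)  # 1 = slot selectable
print(valid.tolist())  # [0.0, 1.0, 1.0, 0.0]
```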

normalise_traffic_matrix(traffic_matrix)

Normalise traffic matrix to sum to 1

Source code in xlron/environments/env_funcs.py
def normalise_traffic_matrix(traffic_matrix):
    """Normalise traffic matrix to sum to 1"""
    traffic_matrix /= jnp.sum(traffic_matrix, promote_integers=False)
    return traffic_matrix
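A minimal usage sketch of the normalisation (shown with NumPy for brevity; the real function uses jnp):

```python
import numpy as np

traffic_matrix = np.array([[0.0, 2.0],
                           [1.0, 1.0]])
# Divide by the total so all entries sum to 1
traffic_matrix = traffic_matrix / traffic_matrix.sum()
print(traffic_matrix.sum())  # 1.0
```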

pad_array(array, fill_value)

Pad a ragged multidimensional array to rectangular shape. Used for training on multiple topologies. Source: https://codereview.stackexchange.com/questions/222623/pad-a-ragged-multidimensional-array-to-rectangular-shape

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| array | | Array to pad | required |
| fill_value | | Value to fill with | required |

Returns:

| Name | Type | Description |
|---|---|---|
| result | | Padded array |

Source code in xlron/environments/env_funcs.py
def pad_array(array, fill_value):
    """
    Pad a ragged multidimensional array to rectangular shape.
    Used for training on multiple topologies.
    Source: https://codereview.stackexchange.com/questions/222623/pad-a-ragged-multidimensional-array-to-rectangular-shape

    Args:
        array: array to pad
        fill_value: value to fill with

    Returns:
        result: padded array
    """

    def get_dimensions(array, level=0):
        yield level, len(array)
        try:
            for row in array:
                yield from get_dimensions(row, level + 1)
        except TypeError:  # not an iterable
            pass

    def get_max_shape(array):
        dimensions = defaultdict(int)
        for level, length in get_dimensions(array):
            dimensions[level] = max(dimensions[level], length)
        return [value for _, value in sorted(dimensions.items())]

    def iterate_nested_array(array, index=()):
        try:
            for idx, row in enumerate(array):
                yield from iterate_nested_array(row, (*index, idx))
        except TypeError:  # final level
            yield (*index, slice(len(array))), array

    dimensions = get_max_shape(array)
    result = np.full(dimensions, fill_value)
    for index, value in iterate_nested_array(array):
        result[index] = value
    return result
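The padding idea is easiest to see in two dimensions. A minimal 2-D sketch of the same approach (not the general n-dimensional function above):

```python
import numpy as np

def pad_ragged_2d(rows, fill_value):
    # Two-dimensional version of the same idea: find the widest row,
    # allocate a filled rectangle, then copy each row into place.
    width = max(len(r) for r in rows)
    out = np.full((len(rows), width), fill_value)
    for i, r in enumerate(rows):
        out[i, :len(r)] = r
    return out

print(pad_ragged_2d([[1, 2, 3], [4]], 0).tolist())  # [[1, 2, 3], [4, 0, 0]]
```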

poisson(key, lam, shape=(), dtype=dtypes.float_)

Sample Exponential random values with given shape and float dtype.

The values are distributed according to the probability density function:

    f(x) = λ e^(−λx)

on the domain 0 ≤ x < ∞.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | Array | PRNG key used as the random key | required |
| lam | ArrayLike | Positive float32 or float64 rate parameter | required |
| shape | Shape | Tuple of nonnegative integers representing the result shape | () |
| dtype | DTypeLike | Float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32) | dtypes.float_ |

Returns:

A random array with the specified shape and dtype.

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(1, 2, 3))
def poisson(
    key: Union[Array, prng.PRNGKeyArray],
    lam: ArrayLike,
    shape: Shape = (),
    dtype: DTypeLike = dtypes.float_,
) -> Array:
    r"""Sample Exponential random values with given shape and float dtype.

    The values are distributed according to the probability density function:

    .. math::
     f(x) = \lambda e^{-\lambda x}

    on the domain :math:`0 \le x < \infty`.

    Args:
    key: a PRNG key used as the random key.
    lam: a positive float32 or float64 `Tensor` indicating the rate parameter
    shape: optional, a tuple of nonnegative integers representing the result
      shape. Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

    Returns:
    A random array with the specified shape and dtype.
    """
    key, _ = jax._src.random._check_prng_key(key)
    if not dtypes.issubdtype(dtype, np.floating):
        raise ValueError(f"dtype argument to `exponential` must be a float dtype, got {dtype}")
    dtype = dtypes.canonicalize_dtype(dtype)
    shape = core.canonicalize_shape(shape)
    return _poisson(key, lam, shape, dtype)
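The exponential density above can be sampled by inverse-transform sampling. A NumPy sketch of the underlying maths, independent of the JAX implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
lam = 2.0
# Inverse-transform sampling: if U ~ Uniform(0, 1), then -ln(1 - U) / lam ~ Exp(lam)
u = rng.uniform(size=100_000)
samples = -np.log1p(-u) / lam
# The sample mean approximates the distribution mean, 1 / lam = 0.5
```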

process_path_action(state, params, path_action)

Process path action to get path index and initial slot index.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | EnvState | Current environment state | required |
| params | EnvParams | Environment parameters | required |
| path_action | chex.Array | Path action | required |

Returns:

| Type | Description |
|---|---|
| chex.Array | Path index |
| chex.Array | Initial slot index |

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(1,))
def process_path_action(
    state: EnvState, params: EnvParams, path_action: chex.Array
) -> tuple[chex.Array, chex.Array]:
    """Process path action to get path index and initial slot index.
    Args:
        state (State): current state
        params (Params): environment parameters
        path_action (int): path action
    Returns:
        int: path index
        int: initial slot index
    """
    num_slot_actions = params.link_resources // params.aggregate_slots
    path_action = differentiable_round_simple(
        path_action, params.temperature, params.differentiable
    )
    path_index = differentiable_floor(
        path_action // num_slot_actions, params.temperature, params.differentiable
    ).astype(dtype_config.LARGE_INT_DTYPE)
    initial_aggregated_slot_index = jnp.mod(path_action, num_slot_actions)
    initial_slot_index = initial_aggregated_slot_index * params.aggregate_slots

    if params.aggregate_slots > 1:
        # Compute flat index into 1D array of shape (k_paths * link_resources,)
        full_mask = state.full_link_slot_mask.reshape(
            (params.k_paths, num_slot_actions, params.aggregate_slots)
        )
        window = jax.lax.dynamic_slice(
            full_mask,
            (path_index, initial_aggregated_slot_index, 0),
            (1, 1, params.aggregate_slots),
        )
        # Use argmax to get index of first 1 in slice of mask
        initial_slot_index = initial_slot_index + differentiable_argmax(
            window, temperature=params.temperature, differentiable=params.differentiable
        ).astype(dtype_config.LARGE_INT_DTYPE)
    return path_index, initial_slot_index
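The integer arithmetic behind the action decoding can be sketched without the differentiable operations:

```python
# Non-differentiable sketch of the flat-action decoding in process_path_action
link_resources = 16
aggregate_slots = 1
num_slot_actions = link_resources // aggregate_slots  # slot actions per path

def decode_action(action):
    # Flat action index -> (path index, initial slot index)
    path_index = action // num_slot_actions
    initial_slot_index = (action % num_slot_actions) * aggregate_slots
    return path_index, initial_slot_index

print(decode_action(19))  # (1, 3): second path, slot 3
```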

read_rsa_request(request_array)

Read RSA request from request array. Return source-destination nodes and bandwidth request.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| request_array | chex.Array | Request array | required |

Returns:

| Type | Description |
|---|---|
| Tuple[chex.Array, chex.Array] | Source-destination nodes and bandwidth request |

Source code in xlron/environments/env_funcs.py
def read_rsa_request(request_array: chex.Array) -> Tuple[chex.Array, chex.Array]:
    """Read RSA request from request array. Return source-destination nodes and bandwidth request.
    Args:
        request_array: request array
    Returns:
        Tuple[chex.Array, chex.Array]: source-destination nodes and bandwidth request
    """
    nodes_sd = request_array[jnp.array([0, 2])]
    requested_datarate = request_array[1]
    return nodes_sd, requested_datarate
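A usage sketch with a concrete request array (NumPy stands in for jnp):

```python
import numpy as np

request = np.array([3, 100, 7])  # [source_node, data-rate, destination_node]
nodes_sd = request[np.array([0, 2])]   # pick out source and destination
requested_datarate = request[1]
print(nodes_sd.tolist(), int(requested_datarate))  # [3, 7] 100
```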

remove_expired_services_rwalr(state, params)

Remove expired services: release link-slots whose departure time has passed and reset their path indices.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | EnvState | Environment state | required |
| params | Optional[EnvParams] | Environment parameters | required |

Returns:

| Type | Description |
|---|---|
| EnvState | Updated environment state |

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(1,))
def remove_expired_services_rwalr(state: EnvState, params: Optional[EnvParams]) -> EnvState:
    """

    Args:
        state: Environment state
        params: Environment parameters

    Returns:
        Updated environment state
    """
    # Set one where link_slot_departure_array is >= zero and <= current time
    current_time = state.current_time if not params.relative_arrival_times else state.arrival_time
    mask_remove = differentiable_compare(
        state.link_slot_departure_array,
        current_time,
        "<=",
        params.differentiable,
        params.temperature,
    ) * differentiable_compare(
        state.link_slot_departure_array,
        zero,
        ">=",
        params.differentiable,
        params.temperature,
    )
    updated_link_slot_departure_array = state.link_slot_departure_array * (
        1 - mask_remove
    )  # Set to zero where mask is one
    if params.relative_arrival_times:
        mask_subtract = differentiable_compare(
            updated_link_slot_departure_array, zero, ">", params.differentiable, params.temperature
        )
        updated_link_slot_departure_array = (
            updated_link_slot_departure_array - jnp.squeeze(current_time) * mask_subtract
        )
    state = state.replace(
        link_slot_array=state.link_slot_array * (1 - mask_remove),
        path_index_array=state.path_index_array * (1 - mask_remove) + (-one) * mask_remove,
        link_slot_departure_array=updated_link_slot_departure_array,
    )
    return state
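The expiry masking reduces to simple array arithmetic. A hard (non-differentiable) NumPy sketch with invented departure times:

```python
import numpy as np

departure = np.array([5.0, 12.0, 0.0, 8.0])   # departure time per link-slot (0 = unused)
link_slots = np.array([1.0, 1.0, 0.0, 1.0])
current_time = 10.0

# 1 where the slot's departure time is in [0, current_time], i.e. expired
mask_remove = ((departure <= current_time) & (departure >= 0)).astype(float)
link_slots = link_slots * (1 - mask_remove)   # release expired slots
departure = departure * (1 - mask_remove)     # zero their departure times
print(link_slots.tolist())  # [0.0, 1.0, 0.0, 0.0]
```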

required_slots(bitrate, se, channel_width, guardband=1, temperature=1.0, differentiable=True)

Calculate required slots for a given bitrate and spectral efficiency.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| bitrate | float | Bit rate in Gbps | required |
| se | float | Spectral efficiency in bps/Hz | required |
| channel_width | float | Channel width in GHz | required |
| guardband | int | Guard band in slots | 1 |
| temperature | float | Temperature for differentiable approximation | 1.0 |
| differentiable | bool | If False, use non-differentiable operations | True |

Returns:

| Type | Description |
|---|---|
| int | Required slots |

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(2, 3, 4, 5))
def required_slots(
    bitrate: float,
    se: int,
    channel_width: float,
    guardband: int = 1,
    temperature: float = 1.0,
    differentiable: bool = True,
) -> int:
    """Calculate required slots for a given bitrate and spectral efficiency.

    Args:
        bitrate (float): Bit rate in Gbps
        se (float): Spectral efficiency in bps/Hz
        channel_width (float): Channel width in GHz
        guardband (int, optional): Guard band. Defaults to 1.
        temperature (float, optional): Temperature for differentiable approximation. Defaults to 1.0.
        differentiable (bool, optional): If False, use non-differentiable operations. Defaults to True.

    Returns:
        int: Required slots
    """
    # Apply differentiable ceiling
    base_calculation = bitrate / (se * channel_width) + guardband
    slots = differentiable_ceil(
        base_calculation, temperature=temperature, differentiable=differentiable
    )
    # Differentiable version of equality comparison bitrate == 0
    is_zero = differentiable_compare(
        bitrate, zero, "==", temperature=temperature, differentiable=differentiable
    )  # High temperature for sharper transition
    # Differentiable version of the conditional zeroing (if bitrate is zero, then required slots should be zero)
    result = slots * (one - is_zero)
    return jnp.squeeze(result).astype(dtype_config.SMALL_INT_DTYPE)
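A hard (non-differentiable) version of the same calculation, with a worked example:

```python
import math

def required_slots_hard(bitrate, se, channel_width, guardband=1):
    # ceil(bitrate / (se * channel_width) + guardband), zeroed for empty requests
    if bitrate == 0:
        return 0
    return math.ceil(bitrate / (se * channel_width) + guardband)

# 100 Gbps at 2 bps/Hz in 12.5 GHz slots: ceil(100 / 25) + 1 guard slot = 5
print(required_slots_hard(100, 2, 12.5))  # 5
```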

set_band_gaps(link_slot_array, params, val)

Set band gaps in the link slot array.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| link_slot_array | chex.Array | Link slot array | required |
| params | RSAGNModelEnvParams | Environment parameters | required |
| val | int | Value to set in the gap slots | required |

Returns:

| Type | Description |
|---|---|
| chex.Array | Link slot array with band gaps set |

Source code in xlron/environments/env_funcs.py
@partial(jax.jit, static_argnums=(1, 2))
def set_band_gaps(link_slot_array: chex.Array, params: RSAGNModelEnvParams, val: int) -> chex.Array:
    """Set band gaps in link slot array
    Args:
        link_slot_array (chex.Array): Link slot array
        params (RSAGNModelEnvParams): Environment parameters
        val (int): Value to set
    Returns:
        chex.Array: Link slot array with band gaps
    """
    # Create array that is size of link_slot array with values of column index
    mask = jnp.arange(params.link_resources)
    mask = jnp.tile(mask, (params.num_links, 1))

    def set_band_gap(i, arr):
        gap_start = params.gap_starts.val[i]
        gap_end = gap_start + params.gap_widths.val[i]
        condition = jnp.logical_and(arr >= gap_start, arr < gap_end)
        arr = jnp.where(condition, -one, arr)
        return arr

    mask = jax.lax.fori_loop(0, params.gap_widths.val.shape[0], set_band_gap, mask)
    link_slot_array = jnp.where(mask == -one, val, link_slot_array)
    return link_slot_array
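The gap-marking logic can be sketched in NumPy with one invented gap (the real gap positions come from params.gap_starts and params.gap_widths):

```python
import numpy as np

num_links, link_resources = 2, 10
gap_starts, gap_widths = [4], [2]   # hypothetical single gap covering slots 4-5
val = 1.0                           # value marking gap slots

link_slot_array = np.zeros((num_links, link_resources))
cols = np.tile(np.arange(link_resources), (num_links, 1))  # column index per cell
mask = cols.copy()
for start, width in zip(gap_starts, gap_widths):
    mask = np.where((cols >= start) & (cols < start + width), -1, mask)
link_slot_array = np.where(mask == -1, val, link_slot_array)
print(link_slot_array[0].tolist())  # [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
```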

update_active_lightpaths_array(state, path_index, initial_slot_index, num_slots)

Update active lightpaths array with new lightpath details. Find the first index of the array with value -1 and replace it with [path_index, initial_slot_index, num_slots].

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | RSAGNModelEnvState | Environment state | required |
| path_index | int | Path index to add to active lightpaths array | required |
| initial_slot_index | int | Initial slot index of the lightpath | required |
| num_slots | int | Number of slots occupied by the lightpath | required |

Returns:

| Type | Description |
|---|---|
| jnp.array | Updated active lightpaths array |

Source code in xlron/environments/env_funcs.py
def update_active_lightpaths_array(
    state: RSAGNModelEnvState, path_index: int, initial_slot_index: int, num_slots: int
) -> chex.Array:
    """Update active lightpaths array with new path index.
    Find the first index of the array with value -1 and replace with path index.
    Args:
        state (RSAGNModelEnvState): Environment state
        path_index (int): Path index to add to active lightpaths array
    Returns:
        jnp.array: Updated active lightpaths array
    """
    first_empty_index = jnp.argmin(
        state.active_lightpaths_array[:, 0]
    )  # Just look at the first column
    return jax.lax.dynamic_update_slice(
        state.active_lightpaths_array,
        jnp.array(
            [[path_index, initial_slot_index, num_slots]], dtype=state.active_lightpaths_array.dtype
        ),
        (first_empty_index, 0),
    )
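The find-first-empty update can be sketched in NumPy. Using argmin on the first column assumes valid path indices are non-negative and a free row exists:

```python
import numpy as np

def add_lightpath(active, path_index, initial_slot_index, num_slots):
    # argmin over the first column finds the first -1 entry,
    # assuming valid path indices are non-negative and a free row exists
    i = int(np.argmin(active[:, 0]))
    active = active.copy()
    active[i] = [path_index, initial_slot_index, num_slots]
    return active

active = np.full((4, 3), -1)
active = add_lightpath(active, 5, 2, 3)
active = add_lightpath(active, 7, 0, 1)
print(active[:2].tolist())  # [[5, 2, 3], [7, 0, 1]]
```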

update_active_lightpaths_array_departure(state, time)

Record the departure time of a lightpath in the active lightpaths departure array. Find the first empty index (value -1 in the first column of active_lightpaths_array) and set the departure time there.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | RSAGNModelEnvState | Environment state | required |
| time | float | Departure time | required |

Returns:

| Type | Description |
|---|---|
| jnp.array | Updated active lightpaths departure array |

Source code in xlron/environments/env_funcs.py
def update_active_lightpaths_array_departure(state: RSAGNModelEnvState, time: float) -> chex.Array:
    """Update active lightpaths array with new path index.
    Find the first index of the array with value -1 and replace with path index.
    Args:
        state (RSAGNModelEnvState): Environment state
        time (float): Departure time
    Returns:
        jnp.array: Updated active lightpaths array
    """
    first_empty_index = jnp.argmin(
        state.active_lightpaths_array[:, 0]
    )  # Just look at the first column
    return jax.lax.dynamic_update_slice(
        state.active_lightpaths_array_departure,
        jnp.stack((time, time, time)),
        (first_empty_index, 0),
    )

update_graph_tuple(state, params)

Update graph tuple for use with Jraph GNNs.

Edge and node features are updated from link_slot_array and node_capacity_array respectively. Global features are updated as request_array.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | EnvState | Environment state | required |
| params | EnvParams | Environment parameters | required |

Returns:

| Name | Type | Description |
|---|---|---|
| state | EnvState | Environment state with updated graph tuple |

Source code in xlron/environments/env_funcs.py
def update_graph_tuple(state: EnvState, params: EnvParams) -> EnvState:
    """Update graph tuple for use with Jraph GNNs.

    Edge and node features are updated from link_slot_array and node_capacity_array respectively.
    Global features are updated as request_array.

    Args:
        state (EnvState): Environment state
        params (EnvParams): Environment parameters

    Returns:
        state (EnvState): Environment state with updated graph tuple
    """
    # Get source and dest from request array
    source_dest, datarate = read_rsa_request(state.request_array)
    source, dest = source_dest[0], source_dest[1]
    # Current request as global feature
    globals = jnp.array(
        [datarate / jnp.max(params.values_bw.val)], dtype=dtype_config.LARGE_FLOAT_DTYPE
    )
    # One-hot encode source and destination
    source_dest_features = jnp.zeros((params.num_nodes, 2), dtype=dtype_config.LARGE_FLOAT_DTYPE)
    # Convert indices to int32 for indexing...
    source_idx = source.astype(dtype_config.LARGE_INT_DTYPE)
    dest_idx = dest.astype(dtype_config.LARGE_INT_DTYPE)
    # ...but maintain grads in value with differentiable index updates
    source_dest_features = differentiable_one_hot_index_update(
        source_dest_features, source_idx, 1.0, params.temperature, params.differentiable
    )
    source_dest_features = differentiable_one_hot_index_update(
        source_dest_features, dest_idx, -1.0, params.temperature, params.differentiable
    )
    spectral_features = state.graph.nodes[..., : params.num_spectral_features]
    holding_time_edge_features = state.link_slot_departure_array / params.mean_service_holding_time

    if params.__class__.__name__ in ["RSAGNModelEnvParams", "RMSAGNModelEnvParams"]:
        # Normalize by max parameters (converted to linear units)
        max_power = isrs_gn_model.from_dbm(params.max_power)
        # Use differentiable rounding
        normalized_power = differentiable_round(
            state.channel_power_array / max_power,
            decimals=3,
            temperature=params.temperature,
            differentiable=params.differentiable,
        )

        max_snr = isrs_gn_model.from_db(params.max_snr)
        # Use differentiable rounding
        normalized_snr = differentiable_round(
            state.link_snr_array / max_snr,
            decimals=3,
            temperature=params.temperature,
            differentiable=params.differentiable,
        )

        edge_features = jnp.stack([normalized_snr, normalized_power], axis=-1)
        node_features = jnp.concatenate([spectral_features, source_dest_features], axis=-1)
    elif params.__class__.__name__ == "VONEEnvParams":
        edge_features = (
            state.link_slot_array
            if params.mean_service_holding_time > 1e5
            else holding_time_edge_features
        )
        node_features = getattr(state, "node_capacity_array", jnp.zeros(params.num_nodes))
        node_features = node_features.reshape(-1, 1)
        node_features = jnp.concatenate(
            [node_features, spectral_features, source_dest_features], axis=-1
        )
    else:
        edge_features = (
            state.link_slot_array
            if params.mean_service_holding_time > 1e5
            else holding_time_edge_features
        )
        node_features = jnp.concatenate([spectral_features, source_dest_features], axis=-1)

    if params.disable_node_features:
        node_features = jnp.zeros((1,), dtype=dtype_config.LARGE_FLOAT_DTYPE)

    edge_features = edge_features if params.directed_graph else jnp.repeat(edge_features, 2, axis=0)
    graph = state.graph._replace(nodes=node_features, edges=edge_features, globals=globals)
    state = state.replace(graph=graph)
    return state
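The source/destination node-feature encoding can be sketched in NumPy. The exact column layout is one plausible reading of the (num_nodes, 2) array; +1.0 marks the source and -1.0 the destination, as in the updates above:

```python
import numpy as np

num_nodes, source, dest = 5, 1, 3
# Two-column node feature array: +1.0 marks the source row, -1.0 the destination row
# (illustrative layout; the real code uses differentiable index updates)
features = np.zeros((num_nodes, 2))
features[source, 0] = 1.0
features[dest, 1] = -1.0
print(features[:, 0].tolist())  # [0.0, 1.0, 0.0, 0.0, 0.0]
```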

Training Utilities

TrainState

Bases: Module

Train state for Equinox models.

The model is stored but marked as non-pytree so JAX doesn't try to trace it.

Source code in xlron/train/train_utils.py
class TrainState(eqx.Module):
    """Train state for Equinox models.

    The model is stored but marked as non-pytree so JAX doesn't try to trace it.
    """

    step: Array
    model_params: eqx.Module
    model_static: eqx.Module = eqx.field(static=True)  # Mark as static/non-pytree
    tx: optax.GradientTransformation = eqx.field(static=True)
    opt_state: optax.OptState
    lr_schedule: Schedule = eqx.field(static=True)
    ent_schedule: Schedule = eqx.field(static=True)
    vml_schedule: Schedule = eqx.field(static=True)
    avg_reward: Array
    reward_stepsize: Array
    reward_stepsize_init: Array
    reward_stepsize_offset: Array
    prio_alpha: Array
    prio_beta0: Array
    prio_beta: Array

    def apply_gradients(self, grads: Any) -> "TrainState":
        """Updates model parameters and opt_state."""
        model = eqx.combine(self.model_params, self.model_static)
        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.model_params)
        new_model = eqx.apply_updates(model, updates)
        new_model_params, new_model_static = eqx.partition(new_model, eqx.is_inexact_array)
        # Can't use eqx.tree_at for static fields, create new instance
        return TrainState(
            step=self.step,
            model_params=new_model_params,
            model_static=new_model_static,
            tx=self.tx,
            opt_state=new_opt_state,
            lr_schedule=self.lr_schedule,
            ent_schedule=self.ent_schedule,
            vml_schedule=self.vml_schedule,
            avg_reward=self.avg_reward,
            reward_stepsize=self.reward_stepsize,
            reward_stepsize_init=self.reward_stepsize_init,
            reward_stepsize_offset=self.reward_stepsize_offset,
            prio_alpha=self.prio_alpha,
            prio_beta0=self.prio_beta0,
            prio_beta=self.prio_beta,
        )

    def update_step_size(self) -> "TrainState":
        """Updates the step size used for reward centering."""
        reward_stepsize_offset = self.reward_stepsize_offset + self.reward_stepsize_init * (
            1 - self.reward_stepsize_offset
        )
        reward_stepsize = self.reward_stepsize_init / reward_stepsize_offset
        return eqx.tree_at(
            lambda state: (state.reward_stepsize, state.reward_stepsize_offset),
            self,
            (reward_stepsize, reward_stepsize_offset),
        )

    @staticmethod
    def create(
        model: eqx.Module | None,
        tx: optax.GradientTransformation,
        lr_schedule: Schedule = lambda x: jnp.array(0.0),
        ent_schedule: Schedule = lambda x: jnp.array(0.0),
        vml_schedule: Schedule = lambda x: jnp.array(0.0),
        prio_alpha: float = 0.0,
        prio_beta0: float = 1.0,
        prio_beta: float = 1.0,
        reward_stepsize_init: float = 0.001,
        initial_avg_reward: float = 0.0,
    ) -> "TrainState":
        """Creates a new instance with step=0 and initialized opt_state."""
        opt_state = tx.init(eqx.filter(model, eqx.is_inexact_array))
        model_params, model_static = eqx.partition(model, eqx.is_inexact_array)
        return TrainState(
            step=jnp.array(0),
            model_params=model_params,
            model_static=model_static,
            tx=tx,
            opt_state=opt_state,
            lr_schedule=lr_schedule,
            ent_schedule=ent_schedule,
            vml_schedule=vml_schedule,
            avg_reward=jnp.array(initial_avg_reward, dtype=dtype_config.REWARD_DTYPE),
            reward_stepsize=jnp.array(reward_stepsize_init, dtype=dtype_config.REWARD_DTYPE),
            reward_stepsize_init=jnp.array(reward_stepsize_init, dtype=dtype_config.REWARD_DTYPE),
            reward_stepsize_offset=jnp.array(1.0, dtype=dtype_config.REWARD_DTYPE),
            prio_alpha=jnp.array(prio_alpha, dtype=dtype_config.REWARD_DTYPE),
            prio_beta0=jnp.array(prio_beta0, dtype=dtype_config.REWARD_DTYPE),
            prio_beta=jnp.array(prio_beta, dtype=dtype_config.REWARD_DTYPE),
        )

apply_gradients(grads)

Updates model parameters and opt_state.

Source code in xlron/train/train_utils.py
def apply_gradients(self, grads: Any) -> "TrainState":
    """Updates model parameters and opt_state."""
    model = eqx.combine(self.model_params, self.model_static)
    updates, new_opt_state = self.tx.update(grads, self.opt_state, self.model_params)
    new_model = eqx.apply_updates(model, updates)
    new_model_params, new_model_static = eqx.partition(new_model, eqx.is_inexact_array)
    # Can't use eqx.tree_at for static fields, create new instance
    return TrainState(
        step=self.step,
        model_params=new_model_params,
        model_static=new_model_static,
        tx=self.tx,
        opt_state=new_opt_state,
        lr_schedule=self.lr_schedule,
        ent_schedule=self.ent_schedule,
        vml_schedule=self.vml_schedule,
        avg_reward=self.avg_reward,
        reward_stepsize=self.reward_stepsize,
        reward_stepsize_init=self.reward_stepsize_init,
        reward_stepsize_offset=self.reward_stepsize_offset,
        prio_alpha=self.prio_alpha,
        prio_beta0=self.prio_beta0,
        prio_beta=self.prio_beta,
    )
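
The update-then-apply flow that `apply_gradients` wraps can be sketched without Equinox or optax. The helpers `sgd_update` and `apply_updates` below are hypothetical stand-ins for `tx.update` and `eqx.apply_updates`; the real method operates on a partitioned Equinox model rather than plain dicts:

```python
# Minimal analogy of apply_gradients: an optimizer turns raw gradients into
# parameter deltas (tx.update), which are then added to the parameters
# (eqx.apply_updates). Plain SGD over a dict, for illustration only.

def sgd_update(grads, lr=0.1):
    """Stand-in for tx.update: map gradients to parameter deltas."""
    return {k: -lr * g for k, g in grads.items()}

def apply_updates(params, updates):
    """Stand-in for eqx.apply_updates: add deltas to parameters."""
    return {k: params[k] + updates[k] for k in params}

params = {"w": 1.0, "b": 0.5}
grads = {"w": 2.0, "b": -1.0}
# Each parameter moves against its gradient: w -> 1.0 - 0.1 * 2.0
new_params = apply_updates(params, sgd_update(grads))
```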

create(model, tx, lr_schedule=lambda x: jnp.array(0.0), ent_schedule=lambda x: jnp.array(0.0), vml_schedule=lambda x: jnp.array(0.0), prio_alpha=0.0, prio_beta0=1.0, prio_beta=1.0, reward_stepsize_init=0.001, initial_avg_reward=0.0) staticmethod

Creates a new instance with step=0 and initialized opt_state.

Source code in xlron/train/train_utils.py
@staticmethod
def create(
    model: eqx.Module | None,
    tx: optax.GradientTransformation,
    lr_schedule: Schedule = lambda x: jnp.array(0.0),
    ent_schedule: Schedule = lambda x: jnp.array(0.0),
    vml_schedule: Schedule = lambda x: jnp.array(0.0),
    prio_alpha: float = 0.0,
    prio_beta0: float = 1.0,
    prio_beta: float = 1.0,
    reward_stepsize_init: float = 0.001,
    initial_avg_reward: float = 0.0,
) -> "TrainState":
    """Creates a new instance with step=0 and initialized opt_state."""
    opt_state = tx.init(eqx.filter(model, eqx.is_inexact_array))
    model_params, model_static = eqx.partition(model, eqx.is_inexact_array)
    return TrainState(
        step=jnp.array(0),
        model_params=model_params,
        model_static=model_static,
        tx=tx,
        opt_state=opt_state,
        lr_schedule=lr_schedule,
        ent_schedule=ent_schedule,
        vml_schedule=vml_schedule,
        avg_reward=jnp.array(initial_avg_reward, dtype=dtype_config.REWARD_DTYPE),
        reward_stepsize=jnp.array(reward_stepsize_init, dtype=dtype_config.REWARD_DTYPE),
        reward_stepsize_init=jnp.array(reward_stepsize_init, dtype=dtype_config.REWARD_DTYPE),
        reward_stepsize_offset=jnp.array(1.0, dtype=dtype_config.REWARD_DTYPE),
        prio_alpha=jnp.array(prio_alpha, dtype=dtype_config.REWARD_DTYPE),
        prio_beta0=jnp.array(prio_beta0, dtype=dtype_config.REWARD_DTYPE),
        prio_beta=jnp.array(prio_beta, dtype=dtype_config.REWARD_DTYPE),
    )
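
`create` relies on `eqx.partition(model, eqx.is_inexact_array)` to separate trainable array leaves (handed to the optimizer) from static structure. A library-free sketch of that idea over a flat dict, with hypothetical `partition`/`combine` helpers standing in for the Equinox functions:

```python
import numpy as np

def partition(tree, predicate):
    """Split a flat dict into (params, static) by a leaf predicate."""
    params = {k: v for k, v in tree.items() if predicate(v)}
    static = {k: v for k, v in tree.items() if not predicate(v)}
    return params, static

def combine(params, static):
    """Recombine the two halves into the original structure."""
    return {**params, **static}

# Float arrays are "inexact" and trainable; strings/config are static.
model = {"weights": np.ones((2, 2)), "activation": "relu"}
is_inexact_array = lambda x: isinstance(x, np.ndarray) and np.issubdtype(x.dtype, np.floating)
params, static = partition(model, is_inexact_array)
restored = combine(params, static)
```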

update_step_size()

Updates the step size used for reward centering.

Source code in xlron/train/train_utils.py
def update_step_size(self) -> "TrainState":
    """Updates the step size used for reward centering."""
    reward_stepsize_offset = self.reward_stepsize_offset + self.reward_stepsize_init * (
        1 - self.reward_stepsize_offset
    )
    reward_stepsize = self.reward_stepsize_init / reward_stepsize_offset
    return eqx.tree_at(
        lambda state: (state.reward_stepsize, state.reward_stepsize_offset),
        self,
        (reward_stepsize, reward_stepsize_offset),
    )
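
The recursion above, in isolation: with `a = reward_stepsize_init`, each call computes `o <- o + a * (1 - o)` and `step = a / o`. Note that `create()` seeds the offset at 1.0, in which case the offset is a fixed point and the step size stays constant at `a`; seeding it at 0 (a hypothetical choice, shown here for illustration) recovers the unbiased constant-step-size behaviour where the first step size is 1 and later ones decay toward `a`:

```python
def step_sizes(a, offset, n):
    """Replay n calls of the update_step_size recursion."""
    out = []
    for _ in range(n):
        offset = offset + a * (1 - offset)  # offset -> 1 geometrically
        out.append(a / offset)              # step size -> a from above
    return out

# Hypothetical seed offset=0.0: first step size is exactly 1.0.
sizes = step_sizes(a=0.1, offset=0.0, n=50)

# The code's actual seed offset=1.0: step size is constant at a.
constant = step_sizes(a=0.1, offset=1.0, n=3)
```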

count_parameters(params)

Counts the number of parameters in a parameter tree.

Source code in xlron/train/train_utils.py
def count_parameters(params: chex.ArrayTree) -> int:
    """Counts the number of parameters in a parameter tree."""
    return sum(x.size for x in jax.tree_util.tree_leaves(params))
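The same arithmetic with plain NumPy arrays, using a hypothetical recursive `count` in place of the `jax.tree_util` traversal:

```python
import numpy as np

# Two dense layers: (4->8) and (8->2), each with a bias vector.
params = {"dense1": {"w": np.zeros((4, 8)), "b": np.zeros(8)},
          "dense2": {"w": np.zeros((8, 2)), "b": np.zeros(2)}}

def count(tree):
    """Sum leaf sizes over a nested dict, mirroring count_parameters."""
    if isinstance(tree, dict):
        return sum(count(v) for v in tree.values())
    return tree.size

total = count(params)  # 32 + 8 + 16 + 2
```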

get_warmup_fn(warmup_state, env, params, train_state, config)

Warmup period for DeepRMSA.

Source code in xlron/train/train_utils.py
def get_warmup_fn(warmup_state, env, params, train_state, config) -> Callable[[Tuple], Tuple]:
    """Warmup period for DeepRMSA."""

    def warmup_fn(warmup_state) -> Tuple[EnvState, chex.Array]:
        rng, state, last_obs = warmup_state

        def warmup_step(i, val) -> Tuple:
            _rng, _state, _params, _train_state, _last_obs = val
            # SELECT ACTION
            _rng, action_key, step_key = jax.random.split(_rng, 3)
            select_action_state = (_rng, _state, _last_obs)
            action_fn = select_action if not config.EVAL_HEURISTIC else select_action_eval
            _state, action, log_prob, value = action_fn(
                select_action_state, env, _params, _train_state, config
            )
            if "gn_model" in config.env_type.lower() and config.launch_power_type == "rl":
                # If the action is launch power, the action is this shape:
                # jnp.concatenate([path_action.reshape((1,)), power_action.reshape((1,))], axis=0)
                # We want to overwrite the launch power with a default launch_power
                path_action = (
                    ksp_lf(_state.env_state, _params)
                    if _params.last_fit is True
                    else ksp_ff(_state.env_state, _params)
                )
                action = jnp.concatenate(
                    [
                        path_action.reshape((1,)),
                        jnp.array(
                            [
                                params.default_launch_power,
                            ]
                        ),
                    ],
                    axis=0,
                )
            elif (
                "gn_model" in config.env_type.lower()
                and config.launch_power_type != "rl"
                and not config.EVAL_HEURISTIC
            ):
                raise ValueError("Check that EVAL_HEURISTIC is set to True if using a heuristic")
            # STEP ENV
            obsv, _state, reward, terminal, truncated, info = env.step(
                step_key, _state, action, params
            )
            obsv = (
                (_state.env_state, params)
                if config.USE_GNN or config.USE_TRANSFORMER
                else tuple([obsv])
            )
            return _rng, _state, _params, _train_state, obsv

        vals = jax.lax.fori_loop(
            0, config.ENV_WARMUP_STEPS, warmup_step, (rng, state, params, train_state, last_obs)
        )

        return vals[1], vals[4]

    return warmup_fn

log_metrics(config, out, total_run_time, increment_run_time, merge_func, episode_count=0, update_count=0, step_count=0)

Log metrics to wandb and/or save episode end metrics to CSV.

Source code in xlron/train/train_utils.py
def log_metrics(
    config: Box,
    out: Dict[str, Dict[str, Array]],
    total_run_time: float,
    increment_run_time: float,
    merge_func: Callable,
    episode_count: int = 0,
    update_count: int = 0,
    step_count: int = 0,
) -> Tuple[Dict, Dict]:
    """Log metrics to wandb and/or save episode end metrics to CSV."""

    with TimeIt("Processing metrics"):
        merged_out, merged_out_loss, processed_data, episode_ends = process_metrics(
            config, out, merge_func
        )

    all_metrics = list(processed_data.keys())
    if not config.LOG_ALL_INFO:
        all_metrics = [
            "service_blocking_probability",
            "bitrate_blocking_probability",
            "accepted_services",
            "accepted_bitrate",
        ]

    with TimeIt("Logging metrics"):
        if config.DATA_OUTPUT_FILE:
            print("Saving metrics to file")
            # Save episode end metrics to file
            episode_end_df = pd.DataFrame(
                {
                    f"{metric}_{stat}": processed_data[metric][stat]
                    for metric in all_metrics
                    for stat in [
                        "episode_end_mean",
                        "episode_end_std",
                        "episode_end_iqr_upper",
                        "episode_end_iqr_lower",
                    ]
                }
            )
            # Check if data output file exists
            write_headers = not os.path.exists(config.DATA_OUTPUT_FILE)
            episode_end_df.to_csv(
                config.DATA_OUTPUT_FILE, mode="a", header=write_headers, index=False
            )
            # Pickle merged_out for further analysis
            with open(config.DATA_OUTPUT_FILE.replace(".csv", ".pkl"), "wb") as f:
                pickle.dump(merged_out, f)

        if config.WANDB:
            print("Logging metrics to wandb")

            if not config.continuous_operation:
                # Log metrics from every step
                # Define the downsample factor to speed up upload to wandb
                # Then reshape the array and compute the mean
                training_time = (
                    jnp.arange(len(processed_data[all_metrics[0]]["episode_end_mean"]))
                    / len(processed_data[all_metrics[0]]["episode_end_mean"])
                    * increment_run_time
                ) + total_run_time
                # Log episode end metrics
                print(f"Logging episode end metrics for {np.sum(episode_ends)} episodes")
                for i in range(len(processed_data[all_metrics[0]]["episode_end_mean"])):
                    log_dict = {
                        f"{metric}_{stat}": processed_data[metric][stat][i]
                        for metric in all_metrics
                        for stat in [
                            "episode_end_mean",
                            "episode_end_std",
                            "episode_end_iqr_upper",
                            "episode_end_iqr_lower",
                        ]
                    }
                    log_dict["training_time"] = training_time[i]
                    log_dict["episode_count"] = i + episode_count
                    wandb.log(log_dict)

            else:
                # Log metrics from every step
                # Define the downsample factor to speed up upload to wandb
                # Then reshape the array and compute the mean
                training_time = (
                    jnp.arange(len(processed_data[all_metrics[0]]["mean"]))
                    / len(processed_data[all_metrics[0]]["mean"])
                    * increment_run_time
                ) + total_run_time

                chop = len(processed_data[all_metrics[0]]["mean"]) % config.DOWNSAMPLE_FACTOR

                def downsample_mean(x: Array) -> Array:
                    x = jnp.asarray(x)
                    return x[chop:].reshape(-1, config.DOWNSAMPLE_FACTOR).mean(axis=1)

                for key in all_metrics:
                    processed_data[key]["mean"] = downsample_mean(processed_data[key]["mean"])
                    processed_data[key]["std"] = downsample_mean(processed_data[key]["std"])
                    processed_data[key]["iqr_upper"] = downsample_mean(
                        processed_data[key]["iqr_upper"]
                    )
                    processed_data[key]["iqr_lower"] = downsample_mean(
                        processed_data[key]["iqr_lower"]
                    )
                training_time = downsample_mean(training_time)

                # Log per step metrics
                print("Logging per step metrics")
                for i in range(len(processed_data[all_metrics[0]]["mean"])):
                    log_dict = {
                        f"{metric}_{agg}": processed_data[metric][agg][i]
                        for metric in all_metrics
                        for agg in ["mean", "std", "iqr_upper", "iqr_lower"]
                    }
                    log_dict["training_time"] = training_time[i]
                    log_dict["env_step"] = (i * config.DOWNSAMPLE_FACTOR) + step_count
                    wandb.log(log_dict)

            if config.LOG_LOSS_INFO and merged_out_loss is not None:
                print("Logging loss info")
                for i in range(len(merged_out_loss["loss/total_loss"])):
                    log_dict = {f"{metric}": merged_out_loss[metric][i] for metric in loss_metrics}
                    if config.REWARD_CENTERING:
                        log_dict_rc = {
                            f"{metric}": merged_out_loss[metric][i]
                            for metric in reward_centering_metrics
                        }
                        log_dict = {**log_dict, **log_dict_rc}
                    if config.ENHANCED_LOGGING:
                        log_dict_diag = {
                            f"{metric}": merged_out_loss[metric][i]
                            for metric in diagnostics_metrics
                        }
                        log_dict = {**log_dict, **log_dict_diag}
                    log_dict["update_epoch"] = i + update_count
                    wandb.log(log_dict)

    return merged_out, processed_data
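
The continuous-operation branch downsamples per-step metrics before upload: drop the first `chop` elements so the length divides `DOWNSAMPLE_FACTOR`, then average each consecutive block. The reshape trick in isolation:

```python
import numpy as np

def downsample_mean(x, factor):
    """Average consecutive blocks of `factor` elements, dropping a leading
    remainder so the length divides evenly (as in log_metrics)."""
    x = np.asarray(x, dtype=float)
    chop = len(x) % factor
    return x[chop:].reshape(-1, factor).mean(axis=1)

# 10 values, factor 4: drops the first 2, averages [2..5] and [6..9].
y = downsample_mean(np.arange(10), factor=4)
```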

make_ent_schedule(config)

Create an entropy coefficient schedule based on the configuration.

Source code in xlron/train/train_utils.py
def make_ent_schedule(config: Box) -> optax.Schedule:
    """Create an entropy coefficient schedule based on the configuration."""

    ENT_COEF = config.ENT_COEF
    ENT_END_FRACTION = config.ENT_END_FRACTION
    NUM_MINIBATCHES = config.NUM_MINIBATCHES
    NUM_UPDATES = config.NUM_UPDATES * config.NUM_INCREMENTS
    UPDATE_EPOCHS = config.UPDATE_EPOCHS
    SCHEDULE_MULTIPLIER = config.ENT_SCHEDULE_MULTIPLIER
    end_value = ENT_COEF * ENT_END_FRACTION

    def ent_schedule(count: chex.Numeric) -> chex.Numeric:
        total_steps = NUM_UPDATES * UPDATE_EPOCHS * NUM_MINIBATCHES * SCHEDULE_MULTIPLIER
        if config.ENT_SCHEDULE == "cosine":
            schedule = optax.cosine_decay_schedule(
                init_value=ENT_COEF,
                decay_steps=total_steps,
                alpha=end_value,
            )
        elif config.ENT_SCHEDULE == "linear":
            schedule = optax.linear_schedule(
                init_value=ENT_COEF,
                end_value=end_value,
                transition_steps=total_steps,
            )
        elif config.ENT_SCHEDULE == "constant":
            schedule = optax.constant_schedule(ENT_COEF)
        else:
            raise ValueError(f"Invalid entropy schedule {config.ENT_SCHEDULE}")
        return schedule(count)

    return ent_schedule

make_lr_schedule(config)

Create a learning rate schedule based on the configuration.

Source code in xlron/train/train_utils.py
def make_lr_schedule(config: Box) -> optax.Schedule:
    """Create a learning rate schedule based on the configuration."""

    LR = config.LR
    LR_END_FRACTION = config.LR_END_FRACTION
    NUM_MINIBATCHES = config.NUM_MINIBATCHES
    NUM_UPDATES = config.NUM_UPDATES * config.NUM_INCREMENTS
    UPDATE_EPOCHS = config.UPDATE_EPOCHS
    SCHEDULE_MULTIPLIER = config.LR_SCHEDULE_MULTIPLIER
    WARMUP_MULTIPLIER = config.WARMUP_MULTIPLIER
    WARMUP_STEPS_FRACTION = config.WARMUP_STEPS_FRACTION
    end_value = LR * LR_END_FRACTION

    def lr_schedule(count: chex.Numeric) -> chex.Numeric:
        total_steps = NUM_UPDATES * UPDATE_EPOCHS * NUM_MINIBATCHES * SCHEDULE_MULTIPLIER
        if config.LR_SCHEDULE == "warmup_cosine":
            schedule = optax.warmup_cosine_decay_schedule(
                init_value=LR,
                peak_value=LR * WARMUP_MULTIPLIER,
                warmup_steps=total_steps * WARMUP_STEPS_FRACTION,
                decay_steps=total_steps,
                end_value=end_value,
            )
        elif config.LR_SCHEDULE == "cosine":
            schedule = optax.cosine_decay_schedule(
                init_value=LR,
                decay_steps=total_steps,
                alpha=end_value,
            )
        elif config.LR_SCHEDULE == "linear":
            schedule = optax.linear_schedule(
                init_value=LR,
                end_value=end_value,
                transition_steps=total_steps,
            )
        elif config.LR_SCHEDULE == "constant":
            schedule = optax.constant_schedule(LR)
        else:
            raise ValueError(f"Invalid LR schedule {config.LR_SCHEDULE}")
        return schedule(count)

    return lr_schedule
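
For the "linear" branch, `optax.linear_schedule` interpolates from `init_value` to `end_value` over `transition_steps` and holds the end value afterwards. A pure-Python sketch of that behaviour (a hypothetical stand-in, not the optax implementation):

```python
def linear_schedule(init_value, end_value, transition_steps):
    """Linear interpolation that clamps at end_value after transition_steps."""
    def schedule(count):
        frac = min(count / transition_steps, 1.0)
        return init_value + frac * (end_value - init_value)
    return schedule

# e.g. LR = 3e-4 decayed to LR * LR_END_FRACTION with LR_END_FRACTION = 0.1
sched = linear_schedule(init_value=3e-4, end_value=3e-5, transition_steps=1000)
# sched(0) is the initial LR, sched(500) the midpoint, sched(2000) stays at the end value
```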

make_vf_lr_schedule(config)

Create a learning rate schedule for the value function optimizer.

Source code in xlron/train/train_utils.py
def make_vf_lr_schedule(config: Box) -> optax.Schedule:
    """Create a learning rate schedule for the value function optimizer."""
    vf_lr = config.VF_LR if config.VF_LR is not None else config.LR / 3.0
    vf_schedule_type = (
        config.VF_LR_SCHEDULE if config.VF_LR_SCHEDULE is not None else config.LR_SCHEDULE
    )
    vf_end_fraction = (
        config.VF_LR_END_FRACTION
        if config.VF_LR_END_FRACTION is not None
        else config.LR_END_FRACTION
    )
    vf_warmup_mult = (
        config.VF_WARMUP_MULTIPLIER
        if config.VF_WARMUP_MULTIPLIER is not None
        else config.WARMUP_MULTIPLIER
    )
    vf_warmup_frac = (
        config.VF_WARMUP_STEPS_FRACTION
        if config.VF_WARMUP_STEPS_FRACTION is not None
        else config.WARMUP_STEPS_FRACTION
    )
    return _make_schedule(
        vf_lr,
        vf_end_fraction,
        vf_schedule_type,
        vf_warmup_mult,
        vf_warmup_frac,
        config,
        schedule_multiplier=config.VF_SCHEDULE_MULTIPLIER,
    )

make_vml_schedule(config)

Create a valid mass loss coefficient schedule based on the configuration.

Source code in xlron/train/train_utils.py
def make_vml_schedule(config: Box) -> optax.Schedule:
    """Create a valid mass loss coefficient schedule based on the configuration."""

    VML_COEF = config.VALID_MASS_LOSS_COEF
    VML_END_FRACTION = config.VML_END_FRACTION
    NUM_MINIBATCHES = config.NUM_MINIBATCHES
    NUM_UPDATES = config.NUM_UPDATES * config.NUM_INCREMENTS
    UPDATE_EPOCHS = config.UPDATE_EPOCHS
    SCHEDULE_MULTIPLIER = config.VML_SCHEDULE_MULTIPLIER
    end_value = VML_COEF * VML_END_FRACTION

    def vml_schedule(count: chex.Numeric) -> chex.Numeric:
        total_steps = NUM_UPDATES * UPDATE_EPOCHS * NUM_MINIBATCHES * SCHEDULE_MULTIPLIER
        if config.VML_SCHEDULE == "cosine":
            schedule = optax.cosine_decay_schedule(
                init_value=VML_COEF,
                decay_steps=total_steps,
                alpha=end_value,
            )
        elif config.VML_SCHEDULE == "linear":
            schedule = optax.linear_schedule(
                init_value=VML_COEF,
                end_value=end_value,
                transition_steps=total_steps,
            )
        elif config.VML_SCHEDULE == "constant":
            schedule = optax.constant_schedule(VML_COEF)
        else:
            raise ValueError(f"Invalid VML schedule {config.VML_SCHEDULE}")
        return schedule(count)

    return vml_schedule

merge_leading_dims(x, num_dims)

Merge leading dimensions.

Note

This implementation is a generic function for merging leading dimensions extracted from Haiku. For the original implementation, please refer to the following link: (https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/basic.py#L207)

Source code in xlron/train/train_utils.py
def merge_leading_dims(x: chex.Array, num_dims: chex.Numeric) -> chex.Array:
    """Merge leading dimensions.

    Note:
        This implementation is a generic function for merging leading dimensions
        extracted from Haiku.
        For the original implementation, please refer to the following link:
        (https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/basic.py#L207)
    """
    # Don't merge if there aren't dimensions to merge.
    if not ndim_at_least(x, num_dims):
        return x

    new_shape = (np.prod(x.shape[:num_dims]),) + x.shape[num_dims:]
    return x.reshape(new_shape)
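
The reshape at the heart of `merge_leading_dims`, shown with NumPy: the first `num_dims` axes collapse into one, e.g. merging `(num_envs, rollout_length)` batch axes before a minibatch split:

```python
import numpy as np

x = np.zeros((4, 8, 3))  # e.g. (envs, steps, features)
num_dims = 2

# Product of the leading axes becomes the new first axis; the rest are kept.
new_shape = (int(np.prod(x.shape[:num_dims])),) + x.shape[num_dims:]
merged = x.reshape(new_shape)  # (4, 8, 3) -> (32, 3)
```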

ndim_at_least(x, num_dims)

Check if the number of dimensions of x is at least num_dims.

Source code in xlron/train/train_utils.py
def ndim_at_least(x: chex.Array, num_dims: chex.Numeric) -> jax.Array:
    """Check if the number of dimensions of `x` is at least `num_dims`."""
    if not (isinstance(x, jax.Array) or isinstance(x, np.ndarray)):
        x = jnp.asarray(x)
    return x.ndim >= num_dims

print_experiment_summary(config, env_params=None)

Print a formatted summary of the experiment configuration.

Source code in xlron/train/train_utils.py
def print_experiment_summary(config: Box, env_params=None) -> None:
    """Print a formatted summary of the experiment configuration."""
    W = 70
    sep = "=" * W

    # --- Determine execution mode ---
    if config.get("EVAL_HEURISTIC", False):
        mode = f"Heuristic Evaluation ({config.get('path_heuristic', '?')})"
    elif config.get("EVAL_MODEL", False):
        mode = "Model Evaluation"
    elif config.get("ACTION_OPTIMIZATION", False):
        mode = "Action Optimization"
    else:
        mode = "RL Training (PPO)"

    # --- Environment ---
    env_type = config.get("env_type", "?")
    topology = config.get("topology_name", "?")
    if env_params:
        num_nodes = env_params.num_nodes
        num_links = env_params.num_links
    elif config.get("NUM_NODES"):
        num_nodes = config.NUM_NODES
        num_links = config.NUM_LINKS
    else:
        graph = make_graph(topology, topology_directory=config.get("topology_directory", None))
        num_nodes = len(graph.nodes)
        num_links = len(graph.edges)
    k = env_params.k_paths if env_params else config.get("k", "?")
    link_resources = env_params.link_resources if env_params else config.get("link_resources", "?")
    slot_size = env_params.slot_size if env_params else config.get("slot_size", 12.5)
    guardband = env_params.guardband if env_params else config.get("guardband", 1)
    directed = env_params.directed_graph if env_params else config.get("directed_graph", "?")
    path_sort = config.get("path_sort_criteria", "hops")
    total_bw = float(link_resources) * float(slot_size) if link_resources != "?" else "?"

    # --- Traffic ---
    load = env_params.load if env_params else config.get("load", "?")
    holding_time = (
        env_params.mean_service_holding_time if env_params else config.get("mean_service_holding_time", "?")
    )
    if env_params:
        arrival_rate = env_params.arrival_rate
    elif load != "?" and holding_time != "?":
        arrival_rate = float(load) / float(holding_time)
    else:
        arrival_rate = "?"
    continuous = env_params.continuous_operation if env_params else config.get("continuous_operation", False)
    incremental = env_params.incremental_loading if env_params else config.get("incremental_loading", False)
    max_requests = env_params.max_requests if env_params else config.get("max_requests", "?")
    warmup = config.get("ENV_WARMUP_STEPS", 0)
    reward_type = env_params.reward_type if env_params else config.get("reward_type", "service")
    values_bw = env_params.values_bw if env_params else config.get("values_bw", None)
    if hasattr(values_bw, "val"):
        values_bw = values_bw.val
    truncate_ht = env_params.truncate_holding_time if env_params else config.get("truncate_holding_time", False)

    # --- Training / Execution ---
    total_timesteps = config.get("TOTAL_TIMESTEPS", "?")
    num_envs = config.get("NUM_ENVS", 1)
    num_learners = config.get("NUM_LEARNERS", 1)
    rollout_length = config.get("ROLLOUT_LENGTH", "?")
    num_updates = config.get("NUM_UPDATES", "?")
    num_minibatches = config.get("NUM_MINIBATCHES", 1)
    update_epochs = config.get("UPDATE_EPOCHS", 1)
    num_increments = config.get("NUM_INCREMENTS", 1)
    steps_per_inc = config.get("STEPS_PER_INCREMENT", "?")
    batch_size = config.get("MINIBATCH_SIZE", "?")

    # --- Model ---
    if config.get("USE_TRANSFORMER", False):
        arch = "Transformer"
        arch_detail = (
            f"{config.get('transformer_num_layers', '?')}L / "
            f"{config.get('transformer_num_heads', '?')}H / "
            f"d={config.get('transformer_embedding_size', '?')}"
        )
    elif config.get("USE_GNN", False):
        arch = "GNN"
        arch_detail = (
            f"{config.get('message_passing_steps', '?')} msg steps / "
            f"edge_emb={config.get('edge_embedding_size', '?')}"
        )
    else:
        arch = "MLP"
        arch_detail = (
            f"{config.get('NUM_LAYERS', '?')} layers x "
            f"{config.get('NUM_UNITS', '?')} units"
        )

    # --- Print ---
    print(f"\n{sep}")
    print("EXPERIMENT SUMMARY")
    print(sep)

    print(f"  Mode:                     {mode}")
    print(f"  Experiment name:          {config.get('EXPERIMENT_NAME', '-')}")

    print(f"\n  {'--- Environment ---':^{W - 4}}")
    print(f"  Type:                     {env_type}")
    print(f"  Topology:                 {topology} ({'directed' if directed else 'undirected'})")
    print(f"  Nodes / Links:            {num_nodes} / {num_links}")
    print(f"  K-shortest paths:         {k}  (sort: {path_sort})")
    print(f"  Slots per link:           {link_resources}  ({slot_size} GHz each, guardband={guardband})")
    if total_bw != "?":
        print(f"  Total spectrum per link:  {total_bw:.0f} GHz")
    consider_mod = env_params.consider_modulation_format if env_params else "?"
    if consider_mod and consider_mod != "?":
        print("  Modulation format:        enabled")
    if config.get("aggregate_slots", 1) > 1:
        print(f"  Slot aggregation:         {config.get('aggregate_slots')}x")

    print(f"\n  {'--- Traffic ---':^{W - 4}}")
    print(f"  Load:                     {load} Erlang")
    print(f"  Arrival rate:             {arrival_rate}")
    print(f"  Mean holding time:        {holding_time}{' (truncated)' if truncate_ht else ''}")
    if values_bw is not None:
        bw_str = ", ".join(str(int(v)) for v in np.asarray(values_bw).flatten()) if hasattr(values_bw, '__len__') else str(values_bw)
        print(f"  Bandwidth values (Gbps):  [{bw_str}]")
    print(f"  Reward type:              {reward_type}")
    op_mode = "continuous" if continuous else ("incremental" if incremental else "episodic")
    print(f"  Operation mode:           {op_mode}")
    if not continuous:
        print(f"  Max requests / episode:   {max_requests}")
    if warmup > 0:
        print(f"  Warmup steps:             {warmup}")

    print(f"\n  {'--- Execution ---':^{W - 4}}")
    print(f"  Total timesteps:          {total_timesteps:,}" if isinstance(total_timesteps, int) else f"  Total timesteps:          {total_timesteps}")
    print(f"  Parallel envs:            {num_envs}")
    if num_learners > 1:
        print(f"  Independent learners:     {num_learners}")
        print(f"  Grand total timesteps:    {total_timesteps * num_learners:,}")
    print(f"  Increments:               {num_increments}  ({steps_per_inc:,} steps each)" if isinstance(steps_per_inc, int) else f"  Increments:               {num_increments}")

    if not config.get("EVAL_HEURISTIC", False):
        print(f"  Rollout length:           {rollout_length}")
        batch_total = num_envs * rollout_length if isinstance(rollout_length, int) else "?"
        print(f"  Batch size:               {batch_total}  (= {num_envs} envs x {rollout_length} steps)")
        print(f"  Minibatches / epoch:      {num_minibatches}  (minibatch size: {batch_size})")
        print(f"  Update epochs:            {update_epochs}")
        total_updates = num_increments * num_updates * update_epochs * num_minibatches
        print(f"  Total gradient steps:     {total_updates:,}" if isinstance(total_updates, int) else f"  Total gradient steps:     {total_updates}")

    if not (config.get("EVAL_HEURISTIC", False) or config.get("EVAL_MODEL", False)):
        print(f"\n  {'--- Model & Optimiser ---':^{W - 4}}")
        print(f"  Architecture:             {arch}  ({arch_detail})")
        print(f"  Activation:               {config.get('ACTIVATION', '?')}")
        print(f"  Learning rate:            {config.get('LR', '?')}  (schedule: {config.get('LR_SCHEDULE', '?')})")
        print(f"  Discount (gamma):         {config.get('GAMMA', '?')}")
        gae = config.get("GAE_LAMBDA", None)
        if gae is not None:
            print(f"  GAE lambda:               {gae}")
        else:
            print(f"  GAE lambda:               annealed ({config.get('INITIAL_LAMBDA', '?')} -> {config.get('FINAL_LAMBDA', '?')})")
        print(f"  PPO clip:                 {config.get('CLIP_EPS', '?')}")
        print(f"  Entropy coef:             {config.get('ENT_COEF', '?')}  (schedule: {config.get('ENT_SCHEDULE', '?')})")
        print(f"  VF coef:                  {config.get('VF_COEF', '?')}")
        print(f"  Max grad norm:            {config.get('MAX_GRAD_NORM', '?')}")
        if config.get("REWARD_CENTERING", False):
            print(f"  Reward centering:         enabled (stepsize={config.get('REWARD_STEPSIZE', '?')})")
        if config.get("SEPARATE_VF_OPTIMIZER", False):
            print(f"  Separate VF optimizer:    enabled (VF_LR={config.get('VF_LR', 'auto')})")

    if config.get("EVAL_DURING_TRAINING", False):
        print(f"\n  Eval during training:     every {config.get('EVAL_FREQUENCY', '?')} increment(s)")

    if config.get("WANDB", False):
        print(f"  Logging:                  wandb ({config.get('PROJECT', '-')})")
    if config.get("SAVE_MODEL", False):
        print("  Model saving:             enabled")

    print(sep + "\n")

process_metrics(config, out, merge_func)

Calculate statistics from training or evaluation run.

Source code in xlron/train/train_utils.py
def process_metrics(config, out, merge_func):
    """Calculate statistics from training or evaluation run."""
    merged_out = {k: jax.tree.map(merge_func, v) for k, v in out["metrics"].items()}
    if config.EVAL_HEURISTIC or config.EVAL_MODEL:
        merged_out_loss = None
    else:
        # Average over minibatches and epochs to get one value per update
        num_learners_or_1 = config.NUM_LEARNERS if config.NUM_LEARNERS > 1 else 1
        merged_out_loss = {
            k: jax.tree.map(
                lambda x: x.reshape((num_learners_or_1, config.NUM_UPDATES, -1))
                .mean(axis=-1)
                .reshape((-1,)),
                v,
            )
            for k, v in out.get("loss_info", {}).items()
        }

    # Calculate blocking probabilities
    merged_out["service_blocking_probability"] = 1 - (
        merged_out["accepted_services"]
        / jnp.where(merged_out["lengths"] == 0, 1, merged_out["lengths"])
    )
    merged_out["bitrate_blocking_probability"] = 1 - (
        merged_out["accepted_bitrate"]
        / jnp.where(merged_out["total_bitrate"] == 0, 1, merged_out["total_bitrate"])
    )

    # Calculate episode ends
    merged_out["done"] = jnp.logical_or(merged_out["terminal"], merged_out["truncated"])
    episode_ends = merged_out["done"]
    # Instead of flattening, create a boolean mask of where episodes end
    # This preserves the structure across environments
    episode_ends = episode_ends.reshape(episode_ends.shape[0], -1)

    episode_ends = jnp.hstack((episode_ends[:, 1:], jnp.full((episode_ends.shape[0], 1), False)))

    # Reshape episode_ends to match the original shape
    episode_ends = episode_ends.reshape(merged_out["done"].shape)

    print(f"Created episode end mask with {np.sum(episode_ends)} episode endings")

    processed_data = {}
    print("Processing output metrics")
    for metric in metrics:
        if metric == "throughput":
            # Shift values down one index position
            ends = jnp.concatenate([jnp.array([False]), episode_ends.flatten()[:-1]]).reshape(
                episode_ends.shape
            )
        else:
            ends = episode_ends
        try:
            episode_end_mean, episode_end_std, episode_end_iqr_upper, episode_end_iqr_lower = (
                get_episode_end_mean_std_iqr(merged_out[metric], ends, config.NUM_ENVS)
            )
            mean, std, iqr_upper, iqr_lower = get_mean_std_iqr(merged_out, metric)
            processed_data[metric] = {
                "mean": mean,
                "std": std,
                "iqr_upper": iqr_upper,
                "iqr_lower": iqr_lower,
                "episode_end_mean": episode_end_mean,
                "episode_end_std": episode_end_std,
                "episode_end_iqr_upper": episode_end_iqr_upper,
                "episode_end_iqr_lower": episode_end_iqr_lower,
            }
        except KeyError:
            continue
    return merged_out, merged_out_loss, processed_data, episode_ends
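
The blocking-probability calculation above guards its denominators with `jnp.where` so that a zero request count or zero offered bitrate does not produce a NaN. A minimal NumPy sketch of the same pattern, using hypothetical counter values:

```python
import numpy as np

# Cumulative counters per step (hypothetical values)
accepted_services = np.array([0, 1, 2, 2, 3])
lengths = np.array([0, 1, 2, 3, 4])  # requests seen so far; 0 at reset

# Guard the denominator: at reset (lengths == 0) the ratio becomes 0/1 = 0,
# giving a blocking probability of 1 rather than a NaN
safe_lengths = np.where(lengths == 0, 1, lengths)
service_bp = 1 - accepted_services / safe_lengths
```

After four requests with three acceptances, `service_bp[4]` is `1 - 3/4 = 0.25`.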

run_eval_during_training(config, run_eval, eval_input, out, best_eval_metric, step_count, first_save=True)

Run evaluation during training and save model if it improves.

Parameters:

Name Type Description Default
config Box

Training configuration.

required
run_eval Callable

Compiled eval function.

required
eval_input Tuple

Initial eval runner state (train_state, env_state, obsv, rng_step, rng_epoch).

required
out Dict

Output from the current training increment.

required
best_eval_metric float

Best eval metric seen so far.

required
step_count int

Current training env step count (for wandb logging).

required
first_save bool

Whether this is the first save of the training run.

True

Returns:

Type Description
Tuple

Tuple of (updated best_eval_metric, updated first_save).

Source code in xlron/train/train_utils.py
def run_eval_during_training(
    config: Box,
    run_eval: Callable,
    eval_input: Tuple,
    out: Dict,
    best_eval_metric: float,
    step_count: int,
    first_save: bool = True,
) -> Tuple:
    """Run evaluation during training and save model if it improves.

    Args:
        config: Training configuration.
        run_eval: Compiled eval function.
        eval_input: Initial eval runner state (train_state, env_state, obsv, rng_step, rng_epoch).
        out: Output from the current training increment.
        best_eval_metric: Best eval metric seen so far.
        step_count: Current training env step count (for wandb logging).
        first_save: Whether this is the first save of the training run.

    Returns:
        Tuple of (updated best_eval_metric, updated first_save).
    """
    # Inject current model params into a copy of the eval runner state.
    # Copy the eval inputs to prevent buffer donation from invalidating the
    # original buffers, which need to be reused on subsequent eval runs.
    current_train_state = out["runner_state"][0]
    eval_runner_state = jax.tree.map(
        jnp.copy,
        (
            current_train_state,
            eval_input[1],  # fresh eval env_state
            eval_input[2],  # fresh eval obsv
            eval_input[3],  # eval rng_step
            eval_input[4],  # eval rng_epoch
        ),
    )
    eval_out = run_eval(eval_runner_state)
    eval_out["metrics"]["returns"].block_until_ready()

    # Compute eval blocking probability, discarding the warmup transient.
    # Metrics shape: (NUM_EPISODES, steps_per_episode, NUM_ENVS) if NUM_ENVS > 1
    #                (NUM_EPISODES, steps_per_episode) if NUM_ENVS == 1
    # With continuous_operation, env_state counters (accepted_services,
    # accepted_bitrate, total_bitrate) are cumulative and never reset.
    # LogWrapper's `lengths` resets per episode so can't be used as a
    # cumulative request count — use the step count directly instead.
    def _flatten(metric):
        """Flatten (NUM_EPISODES, steps_per_episode, ...) -> (total_steps, ...) keeping env dim."""
        x = eval_out["metrics"][metric]
        return x.reshape(-1, *x.shape[2:])  # (total_steps,) or (total_steps, NUM_ENVS)

    accepted_services = _flatten("accepted_services")
    accepted_bitrate = _flatten("accepted_bitrate")
    total_bitrate = _flatten("total_bitrate")
    total_steps = accepted_services.shape[0]

    # Warmup index is in per-env steps; validated against total_steps below.
    warmup_idx = int(config.ENV_WARMUP_STEPS)

    # Check ENV_WARMUP_STEPS does not exceed total_steps
    assert warmup_idx < total_steps, (
        f"ENV_WARMUP_STEPS ({config.ENV_WARMUP_STEPS}) must be less than "
        f"the total evaluation steps ({total_steps})."
    )

    # Service BP: each step is one request, so denominator = steps after warmup
    # The -2 index avoids the reset at the end of the last episode
    post_warmup_requests = max(total_steps - warmup_idx, 1)
    post_warmup_accepted = accepted_services[-2] - accepted_services[warmup_idx]
    service_bp_per_env = 1 - (post_warmup_accepted / post_warmup_requests)
    service_bp_mean = float(jnp.mean(service_bp_per_env))
    service_bp_std = float(jnp.std(service_bp_per_env))

    # Bitrate BP: denominator is cumulative total_bitrate delta
    post_warmup_total_br = total_bitrate[-2] - total_bitrate[warmup_idx]
    post_warmup_total_br = jnp.where(post_warmup_total_br == 0, 1, post_warmup_total_br)
    post_warmup_accepted_br = accepted_bitrate[-2] - accepted_bitrate[warmup_idx]
    bitrate_bp_per_env = 1 - (post_warmup_accepted_br / post_warmup_total_br)
    bitrate_bp_mean = float(jnp.mean(bitrate_bp_per_env))
    bitrate_bp_std = float(jnp.std(bitrate_bp_per_env))

    if config.reward_type == "bitrate":
        eval_metric_mean = bitrate_bp_mean
        eval_metric_std = bitrate_bp_std
        eval_metric_name = "bitrate_blocking_probability"
    else:
        eval_metric_mean = service_bp_mean
        eval_metric_std = service_bp_std
        eval_metric_name = "service_blocking_probability"

    print(
        f"Eval {eval_metric_name}: {eval_metric_mean:.6f} \u00b1 {eval_metric_std:.6f}"
        f" (best: {best_eval_metric:.6f})"
    )

    if config.WANDB:
        wandb.log(
            {
                "eval/service_blocking_probability_mean": service_bp_mean,
                "eval/service_blocking_probability_std": service_bp_std,
                "eval/bitrate_blocking_probability_mean": bitrate_bp_mean,
                "eval/bitrate_blocking_probability_std": bitrate_bp_std,
                "env_step": step_count,
            }
        )

    if eval_metric_mean <= best_eval_metric:
        best_eval_metric = eval_metric_mean
        print(f"New best eval {eval_metric_name}: {best_eval_metric:.6f}")
        if config.SAVE_MODEL:
            model = eqx.combine(current_train_state.model_params, current_train_state.model_static)
            saved_path = save_model(model, config, first_save=first_save)
            if first_save:
                config.MODEL_PATH = str(saved_path)
                first_save = False

    return best_eval_metric, first_save

scale_gradient(g, scale=1)

Scales the gradient of g by scale but keeps the original value unchanged.

Source code in xlron/train/train_utils.py
def scale_gradient(g: chex.Array, scale: float = 1) -> chex.Array:
    """Scales the gradient of `g` by `scale` but keeps the original value unchanged."""
    return g * scale + jax.lax.stop_gradient(g) * (1.0 - scale)
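
`scale_gradient` is a straight-through construction: the forward value is unchanged, while the backward pass is scaled by `scale`. A quick check with `jax.grad` (the function is reproduced here so the snippet is self-contained):

```python
import jax
import jax.numpy as jnp

def scale_gradient(g, scale=1.0):
    # Forward: g*scale + g*(1 - scale) == g; backward: only the first term carries gradient
    return g * scale + jax.lax.stop_gradient(g) * (1.0 - scale)

value = scale_gradient(jnp.asarray(2.0), 0.5)  # forward value unchanged: 2.0
grad = jax.grad(scale_gradient)(2.0, 0.5)      # gradient scaled: 0.5
```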

select_action(select_action_state, env, env_params, train_state, config)

Select an action from the policy. If using VONE, the action is a tuple of (source, path, destination); otherwise, the action is a single lightpath.

Parameters:

Name Type Description Default
select_action_state Tuple

Tuple of (rng_key, env_state, last_obs)

required
env Environment

Environment

required
env_params EnvParams

Environment parameters

required
train_state TrainState

TrainState

required
config Box

Configuration

required

Returns:

Type Description
Tuple

Tuple of (env_state, action, log_prob, value): the updated environment state, the selected action, its log probability, and the state value.

Source code in xlron/train/train_utils.py
def select_action(select_action_state, env, env_params, train_state, config):
    """Select an action from the policy.
    If using VONE, the action is a tuple of (source, path, destination).
    Otherwise, the action is a single lightpath.
    Args:
        select_action_state: Tuple of (rng_key, env_state, last_obs)
        env: Environment
        env_params: Environment parameters
        train_state: TrainState
        config: Configuration
    Returns:
        env_state: Environment state
        action: Action
        log_prob: Log probability of action
        value: Value of state
    """
    action_key, env_state, last_obs = select_action_state
    if config.USE_GNN or config.USE_TRANSFORMER:
        last_obs = (env_state.env_state, env_params)
    model = eqx.combine(train_state.model_params, train_state.model_static)
    pi, value = model(*last_obs)
    # Action masking
    action_mask, full_action_mask = env.action_mask(env_state.env_state, env_params)

    # Always do action masking with VONE
    if config.env_type.lower() == "vone":
        # TODO - change this to work with single set of logits (probably just slice them)
        vmap_mask_nodes = jax.vmap(env.action_mask_nodes, in_axes=(0, None))
        vmap_mask_slots = jax.vmap(env.action_mask_slots, in_axes=(0, None, 0))
        vmap_mask_dest_node = jax.vmap(env.action_mask_dest_node, in_axes=(0, None, 0))

        env_state = env_state.replace(env_state=vmap_mask_nodes(env_state.env_state, env_params))
        pi_source = distrax.Categorical(
            logits=pi[0]._logits + (-1e8 * (1 - env_state.env_state.node_mask_s))
        )

        action_s = (
            pi_source.sample(seed=action_key) if not config.deterministic else pi_source.mode()
        )

        # Update destination mask now source has been selected
        env_state = env_state.replace(
            env_state=vmap_mask_dest_node(env_state.env_state, env_params, action_s)
        )
        pi_dest = distrax.Categorical(
            logits=pi[0]._logits + (-1e8 * (1 - env_state.env_state.node_mask_d))
        )

        action_p = jnp.full(action_s.shape, 0)
        action_d = pi_dest.sample(seed=action_key) if not config.deterministic else pi_dest.mode()
        action = jnp.stack((action_s, action_p, action_d), axis=1)

        env_state = env_state.replace(
            env_state=vmap_mask_slots(env_state.env_state, env_params, action)
        )
        pi_path = distrax.Categorical(
            logits=pi[0]._logits + (-1e8 * (1 - env_state.env_state.link_slot_mask))
        )
        action_p = pi_path.sample(seed=action_key) if not config.deterministic else pi_path.mode()
        action = jnp.stack((action_s, action_p, action_d), axis=1)

        log_prob_source = pi_source.log_prob(action_s)
        log_prob_path = pi_path.log_prob(action_p)
        log_prob_dest = pi_dest.log_prob(action_d)
        log_prob = log_prob_dest + log_prob_path + log_prob_source
        probs = jax.nn.softmax(pi[0]._logits, axis=-1)
        valid_mass = jnp.sum(probs * action_mask, axis=-1)

    elif "gn_model" in config.env_type.lower() and config.launch_power_type == "rl":
        pi_masked = distrax.Categorical(logits=pi[0]._logits + (-1e8 * (1 - action_mask)))
        if config.GNN_OUTPUT_RSA and not config.GNN_OUTPUT_LP:
            path_action, log_prob = train_state.sample_fn(
                action_key, pi_masked, log_prob=True, deterministic=config.deterministic
            )
            power_action = jnp.array([env_params.default_launch_power])
        elif config.GNN_OUTPUT_RSA and config.GNN_OUTPUT_LP:
            path_action, power_action, log_prob = train_state.sample_fn(
                action_key,
                (pi_masked, pi[1]),
                log_prob=True,
                deterministic=config.deterministic,
            )
        else:
            power_action, log_prob = train_state.sample_fn(
                action_key, pi[1], log_prob=True, deterministic=config.deterministic
            )
            inner_state = env_state.env_state.replace(launch_power_array=power_action)
            env_state = env_state.replace(env_state=inner_state)
            path_action = (
                ksp_lf(env_state.env_state, env_params)
                if env_params.last_fit is True
                else ksp_ff(env_state.env_state, env_params)
            )
        inner_state = env_state.env_state.replace(launch_power_array=power_action)
        env_state = env_state.replace(env_state=inner_state)
        if config.output_globals_size_actor == 0:
            path_index, _ = process_path_action(env_state.env_state, env_params, path_action)
            power_action, log_prob = power_action[path_index], log_prob[path_index]
        action = jnp.concatenate([path_action.reshape((1,)), power_action.reshape((1,))], axis=0)
        probs = jax.nn.softmax(pi[0]._logits, axis=-1)
        valid_mass = jnp.sum(probs * action_mask, axis=-1)

    else:
        pi_masked = distrax.Categorical(logits=pi[0]._logits + (-1e8 * (1 - action_mask)))
        action = pi_masked.sample(seed=action_key) if not config.deterministic else pi_masked.mode()
        log_prob = pi_masked.log_prob(action)
        probs = jax.nn.softmax(pi[0]._logits, axis=-1)
        valid_mass = jnp.sum(probs * action_mask, axis=-1)

    # Single state update at the end
    inner_state = env_state.env_state.replace(
        link_slot_mask=action_mask,
        full_link_slot_mask=full_action_mask,
        valid_mass=valid_mass,
    )
    env_state = env_state.replace(env_state=inner_state)

    return env_state, action, log_prob, value
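
The masking used throughout `select_action` adds a large negative constant to the logits of invalid actions, which drives their post-softmax probability to (numerically) zero while valid actions renormalise among themselves. A standalone sketch of the pattern, with hypothetical logits and mask:

```python
import jax
import jax.numpy as jnp

logits = jnp.array([1.0, 2.0, 0.5, 3.0])
action_mask = jnp.array([1.0, 0.0, 1.0, 0.0])  # 1 = valid, 0 = invalid

# Invalid entries receive logit -1e8, i.e. effectively -inf under softmax
masked_logits = logits + (-1e8 * (1 - action_mask))
probs = jax.nn.softmax(masked_logits)
```

Only indices 0 and 2 keep non-zero probability mass, and the distribution still sums to 1.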

unreplicate_batch_dim(x)

Unreplicates just the update batch dimension (the dimension that is vmapped over when acting and learning).

In Stoix's case it is always the second dimension, after the device dimension. We simply take element 0, as the params are identical across this dimension.

Source code in xlron/train/train_utils.py
def unreplicate_batch_dim(x: chex.ArrayTree) -> chex.ArrayTree:
    """Unreplicated just the update batch dimension.
    (The dimension that is vmapped over when acting and learning)

    In stoix's case it is always the second dimension, after the device dimension.
    We simply take element 0 as the params are identical across this dimension.
    """
    return jax.tree_util.tree_map(lambda x: x[:, 0, ...], x)

unreplicate_n_dims(x, unreplicate_depth=2)

Unreplicates a pytree by removing the first unreplicate_depth axes.

This function takes a pytree and removes some number of axes, associated with parameter duplication for running multiple updates across devices and in parallel with vmap. This is typically one axis for device replication, and one for the update batch size.

Source code in xlron/train/train_utils.py
def unreplicate_n_dims(x: chex.ArrayTree, unreplicate_depth: int = 2) -> chex.ArrayTree:
    """Unreplicates a pytree by removing the first `unreplicate_depth` axes.

    This function takes a pytree and removes some number of axes, associated with parameter
    duplication for running multiple updates across devices and in parallel with `vmap`.
    This is typically one axis for device replication, and one for the `update batch size`.
    """
    return jax.tree_util.tree_map(lambda x: x[(0,) * unreplicate_depth], x)
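
Applied to a pytree whose leaves carry (device, update-batch) leading axes, `unreplicate_n_dims` simply indexes element 0 along each of the first `unreplicate_depth` axes. A small sketch with hypothetical shapes (the function is reproduced so the snippet stands alone):

```python
import jax
import jax.numpy as jnp

def unreplicate_n_dims(x, unreplicate_depth=2):
    # Index element 0 along the first `unreplicate_depth` axes of every leaf
    return jax.tree_util.tree_map(lambda a: a[(0,) * unreplicate_depth], x)

# Params replicated over 2 devices x 4 update-batch copies (hypothetical shapes)
params = {"w": jnp.zeros((2, 4, 3, 3)), "b": jnp.zeros((2, 4, 3))}
single = unreplicate_n_dims(params)
# single["w"].shape == (3, 3); single["b"].shape == (3,)
```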

Models

ActorCriticMLP

Bases: Module

Actor-Critic MLP using Equinox.

Source code in xlron/models/mlp.py
class ActorCriticMLP(eqx.Module):
    """Actor-Critic MLP using Equinox."""

    actor: eqx.Module
    critic: eqx.Module
    activation_fn: Callable
    temperature: float

    def __init__(
        self,
        action_dim: int,
        input_dim: int,
        activation: str = "tanh",
        num_layers: int = 2,
        num_units: int = 64,
        layer_norm: bool = False,  # Not used for now, can add eqx.nn.LayerNorm later
        temperature: float = 1.0,
        dropout_rate: float = 0.0,
        deterministic: bool = True,
        *,
        key: Array,
    ):
        actor_key, critic_key = jax.random.split(key)
        self.activation_fn = select_activation(activation)
        self.temperature = temperature

        # Build actor layers
        actor_features = [num_units] * num_layers + [action_dim]
        self.actor = MLP(
            actor_features,
            input_dim,
            activation=activation,
            layer_norm=layer_norm,
            dropout_rate=dropout_rate,
            deterministic=deterministic,
            key=actor_key,
        )
        critic_features = [num_units] * num_layers + [1]
        self.critic = MLP(
            critic_features,
            input_dim,
            activation=activation,
            layer_norm=layer_norm,
            dropout_rate=dropout_rate,
            deterministic=deterministic,
            key=critic_key,
        )

    def __call__(self, x: Array, key: Optional[Array] = None) -> Tuple[distrax.Categorical, Array]:
        # Actor forward pass
        actor_key, critic_key = jax.random.split(key) if key is not None else (None, None)
        logits = self.actor(x, key=actor_key) / self.temperature
        action_dist = distrax.Categorical(logits=logits)
        value = self.critic(x, key=critic_key)
        return action_dist, jnp.squeeze(value, axis=-1)

    def sample_action(
        self,
        seed: chex.PRNGKey,
        dist: distrax.Categorical,
        log_prob: bool = False,
        deterministic: bool = False,
    ) -> Union[Array, Tuple[Array, Array]]:
        """Sample an action from the distribution"""
        action = jnp.argmax(dist.probs) if deterministic else dist.sample(seed=seed)
        if log_prob:
            return action, dist.log_prob(action)
        return action
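
`ActorCriticMLP` divides the actor logits by `temperature` before building the Categorical distribution: a low temperature sharpens the policy towards its mode, a high temperature flattens it towards uniform. The effect in isolation, with hypothetical logits:

```python
import jax
import jax.numpy as jnp

logits = jnp.array([2.0, 1.0, 0.0])

sharp = jax.nn.softmax(logits / 0.5)  # low temperature: more peaked on argmax
flat = jax.nn.softmax(logits / 5.0)   # high temperature: closer to uniform
```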

sample_action(seed, dist, log_prob=False, deterministic=False)

Sample an action from the distribution

Source code in xlron/models/mlp.py
def sample_action(
    self,
    seed: chex.PRNGKey,
    dist: distrax.Categorical,
    log_prob: bool = False,
    deterministic: bool = False,
) -> Union[Array, Tuple[Array, Array]]:
    """Sample an action from the distribution"""
    action = jnp.argmax(dist.probs) if deterministic else dist.sample(seed=seed)
    if log_prob:
        return action, dist.log_prob(action)
    return action

LaunchPowerActorCriticMLP

Bases: Module

Actor-Critic MLP for launch power optimization.

Takes an observation of the current request + statistics on each of the K candidate paths. Makes K forward passes, one for each path, and outputs a distribution over power levels for each path.

Source code in xlron/models/mlp.py
class LaunchPowerActorCriticMLP(eqx.Module):
    """Actor-Critic MLP for launch power optimization.

    Takes an observation of the current request + statistics on each of the K candidate paths.
    Makes K forward passes, one for each path, and outputs a distribution over power levels for each path.
    """

    # Actor/critic stacks and output heads (referenced in __init__ and __call__)
    actor_layers: tuple
    actor_output: Optional[eqx.nn.Linear]
    critic_layers: tuple
    critic_output: eqx.nn.Linear

    # For continuous action space (Beta distribution)
    alpha_out: Optional[eqx.nn.Linear]
    beta_out: Optional[eqx.nn.Linear]

    # Static configuration
    activation: str = eqx.field(static=True)
    layer_norm: bool = eqx.field(static=True)
    min_power_dbm: float = eqx.field(static=True)
    max_power_dbm: float = eqx.field(static=True)
    step_power_dbm: float = eqx.field(static=True)
    discrete: bool = eqx.field(static=True)
    temperature: float = eqx.field(static=True)
    k_paths: int = eqx.field(static=True)
    num_base_features: int = eqx.field(static=True)
    num_path_features: int = eqx.field(static=True)
    min_concentration: float = eqx.field(static=True)
    max_concentration: float = eqx.field(static=True)
    epsilon: float = eqx.field(static=True)

    def __init__(
        self,
        action_dim: Sequence[int],
        input_dim: int,
        activation: str = "tanh",
        num_layers: int = 2,
        num_units: int = 64,
        layer_norm: bool = False,
        min_power_dbm: float = 0.0,
        max_power_dbm: float = 2.0,
        step_power_dbm: float = 0.1,
        discrete: bool = True,
        temperature: float = 1.0,
        k_paths: int = 5,
        num_base_features: int = 4,
        num_path_features: int = 7,
        min_concentration: float = 0.1,
        max_concentration: float = 20.0,
        epsilon: float = 1e-6,
        *,
        key: Array,
    ):
        super().__init__()
        self.activation = activation
        self.layer_norm = layer_norm
        self.min_power_dbm = min_power_dbm
        self.max_power_dbm = max_power_dbm
        self.step_power_dbm = step_power_dbm
        self.discrete = discrete
        self.temperature = temperature
        self.k_paths = k_paths
        self.num_base_features = num_base_features
        self.num_path_features = num_path_features
        self.min_concentration = min_concentration
        self.max_concentration = max_concentration
        self.epsilon = epsilon

        actor_key, critic_key, output_key = jax.random.split(key, 3)

        # Build actor layers
        actor_keys = jax.random.split(actor_key, num_layers)
        actor_layers_list = []
        # Input is base features + path features
        current_in = num_base_features + num_path_features
        for i in range(num_layers):
            linear = make_linear_with_orthogonal_init(
                current_in, num_units, actor_keys[i], scale=np.sqrt(2)
            )
            actor_layers_list.append(linear)
            if layer_norm:
                actor_layers_list.append(eqx.nn.LayerNorm(num_units))
            current_in = num_units
        self.actor_layers = tuple(actor_layers_list)

        # Actor output
        out_key1, out_key2, out_key3 = jax.random.split(output_key, 3)
        if discrete:
            num_power_levels = int((max_power_dbm - min_power_dbm) / step_power_dbm) + 1
            self.actor_output = make_linear_with_orthogonal_init(
                num_units, num_power_levels, out_key1, scale=0.01
            )
            self.alpha_out = None
            self.beta_out = None
        else:
            self.actor_output = None  # Not used for continuous
            self.alpha_out = make_linear_with_orthogonal_init(num_units, 1, out_key2, scale=0.01)
            self.beta_out = make_linear_with_orthogonal_init(num_units, 1, out_key3, scale=0.01)

        # Build critic layers (mirrors the actor stack; takes the full observation)
        critic_keys = jax.random.split(critic_key, num_layers + 1)
        critic_layers_list = []
        current_in = input_dim
        for i in range(num_layers):
            linear = make_linear_with_orthogonal_init(
                current_in, num_units, critic_keys[i], scale=np.sqrt(2)
            )
            critic_layers_list.append(linear)
            if layer_norm:
                critic_layers_list.append(eqx.nn.LayerNorm(num_units))
            current_in = num_units
        self.critic_layers = tuple(critic_layers_list)
        self.critic_output = make_linear_with_orthogonal_init(
            num_units, 1, critic_keys[-1], scale=1.0
        )

    @property
    def num_power_levels(self):
        """Calculate number of power levels dynamically"""
        return int((self.max_power_dbm - self.min_power_dbm) / self.step_power_dbm) + 1

    @property
    def power_levels(self):
        """Calculate power levels dynamically"""
        return jnp.linspace(
            self.min_power_dbm,
            self.max_power_dbm,
            self.num_power_levels,
            dtype=dtype_config.SMALL_FLOAT_DTYPE,
        )

    def _activate(self, x):
        if self.activation == "relu":
            return jax.nn.relu(x)
        elif self.activation == "crelu":
            return crelu(x)
        return jnp.tanh(x)

    def _forward_layers(self, x, layers):
        for layer in layers:
            x = layer(x)
            if isinstance(layer, eqx.nn.Linear):
                x = self._activate(x)
        return x

    def __call__(self, x: Array) -> Tuple[Tuple[None, distrax.Distribution], Array]:
        # Process each path
        def process_path(i):
            base = x[: self.num_base_features]
            path = jax.lax.dynamic_slice(
                x, (self.num_base_features + i * self.num_path_features,), (self.num_path_features,)
            )
            features = jnp.concatenate([base, path])
            actor_hidden = self._forward_layers(features, self.actor_layers)

            if self.discrete:
                return self.actor_output(actor_hidden) / self.temperature
            else:
                alpha = self.min_concentration + jax.nn.softplus(self.alpha_out(actor_hidden)) * (
                    self.max_concentration - self.min_concentration
                )
                beta = self.min_concentration + jax.nn.softplus(self.beta_out(actor_hidden)) * (
                    self.max_concentration - self.min_concentration
                )
                return jnp.concatenate([alpha, beta])

        # Use vmap instead of scan for simpler Equinox pattern
        dist_params = jax.vmap(process_path)(
            jnp.arange(self.k_paths, dtype=dtype_config.MED_INT_DTYPE)
        )

        # Critic forward pass
        critic_hidden = self._forward_layers(x, self.critic_layers)
        value = jnp.squeeze(self.critic_output(critic_hidden), axis=-1)

        # Create distribution
        if self.discrete:
            dist = distrax.Categorical(logits=dist_params)
        else:
            alpha = dist_params[:, 0]
            beta = dist_params[:, 1]
            dist = distrax.Beta(alpha, beta)

        return (None, dist), value

    def sample_action(self, seed, dist, log_prob=False, deterministic=False):
        """Sample an action and convert to power level"""
        if self.discrete:
            if deterministic:
                raw_action = dist.mode()
            else:
                raw_action = dist.sample(seed=seed)
            processed_action = self.power_levels[raw_action].reshape((self.k_paths, 1))
        else:
            if deterministic:
                mean = dist.alpha / (dist.alpha + dist.beta)
                raw_action = jnp.clip(mean, self.epsilon, 1.0 - self.epsilon)
            else:
                raw_action = jnp.clip(dist.sample(seed=seed), self.epsilon, 1.0 - self.epsilon)
            processed_action = self.min_power_dbm + raw_action * (
                self.max_power_dbm - self.min_power_dbm
            )
        processed_action = from_dbm(processed_action)
        if log_prob:
            return processed_action, dist.log_prob(jnp.squeeze(raw_action))
        return processed_action

    def get_action_probs(self, dist):
        """Get probabilities for discrete case or pdf for continuous case"""
        if self.discrete:
            return dist.probs
        else:
            x = jnp.linspace(0, 1, 100)
            return dist.prob(x)

num_power_levels property

Calculate number of power levels dynamically

power_levels property

Calculate power levels dynamically

get_action_probs(dist)

Get probabilities for discrete case or pdf for continuous case

Source code in xlron/models/mlp.py
def get_action_probs(self, dist):
    """Get probabilities for discrete case or pdf for continuous case"""
    if self.discrete:
        return dist.probs
    else:
        x = jnp.linspace(0, 1, 100)
        return dist.prob(x)

sample_action(seed, dist, log_prob=False, deterministic=False)

Sample an action and convert to power level

Source code in xlron/models/mlp.py
def sample_action(self, seed, dist, log_prob=False, deterministic=False):
    """Sample an action and convert to power level"""
    if self.discrete:
        if deterministic:
            raw_action = dist.mode()
        else:
            raw_action = dist.sample(seed=seed)
        processed_action = self.power_levels[raw_action].reshape((self.k_paths, 1))
    else:
        if deterministic:
            mean = dist.alpha / (dist.alpha + dist.beta)
            raw_action = jnp.clip(mean, self.epsilon, 1.0 - self.epsilon)
        else:
            raw_action = jnp.clip(dist.sample(seed=seed), self.epsilon, 1.0 - self.epsilon)
        processed_action = self.min_power_dbm + raw_action * (
            self.max_power_dbm - self.min_power_dbm
        )
    processed_action = from_dbm(processed_action)
    if log_prob:
        return processed_action, dist.log_prob(jnp.squeeze(raw_action))
    return processed_action
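In the continuous (Beta) case above, the raw sample in (0, 1) is mapped linearly onto the dBm range and then converted to linear units. A minimal NumPy sketch of that mapping, assuming `from_dbm` converts dBm to milliwatts via `10**(dbm / 10)` (the actual definition lives elsewhere in xlron, so treat this stand-in as illustrative):

```python
import numpy as np

def from_dbm(dbm):
    # Assumed stand-in: dBm -> linear power in mW, i.e. 10^(dBm / 10).
    return 10 ** (np.asarray(dbm) / 10)

min_power_dbm, max_power_dbm = 0.0, 2.0
raw_action = 0.5  # a Beta sample, clipped to (epsilon, 1 - epsilon)

# Linear map from (0, 1) onto [min_power_dbm, max_power_dbm]
power_dbm = min_power_dbm + raw_action * (max_power_dbm - min_power_dbm)
power_mw = from_dbm(power_dbm)

print(power_dbm)                   # 1.0
print(round(float(power_mw), 4))   # 1.2589
```

The clipping by `epsilon` keeps the sample strictly inside (0, 1) so that the Beta log-probability stays finite.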

MLP

Bases: Module

Simple MLP module using Equinox.

Source code in xlron/models/mlp.py
class MLP(eqx.Module):
    """Simple MLP module using Equinox."""

    layers: tuple
    activation_fn: Callable = eqx.field(static=True)
    dropout_rate: float = eqx.field(static=True)
    deterministic: bool = eqx.field(static=True)

    def __init__(
        self,
        features: Sequence[int],
        in_features: int,
        activation: str = "tanh",
        dropout_rate: float = 0.0,
        deterministic: bool = True,
        layer_norm: bool = False,
        *,
        key: Array,
    ):
        self.activation_fn = select_activation(activation)
        self.dropout_rate = dropout_rate
        self.deterministic = deterministic

        layers_list = []
        keys = jax.random.split(key, len(features))
        current_in = in_features

        for i, out_features in enumerate(features):
            linear = make_linear_with_orthogonal_init(
                current_in, out_features, keys[i], scale=np.sqrt(2)
            )
            layers_list.append(linear)
            if layer_norm and i < len(features) - 1:
                layers_list.append(eqx.nn.LayerNorm(out_features))
            current_in = out_features

        self.layers = tuple(layers_list)

    def __call__(self, x: Array, *, key: Optional[Array] = None) -> Array:
        for i, layer in enumerate(self.layers):
            if isinstance(layer, eqx.nn.Linear):
                x = layer(x)
                # Apply activation for all but the last linear layer
                if i < len(self.layers) - 1:
                    x = self.activation_fn(x)
                    if not self.deterministic and self.dropout_rate > 0 and key is not None:
                        key, subkey = jax.random.split(key)
                        x = eqx.nn.Dropout(self.dropout_rate)(x, key=subkey)
            elif isinstance(layer, eqx.nn.LayerNorm):
                x = layer(x)
        return x
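The `__call__` logic above applies the activation (and optional dropout) after every linear layer except the last. A plain-NumPy sketch of the same control flow, with hypothetical weights standing in for the Equinox `Linear` layers:

```python
import numpy as np

rng = np.random.default_rng(0)
in_features = 3
features = [8, 8, 4]  # hidden widths plus output width, as in `features`

# Hypothetical stand-ins for the orthogonally initialized Linear layers
layers = []
current_in = in_features
for out in features:
    W = rng.normal(size=(out, current_in)) * np.sqrt(2 / current_in)
    b = np.zeros(out)
    layers.append((W, b))
    current_in = out

def forward(x):
    for i, (W, b) in enumerate(layers):
        x = W @ x + b
        if i < len(layers) - 1:  # no activation on the final layer
            x = np.tanh(x)
    return x

y = forward(np.ones(in_features))
print(y.shape)  # (4,)
```

Leaving the final layer linear lets downstream heads (e.g. logits or value estimates) take any real value.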

bfloat16_safe_orthogonal(scale=1.0)

Returns an orthogonal initializer that is safe for bfloat16.

Source code in xlron/models/mlp.py
def bfloat16_safe_orthogonal(scale: float = 1.0) -> Callable:
    """Returns an orthogonal initializer that is safe for bfloat16."""

    def init(key: Array, shape: Sequence[int], dtype: jnp.dtype = jnp.float32) -> Array:
        return orthogonal_init(key, tuple(shape), scale=scale, dtype=dtype)

    return init

constant(value)

Returns a constant initializer.

Source code in xlron/models/mlp.py
def constant(value: float) -> Callable:
    """Returns a constant initializer."""

    def init(key: Array, shape: Sequence[int], dtype: jnp.dtype = jnp.float32) -> Array:
        return jnp.full(shape, value, dtype=dtype)

    return init

crelu(x)

Computes the Concatenated ReLU (CReLU) activation function.

Source code in xlron/models/mlp.py
def crelu(x: ArrayLike) -> Array:
    """Computes the Concatenated ReLU (CReLU) activation function."""
    x = jnp.concatenate([x, -x], axis=-1)
    return jax.nn.relu(x)
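CReLU concatenates the input with its negation before applying ReLU, so the output doubles the size of the last axis while preserving both the positive and negative parts of the signal. A NumPy sketch of the same operation:

```python
import numpy as np

def crelu(x):
    # Concatenate x with -x, then ReLU: output last dim is 2x the input's.
    x = np.concatenate([x, -x], axis=-1)
    return np.maximum(x, 0)

x = np.array([1.0, -2.0, 0.5])
print(crelu(x))  # [1.  0.  0.5 0.  2.  0. ]
```

Note the doubled feature dimension: any layer that follows a CReLU activation must expect twice the width.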

make_linear_with_orthogonal_init(in_features, out_features, key, scale=1.0, dtype=None)

Create a Linear layer with orthogonal initialization.

Source code in xlron/models/mlp.py
def make_linear_with_orthogonal_init(
    in_features: int,
    out_features: int,
    key: Array,
    scale: float = 1.0,
    dtype: jnp.dtype = None,
) -> eqx.nn.Linear:
    """Create a Linear layer with orthogonal initialization."""
    if dtype is None:
        dtype = dtype_config.PARAMS_DTYPE

    key1, key2 = jax.random.split(key)
    weight = orthogonal_init(key1, (out_features, in_features), scale=scale, dtype=dtype)
    bias = jnp.zeros(out_features, dtype=dtype)

    linear = eqx.nn.Linear(in_features, out_features, key=key2, dtype=dtype)
    linear = eqx.tree_at(lambda layer: (layer.weight, layer.bias), linear, (weight, bias))
    return linear

orthogonal_init(key, shape, scale=1.0, dtype=jnp.float32)

Orthogonal initializer that is safe for bfloat16 and other dtypes. Based on JAX/Flax orthogonal initializer.

Parameters:

Name Type Description Default
key Array

PRNGKey for initialization

required
shape Tuple[int, ...]

Shape of the weight matrix

required
scale float

Scaling factor for the orthogonal matrix

1.0
dtype dtype

Target dtype (will init in float32 then cast)

float32

Returns:

Type Description
Array

Orthogonally initialized weight matrix

Source code in xlron/models/mlp.py
def orthogonal_init(
    key: Array, shape: Tuple[int, ...], scale: float = 1.0, dtype: jnp.dtype = jnp.float32
) -> Array:
    """
    Orthogonal initializer that is safe for bfloat16 and other dtypes.
    Based on JAX/Flax orthogonal initializer.

    Args:
        key: PRNGKey for initialization
        shape: Shape of the weight matrix
        scale: Scaling factor for the orthogonal matrix
        dtype: Target dtype (will init in float32 then cast)

    Returns:
        Orthogonally initialized weight matrix
    """
    # Always initialize in float32 to avoid dtype issues
    if len(shape) < 2:
        raise ValueError("Orthogonal initialization requires at least 2D shape")

    num_rows, num_cols = shape[-2], shape[-1]
    flat_shape = (max(num_rows, num_cols), min(num_rows, num_cols))

    # Generate random matrix
    a = jax.random.normal(key, flat_shape, dtype=jnp.float32)

    # QR decomposition
    q, r = jnp.linalg.qr(a)
    d = jnp.diag(r)
    q = q * jnp.sign(d)

    # If num_rows < num_cols, we need to transpose
    if num_rows < num_cols:
        q = q.T

    # Take the slice we need
    q = q[:num_rows, :num_cols]

    # Reshape if needed
    if len(shape) > 2:
        q = q.reshape(shape)

    # Scale and cast
    return (scale * q).astype(dtype)
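The key property of this recipe is that the rows (or columns, whichever dimension is smaller) of the result are orthonormal before scaling, and the sign correction by `jnp.sign(d)` makes the QR factorization unique. A NumPy sketch of the same recipe that checks the orthogonality property:

```python
import numpy as np

def orthogonal(rng, shape, scale=1.0):
    # Same recipe as above: QR of a tall random matrix, sign-corrected.
    num_rows, num_cols = shape
    flat_shape = (max(num_rows, num_cols), min(num_rows, num_cols))
    a = rng.normal(size=flat_shape)
    q, r = np.linalg.qr(a)
    q = q * np.sign(np.diag(r))  # fix the sign ambiguity of QR
    if num_rows < num_cols:
        q = q.T
    return scale * q[:num_rows, :num_cols]

rng = np.random.default_rng(0)
W = orthogonal(rng, (4, 6), scale=np.sqrt(2))
# Rows are orthonormal before scaling, so W @ W.T = scale^2 * I here
print(np.allclose(W @ W.T, 2 * np.eye(4)))  # True
```

Initializing in float32 and casting at the end (as the source does) matters because QR in low-precision dtypes such as bfloat16 can lose orthogonality.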

select_activation(activation)

Selects the activation function based on the provided string.

Source code in xlron/models/mlp.py
def select_activation(activation: str) -> Callable[[ArrayLike], Array]:
    """Selects the activation function based on the provided string."""
    if activation == "relu":
        return jax.nn.relu
    elif activation == "crelu":
        return crelu
    else:
        return jax.nn.tanh  # Default to tanh if no valid activation is specified

ActorCriticGNN

Bases: Module

Combined Actor-Critic GNN model.

Source code in xlron/models/gnn.py
class ActorCriticGNN(eqx.Module):
    """Combined Actor-Critic GNN model."""

    actor: ActorGNN
    critic: CriticGNN

    # Static configuration
    vmap: bool = eqx.field(static=True)
    min_power_dbm: float = eqx.field(static=True)
    max_power_dbm: float = eqx.field(static=True)
    step_power_dbm: float = eqx.field(static=True)
    discrete: bool = eqx.field(static=True)
    epsilon: float = eqx.field(static=True)
    output_path: bool = eqx.field(static=True)
    output_power: bool = eqx.field(static=True)

    def __init__(
        self,
        input_edge_features: int,
        input_node_features: int,
        input_global_features: int,
        activation: str = "tanh",
        num_layers: int = 2,
        num_units: int = 64,
        message_passing_steps: int = 1,
        mlp_layers: int = None,
        mlp_latent: int = None,
        edge_embedding_size: int = 128,
        edge_mlp_layers: int = 3,
        edge_mlp_latent: int = 128,
        edge_output_size_actor: int = 1,
        edge_output_size_critic: int = 1,
        global_embedding_size: int = 8,
        global_mlp_layers: int = 0,
        global_mlp_latent: int = 0,
        global_output_size_actor: int = 0,
        global_output_size_critic: int = 0,
        node_embedding_size: int = 16,
        node_mlp_layers: int = 2,
        node_mlp_latent: int = 128,
        node_output_size_actor: int = 0,
        node_output_size_critic: int = 0,
        attn_mlp_layers: int = 2,
        attn_mlp_latent: int = 128,
        gnn_mlp_layers: int = 1,
        use_attention: bool = True,
        normalise_by_link_length: bool = True,
        gnn_layer_norm: bool = True,
        mlp_layer_norm: bool = False,
        vmap: bool = True,
        temperature: float = 1.0,
        min_power_dbm: float = 0.0,
        max_power_dbm: float = 2.0,
        step_power_dbm: float = 0.1,
        discrete: bool = True,
        min_concentration: float = 0.1,
        max_concentration: float = 20.0,
        epsilon: float = 1e-6,
        output_path: bool = True,
        output_power: bool = True,
        *,
        key: Array,
    ):
        assert edge_output_size_actor > 0
        assert edge_output_size_critic + global_output_size_critic > 0

        self.vmap = vmap
        self.min_power_dbm = min_power_dbm
        self.max_power_dbm = max_power_dbm
        self.step_power_dbm = step_power_dbm
        self.discrete = discrete
        self.epsilon = epsilon
        self.output_path = output_path
        self.output_power = output_power

        actor_key, critic_key = jax.random.split(key)

        self.actor = ActorGNN(
            input_edge_features=input_edge_features,
            input_node_features=input_node_features,
            input_global_features=input_global_features,
            num_layers=num_layers,
            num_units=num_units,
            message_passing_steps=message_passing_steps,
            mlp_layers=mlp_layers,
            mlp_latent=mlp_latent,
            edge_embedding_size=edge_embedding_size,
            edge_mlp_layers=edge_mlp_layers,
            edge_mlp_latent=edge_mlp_latent,
            edge_output_size=edge_output_size_actor,
            global_embedding_size=global_embedding_size,
            global_mlp_layers=global_mlp_layers,
            global_mlp_latent=global_mlp_latent,
            global_output_size=global_output_size_actor,
            node_embedding_size=node_embedding_size,
            node_mlp_layers=node_mlp_layers,
            node_mlp_latent=node_mlp_latent,
            node_output_size=node_output_size_actor,
            attn_mlp_layers=attn_mlp_layers,
            attn_mlp_latent=attn_mlp_latent,
            use_attention=use_attention,
            normalise_by_link_length=normalise_by_link_length,
            gnn_layer_norm=gnn_layer_norm,
            mlp_layer_norm=mlp_layer_norm,
            temperature=temperature,
            min_power_dbm=min_power_dbm,
            max_power_dbm=max_power_dbm,
            step_power_dbm=step_power_dbm,
            discrete=discrete,
            min_concentration=min_concentration,
            max_concentration=max_concentration,
            epsilon=epsilon,
            key=actor_key,
        )

        self.critic = CriticGNN(
            input_edge_features=input_edge_features,
            input_node_features=input_node_features,
            input_global_features=input_global_features,
            activation=activation,
            num_layers=num_layers,
            num_units=num_units,
            message_passing_steps=message_passing_steps,
            mlp_layers=mlp_layers,
            mlp_latent=mlp_latent,
            edge_embedding_size=edge_embedding_size,
            edge_mlp_layers=edge_mlp_layers,
            edge_mlp_latent=edge_mlp_latent,
            edge_output_size=edge_output_size_critic,
            global_embedding_size=global_embedding_size,
            global_mlp_layers=global_mlp_layers,
            global_mlp_latent=global_mlp_latent,
            global_output_size=global_output_size_critic,
            node_embedding_size=node_embedding_size,
            node_mlp_layers=node_mlp_layers,
            node_mlp_latent=node_mlp_latent,
            node_output_size=node_output_size_critic,
            attn_mlp_layers=attn_mlp_layers,
            attn_mlp_latent=attn_mlp_latent,
            use_attention=use_attention,
            normalise_by_link_length=normalise_by_link_length,
            gnn_layer_norm=gnn_layer_norm,
            mlp_layer_norm=mlp_layer_norm,
            key=critic_key,
        )

    @property
    def num_power_levels(self):
        return int((self.max_power_dbm - self.min_power_dbm) / self.step_power_dbm) + 1

    @property
    def power_levels(self):
        return jnp.linspace(
            self.min_power_dbm,
            self.max_power_dbm,
            self.num_power_levels,
            dtype=dtype_config.LARGE_FLOAT_DTYPE,
        )

    def __call__(self, state: EnvState, params: EnvParams):
        if self.vmap:
            actor_fn = jax.vmap(self.actor, in_axes=(0, None))
            critic_fn = jax.vmap(self.critic, in_axes=(0, None))
        else:
            actor_fn = self.actor
            critic_fn = self.critic

        actor_out = actor_fn(state, params)
        critic_out = critic_fn(state, params)
        return actor_out, critic_out

    def sample_action_path(self, seed, dist, log_prob=False, deterministic=False):
        """Sample an action from the distribution."""
        action = (
            jnp.argmax(dist.probs()).astype(dtype_config.MED_INT_DTYPE)
            if deterministic
            else dist.sample(seed=seed)
        )
        if log_prob:
            return action, dist.log_prob(action)
        return action

    def sample_action_power(self, seed, dist, log_prob=False, deterministic=False):
        """Sample an action and convert to power level"""
        if self.discrete:
            if deterministic:
                raw_action = dist.mode()
            else:
                raw_action = dist.sample(seed=seed)
            processed_action = self.power_levels[raw_action]
        else:
            if deterministic:
                mean = dist.alpha / (dist.alpha + dist.beta)
                raw_action = jnp.clip(mean, self.epsilon, 1.0 - self.epsilon)
            else:
                raw_action = jnp.clip(dist.sample(seed=seed), self.epsilon, 1.0 - self.epsilon)
            processed_action = self.min_power_dbm + raw_action * (
                self.max_power_dbm - self.min_power_dbm
            )
        processed_action = from_dbm(processed_action)
        if log_prob:
            return processed_action, dist.log_prob(raw_action)
        return processed_action

    def sample_action_path_power(self, seed, dist, log_prob=False, deterministic=False):
        """Sample an action from the distributions."""
        path_action = self.sample_action_path(
            seed, dist[0], log_prob=log_prob, deterministic=deterministic
        )
        power_action = self.sample_action_power(
            seed, dist[1], log_prob=log_prob, deterministic=deterministic
        )
        if log_prob:
            return path_action[0], power_action[0], path_action[1] + power_action[1]
        return path_action, power_action

    def sample_action(self, seed, dist, log_prob=False, deterministic=False):
        """Sample an action from the distributions."""
        if self.output_path and self.output_power:
            return self.sample_action_path_power(
                seed, dist, log_prob=log_prob, deterministic=deterministic
            )
        elif self.output_path:
            return self.sample_action_path(
                seed, dist, log_prob=log_prob, deterministic=deterministic
            )
        elif self.output_power:
            return self.sample_action_power(
                seed, dist, log_prob=log_prob, deterministic=deterministic
            )
        else:
            raise ValueError("No action type specified for sampling.")

sample_action(seed, dist, log_prob=False, deterministic=False)

Sample an action from the distributions.

Source code in xlron/models/gnn.py
def sample_action(self, seed, dist, log_prob=False, deterministic=False):
    """Sample an action from the distributions."""
    if self.output_path and self.output_power:
        return self.sample_action_path_power(
            seed, dist, log_prob=log_prob, deterministic=deterministic
        )
    elif self.output_path:
        return self.sample_action_path(
            seed, dist, log_prob=log_prob, deterministic=deterministic
        )
    elif self.output_power:
        return self.sample_action_power(
            seed, dist, log_prob=log_prob, deterministic=deterministic
        )
    else:
        raise ValueError("No action type specified for sampling.")

sample_action_path(seed, dist, log_prob=False, deterministic=False)

Sample an action from the distribution.

Source code in xlron/models/gnn.py
def sample_action_path(self, seed, dist, log_prob=False, deterministic=False):
    """Sample an action from the distribution."""
    action = (
        jnp.argmax(dist.probs()).astype(dtype_config.MED_INT_DTYPE)
        if deterministic
        else dist.sample(seed=seed)
    )
    if log_prob:
        return action, dist.log_prob(action)
    return action
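In the deterministic branch, taking `argmax` over the categorical probabilities selects the mode of the distribution. A distrax-free NumPy sketch of the path selection, using hypothetical logits for three paths plus a masked no-op slot (the `-1e4` entry mirrors the no-op masking applied to the path logits):

```python
import numpy as np

logits = np.array([1.2, 0.3, -0.5, -1e4])  # last entry: masked no-op
probs = np.exp(logits - logits.max())       # stable softmax
probs /= probs.sum()

action = int(np.argmax(probs))              # deterministic: the mode
log_prob = float(np.log(probs[action]))

print(action)  # 0
```

During training the stochastic branch samples from the same distribution instead, and `log_prob` feeds the policy-gradient loss.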

sample_action_path_power(seed, dist, log_prob=False, deterministic=False)

Sample an action from the distributions.

Source code in xlron/models/gnn.py
def sample_action_path_power(self, seed, dist, log_prob=False, deterministic=False):
    """Sample an action from the distributions."""
    path_action = self.sample_action_path(
        seed, dist[0], log_prob=log_prob, deterministic=deterministic
    )
    power_action = self.sample_action_power(
        seed, dist[1], log_prob=log_prob, deterministic=deterministic
    )
    if log_prob:
        return path_action[0], power_action[0], path_action[1] + power_action[1]
    return path_action, power_action

sample_action_power(seed, dist, log_prob=False, deterministic=False)

Sample an action and convert to power level

Source code in xlron/models/gnn.py
def sample_action_power(self, seed, dist, log_prob=False, deterministic=False):
    """Sample an action and convert to power level"""
    if self.discrete:
        if deterministic:
            raw_action = dist.mode()
        else:
            raw_action = dist.sample(seed=seed)
        processed_action = self.power_levels[raw_action]
    else:
        if deterministic:
            mean = dist.alpha / (dist.alpha + dist.beta)
            raw_action = jnp.clip(mean, self.epsilon, 1.0 - self.epsilon)
        else:
            raw_action = jnp.clip(dist.sample(seed=seed), self.epsilon, 1.0 - self.epsilon)
        processed_action = self.min_power_dbm + raw_action * (
            self.max_power_dbm - self.min_power_dbm
        )
    processed_action = from_dbm(processed_action)
    if log_prob:
        return processed_action, dist.log_prob(raw_action)
    return processed_action

ActorGNN

Bases: Module

Actor network using GNN for processing graph state.

Source code in xlron/models/gnn.py
class ActorGNN(eqx.Module):
    """Actor network using GNN for processing graph state."""

    graph_net: GraphNet
    power_mlp: Optional[eqx.nn.MLP]

    # Static configuration
    activation: str = eqx.field(static=True)
    edge_output_size: int = eqx.field(static=True)
    global_output_size: int = eqx.field(static=True)
    normalise_by_link_length: bool = eqx.field(static=True)
    temperature: float = eqx.field(static=True)
    min_power_dbm: float = eqx.field(static=True)
    max_power_dbm: float = eqx.field(static=True)
    step_power_dbm: float = eqx.field(static=True)
    discrete: bool = eqx.field(static=True)
    min_concentration: float = eqx.field(static=True)
    max_concentration: float = eqx.field(static=True)
    epsilon: float = eqx.field(static=True)

    def __init__(
        self,
        input_edge_features: int,
        input_node_features: int,
        input_global_features: int,
        activation: str = "tanh",
        num_layers: int = 2,
        num_units: int = 64,
        mlp_layers: int = None,
        mlp_latent: int = None,
        edge_embedding_size: int = 128,
        edge_mlp_layers: int = 3,
        edge_mlp_latent: int = 128,
        edge_output_size: int = 0,
        global_embedding_size: int = 8,
        global_mlp_layers: int = 0,
        global_mlp_latent: int = 0,
        global_output_size: int = 0,
        node_embedding_size: int = 16,
        node_mlp_layers: int = 2,
        node_mlp_latent: int = 128,
        node_output_size: int = 0,
        attn_mlp_layers: int = 2,
        attn_mlp_latent: int = 128,
        dropout_rate: float = 0,
        deterministic: bool = False,
        message_passing_steps: int = 1,
        use_attention: bool = True,
        normalise_by_link_length: bool = True,
        gnn_layer_norm: bool = True,
        mlp_layer_norm: bool = False,
        temperature: float = 1.0,
        min_power_dbm: float = 0.0,
        max_power_dbm: float = 2.0,
        step_power_dbm: float = 0.1,
        discrete: bool = True,
        min_concentration: float = 0.1,
        max_concentration: float = 20.0,
        epsilon: float = 1e-6,
        *,
        key: Array,
    ):
        self.activation = activation
        self.edge_output_size = edge_output_size
        self.global_output_size = global_output_size
        self.normalise_by_link_length = normalise_by_link_length
        self.temperature = temperature
        self.min_power_dbm = min_power_dbm
        self.max_power_dbm = max_power_dbm
        self.step_power_dbm = step_power_dbm
        self.discrete = discrete
        self.min_concentration = min_concentration
        self.max_concentration = max_concentration
        self.epsilon = epsilon

        gnn_key, mlp_key = jax.random.split(key)

        self.graph_net = GraphNet(
            input_edge_features=input_edge_features,
            input_node_features=input_node_features,
            input_global_features=input_global_features,
            message_passing_steps=message_passing_steps,
            mlp_layers=mlp_layers,
            mlp_latent=mlp_latent,
            edge_embedding_size=edge_embedding_size,
            edge_mlp_layers=edge_mlp_layers,
            edge_mlp_latent=edge_mlp_latent,
            edge_output_size=edge_output_size,
            global_embedding_size=global_embedding_size,
            global_mlp_layers=global_mlp_layers,
            global_mlp_latent=global_mlp_latent,
            global_output_size=global_output_size,
            node_embedding_size=node_embedding_size,
            node_mlp_layers=node_mlp_layers,
            node_mlp_latent=node_mlp_latent,
            node_output_size=node_output_size,
            attn_mlp_layers=attn_mlp_layers,
            attn_mlp_latent=attn_mlp_latent,
            dropout_rate=dropout_rate,
            use_attention=use_attention,
            gnn_layer_norm=gnn_layer_norm,
            mlp_layer_norm=mlp_layer_norm,
            deterministic=deterministic,
            key=gnn_key,
        )

        # Power MLP
        num_power_levels = int((max_power_dbm - min_power_dbm) / step_power_dbm) + 1
        output_size = num_power_levels if discrete else 2
        edge_feat_size = edge_output_size if edge_output_size > 0 else edge_embedding_size

        self.power_mlp = eqx.nn.MLP(
            in_size=edge_feat_size,
            out_size=output_size,
            width_size=num_units,
            depth=num_layers,
            activation=select_activation(activation),
            key=mlp_key,
        )

    @property
    def num_power_levels(self):
        return int((self.max_power_dbm - self.min_power_dbm) / self.step_power_dbm) + 1

    @property
    def power_levels(self):
        return jnp.linspace(
            self.min_power_dbm,
            self.max_power_dbm,
            self.num_power_levels,
            dtype=dtype_config.SMALL_FLOAT_DTYPE,
        )

    def __call__(
        self, state: EnvState, params: EnvParams
    ) -> distrax.Distribution | Tuple[distrax.Distribution, Optional[distrax.Distribution]]:
        processed_graph = self.graph_net(state.graph)

        # Index edge features
        edge_features = (
            processed_graph.edges
            if params.directed_graph
            else processed_graph.edges[: len(processed_graph.edges) // 2]
        )
        if self.normalise_by_link_length:
            edge_features = edge_features * (
                params.link_length_array.val
                / jnp.sum(params.link_length_array.val, promote_integers=False)
            )

        # Get current request
        nodes_sd, requested_bw = read_rsa_request(state.request_array)
        init_action_array = jnp.zeros(
            params.k_paths * self.edge_output_size, dtype=dtype_config.SMALL_FLOAT_DTYPE
        )

        def get_path_action_dist(i, action_array):
            path_features = get_path_slots(edge_features, params, nodes_sd, i, agg_func="sum")
            action_array = jax.lax.dynamic_update_slice(
                action_array, path_features, (i * self.edge_output_size,)
            )
            return action_array

        path_action_logits = jax.lax.fori_loop(
            0, params.k_paths, get_path_action_dist, init_action_array
        )
        if params.include_no_op:
            path_action_logits = jnp.hstack([path_action_logits, jnp.array([-1e4])])
        path_action_logits = jnp.reshape(path_action_logits, (-1,)) / self.temperature
        path_action_dist = distrax.Categorical(logits=path_action_logits)

        power_action_dist = None
        if params.__class__.__name__ == "RSAGNModelEnvParams":
            if self.global_output_size > 0:
                power_logits = processed_graph.globals.reshape((-1,)) / self.temperature
            else:
                init_feature_array = jnp.zeros(
                    (params.k_paths, edge_features.shape[1]), dtype=dtype_config.LARGE_FLOAT_DTYPE
                )

                def get_power_action_dist(i, feature_array):
                    path_features = get_path_slots(
                        edge_features, params, nodes_sd, i, agg_func="sum"
                    ).reshape((1, -1))
                    feature_array = jax.lax.dynamic_update_slice(
                        feature_array, path_features, (i, 0)
                    )
                    return feature_array

                path_feature_batch = jax.lax.fori_loop(
                    0, params.k_paths, get_power_action_dist, init_feature_array
                )
                power_logits = jax.vmap(self.power_mlp)(path_feature_batch)

            if self.discrete:
                power_action_dist = distrax.Categorical(logits=power_logits)
            else:
                alpha = self.min_concentration + jax.nn.softplus(power_logits) * (
                    self.max_concentration - self.min_concentration
                )
                beta = self.min_concentration + jax.nn.softplus(power_logits) * (
                    self.max_concentration - self.min_concentration
                )
                power_action_dist = distrax.Beta(alpha, beta)

            return (path_action_dist, power_action_dist)

        return path_action_dist

CriticGNN

Bases: Module

Critic network using GNN for processing graph state.

Source code in xlron/models/gnn.py
class CriticGNN(eqx.Module):
    """Critic network using GNN for processing graph state."""

    graph_net: GraphNet
    critic_mlp: eqx.nn.MLP
    critic_output: eqx.nn.Linear

    # Static configuration
    activation: str = eqx.field(static=True)
    global_output_size: int = eqx.field(static=True)
    normalise_by_link_length: bool = eqx.field(static=True)

    def __init__(
        self,
        input_edge_features: int,
        input_node_features: int,
        input_global_features: int,
        activation: str = "tanh",
        num_layers: int = 2,
        num_units: int = 64,
        message_passing_steps: int = 1,
        mlp_layers: int = None,
        mlp_latent: int = None,
        edge_embedding_size: int = 128,
        edge_mlp_layers: int = 3,
        edge_mlp_latent: int = 128,
        edge_output_size: int = 0,
        global_embedding_size: int = 8,
        global_mlp_layers: int = 0,
        global_mlp_latent: int = 0,
        global_output_size: int = 1,  # Must be 1!
        node_embedding_size: int = 16,
        node_mlp_layers: int = 2,
        node_mlp_latent: int = 128,
        node_output_size: int = 0,
        attn_mlp_layers: int = 2,
        attn_mlp_latent: int = 128,
        use_attention: bool = True,
        normalise_by_link_length: bool = True,
        gnn_layer_norm: bool = True,
        mlp_layer_norm: bool = False,
        *,
        key: Array,
    ):
        assert global_output_size == 1
        self.activation = activation
        self.global_output_size = global_output_size
        self.normalise_by_link_length = normalise_by_link_length

        gnn_key, mlp_key, output_key = jax.random.split(key, 3)

        self.graph_net = GraphNet(
            input_edge_features=input_edge_features,
            input_node_features=input_node_features,
            input_global_features=input_global_features,
            message_passing_steps=message_passing_steps,
            mlp_layers=mlp_layers,
            mlp_latent=mlp_latent,
            edge_embedding_size=edge_embedding_size,
            edge_mlp_layers=edge_mlp_layers,
            edge_mlp_latent=edge_mlp_latent,
            edge_output_size=edge_output_size,
            global_embedding_size=global_embedding_size,
            global_mlp_layers=global_mlp_layers,
            global_mlp_latent=global_mlp_latent,
            global_output_size=global_output_size,
            node_embedding_size=node_embedding_size,
            node_mlp_layers=node_mlp_layers,
            node_mlp_latent=node_mlp_latent,
            node_output_size=node_output_size,
            attn_mlp_layers=attn_mlp_layers,
            attn_mlp_latent=attn_mlp_latent,
            use_attention=use_attention,
            gnn_layer_norm=gnn_layer_norm,
            mlp_layer_norm=mlp_layer_norm,
            key=gnn_key,
        )

        # MLP for processing flattened edge features (only used if global_output_size == 0)
        # We use a placeholder size; actual input size depends on runtime graph
        self.critic_mlp = eqx.nn.MLP(
            in_size=edge_output_size if edge_output_size > 0 else edge_embedding_size,
            out_size=num_units,
            width_size=num_units,
            depth=num_layers,
            activation=select_activation(activation),
            key=mlp_key,
        )

        self.critic_output = make_linear_with_orthogonal_init(num_units, 1, output_key, scale=1.0)

    def __call__(self, state: EnvState, params: EnvParams) -> Array:
        # Remove globals so value does not depend on current request
        graph = state.graph._replace(
            globals=jnp.zeros_like(state.graph.globals)
        )
        state = state.replace(graph=graph)

        processed_graph = self.graph_net(state.graph)

        # Global output is already the scalar value
        # Shape: (1, 1)
        return processed_graph.globals.squeeze()
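
The `__call__` above zeroes the graph's globals before the GNN pass, so the value estimate cannot depend on the current request. A toy illustration of that invariance, with a namedtuple standing in for `jraph.GraphsTuple` (the `critic_value` function here is a hypothetical stand-in, not the real network):

```python
from collections import namedtuple
import numpy as np

# Minimal stand-in for jraph.GraphsTuple with only the fields used here
Graph = namedtuple("Graph", ["nodes", "edges", "globals"])

def critic_value(graph):
    # Hypothetical readout: depends on all three feature sets
    return float(graph.nodes.sum() + graph.edges.sum() + graph.globals.sum())

base = dict(nodes=np.ones((3, 2)), edges=np.ones((4, 2)))
g1 = Graph(globals=np.array([[5.0]]), **base)
g2 = Graph(globals=np.array([[-3.0]]), **base)

def masked_value(g):
    # Same masking as CriticGNN.__call__: replace globals with zeros
    return critic_value(g._replace(globals=np.zeros_like(g.globals)))

# Two graphs differing only in their globals (i.e. only in the current
# request) get identical value estimates after masking
v1, v2 = masked_value(g1), masked_value(g2)
```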

GraphNet

Bases: Module

A complete Graph Network model defined with Jraph and Equinox.

Source code in xlron/models/gnn.py
class GraphNet(eqx.Module):
    """A complete Graph Network model defined with Jraph and Equinox."""

    # Embedding layers
    edge_embedder: eqx.nn.Linear
    node_embedder: eqx.nn.Linear
    global_embedder: eqx.nn.Linear

    # MLP layers for each message passing step (list of tuples)
    message_passing_layers: tuple

    # Decoder layers
    edge_decoder: Optional[eqx.nn.Linear]
    node_decoder: Optional[eqx.nn.Linear]
    global_decoder: Optional[eqx.nn.Linear]

    # Layer norms for each step
    layer_norms: Optional[tuple]

    # Static configuration
    message_passing_steps: int = eqx.field(static=True)
    edge_embedding_size: int = eqx.field(static=True)
    node_embedding_size: int = eqx.field(static=True)
    global_embedding_size: int = eqx.field(static=True)
    edge_output_size: int = eqx.field(static=True)
    node_output_size: int = eqx.field(static=True)
    global_output_size: int = eqx.field(static=True)
    dropout_rate: float = eqx.field(static=True)
    skip_connections: bool = eqx.field(static=True)
    use_edge_model: bool = eqx.field(static=True)
    gnn_layer_norm: bool = eqx.field(static=True)
    deterministic: bool = eqx.field(static=True)
    use_attention: bool = eqx.field(static=True)

    def __init__(
        self,
        input_edge_features: int,
        input_node_features: int,
        input_global_features: int,
        message_passing_steps: int = 1,
        mlp_layers: int = 0,
        mlp_latent: int = 0,
        edge_embedding_size: int = 128,
        edge_mlp_layers: int = 3,
        edge_mlp_latent: int = 128,
        edge_output_size: int = 0,
        global_embedding_size: int = 8,
        global_mlp_layers: int = 0,
        global_mlp_latent: int = 0,
        global_output_size: int = 0,
        node_embedding_size: int = 16,
        node_mlp_layers: int = 2,
        node_mlp_latent: int = 128,
        node_output_size: int = 0,
        attn_mlp_layers: int = 2,
        attn_mlp_latent: int = 128,
        dropout_rate: float = 0,
        skip_connections: bool = True,
        use_edge_model: bool = True,
        gnn_layer_norm: bool = True,
        mlp_layer_norm: bool = False,
        deterministic: bool = True,
        use_attention: bool = True,
        *,
        key: Array,
    ):
        self.message_passing_steps = message_passing_steps
        self.edge_embedding_size = edge_embedding_size
        self.node_embedding_size = node_embedding_size
        self.global_embedding_size = global_embedding_size
        self.edge_output_size = edge_output_size
        self.node_output_size = node_output_size
        self.global_output_size = global_output_size
        self.dropout_rate = dropout_rate
        self.skip_connections = skip_connections
        self.use_edge_model = use_edge_model
        self.gnn_layer_norm = gnn_layer_norm
        self.deterministic = deterministic
        self.use_attention = use_attention

        # Determine MLP dimensions
        if mlp_latent is not None:
            global_mlp_dims = edge_mlp_dims = node_mlp_dims = attn_mlp_dims = [
                mlp_latent
            ] * mlp_layers
        else:
            global_mlp_dims = [global_mlp_latent] * global_mlp_layers
            edge_mlp_dims = [edge_mlp_latent] * edge_mlp_layers
            node_mlp_dims = [node_mlp_latent] * node_mlp_layers
            attn_mlp_dims = [attn_mlp_latent] * attn_mlp_layers

        if skip_connections:
            edge_mlp_dims = edge_mlp_dims + [edge_embedding_size]
            node_mlp_dims = node_mlp_dims + [node_embedding_size]
            global_mlp_dims = global_mlp_dims + [global_embedding_size]

        # Split keys
        keys = jax.random.split(key, 10 + message_passing_steps * 4)
        key_idx = 0

        # Create embedders
        self.edge_embedder = eqx.nn.Linear(
            input_edge_features, edge_embedding_size, key=keys[key_idx]
        )
        key_idx += 1
        self.node_embedder = eqx.nn.Linear(
            input_node_features, node_embedding_size, key=keys[key_idx]
        )
        key_idx += 1
        self.global_embedder = eqx.nn.Linear(
            input_global_features, global_embedding_size, key=keys[key_idx]
        )
        key_idx += 1

        # Create message passing layers for each step
        mp_layers = []
        layer_norms_list = []

        # Input sizes for MLPs after concatenation
        # Edge MLP: edges + sender_nodes + receiver_nodes + globals
        edge_mlp_input = edge_embedding_size + 2 * node_embedding_size + global_embedding_size
        # Node MLP: nodes + aggregated_sent + aggregated_received + globals
        node_mlp_input = node_embedding_size + 2 * edge_embedding_size + global_embedding_size
        # Global MLP: aggregated_nodes + aggregated_edges + globals
        global_mlp_input = node_embedding_size + edge_embedding_size + global_embedding_size
        # Attention MLP: edges + sender + receiver + globals
        attn_mlp_input = edge_embedding_size + 2 * node_embedding_size + global_embedding_size

        for step in range(message_passing_steps):
            step_layers = {}

            if use_edge_model and edge_mlp_dims:
                step_layers["edge_mlp"] = eqx.nn.MLP(
                    in_size=edge_mlp_input,
                    out_size=edge_mlp_dims[-1] if edge_mlp_dims else edge_embedding_size,
                    width_size=edge_mlp_dims[0] if edge_mlp_dims else edge_embedding_size,
                    depth=len(edge_mlp_dims),
                    activation=jax.nn.relu,
                    key=keys[key_idx],
                )
                key_idx += 1

            if node_mlp_dims:
                step_layers["node_mlp"] = eqx.nn.MLP(
                    in_size=node_mlp_input,
                    out_size=node_mlp_dims[-1] if node_mlp_dims else node_embedding_size,
                    width_size=node_mlp_dims[0] if node_mlp_dims else node_embedding_size,
                    depth=len(node_mlp_dims),
                    activation=jax.nn.relu,
                    key=keys[key_idx],
                )
                key_idx += 1

            if global_output_size > 0 and global_mlp_dims:
                step_layers["global_mlp"] = eqx.nn.MLP(
                    in_size=global_mlp_input,
                    out_size=global_mlp_dims[-1] if global_mlp_dims else global_embedding_size,
                    width_size=global_mlp_dims[0] if global_mlp_dims else global_embedding_size,
                    depth=len(global_mlp_dims),
                    activation=jax.nn.relu,
                    key=keys[key_idx],
                )
                key_idx += 1

            if use_attention and attn_mlp_dims:
                # Ensure at least depth 1 for GATv2-style dynamic attention
                attn_depth = max(len(attn_mlp_dims), 1)
                step_layers["attn_mlp"] = eqx.nn.MLP(
                    in_size=attn_mlp_input,
                    out_size=1,
                    width_size=attn_mlp_dims[0] if attn_mlp_dims else 128,
                    depth=attn_depth,
                    activation=jax.nn.relu,
                    key=keys[key_idx],
                )
                key_idx += 1

            mp_layers.append(step_layers)

            if gnn_layer_norm:
                layer_norms_list.append(
                    {
                        "node": eqx.nn.LayerNorm(node_embedding_size),
                        "edge": eqx.nn.LayerNorm(edge_embedding_size),
                        "global": eqx.nn.LayerNorm(global_embedding_size)
                        if global_output_size > 0
                        else None,
                    }
                )

        self.message_passing_layers = tuple(mp_layers)
        self.layer_norms = tuple(layer_norms_list) if gnn_layer_norm else None

        # Create decoders
        self.edge_decoder = (
            eqx.nn.Linear(edge_embedding_size, edge_output_size, key=keys[key_idx])
            if edge_output_size > 0
            else None
        )
        key_idx += 1
        self.node_decoder = (
            eqx.nn.Linear(node_embedding_size, node_output_size, key=keys[key_idx])
            if node_output_size > 0
            else None
        )
        key_idx += 1
        self.global_decoder = (
            eqx.nn.Linear(global_embedding_size, global_output_size, key=keys[key_idx])
            if global_output_size > 0
            else None
        )

    def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
        # Flatten edges if needed
        if graphs.edges.ndim >= 3:
            edges = graphs.edges.reshape((graphs.edges.shape[0], -1))
            graphs = graphs._replace(edges=edges)

        # Embed
        nodes = jax.vmap(self.node_embedder)(graphs.nodes)
        edges = jax.vmap(self.edge_embedder)(graphs.edges)
        globals_ = (
            jax.vmap(self.global_embedder)(graphs.globals)
            if graphs.globals is not None
            else jnp.zeros((1, self.global_embedding_size))
        )

        processed_graphs = graphs._replace(nodes=nodes, edges=edges, globals=globals_)

        # Message passing
        for step, step_layers in enumerate(self.message_passing_layers):
            # Build update functions using the MLPs
            # jraph.concatenated_args wraps a fn(concatenated_input) to accept separate args
            # Remember to vmap over edges
            def make_update_edge_fn(edge_mlp):
                def update_edge_fn(concatenated_inputs):
                    return jax.vmap(edge_mlp)(concatenated_inputs)

                return jraph.concatenated_args(update_edge_fn)

            def make_update_node_fn(node_mlp):
                def update_node_fn(concatenated_inputs):
                    return jax.vmap(node_mlp)(concatenated_inputs)

                return jraph.concatenated_args(update_node_fn)

            def make_update_global_fn(global_mlp):
                def update_global_fn(concatenated_inputs):
                    return jax.vmap(global_mlp)(concatenated_inputs)

                return jraph.concatenated_args(update_global_fn)

            update_edge_fn = (
                make_update_edge_fn(step_layers["edge_mlp"]) if "edge_mlp" in step_layers else None
            )
            update_node_fn = (
                make_update_node_fn(step_layers["node_mlp"]) if "node_mlp" in step_layers else None
            )
            update_global_fn = (
                make_update_global_fn(step_layers["global_mlp"])
                if "global_mlp" in step_layers
                else None
            )

            if self.use_attention and "attn_mlp" in step_layers:
                attn_mlp = step_layers["attn_mlp"]

                def attention_logit_fn(edges, sender_attr, receiver_attr, global_edge_attributes):
                    x = jnp.concatenate(
                        (edges, sender_attr, receiver_attr, global_edge_attributes), axis=1
                    )
                    return jax.vmap(attn_mlp)(x)

                def attention_reduce_fn(edges, attention):
                    return attention * edges

                graph_net = GraphNetGAT(
                    update_node_fn=update_node_fn,
                    update_edge_fn=update_edge_fn,
                    update_global_fn=update_global_fn,
                    attention_logit_fn=attention_logit_fn,
                    attention_reduce_fn=attention_reduce_fn,
                )
            else:
                graph_net = GraphNetwork(
                    update_node_fn=update_node_fn,
                    update_edge_fn=update_edge_fn,
                    update_global_fn=update_global_fn,
                )

            new_graphs = graph_net(processed_graphs)

            if self.skip_connections:
                processed_graphs = add_graphs_tuples(new_graphs, processed_graphs)
            else:
                processed_graphs = new_graphs

            if self.gnn_layer_norm and self.layer_norms is not None:
                ln = self.layer_norms[step]
                processed_graphs = processed_graphs._replace(
                    nodes=jax.vmap(ln["node"])(processed_graphs.nodes),
                    edges=jax.vmap(ln["edge"])(processed_graphs.edges),
                    globals=jax.vmap(ln["global"])(processed_graphs.globals)
                    if ln["global"] is not None and processed_graphs.globals is not None
                    else processed_graphs.globals,
                )

        # Decode
        if self.edge_decoder is not None:
            edges = jax.vmap(self.edge_decoder)(processed_graphs.edges)
            processed_graphs = processed_graphs._replace(edges=edges)
        if self.node_decoder is not None:
            nodes = jax.vmap(self.node_decoder)(processed_graphs.nodes)
            processed_graphs = processed_graphs._replace(nodes=nodes)
        if self.global_decoder is not None and processed_graphs.globals is not None:
            globals_ = jax.vmap(self.global_decoder)(processed_graphs.globals)
            processed_graphs = processed_graphs._replace(globals=globals_)

        return processed_graphs
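
The concatenated MLP input sizes in the constructor above follow directly from which features each update function sees. A quick check of the arithmetic using the default embedding sizes (128/16/8):

```python
# Default embedding sizes from the constructor above
edge_e, node_e, glob_e = 128, 16, 8

# Edge MLP input: edge + sender node + receiver node + globals
edge_mlp_in = edge_e + 2 * node_e + glob_e
# Node MLP input: node + aggregated sent edges + aggregated received edges + globals
node_mlp_in = node_e + 2 * edge_e + glob_e
# Global MLP input: aggregated nodes + aggregated edges + globals
global_mlp_in = node_e + edge_e + glob_e
```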

GAT(attention_query_fn, attention_logit_fn, node_update_fn=None)

Returns a method that applies a Graph Attention Network layer.

Graph Attention message passing as described in https://arxiv.org/abs/1710.10903. This model expects node features as a jnp.array, may use edge features for computing attention weights, and ignores global features. It does not support nests.

NOTE: this implementation assumes that the input graph has self edges. To recover the behavior of the referenced paper, please add self edges.

Parameters:

Name Type Description Default
attention_query_fn GATAttentionQueryFn

function that generates attention queries from sender node features.

required
attention_logit_fn GATAttentionLogitFn

function that converts attention queries into logits for softmax attention.

required
node_update_fn Optional[GATNodeUpdateFn]

function that updates the aggregated messages. If None, will apply leaky relu and concatenate (if using multi-head attention).

None

Returns:

Type Description

A function that applies a Graph Attention layer.

Source code in xlron/models/gnn.py
def GAT(
    attention_query_fn: GATAttentionQueryFn,
    attention_logit_fn: GATAttentionLogitFn,
    node_update_fn: Optional[GATNodeUpdateFn] = None,
):
    """Returns a method that applies a Graph Attention Network layer.

    Graph Attention message passing as described in
    https://arxiv.org/abs/1710.10903. This model expects node features as a
    jnp.array, may use edge features for computing attention weights, and
    ignores global features. It does not support nests.

    NOTE: this implementation assumes that the input graph has self edges. To
    recover the behavior of the referenced paper, please add self edges.

    Args:
      attention_query_fn: function that generates attention queries
        from sender node features.
      attention_logit_fn: function that converts attention queries into logits for
        softmax attention.
      node_update_fn: function that updates the aggregated messages. If None,
        will apply leaky relu and concatenate (if using multi-head attention).

    Returns:
      A function that applies a Graph Attention layer.
    """
    # pylint: disable=g-long-lambda
    if node_update_fn is None:
        # By default, apply the leaky relu and then concatenate the heads on the
        # feature axis.
        def node_update_fn(x):
            return jnp.reshape(jax.nn.leaky_relu(x), (x.shape[0], -1))

    def _ApplyGAT(graph):
        """Applies a Graph Attention layer."""
        nodes, edges, receivers, senders, _, _, _ = graph
        # Equivalent to the sum of n_node, but statically known.
        try:
            sum_n_node = nodes.shape[0]
        except IndexError:
            raise IndexError("GAT requires node features")  # pylint: disable=raise-missing-from

        # First pass nodes through the node updater.
        nodes = attention_query_fn(nodes)
        # pylint: disable=g-long-lambda
        # We compute the softmax logits using a function that takes the
        # embedded sender and receiver attributes.
        sent_attributes = nodes[senders]
        received_attributes = nodes[receivers]
        softmax_logits = attention_logit_fn(sent_attributes, received_attributes, edges)

        # Compute the softmax weights on the entire tree.
        weights = utils.segment_softmax(
            softmax_logits, segment_ids=receivers, num_segments=sum_n_node
        )
        # Apply weights
        messages = sent_attributes * weights
        # Aggregate messages to nodes.
        nodes = utils.segment_sum(messages, receivers, num_segments=sum_n_node)

        # Apply an update function to the aggregated messages.
        nodes = node_update_fn(nodes)
        return graph._replace(nodes=nodes)

    # pylint: enable=g-long-lambda
    return _ApplyGAT
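
`GAT` above normalises attention logits per receiving node via `utils.segment_softmax`. A plain-numpy sketch of that semantics (a loop version for clarity, not the jraph implementation):

```python
import numpy as np

def segment_softmax(logits, segment_ids, num_segments):
    # Softmax over entries sharing a segment id
    # (here: edges sharing a receiver node)
    out = np.zeros_like(logits)
    for s in range(num_segments):
        mask = segment_ids == s
        if mask.any():
            x = logits[mask]
            e = np.exp(x - x.max())  # subtract max for stability
            out[mask] = e / e.sum()
    return out

logits = np.array([1.0, 2.0, 3.0, 0.5])
receivers = np.array([0, 0, 1, 1])
w = segment_softmax(logits, receivers, num_segments=2)
```

Each receiver's incoming attention weights sum to one, so the subsequent `segment_sum` of weighted messages is a convex combination per node.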

GraphNetGAT(update_edge_fn, update_node_fn, attention_logit_fn, attention_reduce_fn, update_global_fn=None, aggregate_edges_for_nodes_fn=utils.segment_sum, aggregate_nodes_for_globals_fn=utils.segment_sum, aggregate_edges_for_globals_fn=utils.segment_sum)

Returns a method that applies a GraphNet with attention on edge features.

Parameters:

Name Type Description Default
update_edge_fn GNUpdateEdgeFn

function used to update the edges.

required
update_node_fn GNUpdateNodeFn

function used to update the nodes.

required
attention_logit_fn AttentionLogitFn

function used to calculate the attention weights.

required
attention_reduce_fn AttentionReduceFn

function used to apply attention weights to the edge features.

required
update_global_fn Optional[GNUpdateGlobalFn]

function used to update the globals or None to deactivate globals updates.

None
aggregate_edges_for_nodes_fn AggregateEdgesToNodesFn

function used to aggregate attention-weighted messages to each node.

segment_sum
aggregate_nodes_for_globals_fn AggregateNodesToGlobalsFn

function used to aggregate the nodes for the globals.

segment_sum
aggregate_edges_for_globals_fn AggregateEdgesToGlobalsFn

function used to aggregate attention-weighted edges for the globals.

segment_sum

Returns:

Type Description

A function that applies a GraphNet Graph Attention layer.

Source code in xlron/models/gnn.py
def GraphNetGAT(
    update_edge_fn: GNUpdateEdgeFn,
    update_node_fn: GNUpdateNodeFn,
    attention_logit_fn: AttentionLogitFn,
    attention_reduce_fn: AttentionReduceFn,
    update_global_fn: Optional[GNUpdateGlobalFn] = None,
    aggregate_edges_for_nodes_fn: AggregateEdgesToNodesFn = utils.segment_sum,
    aggregate_nodes_for_globals_fn: AggregateNodesToGlobalsFn = utils.segment_sum,
    aggregate_edges_for_globals_fn: AggregateEdgesToGlobalsFn = utils.segment_sum,
):
    """Returns a method that applies a GraphNet with attention on edge features.

    Args:
      update_edge_fn: function used to update the edges.
      update_node_fn: function used to update the nodes.
      attention_logit_fn: function used to calculate the attention weights.
      attention_reduce_fn: function used to apply attention weights to the edge
        features.
      update_global_fn: function used to update the globals or None to deactivate
        globals updates.
      aggregate_edges_for_nodes_fn: function used to aggregate attention-weighted
        messages to each node.
      aggregate_nodes_for_globals_fn: function used to aggregate the nodes for the
        globals.
      aggregate_edges_for_globals_fn: function used to aggregate
        attention-weighted edges for the globals.

    Returns:
      A function that applies a GraphNet Graph Attention layer.
    """
    if (attention_logit_fn is None) or (attention_reduce_fn is None):
        raise ValueError(
            (
                "`None` value not supported for `attention_logit_fn` or "
                "`attention_reduce_fn` in a Graph Attention network."
            )
        )
    return GraphNetwork(
        update_edge_fn=update_edge_fn,
        update_node_fn=update_node_fn,
        update_global_fn=update_global_fn,
        attention_logit_fn=attention_logit_fn,
        attention_reduce_fn=attention_reduce_fn,
        aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn,
        aggregate_nodes_for_globals_fn=aggregate_nodes_for_globals_fn,
        aggregate_edges_for_globals_fn=aggregate_edges_for_globals_fn,
    )

GraphNetwork(update_edge_fn, update_node_fn, update_global_fn=None, aggregate_edges_for_nodes_fn=utils.segment_sum, aggregate_nodes_for_globals_fn=utils.segment_sum, aggregate_edges_for_globals_fn=utils.segment_sum, attention_logit_fn=None, attention_normalize_fn=utils.segment_softmax, attention_reduce_fn=None)

Returns a method that applies a configured GraphNetwork.

This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261

There is one difference. For the nodes update the class aggregates over the sender edges and receiver edges separately. This is a bit more general than the algorithm described in the paper. The original behaviour can be recovered by using only the receiver edge aggregations for the update.

In addition this implementation supports softmax attention over incoming edge features.

Example usage::

gn = GraphNetwork(update_edge_function, update_node_function, **kwargs)
# Conduct multiple rounds of message passing with the same parameters:
for _ in range(num_message_passing_steps):
    graph = gn(graph)

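
The note above that nodes aggregate sender edges and receiver edges separately can be sketched with a numpy `segment_sum` (mirroring `jraph.utils.segment_sum`):

```python
import numpy as np

def segment_sum(data, segment_ids, num_segments):
    # Sum rows of `data` into buckets given by `segment_ids`
    out = np.zeros((num_segments,) + data.shape[1:])
    np.add.at(out, segment_ids, data)
    return out

edges = np.array([[1.0], [2.0], [4.0]])
senders = np.array([0, 0, 1])
receivers = np.array([1, 2, 2])
num_nodes = 3

# Per-node sum of outgoing edges (sender aggregation)
sent_agg = segment_sum(edges, senders, num_nodes)
# Per-node sum of incoming edges (receiver aggregation);
# using only this recovers the original paper's node update
recv_agg = segment_sum(edges, receivers, num_nodes)
```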
Parameters:

Name Type Description Default
update_edge_fn Optional[GNUpdateEdgeFn]

function used to update the edges or None to deactivate edge updates.

required
update_node_fn Optional[GNUpdateNodeFn]

function used to update the nodes or None to deactivate node updates.

required
update_global_fn Optional[GNUpdateGlobalFn]

function used to update the globals or None to deactivate globals updates.

None
aggregate_edges_for_nodes_fn AggregateEdgesToNodesFn

function used to aggregate messages to each node.

segment_sum
aggregate_nodes_for_globals_fn AggregateNodesToGlobalsFn

function used to aggregate the nodes for the globals.

segment_sum
aggregate_edges_for_globals_fn AggregateEdgesToGlobalsFn

function used to aggregate the edges for the globals.

segment_sum
attention_logit_fn Optional[AttentionLogitFn]

function used to calculate the attention weights or None to deactivate attention mechanism.

None
attention_normalize_fn Optional[AttentionNormalizeFn]

function used to normalize raw attention logits or None if attention mechanism is not active.

segment_softmax
attention_reduce_fn Optional[AttentionReduceFn]

function used to apply weights to the edge features or None if attention mechanism is not active.

None

Returns:

Type Description

A method that applies the configured GraphNetwork.

Source code in xlron/models/gnn.py
def GraphNetwork(
    update_edge_fn: Optional[GNUpdateEdgeFn],
    update_node_fn: Optional[GNUpdateNodeFn],
    update_global_fn: Optional[GNUpdateGlobalFn] = None,
    # TODO: allow RNN/SSM to be used in the aggregation function
    #  https://github.com/luchris429/popjaxrl/blob/main/algorithms/ppo_gru.py
    aggregate_edges_for_nodes_fn: AggregateEdgesToNodesFn = utils.segment_sum,
    aggregate_nodes_for_globals_fn: AggregateNodesToGlobalsFn = utils.segment_sum,
    aggregate_edges_for_globals_fn: AggregateEdgesToGlobalsFn = utils.segment_sum,
    attention_logit_fn: Optional[AttentionLogitFn] = None,
    attention_normalize_fn: Optional[AttentionNormalizeFn] = utils.segment_softmax,
    attention_reduce_fn: Optional[AttentionReduceFn] = None,
):
    """Returns a method that applies a configured GraphNetwork.

    This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261

    There is one difference. For the nodes update the class aggregates over the
    sender edges and receiver edges separately. This is a bit more general
    than the algorithm described in the paper. The original behaviour can be
    recovered by using only the receiver edge aggregations for the update.

    In addition this implementation supports softmax attention over incoming
    edge features.

    Example usage::

      gn = GraphNetwork(update_edge_function,
      update_node_function, **kwargs)
      # Conduct multiple rounds of message passing with the same parameters:
      for _ in range(num_message_passing_steps):
        graph = gn(graph)

    Args:
      update_edge_fn: function used to update the edges or None to deactivate edge
        updates.
      update_node_fn: function used to update the nodes or None to deactivate node
        updates.
      update_global_fn: function used to update the globals or None to deactivate
        globals updates.
      aggregate_edges_for_nodes_fn: function used to aggregate messages to each
        node.
      aggregate_nodes_for_globals_fn: function used to aggregate the nodes for the
        globals.
      aggregate_edges_for_globals_fn: function used to aggregate the edges for the
        globals.
      attention_logit_fn: function used to calculate the attention weights or
        None to deactivate attention mechanism.
      attention_normalize_fn: function used to normalize raw attention logits or
        None if attention mechanism is not active.
      attention_reduce_fn: function used to apply weights to the edge features or
        None if attention mechanism is not active.

    Returns:
      A method that applies the configured GraphNetwork.
    """

    def not_both_supplied(x, y):
        return (x != y) and ((x is None) or (y is None))

    if not_both_supplied(attention_reduce_fn, attention_logit_fn):
        raise ValueError(("attention_logit_fn and attention_reduce_fn must both be supplied."))

    def _ApplyGraphNet(graph):
        """Applies a configured GraphNetwork to a graph.

        This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261

        There is one difference. For the nodes update the class aggregates over the
        sender edges and receiver edges separately. This is a bit more general
        than the algorithm described in the paper. The original behaviour can be
        recovered by using only the receiver edge aggregations for the update.

        In addition this implementation supports softmax attention over incoming
        edge features.

        Many popular Graph Neural Networks can be implemented as special cases of
        GraphNets, for more information please see the paper.

        Args:
          graph: a `GraphsTuple` containing the graph.

        Returns:
          Updated `GraphsTuple`.
        """
        # pylint: disable=g-long-lambda
        nodes, edges, receivers, senders, globals_, n_node, n_edge = graph
        # Equivalent to jnp.sum(n_node), but jittable
        sum_n_node = tree.tree_leaves(nodes)[0].shape[0]
        sum_n_edge = senders.shape[0]
        if not tree.tree_all(tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)):
            raise ValueError("All node arrays in nest must contain the same number of nodes.")

        sent_attributes = tree.tree_map(lambda n: n[senders], nodes)
        received_attributes = tree.tree_map(lambda n: n[receivers], nodes)
        # Here we scatter the global features to the corresponding edges,
        # giving us tensors of shape [num_edges, global_feat].
        global_edge_attributes = tree.tree_map(
            lambda g: jnp.repeat(g, n_edge, axis=0, total_repeat_length=sum_n_edge),
            globals_,
        )

        if update_edge_fn:
            edges = update_edge_fn(
                edges, sent_attributes, received_attributes, global_edge_attributes
            )

        if attention_logit_fn:
            logits = attention_logit_fn(
                edges, sent_attributes, received_attributes, global_edge_attributes
            )
            tree_calculate_weights = functools.partial(
                attention_normalize_fn, segment_ids=receivers, num_segments=sum_n_node
            )
            weights = tree.tree_map(tree_calculate_weights, logits)
            edges = attention_reduce_fn(edges, weights)

        if update_node_fn:
            sent_attributes = tree.tree_map(
                lambda e: aggregate_edges_for_nodes_fn(e, senders, sum_n_node), edges
            )
            received_attributes = tree.tree_map(
                lambda e: aggregate_edges_for_nodes_fn(e, receivers, sum_n_node), edges
            )
            # Here we scatter the global features to the corresponding nodes,
            # giving us tensors of shape [num_nodes, global_feat].
            global_attributes = tree.tree_map(
                lambda g: jnp.repeat(g, n_node, axis=0, total_repeat_length=sum_n_node),
                globals_,
            )
            nodes = update_node_fn(nodes, sent_attributes, received_attributes, global_attributes)

        if update_global_fn:
            n_graph = n_node.shape[0]
            graph_idx = jnp.arange(n_graph)
            # To aggregate nodes and edges from each graph to global features,
            # we first construct tensors that map the node to the corresponding graph.
            # For example, if you have `n_node=[1,2]`, we construct the tensor
            # [0, 1, 1]. We then do the same for edges.
            node_gr_idx = jnp.repeat(graph_idx, n_node, axis=0, total_repeat_length=sum_n_node)
            edge_gr_idx = jnp.repeat(graph_idx, n_edge, axis=0, total_repeat_length=sum_n_edge)
            # We use the aggregation function to pool the nodes/edges per graph.
            node_attributes = tree.tree_map(
                lambda n: aggregate_nodes_for_globals_fn(n, node_gr_idx, n_graph), nodes
            )
            edge_attributes = tree.tree_map(
                lambda e: aggregate_edges_for_globals_fn(e, edge_gr_idx, n_graph), edges
            )
            # These pooled nodes are the inputs to the global update fn.
            globals_ = update_global_fn(node_attributes, edge_attributes, globals_)
        # pylint: enable=g-long-lambda
        return gn_graph.GraphsTuple(
            nodes=nodes,
            edges=edges,
            receivers=receivers,
            senders=senders,
            globals=globals_,
            n_node=n_node,
            n_edge=n_edge,
        )

    return _ApplyGraphNet
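The edge-to-node step above relies on segment-based aggregation: edge features are pooled into buckets indexed by their sender or receiver node, then passed to the update function. A minimal NumPy sketch of that aggregation (a stand-in for `aggregate_edges_for_nodes_fn`; the helper name `segment_sum` is illustrative, mirroring `jax.ops.segment_sum`):

```python
import numpy as np

def segment_sum(data, segment_ids, num_segments):
    """Sum rows of `data` into `num_segments` buckets given by `segment_ids`."""
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    np.add.at(out, segment_ids, data)  # unbuffered scatter-add
    return out

# Three edges feeding two nodes: edges 0 and 1 point at node 0, edge 2 at node 1.
edges = np.array([[1.0], [2.0], [3.0]])
receivers = np.array([0, 0, 1])
print(segment_sum(edges, receivers, num_segments=2))  # sums per receiver: [[3.], [3.]]
```

The same pattern, keyed on `node_gr_idx`/`edge_gr_idx` instead of `receivers`, drives the per-graph pooling in the global update.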

add_graphs_tuples(graphs, other_graphs)

Adds the nodes, edges and global features from other_graphs to graphs.

Source code in xlron/models/gnn.py
def add_graphs_tuples(
    graphs: jraph.GraphsTuple, other_graphs: jraph.GraphsTuple
) -> jraph.GraphsTuple:
    """Adds the nodes, edges and global features from other_graphs to graphs."""
    return graphs._replace(
        nodes=graphs.nodes + other_graphs.nodes,
        edges=graphs.edges + other_graphs.edges,
        globals=graphs.globals + other_graphs.globals if graphs.globals is not None else None,
    )
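`add_graphs_tuples` implements a residual connection over graph features: node, edge, and global arrays are summed element-wise while connectivity (`senders`/`receivers`) is left untouched. A small sketch under stated assumptions (the `Graph` namedtuple is a toy stand-in for `jraph.GraphsTuple`):

```python
import numpy as np
from collections import namedtuple

# Toy stand-in for jraph.GraphsTuple, just enough to show the feature-wise sum.
Graph = namedtuple("Graph", ["nodes", "edges", "globals"])

def add_graphs(a, b):
    # Mirrors add_graphs_tuples: sum features, keep connectivity implicit.
    return a._replace(
        nodes=a.nodes + b.nodes,
        edges=a.edges + b.edges,
        globals=a.globals + b.globals if a.globals is not None else None,
    )

g = add_graphs(
    Graph(np.array([1, 2]), np.array([10]), None),
    Graph(np.array([3, 4]), np.array([20]), None),
)
print(g.nodes, g.edges)  # [4 6] [30]
```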

add_self_edges_fn(receivers, senders, total_num_nodes)

Adds self edges. Assumes self edges are not in the graph yet.

Source code in xlron/models/gnn.py
84
85
86
87
88
89
90
91
def add_self_edges_fn(
    receivers: jnp.ndarray, senders: jnp.ndarray, total_num_nodes: int
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Adds self edges. Assumes self edges are not in the graph yet."""
    # TODo - check if self-edges required
    receivers = jnp.concatenate((receivers, jnp.arange(total_num_nodes)), axis=0)
    senders = jnp.concatenate((senders, jnp.arange(total_num_nodes)), axis=0)
    return receivers, senders
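A quick NumPy illustration of the same append-one-self-edge-per-node operation (the function name here is illustrative):

```python
import numpy as np

def add_self_edges(receivers, senders, total_num_nodes):
    # Append one i -> i edge per node, as in add_self_edges_fn.
    receivers = np.concatenate((receivers, np.arange(total_num_nodes)))
    senders = np.concatenate((senders, np.arange(total_num_nodes)))
    return receivers, senders

r, s = add_self_edges(np.array([1]), np.array([0]), total_num_nodes=3)
print(r, s)  # [1 0 1 2] [0 0 1 2]
```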

ActorCriticTransformer

Bases: Module

Source code in xlron/models/transformer.py
class ActorCriticTransformer(eqx.Module):
    actor_critic: eqx.nn.Shared | Tuple[eqx.Module, eqx.Module]
    actor_mlp: eqx.nn.MLP
    critic_mlp: eqx.nn.MLP
    share_layers: bool
    num_slot_actions: int
    num_request_specific_cols: int = eqx.field(static=True)
    embedding_size: int = eqx.field(static=True)

    def __init__(
        self,
        input_size: int,
        embedding_size: int,
        intermediate_size: int,
        num_slot_actions: int,
        num_layers: int,
        num_heads: int,
        enable_dropout: bool,
        dropout_rate: float,
        attention_dropout_rate: float,
        share_layers: bool,
        num_wire_features: int,
        actor_mlp_width: int,
        critic_mlp_width: int,
        actor_mlp_depth: int,
        critic_mlp_depth: int,
        num_request_specific_cols: int,
        key: chex.PRNGKey,
    ):
        (
            encoder_key,
            actor_key,
            critic_key,
        ) = jax.random.split(key, 3)
        self.share_layers = share_layers
        self.num_request_specific_cols = num_request_specific_cols
        self.embedding_size = embedding_size
        actor = Encoder(
            input_size=input_size,
            intermediate_size=intermediate_size,
            embedding_size=embedding_size,
            num_layers=num_layers,
            num_heads=num_heads,
            num_wire_features=num_wire_features,
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            key=encoder_key,
        )
        critic = Encoder(
            input_size=input_size - num_request_specific_cols,
            intermediate_size=intermediate_size,
            embedding_size=embedding_size,
            num_layers=num_layers,
            num_heads=num_heads,
            num_wire_features=num_wire_features,
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            key=encoder_key,
        )
        if self.share_layers:
            # When sharing layers, use the same encoder for both actor and critic
            self.actor_critic = (actor, actor)
        else:
            self.actor_critic = (actor, critic)
        self.actor_mlp = eqx.nn.MLP(
            in_size=embedding_size,
            width_size=actor_mlp_width,
            out_size=num_slot_actions,
            depth=actor_mlp_depth,
            key=actor_key,
        )
        self.critic_mlp = eqx.nn.MLP(
            in_size=embedding_size,
            width_size=critic_mlp_width,
            out_size=1,
            depth=critic_mlp_depth,
            key=critic_key,
        )
        self.num_slot_actions = num_slot_actions

    def __call__(
        self,
        state: EnvState,
        params: EnvParams,
        *,
        enable_dropout: bool = False,
        key: chex.PRNGKey | None = None,
    ) -> Tuple[distrax.Categorical, Array]:
        """Forward pass through the actor-critic transformer.

        Args:
            state: Environment state
            params: Environment parameters
            enable_dropout: Whether to enable dropout
            key: PRNG key for dropout

        Returns:
            Tuple of (action_distribution, value)
        """
        actor, critic = self.actor_critic
        actor_key, critic_key = jax.random.split(key) if key is not None else (None, None)
        tokens = get_obs_transformer(state, params)

        action_tokens = actor(
            tokens,
            enable_dropout=enable_dropout,
            key=actor_key,
        )["output"]

        # Strip request-specific columns for critic
        tokens_for_critic = tokens[:, : -self.num_request_specific_cols]
        value_tokens = critic(
            tokens_for_critic,
            enable_dropout=enable_dropout,
            key=critic_key,
        )["output"]

        # Project per-link embeddings to slot logits, then pool across path links
        action_tokens = jax.vmap(self.actor_mlp)(action_tokens)

        # POOLING - sum edges per path
        nodes_sd, requested_bw = read_rsa_request(state.request_array)
        def path_action_dist(i):
            return get_path_slots(
                action_tokens,
                params,
                nodes_sd,
                i,
                agg_func="sum",
            )
        path_action_logits = jax.vmap(path_action_dist)(
            jnp.arange(params.k_paths)
        )
        action_logits = path_action_logits.reshape((-1,))

        if params.include_no_op:
            action_logits = jnp.hstack([action_logits, jnp.array([-1e4])])
        action_dist = distrax.Categorical(logits=action_logits)

        # Pool and Value
        value_tokens_mean = jnp.mean(value_tokens, axis=0)
        value = self.critic_mlp(value_tokens_mean).squeeze()

        return action_dist, value

    def sample_action(
        self,
        seed: chex.PRNGKey,
        dist: distrax.Categorical,
        log_prob: bool = False,
        deterministic: bool = False,
    ) -> Union[Array, Tuple[Array, Array]]:
        """Sample an action from the distribution"""
        action = jnp.argmax(dist.probs) if deterministic else dist.sample(seed=seed)
        if log_prob:
            return action, dist.log_prob(action)
        return action
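`sample_action` switches between a greedy action (argmax of the categorical probabilities) and a stochastic draw from the distribution. A minimal NumPy sketch of the two paths, without distrax (`softmax` and the variable names are illustrative):

```python
import numpy as np

def softmax(logits):
    z = np.exp(logits - logits.max())  # shift for numerical stability
    return z / z.sum()

rng = np.random.default_rng(0)
logits = np.array([0.1, 2.0, -1.0])
probs = softmax(logits)

greedy = int(np.argmax(probs))                   # deterministic=True path
sampled = int(rng.choice(len(probs), p=probs))   # dist.sample(seed=...) path
log_prob = float(np.log(probs[sampled]))         # dist.log_prob(action)
print(greedy)  # 1
```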

__call__(state, params, *, enable_dropout=False, key=None)

Forward pass through the actor-critic transformer.

Parameters:

Name Type Description Default
state EnvState

Environment state

required
params EnvParams

Environment parameters

required
enable_dropout bool

Whether to enable dropout

False
key PRNGKey | None

PRNG key for dropout

None

Returns:

Type Description
Tuple[Categorical, Array]

Tuple of (action_distribution, value)

Source code in xlron/models/transformer.py
def __call__(
    self,
    state: EnvState,
    params: EnvParams,
    *,
    enable_dropout: bool = False,
    key: chex.PRNGKey | None = None,
) -> Tuple[distrax.Categorical, Array]:
    """Forward pass through the actor-critic transformer.

    Args:
        state: Environment state
        params: Environment parameters
        enable_dropout: Whether to enable dropout
        key: PRNG key for dropout

    Returns:
        Tuple of (action_distribution, value)
    """
    actor, critic = self.actor_critic
    actor_key, critic_key = jax.random.split(key) if key is not None else (None, None)
    tokens = get_obs_transformer(state, params)

    action_tokens = actor(
        tokens,
        enable_dropout=enable_dropout,
        key=actor_key,
    )["output"]

    # Strip request-specific columns for critic
    tokens_for_critic = tokens[:, : -self.num_request_specific_cols]
    value_tokens = critic(
        tokens_for_critic,
        enable_dropout=enable_dropout,
        key=critic_key,
    )["output"]

    # Project per-link embeddings to slot logits, then pool across path links
    action_tokens = jax.vmap(self.actor_mlp)(action_tokens)

    # POOLING - sum edges per path
    nodes_sd, requested_bw = read_rsa_request(state.request_array)
    def path_action_dist(i):
        return get_path_slots(
            action_tokens,
            params,
            nodes_sd,
            i,
            agg_func="sum",
        )
    path_action_logits = jax.vmap(path_action_dist)(
        jnp.arange(params.k_paths)
    )
    action_logits = path_action_logits.reshape((-1,))

    if params.include_no_op:
        action_logits = jnp.hstack([action_logits, jnp.array([-1e4])])
    action_dist = distrax.Categorical(logits=action_logits)

    # Pool and Value
    value_tokens_mean = jnp.mean(value_tokens, axis=0)
    value = self.critic_mlp(value_tokens_mean).squeeze()

    return action_dist, value

sample_action(seed, dist, log_prob=False, deterministic=False)

Sample an action from the distribution

Source code in xlron/models/transformer.py
def sample_action(
    self,
    seed: chex.PRNGKey,
    dist: distrax.Categorical,
    log_prob: bool = False,
    deterministic: bool = False,
) -> Union[Array, Tuple[Array, Array]]:
    """Sample an action from the distribution"""
    action = jnp.argmax(dist.probs) if deterministic else dist.sample(seed=seed)
    if log_prob:
        return action, dist.log_prob(action)
    return action

AttentionBlock

Bases: Module

A single transformer attention block.

Source code in xlron/models/transformer.py
class AttentionBlock(eqx.Module):
    """A single transformer attention block."""

    attention: MultiheadAttention
    layernorm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout
    num_heads: int = eqx.field(static=True)

    def __init__(
        self,
        embedding_size: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: chex.PRNGKey,
    ):
        self.num_heads = num_heads
        self.attention = MultiheadAttention(
            num_heads=num_heads,
            query_size=embedding_size,
            dropout_p=attention_dropout_rate,
            key=key,
        )
        self.layernorm = eqx.nn.LayerNorm(shape=embedding_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        inputs: Array,
        enable_dropout: bool = False,
        attn_mask: Array | None = None,
        process_heads: Callable | None = None,
        key: chex.PRNGKey | None = None,
    ) -> Array:
        attention_key, dropout_key = (None, None) if key is None else jax.random.split(key)

        norm_input = jax.vmap(self.layernorm)(inputs)
        attention_output = self.attention(
            query=norm_input,
            key_=norm_input,
            value=norm_input,
            mask=attn_mask,
            inference=not enable_dropout,
            key=attention_key,
            process_heads=process_heads,
        )
        result = self.dropout(attention_output, inference=not enable_dropout, key=dropout_key)
        result = result + inputs
        return result

FeedForwardBlock

Bases: Module

A single transformer feed forward block.

Source code in xlron/models/transformer.py
class FeedForwardBlock(eqx.Module):
    """A single transformer feed forward block."""

    linear: eqx.nn.Linear
    output: eqx.nn.Linear
    layernorm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout

    def __init__(
        self,
        embedding_size: int,
        intermediate_size: int,
        dropout_rate: float,
        key: chex.PRNGKey,
    ):
        mlp_key, output_key = jax.random.split(key)
        self.linear = eqx.nn.Linear(
            in_features=embedding_size, out_features=intermediate_size, key=mlp_key
        )
        self.output = eqx.nn.Linear(
            in_features=intermediate_size, out_features=embedding_size, key=output_key
        )
        self.layernorm = eqx.nn.LayerNorm(shape=embedding_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        inputs: Array,
        enable_dropout: bool = True,
        key: chex.PRNGKey | None = None,
    ) -> Array:
        hidden = self.layernorm(inputs)

        # Feed-forward.
        hidden = self.linear(hidden)
        hidden = jax.nn.gelu(hidden)

        # Project back to input size.
        output = self.output(hidden)
        output = self.dropout(output, inference=not enable_dropout, key=key)

        # Residual
        output += inputs

        return output
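The block above is the standard pre-LN feed-forward: normalise, expand with GELU, project back, then add the residual. A NumPy sketch under stated assumptions (weights are random toy matrices; the tanh GELU approximation matches `jax.nn.gelu`'s default):

```python
import numpy as np

def layernorm(x, eps=1e-5):
    return (x - x.mean()) / np.sqrt(x.var() + eps)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def ff_block(x, w_in, w_out):
    # Pre-LN feed-forward with residual: norm -> expand -> GELU -> project -> + x
    h = gelu(layernorm(x) @ w_in)
    return h @ w_out + x

rng = np.random.default_rng(0)
x = rng.normal(size=4)
w_in = rng.normal(size=(4, 8)) * 0.1
w_out = rng.normal(size=(8, 4)) * 0.1
y = ff_block(x, w_in, w_out)
print(y.shape)  # (4,)
```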

TransformerLayer

Bases: Module

A single transformer layer.

Source code in xlron/models/transformer.py
class TransformerLayer(eqx.Module):
    """A single transformer layer."""

    attention_block: AttentionBlock
    ff_block: FeedForwardBlock

    def __init__(
        self,
        embedding_size: int,
        intermediate_size: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: chex.PRNGKey,
        custom_bias: bool = False,
    ):
        attention_key, ff_key = jax.random.split(key)

        self.attention_block = AttentionBlock(
            embedding_size=embedding_size,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            key=attention_key,
        )
        self.ff_block = FeedForwardBlock(
            embedding_size=embedding_size,
            intermediate_size=intermediate_size,
            dropout_rate=dropout_rate,
            key=ff_key,
        )

    def __call__(
        self,
        inputs: Array,
        *,
        enable_dropout: bool = False,
        attn_mask: Array | None = None,
        key: chex.PRNGKey | None = None,
        process_heads: Callable | None = None,
    ) -> Array:
        attn_key, ff_key = (None, None) if key is None else jax.random.split(key)

        attn_result = self.attention_block(
            inputs,
            enable_dropout=enable_dropout,
            attn_mask=attn_mask,
            key=attn_key,
            process_heads=process_heads,
        )
        # attention_block returns Array when return_attention=False (the default)
        attention_output = attn_result if isinstance(attn_result, Array) else attn_result[0]

        seq_len = inputs.shape[0]
        ff_keys = None if ff_key is None else jax.random.split(ff_key, num=seq_len)
        output = jax.vmap(self.ff_block, in_axes=(0, None, 0))(
            attention_output, enable_dropout, ff_keys
        )
        return output

WIRE

Bases: Module

Wavelet-Induced Rotary Encodings for graphs. https://openreview.net/pdf?id=f7BvsdILYx

Projects m-dimensional node features (e.g., RWSE, spectral coords) to rotation angles for RoPE-style positional encoding.

Source code in xlron/models/transformer.py
class WIRE(eqx.Module):
    """Wavelet-Induced Rotary Encodings for graphs.
    https://openreview.net/pdf?id=f7BvsdILYx

    Projects m-dimensional node features (e.g., RWSE, spectral coords)
    to rotation angles for RoPE-style positional encoding.
    """

    freq_proj: eqx.nn.Linear  # (m,) -> (embedding_size // 2,)
    embedding_size: int = eqx.field(static=True)

    def __init__(
        self,
        num_features: int,
        embedding_size: int,
        key: PRNGKeyArray,
        freq_scale: float = 0.01,
    ):
        """
        Args:
            num_features: Dimension of input position features (m)
            embedding_size: Dimension of queries/keys to rotate (must be even)
            key: PRNG key
            freq_scale: Scale for frequency initialisation
        """
        if embedding_size % 2 != 0:
            raise ValueError("embedding_size must be even")

        self.embedding_size = embedding_size

        # Project position features to angles
        # Output dim is embedding_size // 2 (one angle per 2D rotation block)
        self.freq_proj = eqx.nn.Linear(
            in_features=num_features,
            out_features=embedding_size // 2,
            use_bias=False,
            key=key,
        )

        # Optionally scale down initial frequencies for stability
        scaled_weight = self.freq_proj.weight * freq_scale
        self.freq_proj = eqx.tree_at(lambda layer: layer.weight, self.freq_proj, scaled_weight)

    def get_angles(
        self, positions: Float[Array, "num_nodes num_features"]
    ) -> Float[Array, "num_nodes half_emb"]:
        """Compute rotation angles from position features."""
        return jax.vmap(self.freq_proj)(positions)

    def rotate(
        self,
        x: Float[Array, "num_nodes embedding_size"],
        angles: Float[Array, "num_nodes half_emb"],
    ) -> Float[Array, "num_nodes embedding_size"]:
        """Apply rotary encoding to queries or keys.

        For each 2D block [x_{2i}, x_{2i+1}], rotate by angle theta_i:
            x_{2i}'   = x_{2i} * cos(theta) - x_{2i+1} * sin(theta)
            x_{2i+1}' = x_{2i} * sin(theta) + x_{2i+1} * cos(theta)
        """
        # angles: (num_nodes, embedding_size // 2)
        cos_angles = jnp.cos(angles)
        sin_angles = jnp.sin(angles)

        # Repeat for pairs: [cos_0, cos_0, cos_1, cos_1, ...]
        cos_angles = jnp.repeat(cos_angles, 2, axis=-1)
        sin_angles = jnp.repeat(sin_angles, 2, axis=-1)

        # Rotate pairs: for indices [0,1], [2,3], etc.
        # x_rotated = x * cos + rotate_pairs(x) * sin
        # where rotate_pairs swaps and negates: [x0, x1] -> [-x1, x0]
        x_pairs = x.reshape(x.shape[0], -1, 2)  # (num_nodes, num_pairs, 2)
        x_rotated_pairs = jnp.stack([-x_pairs[..., 1], x_pairs[..., 0]], axis=-1)
        x_rotated = x_rotated_pairs.reshape(x.shape)

        return x * cos_angles + x_rotated * sin_angles

    def __call__(
        self,
        queries: Float[Array, "num_nodes embedding_size"],
        keys: Float[Array, "num_nodes embedding_size"],
        positions: Float[Array, "num_nodes num_features"],
    ) -> tuple[
        Float[Array, "num_nodes embedding_size"],
        Float[Array, "num_nodes embedding_size"],
    ]:
        """Apply WIRE to queries and keys.

        Args:
            queries: Query vectors (num_nodes, embedding_size)
            keys: Key vectors (num_nodes, embedding_size)
            positions: Node position features, e.g., RWSE (num_nodes, num_features)

        Returns:
            Rotated (queries, keys)
        """
        angles = self.get_angles(positions)
        return self.rotate(queries, angles), self.rotate(keys, angles)

__call__(queries, keys, positions)

Apply WIRE to queries and keys.

Parameters:

Name Type Description Default
queries Float[Array, 'num_nodes embedding_size']

Query vectors (num_nodes, embedding_size)

required
keys Float[Array, 'num_nodes embedding_size']

Key vectors (num_nodes, embedding_size)

required
positions Float[Array, 'num_nodes num_features']

Node position features, e.g., RWSE (num_nodes, num_features)

required

Returns:

Type Description
tuple[Float[Array, 'num_nodes embedding_size'], Float[Array, 'num_nodes embedding_size']]

Rotated (queries, keys)

Source code in xlron/models/transformer.py
def __call__(
    self,
    queries: Float[Array, "num_nodes embedding_size"],
    keys: Float[Array, "num_nodes embedding_size"],
    positions: Float[Array, "num_nodes num_features"],
) -> tuple[
    Float[Array, "num_nodes embedding_size"],
    Float[Array, "num_nodes embedding_size"],
]:
    """Apply WIRE to queries and keys.

    Args:
        queries: Query vectors (num_nodes, embedding_size)
        keys: Key vectors (num_nodes, embedding_size)
        positions: Node position features, e.g., RWSE (num_nodes, num_features)

    Returns:
        Rotated (queries, keys)
    """
    angles = self.get_angles(positions)
    return self.rotate(queries, angles), self.rotate(keys, angles)

__init__(num_features, embedding_size, key, freq_scale=0.01)

Parameters:

Name Type Description Default
num_features int

Dimension of input position features (m)

required
embedding_size int

Dimension of queries/keys to rotate (must be even)

required
key PRNGKeyArray

PRNG key

required
freq_scale float

Scale for frequency initialisation

0.01
Source code in xlron/models/transformer.py
def __init__(
    self,
    num_features: int,
    embedding_size: int,
    key: PRNGKeyArray,
    freq_scale: float = 0.01,
):
    """
    Args:
        num_features: Dimension of input position features (m)
        embedding_size: Dimension of queries/keys to rotate (must be even)
        key: PRNG key
        freq_scale: Scale for frequency initialisation
    """
    if embedding_size % 2 != 0:
        raise ValueError("embedding_size must be even")

    self.embedding_size = embedding_size

    # Project position features to angles
    # Output dim is embedding_size // 2 (one angle per 2D rotation block)
    self.freq_proj = eqx.nn.Linear(
        in_features=num_features,
        out_features=embedding_size // 2,
        use_bias=False,
        key=key,
    )

    # Optionally scale down initial frequencies for stability
    scaled_weight = self.freq_proj.weight * freq_scale
    self.freq_proj = eqx.tree_at(lambda layer: layer.weight, self.freq_proj, scaled_weight)

get_angles(positions)

Compute rotation angles from position features.

Source code in xlron/models/transformer.py
78
79
80
81
82
def get_angles(
    self, positions: Float[Array, "num_nodes num_features"]
) -> Float[Array, "num_nodes half_emb"]:
    """Compute rotation angles from position features."""
    return jax.vmap(self.freq_proj)(positions)

rotate(x, angles)

Apply rotary encoding to queries or keys.

For each 2D block [x_{2i}, x_{2i+1}], rotate by angle theta_i:

    x_{2i}'   = x_{2i} * cos(theta) - x_{2i+1} * sin(theta)
    x_{2i+1}' = x_{2i} * sin(theta) + x_{2i+1} * cos(theta)

Source code in xlron/models/transformer.py
def rotate(
    self,
    x: Float[Array, "num_nodes embedding_size"],
    angles: Float[Array, "num_nodes half_emb"],
) -> Float[Array, "num_nodes embedding_size"]:
    """Apply rotary encoding to queries or keys.

    For each 2D block [x_{2i}, x_{2i+1}], rotate by angle theta_i:
        x_{2i}'   = x_{2i} * cos(theta) - x_{2i+1} * sin(theta)
        x_{2i+1}' = x_{2i} * sin(theta) + x_{2i+1} * cos(theta)
    """
    # angles: (num_nodes, embedding_size // 2)
    cos_angles = jnp.cos(angles)
    sin_angles = jnp.sin(angles)

    # Repeat for pairs: [cos_0, cos_0, cos_1, cos_1, ...]
    cos_angles = jnp.repeat(cos_angles, 2, axis=-1)
    sin_angles = jnp.repeat(sin_angles, 2, axis=-1)

    # Rotate pairs: for indices [0,1], [2,3], etc.
    # x_rotated = x * cos + rotate_pairs(x) * sin
    # where rotate_pairs swaps and negates: [x0, x1] -> [-x1, x0]
    x_pairs = x.reshape(x.shape[0], -1, 2)  # (num_nodes, num_pairs, 2)
    x_rotated_pairs = jnp.stack([-x_pairs[..., 1], x_pairs[..., 0]], axis=-1)
    x_rotated = x_rotated_pairs.reshape(x.shape)

    return x * cos_angles + x_rotated * sin_angles
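A quick numeric sanity check of the pairwise rotation: rotating the 2D block [1, 0] by pi/2 should give approximately [0, 1], and a rotation never changes the vector's norm. This is a NumPy port of `rotate` for a single pair (illustrative only):

```python
import numpy as np

def rotate(x, angles):
    # NumPy port of WIRE.rotate: angles has shape (num_nodes, emb // 2)
    cos = np.repeat(np.cos(angles), 2, axis=-1)
    sin = np.repeat(np.sin(angles), 2, axis=-1)
    pairs = x.reshape(x.shape[0], -1, 2)
    # rotate_pairs swaps and negates each 2D block: [x0, x1] -> [-x1, x0]
    swapped = np.stack([-pairs[..., 1], pairs[..., 0]], axis=-1).reshape(x.shape)
    return x * cos + swapped * sin

x = np.array([[1.0, 0.0]])
angles = np.array([[np.pi / 2]])
print(rotate(x, angles))  # approximately [[0., 1.]]
```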

Heuristics

best_fit(state, params)

Best-Fit Spectrum Allocation. Returns the best fit slot for each path.

Source code in xlron/heuristics/heuristics.py
def best_fit(state: EnvState, params: RSAEnvParams) -> Tuple[chex.Array, chex.Array]:
    """Best-Fit Spectrum Allocation. Returns the best fit slot for each path."""
    mask = get_action_mask(state, params)
    link_slot_array = jnp.where(state.link_slot_array < 0, 1.0, state.link_slot_array)
    nodes_sd, requested_bw = read_rsa_request(state.request_array)

    # We need to define a wrapper function in order to vmap with named arguments
    def _find_block_sizes(arr, starts_only=False, reverse=True):
        return jax.vmap(find_block_sizes, in_axes=(0, None, None))(arr, starts_only, reverse)

    block_sizes_right = _find_block_sizes(link_slot_array, starts_only=False, reverse=False)
    block_sizes_left = _find_block_sizes(link_slot_array, starts_only=False, reverse=True)
    block_sizes = jnp.maximum((block_sizes_left + block_sizes_right) - 1, 0)
    paths = get_paths(params, nodes_sd)
    se = (
        get_paths_se(params, nodes_sd)
        if params.consider_modulation_format
        else jnp.ones((params.k_paths,))
    )
    num_slots = jax.vmap(required_slots, in_axes=(None, 0, None, None))(
        requested_bw, se, params.slot_size, params.guardband
    )

    # Quantify how well the request fits within a free spectral block
    def get_bf_on_path(path, blocks, req_slots):
        fits = jax.vmap(lambda x: x - req_slots, in_axes=0)(blocks)
        fits = jnp.where(fits >= 0, fits, params.link_resources)
        path_fit = jnp.dot(path, fits) / jnp.sum(path)
        return path_fit

    fits_block = jax.vmap(lambda x, y, z: get_bf_on_path(x, y, z), in_axes=(0, None, 0))(
        paths, block_sizes, num_slots
    )

    # Quantify how much of a gap there is between the assigned slots and the next occupied slot on the left
    def get_bf_on_path_left(path, blocks, req_slots):
        fits = jax.vmap(lambda x: x - req_slots, in_axes=0)(blocks)
        fits = jnp.where(fits >= 0, fits, params.link_resources)
        fits_shift = jax.lax.dynamic_slice(fits, (0, 1), (fits.shape[0], fits.shape[1] - 1))
        fits_shift = jnp.concatenate(
            (jnp.full((fits.shape[0], 1), params.link_resources), fits_shift), axis=1
        )
        fits = fits + 1 / jnp.maximum(fits_shift, 1)
        path_fit = jnp.dot(path, fits) / jnp.sum(path)
        return path_fit

    fits_left = jax.vmap(lambda x, y, z: get_bf_on_path_left(x, y, z), in_axes=(0, None, 0))(
        paths, block_sizes_left, num_slots
    )

    # Quantify how much of a gap there is between the assigned slots and the next occupied slot on the right
    def get_bf_on_path_right(path, blocks, req_slots):
        fits = jax.vmap(lambda x: x - req_slots, in_axes=0)(blocks)
        fits = jnp.where(fits >= 0, fits, params.link_resources)
        fits_shift = jax.lax.dynamic_slice(fits, (0, 0), (fits.shape[0], fits.shape[1] - 1))
        fits_shift = jnp.concatenate(
            (fits_shift, jnp.full((fits.shape[0], 1), params.link_resources)), axis=1
        )
        fits = fits + 1 / jnp.maximum(fits_shift, 1)
        path_fit = jnp.dot(path, fits) / jnp.sum(path)
        return path_fit

    fits_right = jax.vmap(lambda x, y, z: get_bf_on_path_right(x, y, z), in_axes=(0, None, 0))(
        paths, block_sizes_right, num_slots
    )

    # Sum the contribution to the overall quality of fit, and scale down the left/right contributions
    fits = jnp.sum(
        jnp.stack(
            (fits_block, fits_left / params.link_resources, fits_right / params.link_resources),
            axis=0,
        ),
        axis=0,
    )
    # Mask out occupied lightpaths (in case the quality of fit on some links is good enough to be considered, even if the overall path is invalid)
    fits = jnp.where(mask == 0, jnp.inf, fits)
    best_slots = jnp.argmin(fits, axis=1)
    best_fits = jnp.min(fits, axis=1)
    return best_slots, best_fits
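
The left/right run-length combination used above can be sketched in plain NumPy. This is illustrative only: `find_block_sizes` is XLRON's own implementation, and `run_lengths` here is a hypothetical stand-in. For each free slot, the length of the free run ending at it plus the run starting at it, minus one (the slot itself is counted twice), gives the size of the whole block containing it:

```python
import numpy as np

def run_lengths(free, reverse=False):
    """Length of the contiguous free run ending (or, reversed, starting) at each slot."""
    if reverse:
        return run_lengths(free[::-1])[::-1]
    out = np.zeros(len(free), dtype=int)
    run = 0
    for i, f in enumerate(free):
        run = run + 1 if f else 0
        out[i] = run
    return out

free = np.array([1, 1, 0, 1, 1, 1, 0, 1])  # 1 = free slot, 0 = occupied
left = run_lengths(free)                    # run ending at each slot
right = run_lengths(free, reverse=True)     # run starting at each slot
# Each free slot is labelled with the size of the block containing it
block = np.maximum(left + right - 1, 0) * free
```

With the example array, `block` labels the two-slot, three-slot, and single-slot free blocks at every slot they cover, mirroring how `best_fit` derives `block_sizes` from `block_sizes_left` and `block_sizes_right`.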

bf_ksp(state, params)

Get the best-fit slot across the k-shortest paths. Method: Compute the best-fit slot on each path and choose the path with the overall best (lowest) fitness.

Parameters:

Name Type Description Default
state EnvState

Environment state

required
params EnvParams

Environment parameters

required

Returns:

Type Description
Array

chex.Array: Action

Source code in xlron/heuristics/heuristics.py
@partial(jax.jit, static_argnums=(1,))
def bf_ksp(state: EnvState, params: RSAEnvParams) -> chex.Array:
    """Get the first available slot from the first k-shortest paths
    Method: Go through action mask and find the first available slot on all paths

    Args:
        state (EnvState): Environment state
        params (EnvParams): Environment parameters

    Returns:
        chex.Array: Action
    """
    best_slots, fitness = best_fit(state, params)
    # Chosen path is the one with the best fit
    path_index = jnp.argmin(fitness)
    slot_index = best_slots[path_index] % params.link_resources
    # Convert indices to action
    action = path_index * params.link_resources + slot_index
    return action
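
The `path_index * link_resources + slot_index` encoding used by all the heuristics flattens a (path, slot) pair into a single scalar action, which can be decoded again with integer division and modulo. A minimal sketch (the value of `link_resources` here is an arbitrary example, not a default):

```python
link_resources = 100  # assumed number of slots per link (example value)

def encode_action(path_index, slot_index, link_resources):
    """Flatten a (path, slot) pair into a scalar action."""
    return path_index * link_resources + slot_index

def decode_action(action, link_resources):
    """Recover (path, slot) from a scalar action."""
    return action // link_resources, action % link_resources

action = encode_action(2, 17, link_resources)
```

Decoding `action` returns the original pair, which is why the heuristics can safely take `first_slots[path_index] % params.link_resources` before re-encoding.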

ff_ksp(state, params)

Get the first available slot from all paths Method: Go through action mask and find the first available slot on all paths

Parameters:

Name Type Description Default
state EnvState

Environment state

required
params EnvParams

Environment parameters

required

Returns:

Type Description
Array

chex.Array: Action

Source code in xlron/heuristics/heuristics.py
@partial(jax.jit, static_argnums=(1,))
def ff_ksp(state: RSAEnvState, params: RSAEnvParams) -> chex.Array:
    """Get the first available slot from all paths
    Method: Go through action mask and find the first available slot on all paths

    Args:
        state (EnvState): Environment state
        params (EnvParams): Environment parameters

    Returns:
        chex.Array: Action
    """
    first_slots = first_fit(state, params)
    # Chosen path is the one with the lowest index of first available slot
    path_index = jnp.argmin(first_slots)
    slot_index = first_slots[path_index] % params.link_resources
    # Convert indices to action
    action = path_index * params.link_resources + slot_index
    return action

first_fit(state, params)

First-Fit Spectrum Allocation. Returns the first fit slot for each path.

When band_slot_order_ff is set (GN model envs with --band_preference), slots are searched in band preference order rather than raw index order.

Source code in xlron/heuristics/heuristics.py
def first_fit(state: EnvState, params: RSAEnvParams) -> chex.Array:
    """First-Fit Spectrum Allocation. Returns the first fit slot for each path.

    When band_slot_order_ff is set (GN model envs with --band_preference),
    slots are searched in band preference order rather than raw index order.
    """
    mask = get_action_mask(state, params)
    if isinstance(params, GNModelEnvParams) and len(params.band_slot_order_ff.val) > 0:
        order = params.band_slot_order_ff.val
        reordered = mask[:, order]
        reordered = jnp.concatenate((reordered, jnp.full((reordered.shape[0], 1), 1)), axis=1)
        idx = jnp.argmax(reordered, axis=1)
        safe_idx = jnp.clip(idx, 0, params.link_resources - 1)
        first_slots = jnp.where(idx < params.link_resources, order[safe_idx], params.link_resources)
    else:
        # Add a column of ones to make sure occupied paths have non-zero index in "first_slots"
        mask = jnp.concatenate((mask, jnp.full((mask.shape[0], 1), 1)), axis=1)
        first_slots = jnp.argmax(mask, axis=1)
    return first_slots
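
The sentinel-column trick in the fallback branch can be shown with a toy mask (plain NumPy, values chosen for illustration). Without the appended column of ones, `argmax` on a fully occupied row would return 0, indistinguishable from "first slot is free"; with it, occupied rows return `link_resources` as a sentinel:

```python
import numpy as np

link_resources = 4
mask = np.array([
    [0, 1, 1, 0],   # first free slot at index 1
    [0, 0, 0, 0],   # fully occupied path
])
# Append a column of ones so argmax on a fully occupied row
# returns link_resources (a sentinel) instead of 0
padded = np.concatenate([mask, np.ones((mask.shape[0], 1))], axis=1)
first_slots = np.argmax(padded, axis=1)
```

Callers such as `ksp_ff` then test `first_slots < params.link_resources` to tell genuinely free slots apart from the sentinel.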

get_link_weights(state, params)

Get link weights based on occupancy for use in congestion-aware routing heuristics.

Parameters:

Name Type Description Default
state EnvState

Environment state

required
params EnvParams

Environment parameters

required

Returns:

Type Description

chex.Array: Link weights

Source code in xlron/heuristics/heuristics.py
def get_link_weights(state: EnvState, params: RSAEnvParams):
    """Get link weights based on occupancy for use in congestion-aware routing heuristics.

    Args:
        state (EnvState): Environment state
        params (EnvParams): Environment parameters

    Returns:
        chex.Array: Link weights
    """
    if isinstance(params, RWALightpathReuseEnvParams):
        initial_path_capacity = init_path_capacity_array(
            params.link_length_array.val, params.path_link_array.val, scale_factor=1.0
        )
        initial_path_capacity = jnp.squeeze(
            jax.vmap(lambda x: initial_path_capacity[x])(state.path_index_array)
        )
        utilisation = (
            jnp.where(
                initial_path_capacity - state.link_capacity_array < 0,
                0,
                initial_path_capacity - state.link_capacity_array,
            )
            / initial_path_capacity
        )
        link_occupancy = jnp.sum(utilisation, axis=1)
    else:
        link_occupancy = jnp.count_nonzero(state.link_slot_array, axis=1)
    link_weights = jnp.multiply(
        params.link_length_array.val.T, (1 / (1 - link_occupancy / (params.link_resources + 1)))
    )[0]
    return link_weights
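
The weight formula in the final line scales each link's physical length by `1 / (1 - occupancy / (link_resources + 1))`, so an empty link keeps its length while a congested link costs progressively more. A minimal NumPy sketch with illustrative values:

```python
import numpy as np

link_resources = 100
link_lengths = np.array([10.0, 10.0, 20.0])  # example lengths in km
occupied = np.array([0, 50, 0])              # occupied slot counts per link
# Length scaled up by a congestion factor; empty links keep their length
weights = link_lengths / (1 - occupied / (link_resources + 1))
```

Here the half-full link's weight rises to roughly twice its length, biasing shortest-path routing away from it.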

kca_ff(state, params)

Congestion-aware First Fit. Only suitable for RSA/RMSA. Method: Weight each link by its length scaled by its occupancy, sum the weights over the links of each candidate path, choose the least congested path, and take its first available slot.

Source code in xlron/heuristics/heuristics.py
@partial(jax.jit, static_argnums=(1,))
def kca_ff(state: EnvState, params: RSAEnvParams) -> chex.Array:
    """Congestion-aware First Fit. Only suitable for RSA/RMSA.
    Method: Weight each link by its length scaled by its occupancy, sum the
    weights over the links of each candidate path, choose the least congested
    path, and take its first available slot.
    """
    mask = get_action_mask(state, params)
    # Get index of first available slots for each path
    first_slots = first_fit(state, params)
    # Get nodes
    nodes_sd, _ = read_rsa_request(state.request_array)
    # Initialise array to hold congestion on each path
    path_congestion_array = jnp.full((mask.shape[0],), 0.0)
    link_weights = get_link_weights(state, params)

    def get_path_congestion(i, val):
        # Get links on path
        path = get_paths(params, nodes_sd)[i]
        # Get congestion
        path_link_congestion = jnp.multiply(link_weights, path)
        path_congestion = jnp.sum(path_link_congestion).reshape((1,))
        return jax.lax.dynamic_update_slice(val, path_congestion, (i,))

    path_congestion_array = jax.lax.fori_loop(
        0, mask.shape[0], get_path_congestion, path_congestion_array
    )
    path_index = jnp.argmin(path_congestion_array)
    slot_index = first_slots[path_index] % params.link_resources
    action = path_index * params.link_resources + slot_index
    return action
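
The per-path congestion sum computed in the `fori_loop` is a dot product of the link-weight vector with each path's binary link-membership vector. A small NumPy sketch (weights and paths are illustrative values, not taken from any topology):

```python
import numpy as np

link_weights = np.array([5.0, 12.0, 7.0, 3.0])  # example congestion-scaled weights
# Binary link-membership rows, one per candidate path
paths = np.array([
    [1, 1, 0, 0],   # uses links 0 and 1
    [0, 0, 1, 1],   # uses links 2 and 3
])
path_congestion = paths @ link_weights   # summed weight along each path
least_congested = int(np.argmin(path_congestion))
```

The second path (total weight 10 vs. 17) wins, matching `jnp.argmin(path_congestion_array)` in `kca_ff`.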

kmc_ff(state, params)

K-Minimum Cut. Only suitable for RSA/RMSA. Method:

1. Go through action mask and find the first available slot on all paths.
2. For each path, allocate the first available slot.
3. Sum the number of new consecutive zero regions (cuts) created by the assignment on each link.
4. Choose the path that creates the fewest cuts.

Source code in xlron/heuristics/heuristics.py
@partial(jax.jit, static_argnums=(1,))
def kmc_ff(state: EnvState, params: RSAEnvParams) -> chex.Array:
    """K-Minimum Cut. Only suitable for RSA/RMSA.
    Method:
    1. Go through action mask and find the first available slot on all paths.
    2. For each path, allocate the first available slot.
    3. Sum number of new consecutive zero regions (cuts) created by assignment (on each link)
    4. Choose path that creates the fewest cuts.
    """
    mask = get_action_mask(state, params)
    first_slots = first_fit(state, params)
    link_slot_array = jnp.where(state.link_slot_array < 0, 1.0, state.link_slot_array)
    nodes_sd, requested_bw = read_rsa_request(state.request_array)
    block_sizes = jax.vmap(find_block_sizes, in_axes=(0,))(link_slot_array)
    block_sizes_mask = jnp.where(
        block_sizes > 0, 1, 0.0
    )  # Binary array showing initial block starts
    block_count = jnp.sum(block_sizes_mask, axis=1)

    def get_cuts_on_path(i, result):
        initial_slot_index = first_slots[i] % params.link_resources
        path = get_paths(params, nodes_sd)[i]
        se = get_paths_se(params, nodes_sd)[i] if params.consider_modulation_format else 1
        num_slots = required_slots(requested_bw, se, params.slot_size, guardband=params.guardband)
        affected_slots_mask = get_affected_slots_mask(initial_slot_index, num_slots, path, params)
        # Make link-slot_array positive
        updated_slots = set_path_links(link_slot_array, affected_slots_mask, 1.0)
        updated_block_sizes = jax.vmap(find_block_sizes, in_axes=(0,))(updated_slots)
        updated_block_sizes_mask = jnp.where(
            updated_block_sizes > 0, 1, 0
        )  # Binary array showing updated block starts
        updated_block_count = jnp.sum(updated_block_sizes_mask, axis=1)
        num_cuts = jax.lax.cond(
            mask[i][initial_slot_index] == 0.0,  # If true, no valid action for path
            lambda x: jnp.full((1,), params.link_resources * params.num_links).astype(
                jnp.float32
            ),  # Return max no. of cuts
            lambda x: jnp.sum(jnp.maximum(updated_block_count - block_count, 0.0)).reshape(
                (1,)
            ),  # Else, return number of cuts
            1.0,
        )
        result = jax.lax.dynamic_update_slice(result, num_cuts, (i,))
        return result

    # Initialise array to hold number of cuts on each path
    path_cuts_array = jnp.full((mask.shape[0],), 0.0)
    path_cuts_array = jax.lax.fori_loop(0, mask.shape[0], get_cuts_on_path, path_cuts_array)
    path_index = jnp.argmin(path_cuts_array)
    slot_index = first_slots[path_index] % params.link_resources
    # Convert indices to action
    action = path_index * params.link_resources + slot_index
    return action
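
Counting "cuts" amounts to counting maximal free runs before and after the candidate assignment and taking the increase. A plain-NumPy sketch of that idea (illustrative only; `find_block_sizes` and the mask logic in `kmc_ff` handle this in vectorised JAX):

```python
import numpy as np

def count_blocks(occupied):
    """Number of maximal free (zero) runs in a 1-D slot occupancy array."""
    free = occupied == 0
    prev = np.concatenate(([False], free[:-1]))
    return int(np.sum(free & ~prev))  # a block starts where free follows not-free

before = np.array([0, 0, 0, 0, 0, 1, 0, 0])
after = np.array([0, 1, 1, 0, 0, 1, 0, 0])  # allocate slots 1-2
# Allocating mid-block splits one free region into two: one new cut
cuts = max(count_blocks(after) - count_blocks(before), 0)
```

Allocating at a block edge instead would create no new cut, which is exactly what the heuristic rewards.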

kme_ff(state, params)

K-Minimum Entropy. Only suitable for RSA/RMSA. Method:

1. Go through action mask and find the first available slot on all paths.
2. For each path, allocate the first available slot.
3. Compute the fragmentation entropy of the resulting free spectral blocks on each link.
4. Choose the path that yields the lowest total entropy.

Source code in xlron/heuristics/heuristics.py
@partial(jax.jit, static_argnums=(1,))
def kme_ff(state: EnvState, params: RSAEnvParams) -> chex.Array:
    """K-Minimum Entropy. Only suitable for RSA/RMSA.
    Method:
    1. Go through action mask and find the first available slot on all paths.
    2. For each path, allocate the first available slot.
    3. Compute the fragmentation entropy of the resulting free blocks on each link
    4. Choose path that yields the lowest total entropy.
    """
    mask = get_action_mask(state, params)
    first_slots = first_fit(state, params)
    link_slot_array = jnp.where(state.link_slot_array < 0, 1.0, state.link_slot_array)
    nodes_sd, requested_bw = read_rsa_request(state.request_array)
    max_entropy = jnp.sum(jnp.log(params.link_resources)) * params.num_links

    def get_link_entropy(blocks):
        ent = jax.vmap(
            lambda x: jnp.sum(x / params.link_resources * jnp.log(params.link_resources / x)),
            in_axes=0,
        )(blocks)
        return jnp.sum(jnp.where(blocks > 0, ent, 0))

    def get_entropy_on_path(i, result):
        initial_slot_index = first_slots[i] % params.link_resources
        path = get_paths(params, nodes_sd)[i]
        se = get_paths_se(params, nodes_sd)[i] if params.consider_modulation_format else 1
        num_slots = required_slots(requested_bw, se, params.slot_size, guardband=params.guardband)
        affected_slots_mask = get_affected_slots_mask(initial_slot_index, num_slots, path, params)
        # Make link-slot_array positive
        updated_slots = set_path_links(link_slot_array, affected_slots_mask, 1.0)
        updated_block_sizes = jax.vmap(find_block_sizes, in_axes=(0,))(updated_slots)
        updated_entropy = jax.vmap(get_link_entropy, in_axes=(0,))(updated_block_sizes)
        new_path_entropy = jnp.sum(jnp.dot(path, updated_entropy)).reshape((1,))
        new_path_entropy = jax.lax.cond(
            mask[i][initial_slot_index] == 0.0,  # If true, no valid action for path
            lambda x: max_entropy.astype(jnp.float32).reshape((1,)),  # Return maximum entropy
            lambda x: new_path_entropy,  # Else, return path entropy
            1.0,
        )
        result = jax.lax.dynamic_update_slice(result, new_path_entropy, (i,))
        return result

    path_entropy_array = jnp.full((mask.shape[0],), 0.0)
    path_entropy_array = jax.lax.fori_loop(
        0, mask.shape[0], get_entropy_on_path, path_entropy_array
    )
    path_index = jnp.argmin(path_entropy_array)
    slot_index = first_slots[path_index] % params.link_resources
    # Convert indices to action
    action = path_index * params.link_resources + slot_index
    return action
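
The entropy metric in `get_link_entropy` sums `b/N * log(N/b)` over the free-block sizes `b` of a link with `N` slots, so one large block scores lower (less fragmented) than the same capacity split into many small blocks. A standalone sketch:

```python
import numpy as np

def frag_entropy(block_sizes, link_resources):
    """Shannon-style fragmentation entropy of one link's free blocks."""
    b = np.asarray(block_sizes, dtype=float)
    b = b[b > 0]  # ignore zero entries, as the jnp.where in get_link_entropy does
    return float(np.sum(b / link_resources * np.log(link_resources / b)))

# One big free block is less fragmented (lower entropy)
# than the same free capacity split into four small blocks
low = frag_entropy([8], 16)
high = frag_entropy([2, 2, 2, 2], 16)
```

`kme_ff` picks the path whose assignment leaves the network with the lowest total of this quantity.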

kmf_ff(state, params)

K-Minimum Frag-size. Only suitable for RSA/RMSA. Method:

1. Go through action mask and find the first available slot on all paths.
2. For each path, allocate the first available slot.
3. Sum the size of any new spectral fragments created to the left of the assigned block on each link.
4. Choose the path that creates the smallest total new fragment size.

Source code in xlron/heuristics/heuristics.py
@partial(jax.jit, static_argnums=(1,))
def kmf_ff(state: RSAEnvState, params: RSAEnvParams) -> chex.Array:
    """K-Minimum Frag-size. Only suitable for RSA/RMSA.
    Method:
    1. Go through action mask and find the first available slot on all paths.
    2. For each path, allocate the first available slot.
    3. Sum size of new spectral fragments created to the left of the assigned block (on each link)
    4. Choose path that creates the smallest total new fragment size.
    """
    mask = get_action_mask(state, params)
    first_slots = first_fit(state, params)
    link_slot_array = jnp.where(state.link_slot_array < 0, 1.0, state.link_slot_array)
    nodes_sd, requested_bw = read_rsa_request(state.request_array)
    blocks = jax.vmap(find_block_sizes, in_axes=(0,))(link_slot_array)

    def get_frags_on_path(i, result):
        initial_slot_index = first_slots[i] % params.link_resources
        path = get_paths(params, nodes_sd)[i]
        se = get_paths_se(params, nodes_sd)[i] if params.consider_modulation_format else 1
        num_slots = required_slots(requested_bw, se, params.slot_size, guardband=params.guardband)
        affected_slots_mask = get_affected_slots_mask(initial_slot_index, num_slots, path, params)
        # Mask on path links
        block_sizes = jax.vmap(lambda x, y: jnp.where(x > 0, y, 0.0), in_axes=(0, 0))(path, blocks)
        updated_slots = set_path_links(state.link_slot_array, affected_slots_mask, -1)
        updated_block_sizes = jax.vmap(find_block_sizes, in_axes=(0,))(updated_slots)
        # Mask on path links
        updated_block_sizes = jax.vmap(lambda x, y: jnp.where(x > 0, y, 0.0), in_axes=(0, 0))(
            path, updated_block_sizes
        )
        difference = updated_block_sizes - block_sizes
        new_frags = jnp.where(difference != 0, block_sizes + difference, 0.0)
        # Slice new frags up to initial slot index (so as to only consider frags to the left)
        new_frags = jnp.where(
            jnp.arange(params.link_resources) < initial_slot_index, new_frags, 0.0
        )
        new_frag_size = jnp.sum(new_frags)
        num_frags = jax.lax.cond(
            mask[i][initial_slot_index] == 0.0,  # If true, no valid action for path
            lambda x: jnp.full(
                (1,), float(params.link_resources * params.num_links)
            ),  # Return max frag size
            lambda x: new_frag_size.reshape((1,)),
            # Else, return new fragment size
            1.0,
        )
        result = jax.lax.dynamic_update_slice(result, num_frags, (i,))
        return result

    # Initialise array to hold number of cuts on each path
    path_frags_array = jnp.full((mask.shape[0],), 0.0)
    path_frags_array = jax.lax.fori_loop(0, mask.shape[0], get_frags_on_path, path_frags_array)
    path_index = jnp.argmin(path_frags_array)
    slot_index = first_slots[path_index] % params.link_resources
    # Convert indices to action
    action = path_index * params.link_resources + slot_index
    return action

ksp_bf(state, params)

Get the best-fit slot on the shortest available path. Method: Compute the best-fit slot on each path and choose the first (shortest) path with a feasible fit.

Parameters:

Name Type Description Default
state EnvState

Environment state

required
params EnvParams

Environment parameters

required

Returns:

Type Description
Array

chex.Array: Action

Source code in xlron/heuristics/heuristics.py
@partial(jax.jit, static_argnums=(1,))
def ksp_bf(state: EnvState, params: RSAEnvParams) -> chex.Array:
    """Get the first available slot from all k-shortest paths
    Method: Go through action mask and find the first available slot, starting from shortest path

    Args:
        state (EnvState): Environment state
        params (EnvParams): Environment parameters

    Returns:
        chex.Array: Action
    """
    best_slots, fitness = best_fit(state, params)
    # Chosen path is the first one with an available slot
    path_index = jnp.argmin(jnp.where(fitness < jnp.inf, 0, 1))
    slot_index = best_slots[path_index] % params.link_resources
    # Convert indices to action
    action = path_index * params.link_resources + slot_index
    return action

ksp_ff(state, params)

Get the first available slot from the shortest available path Method: Go through action mask and find the first available slot, starting from shortest path

Parameters:

Name Type Description Default
state EnvState

Environment state

required
params EnvParams

Environment parameters

required

Returns:

Type Description
Array

chex.Array: Action

Source code in xlron/heuristics/heuristics.py
@partial(jax.jit, static_argnums=(1,))
def ksp_ff(state: RSAEnvState, params: RSAEnvParams) -> chex.Array:
    """Get the first available slot from the shortest available path
    Method: Go through action mask and find the first available slot, starting from shortest path

    Args:
        state (EnvState): Environment state
        params (EnvParams): Environment parameters

    Returns:
        chex.Array: Action
    """
    first_slots = first_fit(state, params)
    # Chosen path is the first one with an available slot
    path_index = jnp.argmax(first_slots < params.link_resources)
    slot_index = first_slots[path_index] % params.link_resources
    # Convert indices to action
    action = path_index * params.link_resources + slot_index
    return action

ksp_ff_multiband(state, params)

Get the first available slot from all k-shortest paths in multiband scenario Method: Go through action mask and find the first available slot, starting from shortest path

Parameters:

Name Type Description Default
state MultiBandRSAEnvState

Environment state specific to multiband operations

required
params MultiBandRSAEnvParams

Environment parameters including multiband details

required

Returns: chex.Array: Action

Source code in xlron/heuristics/heuristics.py
def ksp_ff_multiband(state: EnvState, params: RSAEnvParams) -> None:
    """Get the first available slot from all k-shortest paths in multiband scenario
    Method: Go through action mask and find the first available slot, starting from shortest path

    Args:
        state (MultiBandRSAEnvState): Environment state specific to multiband operations
        params (MultiBandRSAEnvParams): Environment parameters including multiband details
    Returns:
        chex.Array: Action
    """
    pass

ksp_lf(state, params)

Get the last available slot on the shortest available path Method: Go through action mask and find the last available slot, starting from shortest path

Parameters:

Name Type Description Default
state EnvState

Environment state

required
params EnvParams

Environment parameters

required

Returns:

Type Description
Array

chex.Array: Action

Source code in xlron/heuristics/heuristics.py
@partial(jax.jit, static_argnums=(1,))
def ksp_lf(state: RSAEnvState, params: RSAEnvParams) -> chex.Array:
    """Get the last available slot on the shortest available path
    Method: Go through action mask and find the last available slot, starting from shortest path

    Args:
        state (EnvState): Environment state
        params (EnvParams): Environment parameters

    Returns:
        chex.Array: Action
    """
    last_slots = last_fit(state, params)
    # Chosen path is the first one with an available slot
    path_index = jnp.argmax(last_slots < params.link_resources)
    slot_index = last_slots[path_index] % params.link_resources
    # Convert indices to action
    action = path_index * params.link_resources + slot_index
    return action

ksp_mu(state, params, unique_lightpaths, relative)

Get the most-used slot on the shortest available path. Method: Go through action mask and find the utilisation of available slots on each path. Find the shortest available path and choose the most utilised slot on that path.

Parameters:

Name Type Description Default
state EnvState

Environment state

required
params EnvParams

Environment parameters

required
unique_lightpaths bool

Whether to consider unique lightpaths

required
relative bool

Whether to return relative utilisation

required

Returns:

Type Description
Array

chex.Array: Action

Source code in xlron/heuristics/heuristics.py
@partial(jax.jit, static_argnums=(1, 2, 3))
def ksp_mu(
    state: EnvState, params: RSAEnvParams, unique_lightpaths: bool, relative: bool
) -> chex.Array:
    """Get the most-used slot on the shortest available path.
    Method: Go through action mask and find the utilisation of available slots on each path.
    Find the shortest available path and choose the most utilised slot on that path.

    Args:
        state (EnvState): Environment state
        params (EnvParams): Environment parameters
        unique_lightpaths (bool): Whether to consider unique lightpaths
        relative (bool): Whether to return relative utilisation

    Returns:
        chex.Array: Action
    """
    mask = get_action_mask(state, params)
    most_used_slots = most_used(state, params, unique_lightpaths, relative)
    # Get usage of available slots
    most_used_mask = most_used_slots * mask
    # Get index of most-used available slot for each path
    most_used_slots = jnp.argmax(most_used_mask, axis=1).astype(jnp.int32)
    # Chosen path is the first one with an available slot
    available_paths = jnp.max(mask, axis=1)
    path_index = jnp.argmax(available_paths)
    slot_index = most_used_slots[path_index] % params.link_resources
    # Convert indices to action
    action = path_index * params.link_resources + slot_index
    return action

last_fit(state, params)

Last-Fit Spectrum Allocation. Returns the last fit slot for each path.

When band_slot_order_lf is set (GN model envs with --band_preference), slots are searched in band preference order (descending within each band).

Source code in xlron/heuristics/heuristics.py
def last_fit(state: EnvState, params: RSAEnvParams) -> chex.Array:
    """Last-Fit Spectrum Allocation. Returns the last fit slot for each path.

    When band_slot_order_lf is set (GN model envs with --band_preference),
    slots are searched in band preference order (descending within each band).
    """
    mask = get_action_mask(state, params)
    if isinstance(params, GNModelEnvParams) and len(params.band_slot_order_lf.val) > 0:
        order = params.band_slot_order_lf.val
        reordered = mask[:, order]
        reordered = jnp.concatenate((reordered, jnp.full((reordered.shape[0], 1), 1)), axis=1)
        idx = jnp.argmax(reordered, axis=1)
        safe_idx = jnp.clip(idx, 0, params.link_resources - 1)
        last_slots = jnp.where(idx < params.link_resources, order[safe_idx], params.link_resources)
    else:
        # Add a column of ones to make sure occupied paths have non-zero index in "last_slots"
        mask = jnp.concatenate((jnp.full((mask.shape[0], 1), 1), mask), axis=1)
        last_slots = jnp.argmax(mask[:, ::-1], axis=1)
        last_slots = params.link_resources - last_slots - 1
    return last_slots
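
The fallback branch of `last_fit` searches from the right by reversing each mask row before `argmax`, then maps the reversed index back. A simplified NumPy sketch (it omits the prepended sentinel column that handles fully occupied paths):

```python
import numpy as np

link_resources = 4
mask = np.array([
    [0, 1, 1, 0],   # last free slot at index 2
    [1, 0, 0, 1],   # last free slot at index 3
])
# Reverse each row, find the first free slot from the right, map back
last_slots = link_resources - np.argmax(mask[:, ::-1], axis=1) - 1
```

This mirrors `params.link_resources - last_slots - 1` in the source, with the sentinel column in the real implementation catching rows with no free slot.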

lf_ksp(state, params)

Get the last available slot from all paths Method: Go through action mask and find the last available slot on all paths

Parameters:

Name Type Description Default
state EnvState

Environment state

required
params EnvParams

Environment parameters

required

Returns:

Type Description
Array

chex.Array: Action

Source code in xlron/heuristics/heuristics.py
@partial(jax.jit, static_argnums=(1,))
def lf_ksp(state: EnvState, params: RSAEnvParams) -> chex.Array:
    """Get the last available slot from all paths
    Method: Go through action mask and find the last available slot on all paths

    Args:
        state (EnvState): Environment state
        params (EnvParams): Environment parameters

    Returns:
        chex.Array: Action
    """
    last_slots = last_fit(state, params)
    # Chosen path is the one with the highest index of last available slot
    path_index = jnp.argmax(last_slots)
    slot_index = last_slots[path_index] % params.link_resources
    # Convert indices to action
    action = path_index * params.link_resources + slot_index
    return action

most_used(state, params, unique_lightpaths, relative)

Get the amount of utilised bandwidth on each lightpath. If RWA-LR environment, the utilisation of a slot is defined by either the count of unique active lightpaths on the slot (if unique_lightpaths is True) or the count of active lightpaths on the slot (if unique_lightpaths is False). If RSA-type environment, utilisation is the count of active lightpaths on that slot.

Parameters:

Name Type Description Default
state EnvState

Environment state

required
params EnvParams

Environment parameters

required
unique_lightpaths bool

Whether to consider unique lightpaths

required
relative bool

Whether to return relative utilisation

required

Returns:

Type Description
Array

chex.Array: Most used slots (array length = link_resources)

Source code in xlron/heuristics/heuristics.py
@partial(jax.jit, static_argnums=(1, 2, 3))
def most_used(state: EnvState, params: RSAEnvParams, unique_lightpaths, relative) -> chex.Array:
    """Get the amount of utilised bandwidth on each lightpath.
    If RWA-LR environment, the utilisation of a slot is defined by either the count of unique active lightpaths on the
    slot (if unique_lightpaths is True) or the count of active lightpaths on the slot (if unique_lightpaths is False).
    If RSA-type environment, utilisation is the count of active lightpaths on that slot.

    Args:
        state (EnvState): Environment state
        params (EnvParams): Environment parameters
        unique_lightpaths (bool): Whether to consider unique lightpaths
        relative (bool): Whether to return relative utilisation

    Returns:
        chex.Array: Most used slots (array length = link_resources)
    """
    if isinstance(params, RWALightpathReuseEnvParams) and unique_lightpaths:
        most_used_slots = jnp.count_nonzero(state.path_index_array + 1, axis=0) + 1
    elif isinstance(params, RWALightpathReuseEnvParams) and not unique_lightpaths:
        # Get initial path capacity
        initial_path_capacity = init_path_capacity_array(
            params.link_length_array.val, params.path_link_array.val, scale_factor=1.0
        )
        initial_path_capacity = jnp.squeeze(
            jax.vmap(lambda x: initial_path_capacity[x])(state.path_index_array)
        )
        utilisation = jnp.where(
            initial_path_capacity - state.link_capacity_array < 0,
            0,
            initial_path_capacity - state.link_capacity_array,
        )
        if relative:
            utilisation = utilisation / initial_path_capacity
        # Get most used slots by summing the utilisation along the slots
        most_used_slots = jnp.sum(utilisation, axis=0) + 1
    else:
        most_used_slots = jnp.count_nonzero(state.link_slot_array, axis=0) + 1
    return most_used_slots
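
In the RSA-type fallback branch, a slot's utilisation is simply the number of links on which it is occupied, counted down the columns of the link-slot array (plus one, so even unused slots get a non-zero score). A toy example:

```python
import numpy as np

# Rows = links, columns = slots; 0 = free, nonzero = occupied by a lightpath
link_slot_array = np.array([
    [1, 0, 1],
    [1, 1, 0],
    [1, 1, 0],
])
# Utilisation of each slot = count of links on which it is active (+1, as in most_used)
most_used_slots = np.count_nonzero(link_slot_array, axis=0) + 1
```

`mu_ksp` and `ksp_mu` then multiply this vector into the action mask so that, among the available slots, the most utilised one wins.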

mu_ksp(state, params, unique_lightpaths, relative)

Use the most-used available slot on any path. The most-used slot is that which has the most unique lightpaths (if unique_lightpaths=True) or active lightpaths. Method: Go through action mask and find the usage of available slots, choose available slot that is most utilised.

Parameters:

Name Type Description Default
state EnvState

Environment state

required
params EnvParams

Environment parameters

required
unique_lightpaths bool

Whether to consider unique lightpaths

required
relative bool

Whether to return relative utilisation

required

Returns:

Type Description
Array

chex.Array: Action

Source code in xlron/heuristics/heuristics.py
@partial(jax.jit, static_argnums=(1, 2, 3))
def mu_ksp(
    state: EnvState, params: RSAEnvParams, unique_lightpaths: bool, relative: bool
) -> chex.Array:
    """Use the most-used available slot on any path.
    The most-used slot is that which has the most unique lightpaths (if unique_lightpaths=True) or active lightpaths.
    Method: Go through action mask and find the usage of available slots, choose available slot that is most utilised.

    Args:
        state (EnvState): Environment state
        params (EnvParams): Environment parameters
        unique_lightpaths (bool): Whether to consider unique lightpaths
        relative (bool): Whether to return relative utilisation

    Returns:
        chex.Array: Action
    """
    mask = get_action_mask(state, params)
    # Get most used slots by summing the link_slot_array along the links
    most_used_slots = most_used(state, params, unique_lightpaths, relative)
    # Get usage of available slots
    most_used_mask = most_used_slots * mask
    # Chosen slot is the most used globally
    action = jnp.argmax(most_used_mask)
    return action