Source code for network_gym_client.adapter

#Copyright(C) 2025 Network Simulation Solutions LLC
#SPDX-License-Identifier: Apache-2.0
#File : adapter.py

import sys
import wandb
import numpy as np
import pandas as pd
import json
import time

from threading import Thread
from rich.live import Live
from rich.layout import Layout

from rich.ansi import AnsiDecoder
from rich.console import Group
from rich.jupyter import JupyterMixin
from rich.text import Text
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TaskProgressColumn, TimeRemainingColumn

import plotext as plt

[docs]class Adapter: """The base class for environment data format adapter. This class is an data format "adapter" between the gymnasium environment and network_gym environment. It transforms the network stats measurements (json) to obs and reward (Spaces). It also transforms the action (Spaces) to a policy (json) that can be applied to the network. """ def __init__(self, config_json): """Initialize Adapter. Args: config_json (json): the configuration file """ self.config_json = None self.wandb_log_buffer = None self.wandb = wandb self.config_json = config_json self.action_data_format = None rl_alg = config_json['rl_config']['agent'] config = { "policy_type": "MlpPolicy", "env_id": "network_gym_client", "RL_algo" : rl_alg } if config_json["enable_wandb"]: self.wandb.init( name=config_json['env_config']['env'] + "::" + rl_alg, project="network_gym_client", config=config, sync_tensorboard=True, # auto-upload sb3's tensorboard metrics # save_code=True, # optional ) def initial_rich_thread(self): self.layout = None if self.config_json["enable_terminal_redering"]: rich_thread = Thread(target=self.rich_live) rich_thread.daemon = True # Dies when main thread (only non-daemon thread) exits. rich_thread.start() def wandb_log_buffer_append (self, info): """Add to wandb log buffer, the info will be send to wandb later in the :meth:`wandb_log` function Args: info (dict): information to append to the buffer. """ if info: # info not empty! if not self.wandb_log_buffer: self.wandb_log_buffer = info else: self.wandb_log_buffer.update(info) def wandb_log (self): """Send the log information to WanDB. """ # send info to wandb #print(self.wandb_log_buffer) if self.config_json["enable_wandb"]: self.wandb.log(self.wandb_log_buffer) self.wandb_log_buffer = None def df_to_dict(self, df, id_name='id'): """Transform datatype from pandas.dataframe to dictionary. Args: df (pandas.dataframe): a pandas.dataframe object description (string): a descritption for the data Returns: dict : converted data with dictionary format """ if df is None: return {} description = df['source'] + "::" + df['name'] get_key = lambda u: description+"::"+id_name+f'={u}' dict_key = list(map(get_key, df['id'])) dict_value = df['value'] #print(dict_key) #print(dict_value) data = dict(zip(dict_key, dict_value)) #print(data) return data def fill_empty_feature(self, feature, value): """Fill the missing measurements with a input value Args: feature (pd.DataFrame): feature from the measurement value (int): the value to fill the missing measurement Returns: list: results after replace missing measurements with value """ #Fill the missing data with the input value. if feature is None: print("[WARNING] all users of a feature returns empty measurement.") emptyFeatureArray = np.empty([self.config_json['env_config']['num_users'],], dtype=int) emptyFeatureArray.fill(value) return emptyFeatureArray if len(feature['user']) > self.config_json['env_config']['num_users']: print ("[WARNING] This feature has more user than input!!") print (feature) emptyFeatureArray = np.empty([self.config_json['env_config']['num_users'],], dtype=int) emptyFeatureArray.fill(value) return emptyFeatureArray elif len(feature['user']) == self.config_json['env_config']['num_users']: # measurement size match the user number return feature["value"] elif len(feature['user'])> 0: # some of the user's data are missing, fill with input value. print("[WARNING] some users of a feature returns empty measurement.") #feature = feature.set_index("user") #feature = feature.reindex(list(range(self.config_json['env_config']['num_users'])),fill_value=value) #data = feature[:]["value"] data = np.empty([self.config_json['env_config']['num_users'],], dtype=int) data.fill(value) for index, user_id in enumerate(feature['user']): data[user_id] = feature['value'][index] #print(data) return data else: # all user's data are missing, fill the entire feature will input value. print("[WARNING] all users of a feature returns empty measurement.") emptyFeatureArray = np.empty([self.config_json['env_config']['num_users'],], dtype=int) emptyFeatureArray.fill(value) return emptyFeatureArray def rich_live(self): self.layout = self.make_layout() header = self.layout['header'] title = plt.colorize(self.config_json['rl_config']['agent'], "cyan+", "bold" ) + " agent is interacting with NetworkGym's " + plt.colorize(self.config_json['env_config']['env'], "cyan+", "bold") + " env." header.update(Text(title, justify = "left")) overall_progress = Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), TimeRemainingColumn(), TextColumn("env time: {task.completed}/{task.total} ms"), ) overall_task = overall_progress.add_task("Progress", total=self.config_json['env_config']['env_end_time_ms']) footer = self.layout["footer"] footer.update(overall_progress) self.layout["main"].visible=False step_length = self.config_json['env_config']['measurement_interval_ms'] if 'measurement_guard_interval_ms' in self.config_json['env_config']: step_length = step_length + self.config_json['env_config']['measurement_guard_interval_ms'] with Live(self.layout, refresh_per_second=1, redirect_stderr=False) as live: while True: time.sleep(1) if self.action_data_format is not None: overall_progress.update(overall_task, completed=self.action_data_format["ts"]+step_length) #live.refresh() def make_layout(self): layout = Layout(name="root") layout.split( Layout(name="header", size=1), Layout(name="main", ratio=1), Layout(name="footer", size = 1), ) layout["main"].split_row( Layout(name="left", ratio = 1), Layout(name="right", ratio = 1), ) return layout class plotextMixin(JupyterMixin): def __init__(self, plot_function): self.decoder = AnsiDecoder() self.plot_function = plot_function def __rich_console__(self, console, options): self.width = options.max_width or console.width self.height = options.height or console.height canvas = self.plot_function(self.width, self.height) self.rich_canvas = Group(*self.decoder.decode(canvas)) yield self.rich_canvas