Source code for jammy.cli.cmdline_viz
from .colors import COLORS
from tqdm.auto import tqdm
from collections import defaultdict
from jammy.utils.meter import GroupMeters
class CmdLineViz:
    """Collect per-mode metrics and print them as a colorized summary table.

    Values handed to :meth:`update` are stored in one ``GroupMeters`` per
    mode.  :meth:`flush` prints the ``avg`` mapping of every mode through
    ``tqdm.write`` (one row per metric key, one column per mode), colors each
    value by comparing it with the value printed for the same mode and key by
    the previous flush, and then starts a fresh set of meters.
    """

    def __init__(self):
        # mode -> GroupMeters holding what was recorded since the last flush.
        self.meter = defaultdict(GroupMeters)
        # (mode, key) -> value printed for that cell by the most recent flush.
        self.prev_mean = {}

    def update(self, mode, eval_dict):
        """Record ``eval_dict`` under ``mode``.

        Args:
            mode: Column label; its meter group is created on first use.
            eval_dict: Passed unchanged to the ``update`` method of the
                mode's meter group.
        """
        self.meter[mode].update(eval_dict)

    def flush(self):
        """Print the summary table and reset the meters.

        Does nothing when no mode has been recorded since the last flush.
        Each cell is white the first time its (mode, key) pair is printed,
        green when the value is strictly greater than the previously printed
        one, and red otherwise (including when it is unchanged).  Rows appear
        in the order in which keys were first encountered, so the layout is
        the same on every run.
        """
        if not self.meter:
            return

        # mode -> {key: rendered cell}; a mode that lacks some key renders a
        # blank, 10-character-wide cell so the columns stay aligned.
        cells = {}
        # Used as an insertion-ordered set of row keys: a plain ``set`` would
        # reorder string keys from one interpreter run to the next.
        row_keys = {}
        for mode, meters in self.meter.items():
            column = defaultdict(lambda: " " * 10)
            cells[mode] = column
            for key, value in meters.avg.items():
                row_keys.setdefault(key)
                # Tuple key: concatenating ``mode + key`` lets distinct pairs
                # collide (("train", "val_loss") vs ("trainval", "_loss")).
                stat = (mode, key)
                if stat not in self.prev_mean:
                    color = COLORS.White
                elif value > self.prev_mean[stat]:
                    color = COLORS.Green
                else:
                    color = COLORS.Red
                column[key] = f"{color}{value:10.4f}{COLORS.END_NO_TOKEN}"
                self.prev_mean[stat] = value

        lines = ["\n\t=== Summary ===\n"]
        header = f"\t{'mode':<10}" + "".join(
            f" -- {mode:>10}" for mode in self.meter
        )
        lines.append(header + "\n")
        for key in row_keys:
            row = f"\t{key:10}" + "".join(
                f" -- {column[key]}" for column in cells.values()
            )
            lines.append(row + "\n")
        # tqdm.write is used for output instead of a bare print.
        tqdm.write("".join(lines) + "\n")

        # Start a new accumulation window; prev_mean is kept for coloring.
        self.meter = defaultdict(GroupMeters)
# Demo that runs only when this module is executed as a script: it feeds a
# CmdLineViz two noisy series under the mode "test" and prints summaries.
# NOTE(review): this listing has lost its indentation and the statements are
# reproduced exactly as shown.  Whether each line.flush() belongs inside or
# after the preceding for-loop cannot be determined from here -- confirm
# against the original file before relying on the demo's output.
if __name__ == "__main__":
# numpy is imported here, inside the script guard, rather than at file top.
import numpy as np
# Number of samples in each series.
N = 10
line = CmdLineViz()
# Upward ramp 0..N-1 with standard-normal noise added.
increase = np.arange(N) + np.random.randn(N)
# Downward ramp 0..-(N-1) with standard-normal noise added.
decrease = -np.arange(N) + np.random.randn(N)
# First pass: each key receives the series its name describes.
for cur_inc, cur_dec in zip(increase, decrease):
line.update("test", {"increase": cur_inc, "decrease": cur_dec})
line.flush()
# Second pass: the series are swapped, so the "increase" key now receives the
# downward ramp and the "decrease" key the upward one.
for cur_inc, cur_dec in zip(decrease, increase):
line.update("test", {"increase": cur_inc, "decrease": cur_dec})
line.flush()