"""tf_al.metrics: write training metrics to and read them from csv files."""

import os, sys, csv

class Metrics:
    """Prepare metric values and write them to / read them from csv files.

    Parameters:
        base_path (str): The directory in which the csv files are stored.
        keys (iterable(str)): The metric keys (csv columns) to persist.
            Defaults to ``["accuracy", "loss"]``.
    """

    def __init__(self, base_path, keys=("accuracy", "loss")):
        # Copy into a fresh list: a shared mutable default would let one
        # instance's modifications leak into every later instance.
        self.metric_keys = ["accuracy", "loss"] if keys is None else list(keys)
        self.BASE_PATH = base_path
        self.EXT = "csv"

        # CSV Parameters
        self.delimiter = " "
        self.quotechar = "\""
        self.quoting = csv.QUOTE_MINIMAL

    def _file_path(self, filename):
        """Return the full path for ``filename``, adding the extension if missing.

        Shared by the reading and writing methods so that ``write("x")`` and
        ``write("x.csv")`` both target the file that ``read`` opens.
        """
        suffix = "." + self.EXT
        if not filename.endswith(suffix):
            filename += suffix
        return os.path.join(self.BASE_PATH, filename)

    def write_line(self, filename, values):
        """Placeholder for appending a single line of metrics to a csv file.

        NOTE(review): not implemented yet. It only opens the file in append
        mode (creating it if it does not exist) and writes nothing;
        ``values`` is currently unused.
        """
        with open(self._file_path(filename), "a", newline="") as csv_file:
            pass

    def write(self, filename, values):
        """Write the given metric rows into a csv file, replacing its content.

        Parameters:
            filename (str): The name of the file (".csv" is appended if missing).
            values (list(dict)): One dictionary of metrics/values per row.
                Only keys listed in ``metric_keys`` are written; missing keys
                produce empty cells.
        """
        with open(self._file_path(filename), "w", newline="") as csv_file:
            # Setup csv file
            file_writer = csv.DictWriter(
                csv_file,
                delimiter=self.delimiter,
                quotechar=self.quotechar,
                quoting=self.quoting,
                fieldnames=self.metric_keys,
            )

            # Create content of csv file
            file_writer.writeheader()
            for line in values:
                file_writer.writerow(self.collect(line))

    def read(self, filename):
        """Read a csv file of metrics.

        Lines starting with "#" are treated as comments and skipped.

        Parameters:
            filename (str): The filename to read in (".csv" is appended if missing).

        Returns:
            (list(dict)) a list of metric values (as strings), one per row.
        """
        # newline="" is required by the csv module to parse quoted fields correctly.
        with open(self._file_path(filename), "r", newline="") as csv_file:
            lines = (line for line in csv_file if not line.startswith("#"))
            reader = csv.DictReader(
                lines,
                delimiter=self.delimiter,
                quotechar=self.quotechar,
            )
            return list(reader)

    # -------------
    # Utilities
    # ------------------

    def collect(self, values, keys=None):
        """Collect metric values from a dictionary of values.

        Parameters:
            values (dict): A collection of values collected during training.
            keys (list(str)): Keys to keep; defaults to ``metric_keys``.

        Returns:
            (dict) The subset of ``values`` whose keys are in ``keys``, with
            single-element lists unwrapped to their only element.
        """
        # Set default keys to use
        if keys is None:
            keys = self.metric_keys

        return {
            key: self.__prepare_value(value)
            for key, value in values.items()
            if key in keys
        }

    def __prepare_value(self, value):
        """Unwrap single-element lists so they are not saved as lists."""
        if isinstance(value, list) and len(value) == 1:
            return value[0]
        return value

    # -------------
    # Setter/-Getter
    # ------------------

    def get_path(self):
        """Return the base directory of the metric files."""
        return self.BASE_PATH
def save_history(history, path, filename, keys=None):
    """Save the values of a training history as a csv file.

    Parameters:
        history (list(dict)): One dictionary of metric values per iteration.
        path (str): Directory in which to store the file.
        filename (str): Name of the csv file.
        keys (list(str)): Optional metric keys to persist. When omitted, the
            default keys of ``Metrics`` are used.
    """
    # Only forward ``keys`` when given so that the Metrics default applies otherwise.
    metrics = Metrics(path) if keys is None else Metrics(path, keys=keys)
    # Metrics.write expects (filename, values).
    metrics.write(filename, history)


def read_history(path, filename):
    """Read back a history previously saved with ``save_history``.

    Returns:
        (list(dict)) one dictionary of (string) metric values per row.
    """
    metrics = Metrics(path)
    return metrics.read(filename)


def aggregates_per_key(history):
    """Aggregate the values of a history per key.

    Parameters:
        history (list(dict)): One dictionary of metric values per iteration.

    Returns:
        (dict) Mapping of each key to the list of its values, in history
        order. An empty history yields an empty dict. A key absent from some
        entries only collects the values of the entries that contain it.
    """
    aggregates = {}
    for entry in history:
        for key, value in entry.items():
            aggregates.setdefault(key, []).append(value)
    return aggregates