1 change: 1 addition & 0 deletions avalanche/evaluation/metrics/__init__.py
@@ -97,6 +97,7 @@
from .cpu_usage import *
from .disk_usage import *
from .forgetting_bwt import *
from .forward_transfer import *
from .gpu_usage import *
from .loss import *
from .mac import *
157 changes: 2 additions & 155 deletions avalanche/evaluation/metrics/forgetting_bwt.py
@@ -447,122 +447,7 @@ def __str__(self):
return "StreamForgetting"


class GenericTaskForgetting(PluginMetric[Dict[int, float]]):
    """
    The GenericTaskForgetting metric describes, for each task observed
    during training and evaluation, the change in a user-chosen metric
    between training on that task and evaluating it later.

    In particular, the user should override:
    * `__init__`, by calling `super` and instantiating the
      `self._current_train_metric` and `self._current_eval_metric`
      attributes as valid avalanche metrics (the same metric type for both).
      The metric should be able to return values for each task separately.
    * `metric_update`, to update the current train/eval metric
    * `__str__`, to define the task forgetting name.

    For each task observed during training, this plugin metric reports the
    difference between the metric result obtained when last training on
    that task and the metric result obtained on the same task at the end of
    successive experiences.

    This metric is computed during the eval phase only.
    """
    def __init__(self):
        super().__init__()
        self.forgetting = Forgetting()
        self._current_train_metric = None
        self._current_eval_metric = None

    def reset(self, **kwargs) -> None:
        self.forgetting.reset()

    def result(self, **kwargs):
        return self.forgetting.result()

    def update(self, k, v, initial):
        self.forgetting.update(k, v, initial=initial)

    def before_training(self, strategy: 'BaseStrategy'):
        self._current_train_metric.reset()

    def after_training_iteration(self, strategy: 'BaseStrategy'):
        super().after_training_iteration(strategy)
        try:
            unique_tasks = strategy.mb_task_id.unique()
            for t in unique_tasks:
                self._current_train_metric.reset(t.item())
        except AssertionError:
            self._current_train_metric.reset()
        self.metric_update(strategy, train=True)

    def before_eval(self, strategy: 'BaseStrategy'):
        self.forgetting.reset_last()
        for k, v in self._current_train_metric.result().items():
            self.update(k, v, initial=True)
        self._current_eval_metric.reset()

    def after_eval_iteration(self, strategy: 'BaseStrategy'):
        super().after_eval_iteration(strategy)
        self.metric_update(strategy, train=False)

    def after_eval(self, strategy: 'BaseStrategy') -> 'MetricResult':
        for k, v in self._current_eval_metric.result().items():
            self.update(k, v, initial=False)
        return self._package_result(strategy)

    def _package_result(self, strategy: 'BaseStrategy') -> \
            MetricResult:
        metric_value = self.result()
        plot_x_position = self.get_global_counter()
        results = []
        for k, v in metric_value.items():
            metric_name = get_metric_name(self, strategy, add_experience=False,
                                          add_task=k)
            results.append(MetricValue(self, metric_name, v, plot_x_position))

        return results

    def metric_update(self, strategy, train):
        raise NotImplementedError

    def __str__(self):
        raise NotImplementedError
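
For reference, TaskForgetting below is the built-in instance of the recipe
described in the docstring above. A purely hypothetical sketch of another
subclass, measuring forgetting on the loss instead of the accuracy, could
look like the following (the per-task Loss API is assumed here, not shown
in this diff):

from avalanche.evaluation.metrics import Loss  # assumed per-task Loss metric


class TaskLossForgetting(GenericTaskForgetting):
    """Task-wise forgetting measured on loss instead of accuracy (sketch)."""

    def __init__(self):
        super().__init__()
        # Same metric type for train and eval; must report per-task values.
        self._current_train_metric = Loss()
        self._current_eval_metric = Loss()

    def metric_update(self, strategy, train):
        # Assumed Loss.update signature: (loss_value, num_patterns, task_label).
        task_label = strategy.experience.task_labels[0]
        metric = self._current_train_metric if train \
            else self._current_eval_metric
        metric.update(strategy.loss, len(strategy.mb_y), task_label)

    def __str__(self):
        return "TaskLossForgetting"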


class TaskForgetting(GenericTaskForgetting):
    """
    The Task Forgetting metric returns the amount of forgetting
    on each task separately. The task-wise forgetting is computed
    as the difference between the average accuracy when last training
    on the task and the average accuracy when last evaluating on the
    same task.
    """

    def __init__(self):
        super().__init__()
        self._current_train_metric = Accuracy()
        self._current_eval_metric = Accuracy()

    def metric_update(self, strategy, train):
        # task labels defined for each experience
        task_labels = strategy.experience.task_labels
        if len(task_labels) > 1:
            # task labels defined for each pattern
            task_labels = strategy.mb_task_id
        else:
            task_labels = task_labels[0]
        if train:
            self._current_train_metric.update(strategy.mb_output,
                                              strategy.mb_y, task_labels)
        else:
            self._current_eval_metric.update(strategy.mb_output,
                                             strategy.mb_y, task_labels)

    def __str__(self):
        return "TaskForgetting"


def forgetting_metrics(*, experience=False, stream=False, task=False) \
def forgetting_metrics(*, experience=False, stream=False) \
        -> List[PluginMetric]:
    """
    Helper method that can be used to obtain the desired set of
@@ -573,8 +458,6 @@ def forgetting_metrics(*, experience=False, stream=False, task=False) \
    :param stream: If True, will return a metric able to log
        the forgetting averaged over the evaluation stream experiences,
        which have been observed during training.
    :param task: If True, will return a metric able to log the forgetting
        across each task encountered during training and evaluation.

    :return: A list of plugin metrics.
    """
@@ -587,9 +470,6 @@ def forgetting_metrics(*, experience=False, stream=False, task=False) \
    if stream:
        metrics.append(StreamForgetting())

    if task:
        metrics.append(TaskForgetting())

    return metrics
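
As a usage sketch (not part of this diff), these helpers are typically
passed to an EvaluationPlugin together with the metrics they build on; the
plugin and logger below are the standard Avalanche ones and assumed to be
available in this version:

from avalanche.evaluation.metrics import accuracy_metrics, forgetting_metrics
from avalanche.logging import InteractiveLogger
from avalanche.training.plugins import EvaluationPlugin

# Experience-level and stream-level forgetting, reported during eval.
eval_plugin = EvaluationPlugin(
    accuracy_metrics(experience=True, stream=True),
    forgetting_metrics(experience=True, stream=True),
    loggers=[InteractiveLogger()])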


@@ -699,32 +579,7 @@ def __str__(self):
return "StreamBWT"


class TaskBWT(TaskForgetting):
    """
    The TaskBWT metric, emitting the backward transfer (BWT) for each task.

    This plugin metric, computed over all tasks observed during training,
    is the difference between the last accuracy result obtained on a task
    and the accuracy result obtained when last training on that task.

    This metric is computed during the eval phase only.
    """

    def result(self) -> Union[float, None, Dict[int, float]]:
        """
        Result for each task, keyed by its task label.
        See `BWT` documentation for more detailed information.

        """
        forgetting = super().result()
        return forgetting_to_bwt(forgetting)

    def __str__(self):
        return "TaskBWT"


def bwt_metrics(*, experience=False, stream=False, task=False) \
def bwt_metrics(*, experience=False, stream=False) \
        -> List[PluginMetric]:
    """
    Helper method that can be used to obtain the desired set of
@@ -735,8 +590,6 @@ def bwt_metrics(*, experience=False, stream=False, task=False) \
    :param stream: If True, will return a metric able to log
        the backward transfer averaged over the evaluation stream experiences
        which have been observed during training.
    :param task: If True, will return a metric able to log
        the backward transfer for each task in the evaluation stream.
    :return: A list of plugin metrics.
    """

@@ -748,24 +601,18 @@ def bwt_metrics(*, experience=False, stream=False, task=False) \
    if stream:
        metrics.append(StreamBWT())

    if task:
        metrics.append(TaskBWT())

    return metrics
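
Analogously to forgetting_metrics above, a hedged usage sketch for
bwt_metrics (same assumed EvaluationPlugin wiring):

from avalanche.evaluation.metrics import bwt_metrics
from avalanche.training.plugins import EvaluationPlugin

# BWT reported per experience and averaged over the eval stream.
eval_plugin = EvaluationPlugin(bwt_metrics(experience=True, stream=True))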


__all__ = [
    'Forgetting',
    'GenericExperienceForgetting',
    'GenericStreamForgetting',
    'GenericTaskForgetting',
    'ExperienceForgetting',
    'StreamForgetting',
    'TaskForgetting',
    'forgetting_metrics',
    'BWT',
    'ExperienceBWT',
    'StreamBWT',
    'TaskBWT',
    'bwt_metrics'
]