WAUM¶
- class WAUM(tasks, answers, n_classes, model, criterion, optimizer, n_epoch, verbose=False, use_pleiss=False, topk=False, **kwargs)¶
WAUM (Lefort et al., 2024 in TMLR)¶
Measures the WAUM (Weighted Area Under the Margin) per worker and task by duplicating each task by the number of workers who answered it. Tasks too prone to confusion are removed; the final label is a distribution weighted by the diagonal of the estimated confusion matrix.
Uses:
- margin estimation
- a trust score per worker and task
- __init__(tasks, answers, n_classes, model, criterion, optimizer, n_epoch, verbose=False, use_pleiss=False, topk=False, **kwargs)¶
Compute the WAUM score for each task using a stacked version of the dataset (stacked over workers)
\[\mathrm{WAUM}(x_i)= \frac{1}{\displaystyle\sum_{j'\in\mathcal{A}(x_i)} s^{(j')}(x_i)}\sum_{j\in\mathcal{A}(x_i)} s^{(j)}(x_i)\, \mathrm{AUM}\big(x_i, y_i^{(j)}\big)\]
- Parameters:
tasks (torch Dataset) – Loader for dataset of tasks as \((x_i, y_i^{(j)}, w^{(j)}, y_i^\star, i)_{(i,j)}\)
answers (dict) –
Dictionary of workers answers with format
{ task0: {worker0: label, worker1: label}, task1: {worker1: label} }
n_classes (int) – Number of possible classes, defaults to 2
model (torch Module) – Neural network to use
criterion (torch loss) – loss to minimize for the network
optimizer (torch optimizer) – Optimization strategy for the minimization
n_epoch (int) – Number of epochs
verbose (bool, optional) – Print details in log, defaults to False
use_pleiss (bool, optional) – Use the Pleiss margin instead of the Yang margin, defaults to False
topk (bool, optional) – Use the top-k margin (see get_psi5_waum()), defaults to False
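A minimal construction sketch (not part of the documented examples), assuming tasks are 28×28 images with 10 classes; train_set and answers are placeholders for a dataset in the stacked format and a dictionary in the format described above:

    import torch
    from torch import nn

    # Placeholders: train_set must yield (x_i, y_i^(j), w^(j), y_i^*, i) tuples,
    # answers follows the {task: {worker: label}} format documented above.
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    waum = WAUM(
        tasks=train_set,
        answers=answers,
        n_classes=10,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        n_epoch=50,
    )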
- run_DS(cut=False)¶
Run the Dawid and Skene aggregation model. If cut=True, runs it on the pruned dataset.
- Parameters:
cut (bool, optional) – Run on full dataset or dataset pruned from lower WAUM tasks, defaults to False
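For instance, to compare aggregation before and after pruning (a usage sketch on an already-constructed waum instance):

    waum.run_DS()          # Dawid-Skene on the full dataset
    waum.run_DS(cut=True)  # Dawid-Skene on the dataset pruned of low-WAUM tasks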
- make_step(batch)¶
One optimization step
- Parameters:
batch (batch) –
Batch of tasks
- Batch:
index 0: tasks \((x_i)_i\)
index 1: labels
index 2: worker
index 3: true index (without redundancy)
index 4: tasks index \((i)_i\)
- Returns:
Tuple with length, logits, targets, workers, ground truths and index
- Return type:
tuple
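A sketch of one such step, assuming the batch layout above; model, criterion, optimizer, and batch are hypothetical names standing in for the surrounding training loop:

    # Unpack following the documented batch layout (indices 0 to 4).
    x, labels, workers, truths, task_idx = batch
    logits = model(x)                  # scores C(x_i)
    loss = criterion(logits, labels)   # loss against the workers' labels
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()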
- get_aum()¶
Records prediction scores of interest for the AUM during n_epoch training epochs
- reset()¶
Reload the model and optimizer from the last checkpoint. The checkpoint path is ./temp/checkpoint_waum.pth.
- get_psuccess(probas, pij)¶
Compute weights as:
\[s_i^{(j)}=\sum_{k=1}^K \sigma(\mathcal{C}(x_i))_k \pi^{(j)}_{k,k}\]
If one wishes to modify the weight, they only need to modify this function.
- Parameters:
probas (torch.Tensor) – Output predictions of neural network classifier
pij (torch.Tensor) – Confusion matrix of worker j
- Returns:
Weight in the WAUM
- Return type:
torch.Tensor
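A self-contained sketch of this weight, with hypothetical scores and confusion matrix:

    import torch

    logits = torch.tensor([[2.0, 0.5, -1.0]])      # C(x_i), hypothetical scores
    pij = torch.tensor([[0.8, 0.1, 0.1],
                        [0.2, 0.7, 0.1],
                        [0.1, 0.2, 0.7]])          # worker j's confusion matrix

    probas = torch.softmax(logits, dim=1)          # sigma(C(x_i))
    weight = (probas * pij.diag()).sum(dim=1)      # s_i^(j) = sum_k sigma_k * pi_kk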
- get_psi1_waum()¶
To use the original margin from Pleiss et al. (2020):
\[\sigma(\mathcal{C}(x_i))_{y_i^{(j)}} - \max_{k\neq y_i^{(j)}} \sigma(\mathcal{C}(x_i))_{k}\]
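On hypothetical softmax scores, this margin reads:

    import torch

    probas = torch.tensor([0.6, 0.3, 0.1])     # sigma(C(x_i)), hypothetical
    y = 1                                       # worker j's label y_i^(j)

    others = torch.cat([probas[:y], probas[y + 1:]])
    margin = probas[y] - others.max()           # 0.3 - 0.6 = -0.3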
- get_psi5_waum()¶
To use the margin for top-2 (or top-k):
\[\sigma(\mathcal{C}(x_i))_{y_i^{(j)}} - \sigma(\mathcal{C}(x_i))_{[2]}\]
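On the same hypothetical scores, the top-2 variant subtracts the second-largest score overall:

    import torch

    probas = torch.tensor([0.6, 0.3, 0.1])      # sigma(C(x_i)), hypothetical
    y = 1                                        # worker j's label y_i^(j)

    top2 = torch.topk(probas, k=2).values        # two largest scores
    margin = probas[y] - top2[1]                 # sigma_y - sigma_[2] = 0.0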
- cut_lowests(alpha=0.01)¶
Identify tasks with lowest WAUM
- Parameters:
alpha (float, optional) – quantile for identification, defaults to 0.01
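A sketch of the quantile cut on a hypothetical mapping from task index to WAUM score:

    import numpy as np

    waum_scores = {0: 0.91, 1: 0.12, 2: 0.78, 3: 0.05}   # hypothetical WAUM values
    alpha = 0.25

    threshold = np.quantile(list(waum_scores.values()), alpha)
    too_ambiguous = [i for i, s in waum_scores.items() if s <= threshold]  # [3]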
- run(alpha=0.01)¶
Run WAUM identification and label aggregation using the cut-off hyperparameter alpha
- Parameters:
alpha (float, optional) – WAUM quantile below which tasks are removed, defaults to 0.01
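Continuing the construction sketch above, the full pipeline is then:

    waum.run(alpha=0.01)        # train, compute WAUM, prune, aggregate
    probas = waum.get_probas()  # soft labels on the pruned dataset
    yhat = waum.get_answers()   # hard labels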
- get_probas()¶
Get soft labels distribution for each task
- Returns:
Weighted label frequency for each task in D_pruned
- Return type:
numpy.ndarray(n_task, n_classes)
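To illustrate the weighted frequency for a single task, with the diagonal entries of each worker's estimated confusion matrix standing in as hypothetical weights:

    import numpy as np

    n_classes = 3
    votes = {0: 2, 1: 2, 2: 0}          # worker -> label for one task (hypothetical)
    weights = {0: 0.9, 1: 0.5, 2: 0.2}  # worker -> trust weight s_i^(j)

    soft = np.zeros(n_classes)
    for j, y in votes.items():
        soft[y] += weights[j]
    soft /= soft.sum()                  # array([0.125, 0.   , 0.875])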
- get_answers()¶
Argmax of soft labels.
- Returns:
Hard labels
- Return type:
numpy.ndarray(n_task,)
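Equivalently, the hard labels are the row-wise argmax of the soft label matrix:

    import numpy as np

    probas = np.array([[0.7, 0.3],
                       [0.2, 0.8]])     # hypothetical soft labels
    yhat = probas.argmax(axis=1)        # array([0, 1])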