WAUM_perworker

class WAUM_perworker(tasks, answers, n_classes, model, criterion, optimizer, n_epoch, verbose=False, use_pleiss=False, **kwargs)

WAUM per worker (Lefort et al. 2024 in TMLR)

Measures the WAUM per worker and per task, without duplicating each task by the number of workers who answered it. Once tasks that are too prone to confusion are removed, the final label is a distribution weighted by the diagonal of the estimated confusion matrix.

Using:

  • Margin estimation

  • Trust score per worker and task

__init__(tasks, answers, n_classes, model, criterion, optimizer, n_epoch, verbose=False, use_pleiss=False, **kwargs)

Compute the WAUM score for each task using a stacked version of the dataset (stacked over workers). Each classifier is trained on \((x_i, y_i^{(j)})_i\) for a given worker \(j\). The WAUM per worker is defined as:

\[\mathrm{WAUM}(x_i)= \frac{1}{\displaystyle\sum_{j'\in\mathcal{A}(x_i)} s^{(j')}(x_i)}\sum_{j\in\mathcal{A}(x_i)} s^{(j)}(x_i) \mathrm{AUM}\big(x_i, y_i^{(j)}\big)\]

where the difference with the classical WAUM is that the weights are not influenced by other workers’ answers in the classifier’s prediction. In regimes with few votes per worker, the classical WAUM is recommended.
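
A minimal sketch of this weighted average for a single task, with illustrative per-worker trust scores and AUM values (the dictionaries below are toy values, not part of the API):

    # Toy trust scores s^{(j)}(x_i) and AUM(x_i, y_i^{(j)}) values
    # for the workers j in A(x_i); values are illustrative.
    scores = {0: 0.9, 1: 0.4, 2: 0.7}
    aums = {0: 0.8, 1: 0.1, 2: 0.5}

    # WAUM(x_i): AUMs averaged with normalized trust scores
    waum = sum(scores[j] * aums[j] for j in scores) / sum(scores.values())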

Parameters:
  • tasks (torch Dataset) – Loader for dataset of tasks as \((x_i, y_i^{(j)}, w^{(j)}, y_i^\star, i)_{(i,j)}\)

  • answers (dict) –

Dictionary of workers' answers with format

    {
        task0: {worker0: label, worker1: label},
        task1: {worker1: label}
    }
    

  • n_classes (int) – Number of possible classes, defaults to 2

  • model (torch Module) – Neural network to use

  • criterion (torch loss) – loss to minimize for the network

  • optimizer (torch optimizer) – Optimization strategy for the minimization

  • n_epoch (int) – Number of epochs

  • verbose (bool, optional) – Print details in log, defaults to False

  • use_pleiss (bool, optional) – Use Pleiss margin instead of Yang, defaults to False
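
A hypothetical instantiation sketch; the toy dataset, model, and answers dictionary below are illustrative stand-ins, not part of the library:

    import torch
    import torch.nn as nn
    # WAUM_perworker imported from this library (import path omitted)

    answers = {
        0: {0: 1, 1: 1},  # task 0: workers 0 and 1 answered class 1
        1: {1: 0},        # task 1: only worker 1 answered, class 0
    }

    class ToyTasks(torch.utils.data.Dataset):
        """Yields (x_i, y_i^{(j)}, w^{(j)}, y_i^*, i) tuples, one per vote."""
        def __init__(self, xs, answers, truths):
            self.items = [(xs[i], y, j, truths[i], i)
                          for i, votes in answers.items()
                          for j, y in votes.items()]
        def __len__(self):
            return len(self.items)
        def __getitem__(self, k):
            return self.items[k]

    xs = torch.randn(2, 10)  # two tasks with 10 features each
    model = nn.Linear(10, 2)
    waum = WAUM_perworker(
        ToyTasks(xs, answers, truths=[1, 0]),
        answers,
        n_classes=2,
        model=model,
        criterion=nn.CrossEntropyLoss(),
        optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
        n_epoch=5,
    )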

run_DS(cut=False)

Run the Dawid and Skene aggregation model. If cut=True, runs it on the pruned dataset.

Parameters:

cut (bool, optional) – If True, run on the dataset pruned of the lowest-WAUM tasks; otherwise run on the full dataset, defaults to False

make_step(data_j, batch)

One optimization step

Parameters:

batch (batch) –

Batch of tasks

Batch:
  • index 0: tasks \((x_i)_i\)

  • index 1: labels

  • index 2: tasks index \((i)_i\)

Returns:

Tuple with length, logits, targets, ground truths and index

Return type:

tuple
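
An illustrative sketch of the core of such a step under the batch layout above; it omits the data_j argument and the ground-truth bookkeeping of the actual method:

    def make_step_sketch(model, criterion, optimizer, batch):
        # batch follows the layout above: tasks, labels, task indices
        xi, yj, idx = batch[0], batch[1], batch[2]
        optimizer.zero_grad()
        logits = model(xi)            # forward pass on worker j's answers
        loss = criterion(logits, yj)
        loss.backward()
        optimizer.step()
        # ground truths, carried alongside the batch, are omitted here
        return len(xi), logits.detach(), yj, idx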

get_aum()

Records the prediction scores of interest for the AUM during the n_epoch training epochs of each worker's classifier
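
A minimal sketch of this bookkeeping, assuming the scores of one task have been recorded at every epoch (the tensors below are toy values):

    import torch

    # logits of one task recorded over two epochs; observed label y = 1
    recorded = [torch.tensor([0.2, 0.5, 0.1]), torch.tensor([0.1, 0.9, 0.3])]
    y = 1

    margins = []
    for logits in recorded:
        probs = torch.softmax(logits, dim=0)
        others = torch.cat([probs[:y], probs[y + 1:]])
        margins.append((probs[y] - others.max()).item())
    aum = sum(margins) / len(margins)  # AUM: margin averaged over epochs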

reset()

Reload the model and optimizer from the last checkpoint. The checkpoint path is ./temp/checkpoint_waum.pth. The classifier backbone is reset between each worker in this version of the WAUM.

get_psuccess(probas, pij)

From the classifier \(\mathcal{C}^{(j)}\) associated with worker \(j\), computes weights as:

\[s_i^{(j)}=\sum_{k=1}^K \sigma(\mathcal{C}^{(j)}(x_i))_k \pi^{(j)}_{k,k}\]

To use a different weighting, only this function needs to be modified.

Parameters:
  • probas (torch.Tensor) – Output predictions of neural network classifier

  • pij (torch.Tensor) – Confusion matrix of worker j

Returns:

Weight in the WAUM

Return type:

torch.Tensor
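
A sketch of this weight computation in torch, with illustrative tensors:

    import torch

    probas = torch.tensor([[0.7, 0.2, 0.1],   # softmax outputs, one row per task
                           [0.1, 0.6, 0.3]])
    pij = torch.tensor([[0.8, 0.1, 0.1],      # worker j's (K, K) confusion matrix
                        [0.2, 0.7, 0.1],
                        [0.1, 0.1, 0.8]])

    # s_i^{(j)} = sum_k sigma(C^{(j)}(x_i))_k * pi_{k,k}
    s = probas @ pij.diagonal()  # one weight per task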

get_psi1_waum()

To use the original margin from Pleiss et al. (2020):

\[\sigma(\mathcal{C}^{(j)}(x_i))_{y_i^{(j)}} - \max_{k\neq y_i^{(j)}} \sigma(\mathcal{C}^{(j)}(x_i))_{k}\]

get_psi5_waum()

To use the margin for top-2 (or top-k):

\[\sigma(\mathcal{C}^{(j)}(x_i))_{y_i^{(j)}} - \sigma(\mathcal{C}^{(j)}(x_i))_{[2]}\]
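
A sketch of both margins on the softmax scores of a single task, with an illustrative tensor:

    import torch

    probs = torch.softmax(torch.tensor([1.2, 0.3, -0.5]), dim=0)
    y = 0  # observed label

    # Pleiss margin: observed-label score minus the largest other score
    others = torch.cat([probs[:y], probs[y + 1:]])
    psi1 = probs[y] - others.max()

    # Top-2 margin: observed-label score minus the second-largest score
    psi5 = probs[y] - probs.topk(2).values[1]
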
cut_lowests(alpha=0.01)

Identify tasks with lowest WAUM

Parameters:

alpha (float, optional) – quantile for identification, defaults to 0.01
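
A sketch of the quantile-based identification, assuming a dictionary mapping task indices to WAUM scores (names are illustrative):

    import numpy as np

    waum_scores = {0: 0.9, 1: 0.05, 2: 0.6, 3: 0.4}
    alpha = 0.25
    threshold = np.quantile(list(waum_scores.values()), alpha)
    too_hard = [i for i, s in waum_scores.items() if s < threshold]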

run(alpha=0.01)

Run WAUM identification and label aggregation using the cut-off hyperparameter alpha

Parameters:

alpha (float, optional) – WAUM quantile below which tasks are removed, defaults to 0.01

get_probas()

Get the soft label distribution for each task

Returns:

Weighted label frequency for each task

Return type:

numpy.ndarray(n_task, n_classes)

get_answers()

Argmax of the soft labels; in this case this corresponds to a majority vote

Returns:

Hard labels

Return type:

numpy.ndarray
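
An end-to-end sketch, continuing the instantiation example above:

    waum.run(alpha=0.01)       # train per worker, compute WAUM, prune, aggregate
    soft = waum.get_probas()   # (n_task, n_classes) weighted label frequencies
    hard = waum.get_answers()  # argmax of the soft labels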