Source code for transformers_domain_adaptation.data_selection.metrics.diversity

"""Diversity metrics for data selection introduced by Ruder and Plank.

The functions here were adapted and vectorized
from those in the authors' `repo <https://github.com/sebastianruder/learn-to-select-data/blob/master/features.py>`_.
"""
from functools import partial
from typing import Callable, Dict, Sequence
from typing_extensions import Literal

import numpy as np
import scipy.stats


from transformers_domain_adaptation.type import Token


def number_of_term_types(example: Sequence[Token]) -> int:
    """Count the number of term types of the example.

    Args:
        example: Tokens of a single example.

    Returns:
        The number of distinct tokens in ``example``.
    """
    distinct_terms = set(example)
    return len(distinct_terms)
def type_token_diversity(example: Sequence[Token]) -> float:
    """Calculate diversity based on the type-token ratio of the example.

    Args:
        example: Tokens of a single example.

    Returns:
        The negated ratio of distinct tokens to total tokens, or ``1`` when
        ``example`` is empty.
    """
    total_tokens = len(example)
    if total_tokens == 0:
        # The ratio is undefined without tokens; a fixed value is returned instead.
        return 1

    distinct_tokens = len(set(example))
    return -(distinct_tokens / total_tokens)
def entropy(example: Sequence[Token], vocab2id: Dict[Token, int]) -> float:
    """Calculate Entropy (https://en.wikipedia.org/wiki/Entropy_(information_theory)#Definition).

    Args:
        example: Tokens of a single example. Duplicates and tokens missing
            from ``vocab2id`` are ignored.
        vocab2id: Vocabulary-to-id mapping.

    Returns:
        The result of ``scipy.stats.entropy`` applied to the vocabulary ids
        of the distinct in-vocabulary tokens.
    """
    # NOTE(review): the vocabulary ids themselves are handed to
    # `scipy.stats.entropy` as the distribution, not term probabilities.
    # Confirm that this is intended.
    known_terms = {token for token in example if token in vocab2id}
    ids_of_known_terms = [vocab2id[token] for token in known_terms]
    return scipy.stats.entropy(ids_of_known_terms)
def simpsons_index(
    example: Sequence[Token], train_term_dist: np.ndarray, vocab2id: Dict[Token, int]
) -> float:
    """Calculate Simpson's Index (https://en.wikipedia.org/wiki/Diversity_index#Simpson_index).

    Args:
        example: Tokens of a single example. Duplicates and tokens missing
            from ``vocab2id`` are ignored.
        train_term_dist: Term distribution array, indexed by the ids stored
            in ``vocab2id``.
        vocab2id: Vocabulary-to-id mapping.

    Returns:
        The sum of squared ``train_term_dist`` entries of the distinct
        in-vocabulary tokens, or ``0`` when ``example`` is empty.
    """
    if len(example) == 0:
        return 0

    known_terms = {token for token in example if token in vocab2id}
    known_ids = [vocab2id[token] for token in known_terms]
    selected_probs = train_term_dist[known_ids]
    return (selected_probs ** 2).sum()
def renyi_entropy(
    example: Sequence[Token],
    domain_term_dist: np.ndarray,
    vocab2id: Dict[Token, int],
    *,
    alpha: float = 0.99,
) -> float:
    """Calculate Rényi Entropy (https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy).

    Args:
        example: Tokens of a single example. Duplicates and tokens missing
            from ``vocab2id`` are ignored.
        domain_term_dist: Term distribution array, indexed by the ids stored
            in ``vocab2id``.
        vocab2id: Vocabulary-to-id mapping.
        alpha: Order of the Rényi entropy. Must be non-negative and different
            from 1. Defaults to 0.99, the value that used to be hard-coded.

    Returns:
        ``1 / (1 - alpha) * log(sum(p ** alpha))`` where ``p`` runs over the
        ``domain_term_dist`` entries of the distinct in-vocabulary tokens.

    Raises:
        ValueError: If ``alpha`` is negative or equal to 1.
    """
    if alpha < 0 or alpha == 1:
        # alpha == 1 would divide by zero in the 1 / (1 - alpha) factor below.
        raise ValueError(
            f"alpha must be non-negative and different from 1, got {alpha}."
        )

    known_terms = {term for term in example if term in vocab2id}
    term_ids = [vocab2id[term] for term in known_terms]

    summed = (domain_term_dist[term_ids] ** alpha).sum()
    if summed == 0:
        # 0 if none of the terms appear in the dictionary;
        # set to a small constant == low prob instead
        summed = 0.0001

    return 1 / (1 - alpha) * np.log(summed)
#############################
##### Function factory #####
#############################

# Call signature shared by the diversity functions handed out by the factory:
# a sequence of tokens goes in, a score comes out.
DiversityFunction = Callable[[Sequence[Token]], float]

# Names of the supported metrics, as a static type.
DiversityMetric = Literal[
    "num_token_types",
    "type_token_ratio",
    "entropy",
    "simpsons_index",
    "renyi_entropy",
]

# The same metric names as a runtime set, used for membership checks.
DIVERSITY_FEATURES = {
    "num_token_types",
    "type_token_ratio",
    "entropy",
    "simpsons_index",
    "renyi_entropy",
}
def diversity_func_factory(
    metric: DiversityMetric, train_term_dist: np.ndarray, vocab2id: Dict[Token, int]
) -> DiversityFunction:
    """Return the corresponding diversity function based on the provided metric.

    Args:
        metric (str): Diversity metric
        train_term_dist: Term distribution of the training data
        vocab2id: Vocabulary-to-id mapping

    Returns:
        A callable taking only the example's tokens; any distribution or
        vocabulary arguments the metric needs are already bound.

    Raises:
        ValueError: If `metric` does not exist in DIVERSITY_FEATURES
    """
    if metric not in DIVERSITY_FEATURES:
        raise ValueError(f'"{metric}" is not a valid diversity metric.')

    # Metrics that need nothing beyond the example itself.
    if metric == "num_token_types":
        return number_of_term_types
    if metric == "type_token_ratio":
        return type_token_diversity

    # Metrics that need the vocabulary and/or the term distribution bound.
    if metric == "entropy":
        return partial(entropy, vocab2id=vocab2id)
    if metric == "simpsons_index":
        return partial(
            simpsons_index, train_term_dist=train_term_dist, vocab2id=vocab2id
        )
    if metric == "renyi_entropy":
        return partial(
            renyi_entropy,
            domain_term_dist=train_term_dist,
            vocab2id=vocab2id,  # TODO: Double check correctness of `domain_term_dist`
        )

    # Only reachable when DIVERSITY_FEATURES lists a metric that has no
    # implementation here; same error a failed table lookup would raise.
    raise KeyError(metric)