Source code for transformers_domain_adaptation.data_selection.metrics.similarity

"""Similarity metrics for data selection introduced by Ruder and Plank.

The functions here were adapted and vectorized
from those in the authors' `repo <https://github.com/sebastianruder/learn-to-select-data/blob/master/similarity.py>`_.
"""
from typing import Callable, Dict
from typing_extensions import Literal

import numpy as np
import scipy.stats
import scipy.spatial.distance


def jensen_shannon_similarity(repr1: np.ndarray, repr2: np.ndarray) -> np.ndarray:
    """Calculate similarity based on Jensen-Shannon divergence.

    https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence

    Each row pair ``(p, q)`` is scored as ``1 - JSD(p, q)``. The two KL terms
    are computed with :func:`scipy.stats.entropy`, which normalizes each row
    to sum to 1 and uses the natural logarithm.

    Args:
        repr1: 2-D array of row representations. A single row is tiled to
            match the number of rows in ``repr2``.
        repr2: 2-D array of row representations. A single row is tiled to
            match the number of rows in ``repr1``.
            Otherwise rows are paired positionally. Extra rows in the longer
            input are ignored.

    Returns:
        1-D array with one similarity per row pair. A row whose similarity is
        not finite is scored 0. For example, an all-zero representation makes
        the normalization divide 0 by 0.
    """
    if len(repr1) == 1:
        repr1 = np.repeat(repr1, len(repr2), axis=0)
    elif len(repr2) == 1:
        repr2 = np.repeat(repr2, len(repr1), axis=0)

    avg_repr = 0.5 * (repr1 + repr2)
    # An all-zero row makes scipy's normalization compute 0/0. The resulting
    # nan is replaced below, so the warning is only noise.
    with np.errstate(divide="ignore", invalid="ignore"):
        sim = np.array(
            [
                1 - 0.5 * (scipy.stats.entropy(p, avg) + scipy.stats.entropy(q, avg))
                for p, q, avg in zip(repr1, repr2, avg_repr)
            ]
        )
    # When no term in the document is in the vocabulary, the similarity is
    # undefined. It comes out as nan, not inf, so test with isfinite to catch
    # both, and score such rows 0.
    sim = np.where(np.isfinite(sim), sim, 0)
    return sim
def renyi_similarity(
    repr1: np.ndarray, repr2: np.ndarray, alpha: float = 0.99
) -> np.ndarray:
    """Calculate similarity based on Rényi divergence.

    https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy#R.C3.A9nyi_divergence

    Computes ``D_alpha(P || Q) = 1 / (alpha - 1) * log(sum_i p_i**alpha / q_i**(alpha - 1))``
    along the last axis and returns ``-D_alpha``, so that higher means more similar.

    The sum runs over the support of ``repr1``: positions where ``repr1`` is 0
    contribute nothing. Without this, ``alpha > 1`` evaluates ``0 / 0 = nan``
    wherever both inputs are 0.

    Args:
        repr1: Representation(s) P. Combined with ``repr2`` by NumPy
            broadcasting. Presumably non-negative term distributions; confirm
            against the caller.
        repr2: Representation(s) Q.
        alpha: Order of the divergence.

    Returns:
        Negated Rényi divergence for each row. It may be ``-inf`` when the
        supports of P and Q do not overlap.

    Raises:
        ZeroDivisionError: If ``alpha`` is the Python float ``1.0``, because
            the ``1 / (alpha - 1)`` factor divides by zero.
    """
    # q == 0 raised to a negative power is inf, and log(0) is -inf. Both are
    # legitimate outcomes here, so silence the warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        terms = np.power(repr1, alpha) / np.power(repr2, alpha - 1)
        terms = np.where(repr1 == 0, 0.0, terms)
        log_sum = terms.sum(axis=-1)
        renyi_divergence = 1 / (alpha - 1) * np.log(log_sum)
    return -renyi_divergence
[docs]def cosine_similarity(repr1: np.ndarray, repr2: np.ndarray) -> np.ndarray: """Calculate cosine similarity (https://en.wikipedia.org/wiki/Cosine_similarity).""" if len(repr1) == 1: repr1 = np.repeat(repr1, len(repr2), axis=0) elif len(repr2) == 1: repr2 = np.repeat(repr2, len(repr1), axis=0) assert not (np.isnan(repr2).any() or np.isinf(repr2).any()) assert not (np.isnan(repr1).any() or np.isinf(repr1).any()) sim = np.array( [1 - scipy.spatial.distance.cosine(p, q) for p, q in zip(repr1, repr2)] ) # the similarity is nan if no term in the document is in the vocabulary sim = np.where(np.isnan(sim), 0, sim) return sim
def euclidean_similarity(repr1: np.ndarray, repr2: np.ndarray) -> np.ndarray:
    """Calculate similarity based on Euclidean distance.

    https://en.wikipedia.org/wiki/Euclidean_distance

    The inputs are combined with ordinary NumPy broadcasting. The L2
    distance is taken along the last axis and negated, so identical inputs
    score 0 and farther-apart inputs score lower.
    """
    difference = repr1 - repr2
    squared_length = np.sum(difference * difference, axis=-1)
    return -np.sqrt(squared_length)
def variational_similarity(repr1: np.ndarray, repr2: np.ndarray) -> np.ndarray:
    """Calculate similarity based on L1 / Manhattan distance.

    https://en.wikipedia.org/wiki/Taxicab_geometry

    The sum of absolute element-wise differences is taken along the last
    axis, after NumPy broadcasting, and returned negated so that higher
    means more similar.
    """
    absolute_gaps = np.abs(repr1 - repr2)
    l1_distance = np.sum(absolute_gaps, axis=-1)
    return -l1_distance
def bhattacharyya_similarity(repr1: np.ndarray, repr2: np.ndarray) -> np.ndarray:
    """Calculate similarity based on Bhattacharyya distance.

    https://en.wikipedia.org/wiki/Bhattacharyya_distance

    The Bhattacharyya coefficient ``sum(sqrt(p * q))`` is taken along the
    last axis, after NumPy broadcasting. Its negative log is the distance,
    and the negated distance is returned.

    Raises:
        AssertionError: If any distance is nan (e.g. from negative inputs).
            This check is skipped under ``python -O``.
    """
    coefficient = np.sum(np.sqrt(repr1 * repr2), axis=-1)
    distance = -np.log(coefficient)
    assert not np.isnan(distance).any(), "Error: Similarity is nan."

    # A zero coefficient (no term in the review is in the vocabulary) makes
    # the distance +inf. Such rows are given distance 0.
    # NOTE(review): for normalized distributions the coefficient is <= 1, so
    # distances are >= 0 and 0 is the best possible score. This mapping would
    # then rank degenerate rows as most similar. Confirm this is intended.
    distance = np.where(np.isinf(distance), 0, distance)
    return -distance
############################
##### Function factory #####
############################

# Metric names accepted by ``similarity_func_factory``.
SimilarityMetric = Literal[
    "jensen-shannon",
    "renyi",
    "cosine",
    "euclidean",
    "variational",
    "bhattacharyya",
]

# Signature shared by every similarity function: two representation arrays
# in, an array of scores out.
SimilarityFunc = Callable[[np.ndarray, np.ndarray], np.ndarray]

# Runtime set of the names in ``SimilarityMetric``, used to validate input.
SIMILARITY_FEATURES = {
    # divergence-based metrics
    "jensen-shannon",
    "renyi",
    "bhattacharyya",
    # geometric metrics
    "cosine",
    "euclidean",
    "variational",
}
def similarity_func_factory(metric: SimilarityMetric) -> SimilarityFunc:
    """Look up the similarity function registered under ``metric``.

    Args:
        metric (str): Name of the similarity metric. It must be one of
            SIMILARITY_FEATURES.

    Returns:
        The function that computes the requested similarity.

    Raises:
        ValueError: If `metric` does not exist in SIMILARITY_FEATURES
    """
    functions_by_name: Dict[SimilarityMetric, SimilarityFunc] = {
        "jensen-shannon": jensen_shannon_similarity,
        "renyi": renyi_similarity,
        "cosine": cosine_similarity,
        "euclidean": euclidean_similarity,
        "variational": variational_similarity,
        "bhattacharyya": bhattacharyya_similarity,
    }
    if metric in SIMILARITY_FEATURES:
        return functions_by_name[metric]
    raise ValueError(f'"{metric}" is not a valid similarity metric.')