Perplexity
class keras_nlp.metrics.Perplexity(
from_logits=False, mask_token_id=None, dtype="float32", name="perplexity", **kwargs
)
Perplexity metric.
This class implements the perplexity metric: it computes the cross-entropy loss between the targets and the predictions and returns its exponent. Note: this implementation is not suitable for fixed-size windows.
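As a rough illustration of "cross entropy, then exponent", the pure-Python sketch below computes perplexity as exp(mean negative log-likelihood) over a flat sequence of tokens. The helpers `perplexity` and `log_softmax` are hypothetical names for this example, not part of the keras_nlp/keras_hub API. The sketch assumes that, when from_logits=False, each row is already a normalized probability distribution. It is not meant to reproduce the library's exact numerics (clipping, normalization, dtype).

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Numerically stable log-softmax: subtract the max before exponentiating.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def perplexity(
    targets: Sequence[int],
    scores: Sequence[Sequence[float]],
    from_logits: bool = False,
) -> float:
    """Hypothetical helper: exp(mean cross-entropy) over a flat token sequence.

    targets[i] is the true token ID at position i; scores[i] holds either
    probabilities over the vocabulary (from_logits=False, assumed normalized)
    or unnormalized logits (from_logits=True).
    """
    nll = 0.0
    for t, row in zip(targets, scores):
        if from_logits:
            log_p = log_softmax(row)[t]
        else:
            log_p = math.log(row[t])
        nll -= log_p  # accumulate the negative log-likelihood
    return math.exp(nll / len(targets))


# A uniform distribution over 4 tokens has perplexity 4.
uniform = [[0.25] * 4] * 3
print(perplexity([0, 1, 2], uniform))  # approximately 4.0
```

The uniform case is a handy sanity check: perplexity can be read as the effective number of equally likely choices the model is hedging between at each position.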
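Arguments
- from_logits: bool. If True, y_pred (the input to update_state()) should be the logits as returned by the model. Otherwise, y_pred is a tensor of probabilities. Defaults to False.
- mask_token_id: int. ID of the token to be masked (for example, padding). If provided, positions whose target equals this ID are excluded from the computation. If the sample_weight field in update_state() is also provided, the final sample_weight is computed as the element-wise product of the mask and the sample_weight. Defaults to None.
- dtype: string or dtype. Precision of the metric computation. Defaults to "float32".
- name: string. Name of the metric instance. Defaults to "perplexity".
- **kwargs: other keyword arguments.
Examples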
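1. Calculate perplexity by calling update_state() and result().
1.1. sample_weight and mask_token_id are not provided.
>>> import numpy as np
>>> import keras_hub
>>> np.random.seed(42)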
>>> perplexity = keras_hub.metrics.Perplexity(name="perplexity")
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> perplexity.update_state(target, logits)
>>> perplexity.result()
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>
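Like other Keras metrics, update_state() can be called repeatedly, and result() reports the value over everything accumulated since the last reset. The sketch below models that accumulate-then-exponentiate pattern with a hypothetical `RunningPerplexity` class that only accepts probabilities. It weights each token's negative log-likelihood by its sample weight and divides by the total weight; this aggregation rule, and returning 0.0 when nothing has been accumulated, are assumptions for illustration, not a description of how the library aggregates across batches internally.

```python
import math
from typing import Optional, Sequence


class RunningPerplexity:
    """Hypothetical stand-in for a stateful perplexity metric (probabilities only)."""

    def __init__(self) -> None:
        self.reset_state()

    def reset_state(self) -> None:
        self.total_nll = 0.0     # sum of weighted negative log-likelihoods
        self.total_weight = 0.0  # sum of token weights seen so far

    def update_state(
        self,
        targets: Sequence[int],
        probs: Sequence[Sequence[float]],
        sample_weight: Optional[Sequence[float]] = None,
    ) -> None:
        if sample_weight is None:
            sample_weight = [1.0] * len(targets)
        for t, row, w in zip(targets, probs, sample_weight):
            if w == 0:
                continue  # fully masked token: contributes nothing (and avoids log(0))
            self.total_nll -= w * math.log(row[t])
            self.total_weight += w

    def result(self) -> float:
        if self.total_weight == 0:
            return 0.0  # this sketch's choice when no tokens have been seen
        return math.exp(self.total_nll / self.total_weight)


m = RunningPerplexity()
m.update_state([0], [[0.5, 0.5]])    # one token with p = 0.5
m.update_state([1], [[0.75, 0.25]])  # one token with p = 0.25
print(m.result())  # exp(-(ln 0.5 + ln 0.25) / 2) = sqrt(8), about 2.828
```

Accumulating sums rather than per-batch perplexities matters: averaging two batch perplexities is not the same as the perplexity of the combined tokens, because the exponent is applied after averaging.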
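1.2. sample_weight specified (masking tokens with ID 0). With this seed, target contains no zeros, so the mask keeps every token and the result matches example 1.1.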
>>> np.random.seed(42)
>>> perplexity = keras_hub.metrics.Perplexity(name="perplexity")
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> sample_weight = (target != 0).astype("float32")
>>> perplexity.update_state(target, logits, sample_weight)
>>> perplexity.result()
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>
2. Call perplexity directly.
>>> np.random.seed(42)
>>> perplexity = keras_hub.metrics.Perplexity(name="perplexity")
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> perplexity(target, logits)
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>
3. Provide the padding token ID and let the class compute the mask itself.
>>> np.random.seed(42)
>>> perplexity = keras_hub.metrics.Perplexity(mask_token_id=0)
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> perplexity(target, logits)
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>
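The documented mask_token_id behaviour can be sketched as follows: positions whose target equals mask_token_id get weight 0, and if a sample_weight is also supplied, the final weight is the element-wise product of the mask and the sample_weight. `combined_weights` is a hypothetical helper written for this illustration; it mirrors the semantics described under Arguments, not the library's source code.

```python
from typing import List, Optional, Sequence


def combined_weights(
    targets: Sequence[int],
    mask_token_id: Optional[int] = None,
    sample_weight: Optional[Sequence[float]] = None,
) -> List[float]:
    """Hypothetical helper: per-token weights after applying an optional mask."""
    weights = [1.0] * len(targets) if sample_weight is None else list(sample_weight)
    if mask_token_id is None:
        return weights
    # The mask is 1.0 for real tokens and 0.0 where the target is the masked ID.
    mask = [0.0 if t == mask_token_id else 1.0 for t in targets]
    # Element-wise product of the mask and the user-supplied weights.
    return [m * w for m, w in zip(mask, weights)]


targets = [5, 3, 0, 0]  # two real tokens followed by two padding tokens (ID 0)
print(combined_weights(targets, mask_token_id=0))
# [1.0, 1.0, 0.0, 0.0]
print(combined_weights(targets, mask_token_id=0, sample_weight=[0.5, 2.0, 1.0, 1.0]))
# [0.5, 2.0, 0.0, 0.0]
```

This is why examples 1.2 and 3 are two routes to the same weighting: passing sample_weight = (target != 0) by hand and passing mask_token_id=0 both zero out the token-0 positions.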