BeamSampler
class keras_hub.samplers.BeamSampler(num_beams=5, return_all_beams=False, **kwargs)
Beam Sampler class.
This sampler implements the beam search algorithm. At each time-step, beam
search keeps the num_beams beams (sequences) with the highest accumulated
probabilities, and uses each of those beams to predict candidate next tokens.
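The procedure described above can be illustrated with a minimal, framework-free sketch. This is not the KerasHub implementation: the function `beam_search`, the callback `next_log_probs`, and the toy transition table are all hypothetical names introduced for illustration, and the sketch assumes scores are accumulated as log-probabilities (a common convention that avoids numerical underflow).

```python
import math


def beam_search(next_log_probs, start, num_beams, steps, return_all_beams=False):
    """Toy beam search (illustrative only, not the KerasHub code).

    `next_log_probs(seq)` is a hypothetical callback returning a dict
    {token: log_prob} for the next token given the sequence `seq`.
    """
    if num_beams <= 0:
        raise ValueError("num_beams should be strictly positive.")
    # Each beam is a (sequence, accumulated log-probability) pair.
    beams = [(list(start), 0.0)]
    for _ in range(steps):
        candidates = []
        # Expand every kept beam with every candidate next token.
        for seq, score in beams:
            for token, log_prob in next_log_probs(seq).items():
                candidates.append((seq + [token], score + log_prob))
        # Keep only the num_beams candidates with the highest scores.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:num_beams]
    # Either return every beam with its score, or just the best sequence.
    return beams if return_all_beams else beams[0][0]


# Hypothetical two-token "model": next-token probabilities depend only
# on the previous token.
TABLE = {
    "a": {"a": 0.6, "b": 0.4},
    "b": {"a": 0.05, "b": 0.95},
}


def toy_next_log_probs(seq):
    return {tok: math.log(p) for tok, p in TABLE[seq[-1]].items()}


greedy = beam_search(toy_next_log_probs, ["a"], num_beams=1, steps=2)
best = beam_search(toy_next_log_probs, ["a"], num_beams=2, steps=2)
all_beams = beam_search(
    toy_next_log_probs, ["a"], num_beams=2, steps=2, return_all_beams=True
)
print(greedy)  # ['a', 'a', 'a']
print(best)    # ['a', 'b', 'b']
```

In this toy setup, a single beam (equivalent to greedy search) commits to the locally best token and ends with probability 0.36, while two beams recover the sequence with probability 0.38. That is the motivation for keeping several beams per time-step.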
Arguments
num_beams: int. The number of beams to keep at each time-step. num_beams should be strictly positive.
return_all_beams: bool. When set to True, the sampler will return all beams and their respective probability scores.
Call arguments
{{call_args}}
Examples
import keras_hub

causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")
# Pass by name to compile.
causal_lm.compile(sampler="beam")
causal_lm.generate(["Keras is a"])
# Pass by object to compile.
sampler = keras_hub.samplers.BeamSampler(num_beams=5)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])