
KerasNLP Models

KerasNLP contains end-to-end implementations of popular model architectures. These models can be created in two ways:

  • Through the from_preset() constructor, which instantiates an object with pre-trained configurations, vocabularies, and (optionally) weights.
  • Through custom configuration controlled by the user; a minimal sketch of this path follows below.

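For the second path, you construct a backbone directly and choose every hyperparameter yourself. The snippet below is only an illustrative sketch: the class is the library's BertBackbone, but the hyperparameter values are made up rather than taken from any published configuration.

import keras_nlp

# Randomly initialized BERT-style backbone with user-chosen (illustrative)
# hyperparameters; no pre-trained weights are loaded.
backbone = keras_nlp.models.BertBackbone(
    vocabulary_size=30552,
    num_layers=4,
    num_heads=4,
    hidden_dim=256,
    intermediate_dim=512,
    max_sequence_length=128,
)
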
Below, we list all presets available in the library. For more detailed usage, browse the docstring for a particular class. For an in-depth introduction to our API, see the getting started guide.

Presets

Each of the following preset names corresponds to a configuration and weights for a pretrained model. Calling from_preset() on any task, preprocessor, backbone, or tokenizer class creates a model from the saved preset.

import keras_nlp

backbone = keras_nlp.models.Backbone.from_preset("bert_base_en")
tokenizer = keras_nlp.models.Tokenizer.from_preset("bert_base_en")
classifier = keras_nlp.models.TextClassifier.from_preset("bert_base_en", num_classes=2)
preprocessor = keras_nlp.models.TextClassifierPreprocessor.from_preset("bert_base_en")
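Because task presets bundle a matching preprocessor, a classifier loaded this way can consume raw strings directly. As a quick, hedged illustration (the preset name is taken from the table below and the example texts are arbitrary):

import keras_nlp

# Task preset with a fine-tuned classification head and built-in preprocessing.
classifier = keras_nlp.models.TextClassifier.from_preset("bert_tiny_en_uncased_sst2")
# Raw strings go straight in; tokenization and packing happen inside the task.
predictions = classifier.predict(["What an amazing movie!", "A total waste of my time."])
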
Preset name Model Parameters Description
albert_base_en_uncased ALBERT 11.68M 12-layer ALBERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. Model Card
albert_large_en_uncased ALBERT 17.68M 24-layer ALBERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. Model Card
albert_extra_large_en_uncased ALBERT 58.72M 24-layer ALBERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. Model Card
albert_extra_extra_large_en_uncased ALBERT 222.60M 12-layer ALBERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. Model Card
bart_base_en BART 139.42M 6-layer BART model where case is maintained. Trained on BookCorpus, English Wikipedia and CommonCrawl. Model Card
bart_large_en BART 406.29M 12-layer BART model where case is maintained. Trained on BookCorpus, English Wikipedia and CommonCrawl. Model Card
bart_large_en_cnn BART 406.29M The bart_large_en backbone model fine-tuned on the CNN+DM summarization dataset. Model Card
bert_tiny_en_uncased BERT 4.39M 2-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. Model Card
bert_small_en_uncased BERT 28.76M 4-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. Model Card
bert_medium_en_uncased BERT 41.37M 8-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. Model Card
bert_base_en_uncased BERT 109.48M 12-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. Model Card
bert_base_en BERT 108.31M 12-layer BERT model where case is maintained. Trained on English Wikipedia + BooksCorpus. Model Card
bert_base_zh BERT 102.27M 12-layer BERT model. Trained on Chinese Wikipedia. Model Card
bert_base_multi BERT 177.85M 12-layer BERT model where case is maintained. Trained on Wikipedias of 104 languages. Model Card
bert_large_en_uncased BERT 335.14M 24-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. Model Card
bert_large_en BERT 333.58M 24-layer BERT model where case is maintained. Trained on English Wikipedia + BooksCorpus. Model Card
bert_tiny_en_uncased_sst2 BERT 4.39M The bert_tiny_en_uncased backbone model fine-tuned on the SST-2 sentiment analysis dataset. Model Card
bloom_560m_multi BLOOM 559.21M 24-layer Bloom model with hidden dimension of 1024. Trained on 45 natural languages and 12 programming languages. Model Card
bloom_1.1b_multi BLOOM 1.07B 24-layer Bloom model with hidden dimension of 1536. Trained on 45 natural languages and 12 programming languages. Model Card
bloom_1.7b_multi BLOOM 1.72B 24-layer Bloom model with hidden dimension of 2048. Trained on 45 natural languages and 12 programming languages. Model Card
bloom_3b_multi BLOOM 3.00B 30-layer Bloom model with hidden dimension of 2560. Trained on 45 natural languages and 12 programming languages. Model Card
bloomz_560m_multi BLOOMZ 559.21M 24-layer Bloom model with hidden dimension of 1024. Finetuned on the crosslingual task mixture (xP3) dataset. Model Card
bloomz_1.1b_multi BLOOMZ 1.07B 24-layer Bloom model with hidden dimension of 1536. Finetuned on the crosslingual task mixture (xP3) dataset. Model Card
bloomz_1.7b_multi BLOOMZ 1.72B 24-layer Bloom model with hidden dimension of 2048. Finetuned on the crosslingual task mixture (xP3) dataset. Model Card
bloomz_3b_multi BLOOMZ 3.00B 30-layer Bloom model with hidden dimension of 2560. Finetuned on the crosslingual task mixture (xP3) dataset. Model Card
deberta_v3_extra_small_en DeBERTaV3 70.68M 12-layer DeBERTaV3 model where case is maintained. Trained on English Wikipedia, BookCorpus and OpenWebText. Model Card
deberta_v3_small_en DeBERTaV3 141.30M 6-layer DeBERTaV3 model where case is maintained. Trained on English Wikipedia, BookCorpus and OpenWebText. Model Card
deberta_v3_base_en DeBERTaV3 183.83M 12-layer DeBERTaV3 model where case is maintained. Trained on English Wikipedia, BookCorpus and OpenWebText. Model Card
deberta_v3_large_en DeBERTaV3 434.01M 24-layer DeBERTaV3 model where case is maintained. Trained on English Wikipedia, BookCorpus and OpenWebText. Model Card
deberta_v3_base_multi DeBERTaV3 278.22M 12-layer DeBERTaV3 model where case is maintained. Trained on the 2.5TB multilingual CC100 dataset. Model Card
deeplab_v3_plus_resnet50_pascalvoc DeepLabV3 39.19M DeepLabV3+ model with a ResNet50 image encoder, trained on the Pascal VOC dataset augmented with the Semantic Boundaries Dataset (SBD), reaching a categorical accuracy of 90.01 and a mean IoU of 0.63. Model Card
densenet_121_imagenet DenseNet 7.04M 121-layer DenseNet model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Model Card
densenet_169_imagenet DenseNet 12.64M 169-layer DenseNet model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Model Card
densenet_201_imagenet DenseNet 18.32M 201-layer DenseNet model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Model Card
distil_bert_base_en_uncased DistilBERT 66.36M 6-layer DistilBERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus using BERT as the teacher model. Model Card
distil_bert_base_en DistilBERT 65.19M 6-layer DistilBERT model where case is maintained. Trained on English Wikipedia + BooksCorpus using BERT as the teacher model. Model Card
distil_bert_base_multi DistilBERT 134.73M 6-layer DistilBERT model where case is maintained. Trained on Wikipedias of 104 languages. Model Card
electra_small_discriminator_uncased_en ELECTRA 13.55M 12-layer small ELECTRA discriminator model. All inputs are lowercased. Trained on English Wikipedia + BooksCorpus. Model Card
electra_small_generator_uncased_en ELECTRA 13.55M 12-layer small ELECTRA generator model. All inputs are lowercased. Trained on English Wikipedia + BooksCorpus. Model Card
electra_base_discriminator_uncased_en ELECTRA 109.48M 12-layer base ELECTRA discriminator model. All inputs are lowercased. Trained on English Wikipedia + BooksCorpus. Model Card
electra_base_generator_uncased_en ELECTRA 33.58M 12-layer base ELECTRA generator model. All inputs are lowercased. Trained on English Wikipedia + BooksCorpus. Model Card
electra_large_discriminator_uncased_en ELECTRA 335.14M 24-layer large ELECTRA discriminator model. All inputs are lowercased. Trained on English Wikipedia + BooksCorpus. Model Card
electra_large_generator_uncased_en ELECTRA 51.07M 24-layer large ELECTRA generator model. All inputs are lowercased. Trained on English Wikipedia + BooksCorpus. Model Card
f_net_base_en FNet 82.86M 12-layer FNet model where case is maintained. Trained on the C4 dataset. Model Card
f_net_large_en FNet 236.95M 24-layer FNet model where case is maintained. Trained on the C4 dataset. Model Card
falcon_refinedweb_1b_en Falcon 1.31B 24-layer Falcon model (Falcon with 1B parameters), trained on 350B tokens of the RefinedWeb dataset. Model Card
resnet_18_imagenet ResNet 11.19M 18-layer ResNet model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Model Card
resnet_50_imagenet ResNet 23.56M 50-layer ResNet model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Model Card
resnet_101_imagenet ResNet 42.61M 101-layer ResNet model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Model Card
resnet_152_imagenet ResNet 58.30M 152-layer ResNet model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Model Card
resnet_v2_50_imagenet ResNet 23.56M 50-layer ResNetV2 model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Model Card
resnet_v2_101_imagenet ResNet 42.61M 101-layer ResNetV2 model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Model Card
resnet_vd_18_imagenet ResNet 11.72M 18-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Model Card
resnet_vd_34_imagenet ResNet 21.84M 34-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Model Card
resnet_vd_50_imagenet ResNet 25.63M 50-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Model Card
resnet_vd_50_ssld_imagenet ResNet 25.63M 50-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution with knowledge distillation. Model Card
resnet_vd_50_ssld_v2_imagenet ResNet 25.63M 50-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution with knowledge distillation and AutoAugment. Model Card
resnet_vd_50_ssld_v2_fix_imagenet ResNet 25.63M 50-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution with knowledge distillation, AutoAugment and additional fine-tuning of the classification head. Model Card
resnet_vd_101_imagenet ResNet 44.67M 101-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Model Card
resnet_vd_101_ssld_imagenet ResNet 44.67M 101-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution with knowledge distillation. Model Card
resnet_vd_152_imagenet ResNet 60.36M 152-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Model Card
resnet_vd_200_imagenet ResNet 74.93M 200-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Model Card
mit_b0_ade20k_512 MiT 3.32M MiT (MixTransformer) model with 8 transformer blocks.
mit_b1_ade20k_512 MiT 13.16M MiT (MixTransformer) model with 8 transformer blocks.
mit_b2_ade20k_512 MiT 24.20M MiT (MixTransformer) model with 16 transformer blocks.
mit_b3_ade20k_512 MiT 44.08M MiT (MixTransformer) model with 28 transformer blocks.
mit_b4_ade20k_512 MiT 60.85M MiT (MixTransformer) model with 41 transformer blocks.
mit_b5_ade20k_640 MiT 81.45M MiT (MixTransformer) model with 52 transformer blocks.
mit_b0_cityscapes_1024 MiT 3.32M MiT (MixTransformer) model with 8 transformer blocks.
mit_b1_cityscapes_1024 MiT 13.16M MiT (MixTransformer) model with 8 transformer blocks.
mit_b2_cityscapes_1024 MiT 24.20M MiT (MixTransformer) model with 16 transformer blocks.
mit_b3_cityscapes_1024 MiT 44.08M MiT (MixTransformer) model with 28 transformer blocks.
mit_b4_cityscapes_1024 MiT 60.85M MiT (MixTransformer) model with 41 transformer blocks.
mit_b5_cityscapes_1024 MiT 81.45M MiT (MixTransformer) model with 52 transformer blocks.
gemma_2b_en Gemma 2.51B 2 billion parameter, 18-layer, base Gemma model. Model Card
gemma_instruct_2b_en Gemma 2.51B 2 billion parameter, 18-layer, instruction tuned Gemma model. Model Card
gemma_1.1_instruct_2b_en Gemma 2.51B 2 billion parameter, 18-layer, instruction tuned Gemma model. The 1.1 update improves model quality. Model Card
code_gemma_1.1_2b_en Gemma 2.51B 2 billion parameter, 18-layer, CodeGemma model. This model has been trained on a fill-in-the-middle (FIM) task for code completion. The 1.1 update improves model quality. Model Card
code_gemma_2b_en Gemma 2.51B 2 billion parameter, 18-layer, CodeGemma model. This model has been trained on a fill-in-the-middle (FIM) task for code completion. Model Card
gemma_7b_en Gemma 8.54B 7 billion parameter, 28-layer, base Gemma model. Model Card
gemma_instruct_7b_en Gemma 8.54B 7 billion parameter, 28-layer, instruction tuned Gemma model. Model Card
gemma_1.1_instruct_7b_en Gemma 8.54B 7 billion parameter, 28-layer, instruction tuned Gemma model. The 1.1 update improves model quality. Model Card
code_gemma_7b_en Gemma 8.54B 7 billion parameter, 28-layer, CodeGemma model. This model has been trained on a fill-in-the-middle (FIM) task for code completion. Model Card
code_gemma_instruct_7b_en Gemma 8.54B 7 billion parameter, 28-layer, instruction tuned CodeGemma model. This model has been trained for chat use cases related to code. Model Card
code_gemma_1.1_instruct_7b_en Gemma 8.54B 7 billion parameter, 28-layer, instruction tuned CodeGemma model. This model has been trained for chat use cases related to code. The 1.1 update improves model quality. Model Card
gemma2_2b_en Gemma 2.61B 2 billion parameter, 26-layer, base Gemma model. Model Card
gemma2_instruct_2b_en Gemma 2.61B 2 billion parameter, 26-layer, instruction tuned Gemma model. Model Card
gemma2_9b_en Gemma 9.24B 9 billion parameter, 42-layer, base Gemma model. Model Card
gemma2_instruct_9b_en Gemma 9.24B 9 billion parameter, 42-layer, instruction tuned Gemma model. Model Card
gemma2_27b_en Gemma 27.23B 27 billion parameter, 42-layer, base Gemma model. Model Card
gemma2_instruct_27b_en Gemma 27.23B 27 billion parameter, 42-layer, instruction tuned Gemma model. Model Card
shieldgemma_2b_en Gemma 2.61B 2 billion parameter, 26-layer, ShieldGemma model. Model Card
shieldgemma_9b_en Gemma 9.24B 9 billion parameter, 42-layer, ShieldGemma model. Model Card
shieldgemma_27b_en Gemma 27.23B 27 billion parameter, 42-layer, ShieldGemma model. Model Card
gpt2_base_en GPT-2 124.44M 12-layer GPT-2 model where case is maintained. Trained on WebText. Model Card
gpt2_medium_en GPT-2 354.82M 24-layer GPT-2 model where case is maintained. Trained on WebText. Model Card
gpt2_large_en GPT-2 774.03M 36-layer GPT-2 model where case is maintained. Trained on WebText. Model Card
gpt2_extra_large_en GPT-2 1.56B 48-layer GPT-2 model where case is maintained. Trained on WebText. Model Card
gpt2_base_en_cnn_dailymail GPT-2 124.44M 12-layer GPT-2 model where case is maintained. Finetuned on the CNN/DailyMail summarization dataset.
llama3_8b_en LLaMA 3 8.03B 8 billion parameter, 32-layer, base LLaMA 3 model. Model Card
llama3_8b_en_int8 LLaMA 3 8.03B 8 billion parameter, 32-layer, base LLaMA 3 model with activation and weights quantized to int8. Model Card
llama3_instruct_8b_en LLaMA 3 8.03B 8 billion parameter, 32-layer, instruction tuned LLaMA 3 model. Model Card
llama3_instruct_8b_en_int8 LLaMA 3 8.03B 8 billion parameter, 32-layer, instruction tuned LLaMA 3 model with activation and weights quantized to int8. Model Card
llama2_7b_en LLaMA 2 6.74B 7 billion parameter, 32-layer, base LLaMA 2 model. Model Card
llama2_7b_en_int8 LLaMA 2 6.74B 7 billion parameter, 32-layer, base LLaMA 2 model with activation and weights quantized to int8. Model Card
llama2_instruct_7b_en LLaMA 2 6.74B 7 billion parameter, 32-layer, instruction tuned LLaMA 2 model. Model Card
llama2_instruct_7b_en_int8 LLaMA 2 6.74B 7 billion parameter, 32-layer, instruction tuned LLaMA 2 model with activation and weights quantized to int8. Model Card
vicuna_1.5_7b_en Vicuna 6.74B 7 billion parameter, 32-layer, instruction tuned Vicuna v1.5 model. Model Card
mistral_7b_en Mistral 7.24B Mistral 7B base model. Model Card
mistral_instruct_7b_en Mistral 7.24B Mistral 7B instruct model. Model Card
mistral_0.2_instruct_7b_en Mistral 7.24B Mistral 7B instruct version 0.2 model. Model Card
opt_125m_en OPT 125.24M 12-layer OPT model where case is maintained. Trained on BookCorpus, CommonCrawl, Pile, and PushShift.io corpora. Model Card
opt_1.3b_en OPT 1.32B 24-layer OPT model where case is maintained. Trained on BookCorpus, CommonCrawl, Pile, and PushShift.io corpora. Model Card
opt_2.7b_en OPT 2.70B 32-layer OPT model where case is maintained. Trained on BookCorpus, CommonCrawl, Pile, and PushShift.io corpora. Model Card
opt_6.7b_en OPT 6.70B 32-layer OPT model where case is maintained. Trained on BookCorpus, CommonCrawl, Pile, and PushShift.io corpora. Model Card
pali_gemma_3b_mix_224 PaliGemma 2.92B Image size 224, mix fine-tuned, text sequence length of 256. Model Card
pali_gemma_3b_mix_448 PaliGemma 2.92B Image size 448, mix fine-tuned, text sequence length of 512. Model Card
pali_gemma_3b_224 PaliGemma 2.92B Image size 224, pre-trained, text sequence length of 128. Model Card
pali_gemma_3b_448 PaliGemma 2.92B Image size 448, pre-trained, text sequence length of 512. Model Card
pali_gemma_3b_896 PaliGemma 2.93B Image size 896, pre-trained, text sequence length of 512. Model Card
phi3_mini_4k_instruct_en Phi-3 3.82B 3.8 billion parameters, 32 layers, 4k context length, Phi-3 model. The model was trained on the Phi-3 datasets, which include both synthetic data and filtered publicly available website data, with an emphasis on high-quality and reasoning-dense properties. Model Card
phi3_mini_128k_instruct_en Phi-3 3.82B 3.8 billion parameters, 32 layers, 128k context length, Phi-3 model. The model was trained on the Phi-3 datasets, which include both synthetic data and filtered publicly available website data, with an emphasis on high-quality and reasoning-dense properties. Model Card
roberta_base_en RoBERTa 124.05M 12-layer RoBERTa model where case is maintained. Trained on English Wikipedia, BooksCorpus, CommonCrawl, and OpenWebText. Model Card
roberta_large_en RoBERTa 354.31M 24-layer RoBERTa model where case is maintained. Trained on English Wikipedia, BooksCorpus, CommonCrawl, and OpenWebText. Model Card
xlm_roberta_base_multi XLM-RoBERTa 277.45M 12-layer XLM-RoBERTa model where case is maintained. Trained on CommonCrawl in 100 languages. Model Card
xlm_roberta_large_multi XLM-RoBERTa 558.84M 24-layer XLM-RoBERTa model where case is maintained. Trained on CommonCrawl in 100 languages. Model Card
sam_base_sa1b SAMImageSegmenter 93.74M The base SAM model trained on the SA1B dataset. Model Card
sam_large_sa1b SAMImageSegmenter 312.34M The large SAM model trained on the SA1B dataset. Model Card
sam_huge_sa1b SAMImageSegmenter 641.09M The huge SAM model trained on the SA1B dataset. Model Card
stable_diffusion_3_medium StableDiffusion3 2.99B 3 billion parameter model, including CLIP L and CLIP G text encoders, MMDiT generative model, and VAE autoencoder. Developed by Stability AI. Model Card
t5_small_multi T5 0 8-layer T5 model. Trained on the Colossal Clean Crawled Corpus (C4). Model Card
t5_base_multi T5 0 12-layer T5 model. Trained on the Colossal Clean Crawled Corpus (C4). Model Card
t5_large_multi T5 0 24-layer T5 model. Trained on the Colossal Clean Crawled Corpus (C4). Model Card
flan_small_multi T5 0 8-layer T5 model. Trained on the Colossal Clean Crawled Corpus (C4). Model Card
flan_base_multi T5 0 12-layer T5 model. Trained on the Colossal Clean Crawled Corpus (C4). Model Card
flan_large_multi T5 0 24-layer T5 model. Trained on the Colossal Clean Crawled Corpus (C4). Model Card
vgg_11_imagenet VGG 9.22M 11-layer VGG model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Model Card
vgg_13_imagenet VGG 9.40M 13-layer VGG model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Model Card
vgg_16_imagenet VGG 14.71M 16-layer VGG model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Model Card
vgg_19_imagenet VGG 20.02M 19-layer VGG model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Model Card
whisper_tiny_en Whisper 37.18M 4-layer Whisper model. Trained on 438,000 hours of labelled English speech data. Model Card
whisper_base_en Whisper 124.44M 6-layer Whisper model. Trained on 438,000 hours of labelled English speech data. Model Card
whisper_small_en Whisper 241.73M 12-layer Whisper model. Trained on 438,000 hours of labelled English speech data. Model Card
whisper_medium_en Whisper 763.86M 24-layer Whisper model. Trained on 438,000 hours of labelled English speech data. Model Card
whisper_tiny_multi Whisper 37.76M 4-layer Whisper model. Trained on 680,000 hours of labelled multilingual speech data. Model Card
whisper_base_multi Whisper 72.59M 6-layer Whisper model. Trained on 680,000 hours of labelled multilingual speech data. Model Card
whisper_small_multi Whisper 241.73M 12-layer Whisper model. Trained on 680,000 hours of labelled multilingual speech data. Model Card
whisper_medium_multi Whisper 763.86M 24-layer Whisper model. Trained on 680,000 hours of labelled multilingual speech data. Model Card
whisper_large_multi Whisper 1.54B 32-layer Whisper model. Trained on 680,000 hours of labelled multilingual speech data. Model Card
whisper_large_multi_v2 Whisper 1.54B 32-layer Whisper model. Trained for 2.5 epochs on 680,000 hours of labelled multilingual speech data. An improved version of whisper_large_multi. Model Card

Note: The links provided will lead to the model card or to the official README, if no model card has been provided by the author.
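Generative presets in the table are used the same way through their task classes. For example, here is a minimal sketch of text generation with one of the GPT-2 presets (the prompt and max_length are arbitrary choices):

import keras_nlp

# High-level causal language modeling task; preprocessing and sampling are built in.
causal_lm = keras_nlp.models.CausalLM.from_preset("gpt2_base_en")
output = causal_lm.generate("The Keras ecosystem", max_length=64)
print(output)
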

API Documentation

Albert

Bart

Bert

Bloom

DebertaV3

DistilBert

Gemma

Electra

Falcon

FNet

GPT2

Llama

Llama3

Mistral

OPT

PaliGemma

Phi3

Roberta

XLMRoberta