Author: Humbulani Ndou
Date created: 2025/03/17
Last modified: 2025/03/17
Description: Gated Residual and Variable Selection Networks prediction with hyperparameter tuning.
The following example extends the script structured_data/classification_with_grn_and_vsn.py by incorporating hyperparameter tuning using AutoKeras and KerasTuner. Specifics regarding which APIs are used from these two packages are described in detail in the relevant code sections.
This example demonstrates the use of Gated Residual Networks (GRN) and Variable Selection Networks (VSN), proposed by Bryan Lim et al. in Temporal Fusion Transformers (TFT) for Interpretable Multi-horizon Time Series Forecasting, for structured data classification. GRNs give the model the flexibility to apply non-linear processing only where needed. VSNs allow the model to softly remove any unnecessary noisy inputs that could negatively impact performance. Together, these techniques help improve the learning capacity of deep neural network models.
Note that this example implements only the GRN and VSN components described in the paper, rather than the whole TFT model, as GRN and VSN can be useful on their own for structured data learning tasks.
To run the code you need to use TensorFlow 2.3 or higher.
Our dataset is provided by the Cleveland Clinic Foundation for Heart Disease. It's a CSV file with 303 rows. Each row contains information about a patient (a sample), and each column describes an attribute of the patient (a feature). We use the features to predict whether a patient has a heart disease (binary classification).
Here's the description of each feature:
Column | Description | Feature Type |
---|---|---|
Age | Age in years | Numerical |
Sex | (1 = male; 0 = female) | Categorical |
CP | Chest pain type (0, 1, 2, 3, 4) | Categorical |
Trestbpd | Resting blood pressure (in mm Hg on admission) | Numerical |
Chol | Serum cholesterol in mg/dl | Numerical |
FBS | fasting blood sugar in 120 mg/dl (1 = true; 0 = false) | Categorical |
RestECG | Resting electrocardiogram results (0, 1, 2) | Categorical |
Thalach | Maximum heart rate achieved | Numerical |
Exang | Exercise induced angina (1 = yes; 0 = no) | Categorical |
Oldpeak | ST depression induced by exercise relative to rest | Numerical |
Slope | Slope of the peak exercise ST segment | Numerical |
CA | Number of major vessels (0-3) colored by fluoroscopy | Both numerical & categorical |
Thal | 3 = normal; 6 = fixed defect; 7 = reversible defect | Categorical |
Target | Diagnosis of heart disease (1 = true; 0 = false) | Target |
import os
import subprocess
import tarfile
import numpy as np
import pandas as pd
import tree
from typing import Optional, Union
os.environ["KERAS_BACKEND"] = "tensorflow" # or jax, or torch
# Keras imports
import keras
from keras import layers
# KerasTuner imports
import keras_tuner
from keras_tuner import HyperParameters
# AutoKeras imports
import autokeras as ak
from autokeras.utils import utils, types
Let's download the data and load it into a Pandas dataframe:
file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
dataframe = pd.read_csv(file_url)
The dataset includes 303 samples with 14 columns per sample (13 features, plus the target label):
dataframe.shape
(303, 14)
Here's a preview of a few samples:
dataframe.head()
 | age | sex | cp | trestbps | chol | fbs | restecg | thalach | exang | oldpeak | slope | ca | thal | target
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 63 | 1 | 1 | 145 | 233 | 1 | 2 | 150 | 0 | 2.3 | 3 | 0 | fixed | 0 |
1 | 67 | 1 | 4 | 160 | 286 | 0 | 2 | 108 | 1 | 1.5 | 2 | 3 | normal | 1 |
2 | 67 | 1 | 4 | 120 | 229 | 0 | 2 | 129 | 1 | 2.6 | 2 | 2 | reversible | 0 |
3 | 37 | 1 | 3 | 130 | 250 | 0 | 0 | 187 | 0 | 3.5 | 3 | 0 | normal | 0 |
4 | 41 | 0 | 2 | 130 | 204 | 0 | 2 | 172 | 0 | 1.4 | 1 | 0 | normal | 0 |
The last column, "target", indicates whether the patient has a heart disease (1) or not (0).
Let's split the data into a training and validation set:
val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
train_dataframe = dataframe.drop(val_dataframe.index)
print(
f"Using {len(train_dataframe)} samples for training "
f"and {len(val_dataframe)} for validation"
)
Using 242 samples for training and 61 for validation
Here, we define the metadata of the dataset that will be useful for reading and parsing the data into input features, and encoding the input features with respect to their types.
COLUMN_NAMES = [
"age",
"sex",
"cp",
"trestbps",
"chol",
"fbs",
"restecg",
"thalach",
"exang",
"oldpeak",
"slope",
"ca",
"thal",
"target",
]
# Target feature name.
TARGET_FEATURE_NAME = "target"
# Numeric feature names.
NUMERIC_FEATURE_NAMES = ["age", "trestbps", "thalach", "oldpeak", "slope", "chol"]
# Categorical features and their vocabulary lists.
# Note that integer categorical values are kept as integers and string values as
# strings, so that each lookup layer receives the dtype it expects.
CATEGORICAL_FEATURES_WITH_VOCABULARY = {
feature_name: sorted(
[
# Integer categorical values must be int and string values must be str
value if dataframe[feature_name].dtype == "int64" else str(value)
for value in list(dataframe[feature_name].unique())
]
)
for feature_name in COLUMN_NAMES
if feature_name not in list(NUMERIC_FEATURE_NAMES + [TARGET_FEATURE_NAME])
}
# All features names.
FEATURE_NAMES = NUMERIC_FEATURE_NAMES + list(
CATEGORICAL_FEATURES_WITH_VOCABULARY.keys()
)
The following features are categorical features encoded as integers:

sex
cp
fbs
restecg
exang
ca

We will encode these features using one-hot encoding. We have two options here:

- CategoryEncoding(), which requires knowing the range of input values and will error on input outside the range.
- IntegerLookup(), which will build a lookup table for inputs and reserve an output index for unknown input values.

For this example, we want a simple solution that will handle out-of-range inputs at inference, so we will use IntegerLookup().

We also have a categorical feature encoded as a string: thal. We will create an index of all possible features and encode output using the StringLookup() layer.
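As a minimal standalone illustration (not part of the pipeline below), this is how a lookup layer configured like the ones we use turns a category value into a multi-hot vector over its vocabulary. The vocabulary here is made up for the sketch:

# Standalone sketch: encode the value 3 against a hypothetical vocabulary [0, 1, 2, 3, 4].
lookup = layers.IntegerLookup(
    vocabulary=[0, 1, 2, 3, 4],
    mask_token=None,
    num_oov_indices=0,  # no slot reserved for out-of-vocabulary values, as in the code below
    output_mode="binary",
)
print(lookup([3]))  # a length-5 multi-hot vector with a 1 at the position of value 3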
Finally, the following features are continuous numerical features:

age
trestbps
chol
thalach
oldpeak
slope

For each of these features, we will use a Normalization() layer to make sure the mean of each feature is 0 and its standard deviation is 1.
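As a quick standalone sketch (not part of the model code), a Normalization layer learns a feature's mean and variance via adapt() and then standardizes new inputs; the values here are made up:

normalizer = layers.Normalization()
normalizer.adapt(np.array([[120.0], [140.0], [160.0]]))  # learn mean and variance
print(normalizer(np.array([[140.0]])))  # -> approximately 0.0, since 140 is the mean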
Below, we define a utility function, encode_categorical, to one-hot encode string or integer categorical features.

# Tensorflow required for tf.data.Dataset
import tensorflow as tf
# We encode the categorical elements of our dataset here and convert them to one-hot
# vectors, to avoid doing this step during model training, since only the TensorFlow
# backend supports strings.
def encode_categorical(features, target):
for f in features:
if f in CATEGORICAL_FEATURES_WITH_VOCABULARY:
# Create a lookup layer to convert string or integer categorical values to a
# one-hot encoding.
# Since we are not using a mask token nor expecting any out of vocabulary
# (oov) token, we set mask_token to None and num_oov_indices to 0.
cls = (
layers.StringLookup
if features[f].dtype == "string"
else layers.IntegerLookup
)
features[f] = cls(
vocabulary=CATEGORICAL_FEATURES_WITH_VOCABULARY[f],
mask_token=None,
num_oov_indices=0,
output_mode="binary",
)(features[f])
# Change features from OrderedDict to Dict to match Inputs as they are Dict.
return dict(features), target
Let's generate tf.data.Dataset
objects for each dataframe:
def dataframe_to_dataset(dataframe):
dataframe = dataframe.copy()
labels = dataframe.pop("target")
ds = (
tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
.map(encode_categorical)
.shuffle(buffer_size=len(dataframe))
)
return ds
train_ds = dataframe_to_dataset(train_dataframe)
val_ds = dataframe_to_dataset(val_dataframe)
Each Dataset yields a tuple (input, target) where input is a dictionary of features and target is the value 0 or 1:
for x, y in train_ds.take(1):
print("Input:", x)
print("Target:", y)
Input: {'age': <tf.Tensor: shape=(), dtype=int64, numpy=37>, 'sex': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, 'cp': <tf.Tensor: shape=(5,), dtype=int64, numpy=array([0, 0, 0, 1, 0])>, 'trestbps': <tf.Tensor: shape=(), dtype=int64, numpy=120>, 'chol': <tf.Tensor: shape=(), dtype=int64, numpy=215>, 'fbs': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, 'restecg': <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 0, 0])>, 'thalach': <tf.Tensor: shape=(), dtype=int64, numpy=170>, 'exang': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, 'oldpeak': <tf.Tensor: shape=(), dtype=float64, numpy=0.0>, 'slope': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'ca': <tf.Tensor: shape=(4,), dtype=int64, numpy=array([1, 0, 0, 0])>, 'thal': <tf.Tensor: shape=(5,), dtype=int64, numpy=array([0, 0, 0, 1, 0])>}
Target: tf.Tensor(0, shape=(), dtype=int64)
Let's batch the datasets:
train_ds = train_ds.batch(32)
val_ds = val_ds.batch(32)
Here we subclass the AutoKeras Graph.

build: we override this method to be able to handle model Inputs passed as dictionaries. In structured data analysis, Inputs are normally passed as dictionaries, one for each feature of interest.

class Graph(ak.graph.Graph):
def build(self, hp):
"""Build the HyperModel into a Keras Model."""
keras_nodes = {}
keras_input_nodes = []
for node in self.inputs:
node_id = self._node_to_id[node]
input_node = node.build_node(hp)
output_node = node.build(hp, input_node)
keras_input_nodes.append(input_node)
keras_nodes[node_id] = output_node
for block in self.blocks:
temp_inputs = (
{
n.name: keras_nodes[self._node_to_id[n]]
for n in block.inputs
if isinstance(n, ak.Input)
}
if isinstance(block.inputs[0], ak.Input)
else [keras_nodes[self._node_to_id[n]] for n in block.inputs]
)
outputs = tree.flatten(block.build(hp, inputs=temp_inputs))
for n, o in zip(block.outputs, outputs):
keras_nodes[self._node_to_id[n]] = o
model = keras.models.Model(
keras_input_nodes,
[
keras_nodes[self._node_to_id[output_node]]
for output_node in self.outputs
],
)
return self._compile_keras_model(hp, model)
def _compile_keras_model(self, hp, model):
# Specify hyperparameters from compile(...)
optimizer_name = hp.Choice(
"optimizer",
["adam", "sgd"],
default="adam",
)
learning_rate = hp.Choice(
"learning_rate", [1e-1, 1e-2, 1e-3, 1e-4, 2e-5, 1e-5], default=1e-3
)
if optimizer_name == "adam":
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
elif optimizer_name == "sgd":
optimizer = keras.optimizers.SGD(learning_rate=learning_rate)
model.compile(
optimizer=optimizer,
metrics=self._get_metrics(),
loss=self._get_loss(),
)
return model
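For reference, here is a small standalone sketch of how the hp.Choice calls in _compile_keras_model behave: each call registers a hyperparameter in the search space and returns its currently active value (the default when nothing has been sampled yet):

hp = HyperParameters()
optimizer_name = hp.Choice("optimizer", ["adam", "sgd"], default="adam")
learning_rate = hp.Choice("learning_rate", [1e-1, 1e-2, 1e-3, 1e-4, 2e-5, 1e-5], default=1e-3)
print(optimizer_name, learning_rate)  # -> adam 0.001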
Input

Here we subclass the AutoKeras Input node object and override the dtype attribute from None to a user-supplied value. We also override the build_node method to use the user-supplied name for the Input layers.
class Input(ak.Input):
def __init__(self, dtype, name=None, **kwargs):
super().__init__(name=name, **kwargs)
# Override dtype to a user dtype value
self.dtype = dtype
self.name = name
def build_node(self, hp):
return keras.Input(name=self.name, shape=self.shape, dtype=self.dtype)
Here we subclass the AutoKeras ClassificationHead and override its __init__ method, and we add the method get_expected_shape to infer the labels' shape. We remove the preprocessing functionality, as we prefer to handle it manually.
class ClassifierHead(ak.ClassificationHead):
def __init__(
self,
num_classes: Optional[int] = None,
multi_label: bool = False,
loss: Optional[types.LossType] = None,
metrics: Optional[types.MetricsType] = None,
dropout: Optional[float] = None,
**kwargs,
):
self.num_classes = num_classes
self.multi_label = multi_label
self.dropout = dropout
if metrics is None:
metrics = ["accuracy"]
if loss is None:
loss = self.infer_loss()
ak.Head.__init__(self, loss=loss, metrics=metrics, **kwargs)
self.shape = self.get_expected_shape()
def get_expected_shape(self):
# Compute expected shape from num_classes.
if self.num_classes == 2 and not self.multi_label:
return [1]
return [self.num_classes]
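As a quick standalone check of get_expected_shape (relying on the AutoKeras Head internals imported above), binary, non-multi-label classification maps to a single output unit:

head = ClassifierHead(num_classes=2, multi_label=False)
print(head.shape)  # -> [1], a single sigmoid output unit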
This is a Keras layer defined in the script structured_data/classification_with_grn_and_vsn.py. More details about this layer may be found in that script.
class GatedLinearUnit(layers.Layer):
def __init__(self, num_units, activation, **kwargs):
super().__init__(**kwargs)
self.linear = layers.Dense(num_units)
self.sigmoid = layers.Dense(num_units, activation=activation)
def call(self, inputs):
return self.linear(inputs) * self.sigmoid(inputs)
def build(self):
self.built = True
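A quick standalone shape check for this layer (not part of the pipeline): the input is projected to num_units dimensions and gated element-wise:

glu = GatedLinearUnit(num_units=16, activation="sigmoid")
print(glu(keras.ops.ones((2, 8))).shape)  # -> (2, 16)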
This is a Keras layer defined in the script structured_data/classification_with_grn_and_vsn.py. More details about this layer may be found in that script.
class GatedResidualNetwork(layers.Layer):
def __init__(
self, num_units, dropout_rate, activation, use_layernorm=None, **kwargs
):
super().__init__(**kwargs)
self.num_units = num_units
self.use_layernorm = use_layernorm
self.elu_dense = layers.Dense(num_units, activation=activation)
self.linear_dense = layers.Dense(num_units)
self.dropout = layers.Dropout(dropout_rate)
self.gated_linear_unit = GatedLinearUnit(num_units, activation)
self.layer_norm = layers.LayerNormalization()
self.project = layers.Dense(num_units)
def call(self, inputs, hp):
x = self.elu_dense(inputs)
x = self.linear_dense(x)
x = self.dropout(x)
if inputs.shape[-1] != self.num_units:
inputs = self.project(inputs)
x = inputs + self.gated_linear_unit(x)
use_layernorm = self.use_layernorm
if use_layernorm is None:
use_layernorm = hp.Boolean("use_layernorm", default=True)
if use_layernorm:
x = self.layer_norm(x)
return x
def build(self):
self.built = True
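Similarly, a standalone shape check for the GRN (passing use_layernorm=True so the hp argument is never consulted and can be None here):

grn = GatedResidualNetwork(
    num_units=16, dropout_rate=0.1, activation="elu", use_layernorm=True
)
print(grn(keras.ops.ones((2, 8)), hp=None).shape)  # -> (2, 16), the input is projected to num_units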
VariableSelection Block

We have converted the following Keras layer to an AutoKeras Block to include hyperparameters to tune. Refer to the AutoKeras blocks API for writing custom Blocks.
class VariableSelection(ak.Block):
def __init__(
self,
num_units: Optional[Union[int, HyperParameters.Choice]] = None,
dropout_rate: Optional[Union[float, HyperParameters.Choice]] = None,
activation: Optional[Union[str, HyperParameters.Choice]] = None,
**kwargs,
):
super().__init__(**kwargs)
self.dropout = utils.get_hyperparameter(
dropout_rate,
HyperParameters().Choice("dropout", [0.0, 0.25, 0.5], default=0.0),
float,
)
self.num_units = utils.get_hyperparameter(
num_units,
HyperParameters().Choice(
"num_units", [16, 32, 64, 128, 256, 512, 1024], default=16
),
int,
)
self.activation = utils.get_hyperparameter(
activation,
HyperParameters().Choice(
"vsn_activation", ["sigmoid", "elu"], default="sigmoid"
),
str,
)
def build(self, hp, inputs):
num_units = utils.add_to_hp(self.num_units, hp, "num_units")
dropout_rate = utils.add_to_hp(self.dropout, hp, "dropout_rate")
activation = utils.add_to_hp(self.activation, hp, "activation")
concat_inputs = []
# Project the features to 'num_units' dimension
for input_ in inputs:
if input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY:
concat_inputs.append(
keras.layers.Dense(units=num_units)(inputs[input_])
)
else:
# Create a Normalization layer for our feature
normalizer = layers.Normalization()
# Prepare a Dataset that only yields our feature
feature_ds = train_ds.map(lambda x, y: x[input_]).map(
lambda x: keras.ops.expand_dims(x, -1)
)
# Learn the statistics of the data
normalizer.adapt(feature_ds)
# Normalize the input feature
normal_feature = normalizer(inputs[input_])
concat_inputs.append(
keras.layers.Dense(units=num_units)(normal_feature)
)
v = layers.concatenate(concat_inputs)
v = GatedResidualNetwork(
num_units=num_units, dropout_rate=dropout_rate, activation=activation
)(v, hp=hp)
v = keras.ops.expand_dims(
layers.Dense(units=len(inputs), activation=activation)(v), axis=-1
)
x = []
x += [
GatedResidualNetwork(num_units, dropout_rate, activation)(i, hp=hp)
for i in concat_inputs
]
x = keras.ops.stack(x, axis=1)
return keras.ops.squeeze(
keras.ops.matmul(keras.ops.transpose(v, axes=[0, 2, 1]), x), axis=1
)
# Categorical features have different shapes after the encoding, depending on the
# vocabulary or unique values of each feature. We create them accordingly to match the
# input data elements generated by tf.data.Dataset after pre-processing them.
def create_model_inputs():
inputs = {
f: (
Input(
name=f,
shape=(len(CATEGORICAL_FEATURES_WITH_VOCABULARY[f]),),
dtype="int64",
)
if f in CATEGORICAL_FEATURES_WITH_VOCABULARY
else Input(name=f, shape=(1,), dtype="float32")
)
for f in FEATURE_NAMES
}
return inputs
HyperModel

Here we use the AutoKeras Functional API to construct a network of Blocks, which will be built into a KerasTuner HyperModel and finally into a Keras Model.
class MyHyperModel(keras_tuner.HyperModel):
def build(self, hp):
inputs = create_model_inputs()
features = VariableSelection()(inputs)
outputs = ClassifierHead(num_classes=2, multi_label=False)(features)
model = Graph(inputs=inputs, outputs=outputs)
model = model.build(hp)
return model
def fit(self, hp, model, *args, **kwargs):
return model.fit(
*args,
# Tune whether to shuffle the data in each epoch.
shuffle=hp.Boolean("shuffle"),
**kwargs,
)
RandomSearch Tuner to find the best HyperParameters

We use the RandomSearch tuner to search for hyperparameters in the search space. We also display the search space summary.
print("Start training and searching for the best model...")
tuner = keras_tuner.RandomSearch(
MyHyperModel(),
objective="val_accuracy",
max_trials=3,
overwrite=True,
directory="my_dir",
project_name="tune_hypermodel",
)
# Show the search space summary
print("Tuner search space summary:\n")
tuner.search_space_summary()
# Search for best model
tuner.search(train_ds, epochs=2, validation_data=val_ds)
Trial 3 Complete [00h 00m 16s]
val_accuracy: 0.8032786846160889
Best val_accuracy So Far: 0.8032786846160889
Total elapsed time: 00h 00m 34s
# Get the top model.
models = tuner.get_best_models(num_models=1)
best_model = models[0]
best_model.summary()
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/saving/saving_lib.py:757: UserWarning: Skipping variable loading for optimizer 'adam', because it has 2 variables whereas the saved optimizer has 346 variables.
saveable.load_own_variables(weights_store.get(inner_path))
Model: "functional"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ │ age (InputLayer) │ (None, 1) │ 0 │ - │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ chol (InputLayer) │ (None, 1) │ 0 │ - │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ oldpeak │ (None, 1) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ slope (InputLayer) │ (None, 1) │ 0 │ - │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ thalach │ (None, 1) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ trestbps │ (None, 1) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ cast_to_float32 │ (None, 1) │ 0 │ age[0][0] │ │ (CastToFloat32) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ ca (InputLayer) │ (None, 4) │ 0 │ - │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ cast_to_float32_2 │ (None, 1) │ 0 │ chol[0][0] │ │ (CastToFloat32) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ cp (InputLayer) │ (None, 5) │ 0 │ - │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ exang (InputLayer) │ (None, 2) │ 0 │ - │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ fbs (InputLayer) │ (None, 2) │ 0 │ - │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ cast_to_float32_6 │ (None, 1) │ 0 │ oldpeak[0][0] │ │ (CastToFloat32) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ restecg │ (None, 3) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ sex (InputLayer) │ (None, 2) │ 0 │ - │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ cast_to_float32_9 │ (None, 1) │ 0 │ slope[0][0] │ │ (CastToFloat32) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ thal (InputLayer) │ (None, 5) │ 0 │ - │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ cast_to_float32_11 │ (None, 1) │ 0 │ thalach[0][0] │ │ (CastToFloat32) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ cast_to_float32_12 │ (None, 1) │ 0 │ trestbps[0][0] │ │ (CastToFloat32) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ normalization │ (None, 1) │ 3 │ cast_to_float32[… │ │ (Normalization) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ cast_to_float32_1 │ (None, 4) │ 0 │ ca[0][0] │ │ (CastToFloat32) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ normalization_1 │ (None, 1) │ 3 │ cast_to_float32_… │ │ (Normalization) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ cast_to_float32_3 │ (None, 5) │ 0 │ cp[0][0] │ │ (CastToFloat32) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ cast_to_float32_4 │ (None, 2) │ 0 │ exang[0][0] │ │ (CastToFloat32) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ cast_to_float32_5 │ (None, 2) │ 0 │ 
fbs[0][0] │ │ (CastToFloat32) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ normalization_2 │ (None, 1) │ 3 │ cast_to_float32_… │ │ (Normalization) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ cast_to_float32_7 │ (None, 3) │ 0 │ restecg[0][0] │ │ (CastToFloat32) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ cast_to_float32_8 │ (None, 2) │ 0 │ sex[0][0] │ │ (CastToFloat32) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ normalization_3 │ (None, 1) │ 3 │ cast_to_float32_… │ │ (Normalization) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ cast_to_float32_10 │ (None, 5) │ 0 │ thal[0][0] │ │ (CastToFloat32) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ normalization_4 │ (None, 1) │ 3 │ cast_to_float32_… │ │ (Normalization) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ normalization_5 │ (None, 1) │ 3 │ cast_to_float32_… │ │ (Normalization) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense (Dense) │ (None, 16) │ 32 │ normalization[0]… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_1 (Dense) │ (None, 16) │ 80 │ cast_to_float32_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_2 (Dense) │ (None, 16) │ 32 │ normalization_1[… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_3 (Dense) │ (None, 16) │ 96 │ cast_to_float32_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_4 (Dense) │ (None, 16) │ 48 │ cast_to_float32_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_5 (Dense) │ (None, 16) │ 48 │ cast_to_float32_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_6 (Dense) │ (None, 16) │ 32 │ normalization_2[… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_7 (Dense) │ (None, 16) │ 64 │ cast_to_float32_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_8 (Dense) │ (None, 16) │ 48 │ cast_to_float32_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_9 (Dense) │ (None, 16) │ 32 │ normalization_3[… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_10 (Dense) │ (None, 16) │ 96 │ cast_to_float32_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_11 (Dense) │ (None, 16) │ 32 │ normalization_4[… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_12 (Dense) │ (None, 16) │ 32 │ normalization_5[… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ concatenate │ (None, 208) │ 0 │ dense[0][0], │ │ (Concatenate) │ │ │ dense_1[0][0], │ │ │ │ │ dense_2[0][0], │ │ │ │ │ dense_3[0][0], │ │ │ │ │ dense_4[0][0], │ │ │ │ │ dense_5[0][0], │ │ │ │ │ dense_6[0][0], │ │ │ │ │ dense_7[0][0], │ │ │ │ │ dense_8[0][0], │ │ │ │ │ dense_9[0][0], │ │ │ │ │ dense_10[0][0], │ │ │ │ │ dense_11[0][0], │ │ │ │ │ dense_12[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ gated_residual_net… │ (None, 16) │ 7,536 │ concatenate[0][0] │ │ (GatedResidualNetw… │ │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_18 (Dense) │ (None, 13) │ 221 │ gated_residual_n… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ expand_dims │ (None, 13, 1) │ 0 │ dense_18[0][0] │ │ (ExpandDims) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ gated_residual_net… │ (None, 16) │ 1,120 │ dense[0][0] │ │ (GatedResidualNetw… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ gated_residual_net… │ (None, 16) │ 1,120 │ dense_1[0][0] │ │ (GatedResidualNetw… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ gated_residual_net… │ (None, 16) │ 1,120 │ dense_2[0][0] │ │ (GatedResidualNetw… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ gated_residual_net… │ (None, 16) │ 1,120 │ dense_3[0][0] │ │ (GatedResidualNetw… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ gated_residual_net… │ (None, 16) │ 1,120 │ dense_4[0][0] │ │ (GatedResidualNetw… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ gated_residual_net… │ (None, 16) │ 1,120 │ dense_5[0][0] │ │ (GatedResidualNetw… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ gated_residual_net… │ (None, 16) │ 1,120 │ dense_6[0][0] │ │ (GatedResidualNetw… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ gated_residual_net… │ (None, 16) │ 1,120 │ dense_7[0][0] │ │ (GatedResidualNetw… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ gated_residual_net… │ (None, 16) │ 1,120 │ dense_8[0][0] │ │ (GatedResidualNetw… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ gated_residual_net… │ (None, 16) │ 1,120 │ dense_9[0][0] │ │ (GatedResidualNetw… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ gated_residual_net… │ (None, 16) │ 1,120 │ dense_10[0][0] │ │ (GatedResidualNetw… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ gated_residual_net… │ (None, 16) │ 1,120 │ dense_11[0][0] │ │ (GatedResidualNetw… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ gated_residual_net… │ (None, 16) │ 1,120 │ dense_12[0][0] │ │ (GatedResidualNetw… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ transpose │ (None, 1, 13) │ 0 │ expand_dims[0][0] │ │ (Transpose) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ stack (Stack) │ (None, 13, 16) │ 0 │ gated_residual_n… │ │ │ │ │ gated_residual_n… │ │ │ │ │ gated_residual_n… │ │ │ │ │ gated_residual_n… │ │ │ │ │ gated_residual_n… │ │ │ │ │ gated_residual_n… │ │ │ │ │ gated_residual_n… │ │ │ │ │ gated_residual_n… │ │ │ │ │ gated_residual_n… │ │ │ │ │ gated_residual_n… │ │ │ │ │ gated_residual_n… │ │ │ │ │ gated_residual_n… │ │ │ │ │ gated_residual_n… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ matmul (Matmul) │ (None, 1, 16) │ 0 │ transpose[0][0], │ │ │ │ │ stack[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ squeeze (Squeeze) │ (None, 16) │ 0 │ matmul[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dropout_14 │ (None, 16) │ 0 │ squeeze[0][0] │ │ (Dropout) │ │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_84 (Dense) │ (None, 1) │ 17 │ dropout_14[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ classifier_head_1 │ (None, 1) │ 0 │ dense_84[0][0] │ │ (Activation) │ │ │ │ └─────────────────────┴───────────────────┴────────────┴───────────────────┘
Total params: 23,024 (89.96 KB)
Trainable params: 23,006 (89.87 KB)
Non-trainable params: 18 (96.00 B)
To get a prediction for a new sample, you can simply call model.predict(). There are just two things you need to do:

- one-hot encode the categorical feature values, as we did when preparing the training data
- wrap each scalar into a list so as to have a batch dimension, and call convert_to_tensor on each feature

sample = {
"age": 60,
"sex": 1,
"cp": 1,
"trestbps": 145,
"chol": 233,
"fbs": 1,
"restecg": 2,
"thalach": 150,
"exang": 0,
"oldpeak": 2.3,
"slope": 3,
"ca": 0,
"thal": "fixed",
}
# Given the category (in the sample above - key) and the category value (in the sample above - value),
# we return its one-hot encoding
def get_cat_encoding(cat, cat_value):
# Create a list of zeros with the same length as categories
encoding = [0] * len(cat)
# Find the index of category_value in categories and set the corresponding position to 1
if cat_value in cat:
encoding[cat.index(cat_value)] = 1
return encoding
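For example, as a standalone illustration with a made-up five-value vocabulary, the category value 1 becomes:

print(get_cat_encoding([0, 1, 2, 3, 4], 1))  # -> [0, 1, 0, 0, 0]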
for name, value in sample.items():
if name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
sample.update(
{
name: get_cat_encoding(
CATEGORICAL_FEATURES_WITH_VOCABULARY[name], sample[name]
)
}
)
# Convert inputs to tensors
input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}
predictions = best_model.predict(input_dict)
print(
f"This particular patient had a {100 * predictions[0][0]:.1f} "
"percent probability of having a heart disease, "
"as evaluated by our model."
)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 136ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 162ms/step
This particular patient had a 28.1 percent probability of having a heart disease, as evaluated by our model.