Author: fchollet
Date created: 2020/06/09
Last modified: 2020/06/09
Description: Binary classification of structured data including numerical and categorical features.
This example demonstrates how to do structured data classification, starting from a raw CSV file. Our data includes both numerical and categorical features. We will use Keras preprocessing layers to normalize the numerical features and vectorize the categorical ones.
Note that this example should be run with TensorFlow 2.5 or higher.
Our dataset is provided by the Cleveland Clinic Foundation for Heart Disease. It's a CSV file with 303 rows. Each row contains information about a patient (a sample), and each column describes an attribute of the patient (a feature). We use the features to predict whether a patient has a heart disease (binary classification).
Here's the description of each feature:
Column | Description | Feature Type |
---|---|---|
Age | Age in years | Numerical |
Sex | (1 = male; 0 = female) | Categorical |
CP | Chest pain type (0, 1, 2, 3, 4) | Categorical |
Trestbps | Resting blood pressure (in mm Hg on admission) | Numerical |
Chol | Serum cholesterol in mg/dl | Numerical |
FBS | Fasting blood sugar > 120 mg/dl (1 = true; 0 = false) | Categorical |
RestECG | Resting electrocardiogram results (0, 1, 2) | Categorical |
Thalach | Maximum heart rate achieved | Numerical |
Exang | Exercise induced angina (1 = yes; 0 = no) | Categorical |
Oldpeak | ST depression induced by exercise relative to rest | Numerical |
Slope | Slope of the peak exercise ST segment | Numerical |
CA | Number of major vessels (0-3) colored by fluoroscopy | Both numerical & categorical |
Thal | 3 = normal; 6 = fixed defect; 7 = reversible defect | Categorical |
Target | Diagnosis of heart disease (1 = true; 0 = false) | Target |
import os
# TensorFlow is the only backend that supports string inputs.
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
import pandas as pd
import keras
from keras import layers
Let's download the data and load it into a Pandas dataframe:
file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
dataframe = pd.read_csv(file_url)
The dataset includes 303 samples with 14 columns per sample (13 features, plus the target label):
dataframe.shape
(303, 14)
Here's a preview of a few samples:
dataframe.head()
 | age | sex | cp | trestbps | chol | fbs | restecg | thalach | exang | oldpeak | slope | ca | thal | target |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 63 | 1 | 1 | 145 | 233 | 1 | 2 | 150 | 0 | 2.3 | 3 | 0 | fixed | 0 |
1 | 67 | 1 | 4 | 160 | 286 | 0 | 2 | 108 | 1 | 1.5 | 2 | 3 | normal | 1 |
2 | 67 | 1 | 4 | 120 | 229 | 0 | 2 | 129 | 1 | 2.6 | 2 | 2 | reversible | 0 |
3 | 37 | 1 | 3 | 130 | 250 | 0 | 0 | 187 | 0 | 3.5 | 3 | 0 | normal | 0 |
4 | 41 | 0 | 2 | 130 | 204 | 0 | 2 | 172 | 0 | 1.4 | 1 | 0 | normal | 0 |
The last column, "target", indicates whether the patient has a heart disease (1) or not (0).
Let's split the data into a training and validation set:
val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
train_dataframe = dataframe.drop(val_dataframe.index)
print(
    f"Using {len(train_dataframe)} samples for training "
    f"and {len(val_dataframe)} for validation"
)
Using 242 samples for training and 61 for validation
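As a rough illustration of what this holdout split does, here is a pandas-free sketch. The helper name `holdout_split` is hypothetical, and the rounding rule is an assumption chosen to reproduce the 242/61 counts printed above; the specific rows selected will differ from the ones `dataframe.sample` picks.

```python
import random


def holdout_split(n_rows, frac=0.2, seed=1337):
    """Return (train_indices, val_indices) for a random holdout split."""
    rng = random.Random(seed)
    n_val = round(frac * n_rows)  # assumption: round to the nearest integer
    val_indices = sorted(rng.sample(range(n_rows), n_val))
    val_set = set(val_indices)
    # Everything not drawn for validation stays in the training set.
    train_indices = [i for i in range(n_rows) if i not in val_set]
    return train_indices, val_indices


train_idx, val_idx = holdout_split(303)
print(f"Using {len(train_idx)} samples for training and {len(val_idx)} for validation")
```

The two index lists are disjoint and together cover all 303 rows, which is the property the `sample` / `drop` pair above relies on.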
Let's generate tf.data.Dataset objects for each dataframe:
def dataframe_to_dataset(dataframe):
    dataframe = dataframe.copy()
    labels = dataframe.pop("target")
    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
    ds = ds.shuffle(buffer_size=len(dataframe))
    return ds
train_ds = dataframe_to_dataset(train_dataframe)
val_ds = dataframe_to_dataset(val_dataframe)
Each Dataset yields a tuple (input, target) where input is a dictionary of features and target is the value 0 or 1:
for x, y in train_ds.take(1):
    print("Input:", x)
    print("Target:", y)
Input: {'age': <tf.Tensor: shape=(), dtype=int64, numpy=64>, 'sex': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'cp': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'trestbps': <tf.Tensor: shape=(), dtype=int64, numpy=128>, 'chol': <tf.Tensor: shape=(), dtype=int64, numpy=263>, 'fbs': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'restecg': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'thalach': <tf.Tensor: shape=(), dtype=int64, numpy=105>, 'exang': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'oldpeak': <tf.Tensor: shape=(), dtype=float64, numpy=0.2>, 'slope': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'ca': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'thal': <tf.Tensor: shape=(), dtype=string, numpy=b'reversible'>}
Target: tf.Tensor(0, shape=(), dtype=int64)
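To make the (input, target) structure concrete, the following plain-Python sketch slices column-oriented data into per-row pairs, roughly the way the dataset above yields one dictionary per sample. The generator name `slice_columns` is hypothetical, and the sketch works on Python lists rather than tensors.

```python
def slice_columns(columns, labels):
    """Yield one (features, label) pair per row from column-oriented data."""
    names = list(columns)
    rows = zip(*(columns[name] for name in names))
    for row_values, label in zip(rows, labels):
        yield dict(zip(names, row_values)), label


# Three columns taken from the first two rows of the preview table above.
columns = {"age": [63, 67], "chol": [233, 286], "thal": ["fixed", "normal"]}
labels = [0, 1]

pairs = list(slice_columns(columns, labels))
for features, label in pairs:
    print("Input:", features, "Target:", label)
```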
Let's batch the datasets:
train_ds = train_ds.batch(32)
val_ds = val_ds.batch(32)
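A quick back-of-the-envelope check of what batching produces, assuming the final partial batch is kept rather than dropped. The helper `batch_sizes` is hypothetical. With 242 training samples and a batch size of 32, this gives 8 batches per epoch, consistent with the `8/8` step counter in the training log further down.

```python
import math


def batch_sizes(n_samples, batch_size):
    """Return the size of every batch, including a smaller final one."""
    n_batches = math.ceil(n_samples / batch_size)
    return [min(batch_size, n_samples - i * batch_size) for i in range(n_batches)]


train_batches = batch_sizes(242, 32)
val_batches = batch_sizes(61, 32)
print(len(train_batches), "training batches, last one has", train_batches[-1], "samples")
print(len(val_batches), "validation batches, last one has", val_batches[-1], "samples")
```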
The following features are categorical features encoded as integers:
sex
cp
fbs
restecg
exang
ca
We will encode these features using one-hot encoding. We have two options here:
CategoryEncoding(), which requires knowing the range of input values and will error on input outside the range.
IntegerLookup(), which will build a lookup table for inputs and reserve an output index for unknown input values.
For this example, we want a simple solution that will handle out-of-range inputs at inference, so we will use IntegerLookup().
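The idea of reserving an output index for unknown values can be sketched in plain Python. This is a conceptual toy, not the Keras implementation: the class name `TinyIntegerLookup` is hypothetical, and indexing known values in sorted order starting at 1 is an assumption made for readability (the real layer may order its vocabulary differently).

```python
class TinyIntegerLookup:
    """Toy lookup table: index 0 is reserved for values not seen in adapt()."""

    def __init__(self):
        self.index = {}

    def adapt(self, values):
        # Assumption: known values are indexed in sorted order, starting at 1.
        for position, value in enumerate(sorted(set(values)), start=1):
            self.index[value] = position

    def __call__(self, value):
        one_hot = [0] * (len(self.index) + 1)
        # Unknown values fall back to the reserved index 0.
        one_hot[self.index.get(value, 0)] = 1
        return one_hot


lookup = TinyIntegerLookup()
lookup.adapt([0, 2, 2, 1, 0])  # e.g. "restecg" values seen during training
print(lookup(2))  # known value
print(lookup(7))  # value never seen during adapt()
```

Because an unseen value still maps to a valid one-hot vector, inference does not fail on out-of-range input, which is the reason given above for preferring a lookup over a fixed-range encoding.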
We also have a categorical feature encoded as a string: thal. We will build an index of all its possible values and encode them using the StringLookup() layer.
Finally, the following features are continuous numerical features:
age
trestbps
chol
thalach
oldpeak
slope
For each of these features, we will use a Normalization() layer to make sure the mean of each feature is 0 and its standard deviation is 1.
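As a minimal sketch of what "mean 0, standard deviation 1" means in practice, the snippet below fits a standardizer on a handful of values and applies it. The helper `fit_normalizer` is hypothetical, and using the population standard deviation is an assumption; it illustrates the arithmetic only, not the layer's internals.

```python
import statistics


def fit_normalizer(values):
    """Return a function that standardizes values with the fitted mean and std."""
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)  # assumption: population standard deviation

    def normalize(x):
        return (x - mean) / std

    return normalize


ages = [63, 67, 67, 37, 41]  # "age" column of the preview rows above
normalize_age = fit_normalizer(ages)
normalized = [normalize_age(a) for a in ages]
print([round(v, 3) for v in normalized])
```

The statistics are learned once from training data and then reused unchanged on validation and inference inputs, which is why the utilities below call `adapt` on the training dataset only.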
Below, we define 2 utility functions to do the operations:
encode_numerical_feature to apply featurewise normalization to numerical features.
encode_categorical_feature to one-hot encode string or integer categorical features.
def encode_numerical_feature(feature, name, dataset):
    # Create a Normalization layer for our feature
    normalizer = layers.Normalization()
    # Prepare a Dataset that only yields our feature
    feature_ds = dataset.map(lambda x, y: x[name])
    feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))
    # Learn the statistics of the data
    normalizer.adapt(feature_ds)
    # Normalize the input feature
    encoded_feature = normalizer(feature)
    return encoded_feature
def encode_categorical_feature(feature, name, dataset, is_string):
    lookup_class = layers.StringLookup if is_string else layers.IntegerLookup
    # Create a lookup layer which will turn string or integer values into a binary (one-hot) encoding
    lookup = lookup_class(output_mode="binary")
    # Prepare a Dataset that only yields our feature
    feature_ds = dataset.map(lambda x, y: x[name])
    feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))
    # Learn the set of possible values and assign them a fixed integer index
    lookup.adapt(feature_ds)
    # Encode the input feature
    encoded_feature = lookup(feature)
    return encoded_feature
With this done, we can create our end-to-end model:
# Categorical features encoded as integers
sex = keras.Input(shape=(1,), name="sex", dtype="int64")
cp = keras.Input(shape=(1,), name="cp", dtype="int64")
fbs = keras.Input(shape=(1,), name="fbs", dtype="int64")
restecg = keras.Input(shape=(1,), name="restecg", dtype="int64")
exang = keras.Input(shape=(1,), name="exang", dtype="int64")
ca = keras.Input(shape=(1,), name="ca", dtype="int64")
# Categorical feature encoded as string
thal = keras.Input(shape=(1,), name="thal", dtype="string")
# Numerical features
age = keras.Input(shape=(1,), name="age")
trestbps = keras.Input(shape=(1,), name="trestbps")
chol = keras.Input(shape=(1,), name="chol")
thalach = keras.Input(shape=(1,), name="thalach")
oldpeak = keras.Input(shape=(1,), name="oldpeak")
slope = keras.Input(shape=(1,), name="slope")
all_inputs = [
    sex,
    cp,
    fbs,
    restecg,
    exang,
    ca,
    thal,
    age,
    trestbps,
    chol,
    thalach,
    oldpeak,
    slope,
]
# Integer categorical features
sex_encoded = encode_categorical_feature(sex, "sex", train_ds, False)
cp_encoded = encode_categorical_feature(cp, "cp", train_ds, False)
fbs_encoded = encode_categorical_feature(fbs, "fbs", train_ds, False)
restecg_encoded = encode_categorical_feature(restecg, "restecg", train_ds, False)
exang_encoded = encode_categorical_feature(exang, "exang", train_ds, False)
ca_encoded = encode_categorical_feature(ca, "ca", train_ds, False)
# String categorical features
thal_encoded = encode_categorical_feature(thal, "thal", train_ds, True)
# Numerical features
age_encoded = encode_numerical_feature(age, "age", train_ds)
trestbps_encoded = encode_numerical_feature(trestbps, "trestbps", train_ds)
chol_encoded = encode_numerical_feature(chol, "chol", train_ds)
thalach_encoded = encode_numerical_feature(thalach, "thalach", train_ds)
oldpeak_encoded = encode_numerical_feature(oldpeak, "oldpeak", train_ds)
slope_encoded = encode_numerical_feature(slope, "slope", train_ds)
all_features = layers.concatenate(
    [
        sex_encoded,
        cp_encoded,
        fbs_encoded,
        restecg_encoded,
        exang_encoded,
        slope_encoded,
        ca_encoded,
        thal_encoded,
        age_encoded,
        trestbps_encoded,
        chol_encoded,
        thalach_encoded,
        oldpeak_encoded,
    ]
)
x = layers.Dense(32, activation="relu")(all_features)
x = layers.Dropout(0.5)(x)
output = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(all_inputs, output)
model.compile("adam", "binary_crossentropy", metrics=["accuracy"])
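For readers unfamiliar with the loss named in the compile call, here is a small standard-library sketch of a sigmoid output paired with binary cross-entropy for a single sample. The function names and the clipping constant `eps` are assumptions made for illustration; this is the textbook formula, not a claim about how Keras computes it internally.

```python
import math


def sigmoid(z):
    """Squash a raw score into a probability between 0 and 1."""
    return 1.0 / (1.0 + math.exp(-z))


def binary_crossentropy(y_true, p, eps=1e-7):
    """Loss for one sample; eps is an assumed clipping constant to avoid log(0)."""
    p = min(max(p, eps), 1.0 - eps)
    return -(y_true * math.log(p) + (1 - y_true) * math.log(1.0 - p))


p = sigmoid(0.0)  # a raw score of 0 corresponds to a probability of 0.5
print(round(binary_crossentropy(1, p), 4))  # 0.6931, i.e. ln(2)
print(binary_crossentropy(1, 0.9) < binary_crossentropy(1, 0.1))  # True
```

A confident correct prediction yields a small loss and a confident wrong one a large loss, which is what training minimizes.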
Let's visualize our connectivity graph:
# `rankdir='LR'` is to make the graph horizontal.
keras.utils.plot_model(model, show_shapes=True, rankdir="LR")
model.fit(train_ds, epochs=50, validation_data=val_ds)
Epoch 1/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 46ms/step - accuracy: 0.3932 - loss: 0.8749 - val_accuracy: 0.3303 - val_loss: 0.7814
Epoch 2/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - accuracy: 0.4262 - loss: 0.8375 - val_accuracy: 0.4914 - val_loss: 0.6980
Epoch 3/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.4835 - loss: 0.7350 - val_accuracy: 0.6541 - val_loss: 0.6320
Epoch 4/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.5932 - loss: 0.6665 - val_accuracy: 0.7543 - val_loss: 0.5743
Epoch 5/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.5861 - loss: 0.6600 - val_accuracy: 0.7683 - val_loss: 0.5360
Epoch 6/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.6489 - loss: 0.6020 - val_accuracy: 0.7748 - val_loss: 0.4998
Epoch 7/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.6880 - loss: 0.5668 - val_accuracy: 0.7699 - val_loss: 0.4800
Epoch 8/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.7572 - loss: 0.5009 - val_accuracy: 0.7559 - val_loss: 0.4573
Epoch 9/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.7492 - loss: 0.5192 - val_accuracy: 0.8060 - val_loss: 0.4414
Epoch 10/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.7212 - loss: 0.4973 - val_accuracy: 0.8077 - val_loss: 0.4259
Epoch 11/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.7616 - loss: 0.4704 - val_accuracy: 0.7904 - val_loss: 0.4143
Epoch 12/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8374 - loss: 0.4342 - val_accuracy: 0.7872 - val_loss: 0.4061
Epoch 13/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.7863 - loss: 0.4630 - val_accuracy: 0.7888 - val_loss: 0.3980
Epoch 14/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.7742 - loss: 0.4492 - val_accuracy: 0.7996 - val_loss: 0.3998
Epoch 15/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8083 - loss: 0.4280 - val_accuracy: 0.8060 - val_loss: 0.3855
Epoch 16/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8058 - loss: 0.4191 - val_accuracy: 0.8217 - val_loss: 0.3819
Epoch 17/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8071 - loss: 0.4111 - val_accuracy: 0.8389 - val_loss: 0.3763
Epoch 18/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.8533 - loss: 0.3676 - val_accuracy: 0.8373 - val_loss: 0.3792
Epoch 19/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8170 - loss: 0.3850 - val_accuracy: 0.8357 - val_loss: 0.3744
Epoch 20/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8207 - loss: 0.3767 - val_accuracy: 0.8168 - val_loss: 0.3759
Epoch 21/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8151 - loss: 0.3596 - val_accuracy: 0.8217 - val_loss: 0.3685
Epoch 22/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.7988 - loss: 0.4087 - val_accuracy: 0.8184 - val_loss: 0.3701
Epoch 23/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8180 - loss: 0.3632 - val_accuracy: 0.8217 - val_loss: 0.3614
Epoch 24/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8295 - loss: 0.3504 - val_accuracy: 0.8200 - val_loss: 0.3683
Epoch 25/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8386 - loss: 0.3864 - val_accuracy: 0.8200 - val_loss: 0.3655
Epoch 26/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8482 - loss: 0.3345 - val_accuracy: 0.8044 - val_loss: 0.3639
Epoch 27/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.8340 - loss: 0.3470 - val_accuracy: 0.8077 - val_loss: 0.3616
Epoch 28/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8418 - loss: 0.3684 - val_accuracy: 0.8060 - val_loss: 0.3629
Epoch 29/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8309 - loss: 0.3147 - val_accuracy: 0.8060 - val_loss: 0.3637
Epoch 30/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8722 - loss: 0.3151 - val_accuracy: 0.8044 - val_loss: 0.3672
Epoch 31/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.8746 - loss: 0.3043 - val_accuracy: 0.8060 - val_loss: 0.3637
Epoch 32/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8794 - loss: 0.3245 - val_accuracy: 0.8200 - val_loss: 0.3685
Epoch 33/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.8644 - loss: 0.3541 - val_accuracy: 0.8357 - val_loss: 0.3714
Epoch 34/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8867 - loss: 0.3007 - val_accuracy: 0.8373 - val_loss: 0.3680
Epoch 35/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8737 - loss: 0.3168 - val_accuracy: 0.8357 - val_loss: 0.3695
Epoch 36/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8191 - loss: 0.3298 - val_accuracy: 0.8357 - val_loss: 0.3736
Epoch 37/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8613 - loss: 0.3543 - val_accuracy: 0.8357 - val_loss: 0.3745
Epoch 38/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8835 - loss: 0.2835 - val_accuracy: 0.8357 - val_loss: 0.3707
Epoch 39/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8784 - loss: 0.2893 - val_accuracy: 0.8357 - val_loss: 0.3716
Epoch 40/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8919 - loss: 0.2587 - val_accuracy: 0.8168 - val_loss: 0.3770
Epoch 41/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8882 - loss: 0.2660 - val_accuracy: 0.8217 - val_loss: 0.3674
Epoch 42/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8790 - loss: 0.2931 - val_accuracy: 0.8200 - val_loss: 0.3723
Epoch 43/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8851 - loss: 0.2892 - val_accuracy: 0.8200 - val_loss: 0.3733
Epoch 44/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8504 - loss: 0.3189 - val_accuracy: 0.8200 - val_loss: 0.3755
Epoch 45/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8610 - loss: 0.3116 - val_accuracy: 0.8184 - val_loss: 0.3788
Epoch 46/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.8956 - loss: 0.2544 - val_accuracy: 0.8184 - val_loss: 0.3738
Epoch 47/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9080 - loss: 0.2895 - val_accuracy: 0.8217 - val_loss: 0.3750
Epoch 48/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8706 - loss: 0.2993 - val_accuracy: 0.8217 - val_loss: 0.3757
Epoch 49/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8724 - loss: 0.2979 - val_accuracy: 0.8184 - val_loss: 0.3781
Epoch 50/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.8609 - loss: 0.2937 - val_accuracy: 0.8217 - val_loss: 0.3791
<keras.src.callbacks.history.History at 0x7efc32e01780>
We quickly get to 80% validation accuracy.
To get a prediction for a new sample, you can simply call model.predict(). There are just two things you need to do:
Wrap each scalar value in a list so that it has a batch dimension (models process batches of data, not single samples)
Call convert_to_tensor on each feature
sample = {
    "age": 60,
    "sex": 1,
    "cp": 1,
    "trestbps": 145,
    "chol": 233,
    "fbs": 1,
    "restecg": 2,
    "thalach": 150,
    "exang": 0,
    "oldpeak": 2.3,
    "slope": 3,
    "ca": 0,
    "thal": "fixed",
}
input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}
predictions = model.predict(input_dict)
print(
    f"This particular patient had a {100 * predictions[0][0]:.1f} "
    "percent probability of having a heart disease, "
    "as evaluated by our model."
)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 252ms/step
This particular patient had a 27.6 percent probability of having a heart disease, as evaluated by our model.
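The input preparation and the interpretation of the output can be sketched without TensorFlow. Both helper names below are hypothetical, and the 0.5 decision threshold is an assumption, since the example itself only reports a probability.

```python
def add_batch_dimension(sample):
    """Wrap every scalar in a list so each feature becomes a batch of one."""
    return {name: [value] for name, value in sample.items()}


def to_label(probability, threshold=0.5):
    """Map a predicted probability to a class; the 0.5 cutoff is an assumption."""
    return int(probability >= threshold)


batched = add_batch_dimension({"age": 60, "thal": "fixed"})
print(batched)
print(to_label(0.276))
```

With that threshold, the 27.6 percent probability printed above would map to class 0 (no heart disease); a different cutoff may be more appropriate depending on the relative cost of false negatives.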