{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Titanic data set example\n", "\n", "Note: \n", "The focus of this example is less on finding anomalies and more on illustrating model explainability in the case of categorical and continuous features." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.datasets import fetch_openml\n", "from bhad.model import BHAD" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pclassnamesexagesibspparchticketfarecabinembarkedboatbodyhome.dest
01Allen, Miss. Elisabeth Waltonfemale29.00000024160211.3375B5S2NaNSt Louis, MO
11Allison, Master. Hudson Trevormale0.916712113781151.5500C22 C26S11NaNMontreal, PQ / Chesterville, ON
\n", "
" ], "text/plain": [ " pclass name sex age sibsp parch \\\n", "0 1 Allen, Miss. Elisabeth Walton female 29.0000 0 0 \n", "1 1 Allison, Master. Hudson Trevor male 0.9167 1 2 \n", "\n", " ticket fare cabin embarked boat body \\\n", "0 24160 211.3375 B5 S 2 NaN \n", "1 113781 151.5500 C22 C26 S 11 NaN \n", "\n", " home.dest \n", "0 St Louis, MO \n", "1 Montreal, PQ / Chesterville, ON " ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X, y = fetch_openml(\"titanic\", version=1, as_frame=True, return_X_y=True)\n", "\n", "X.head(2)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Index: 684 entries, 0 to 1281\n", "Data columns (total 8 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 pclass 684 non-null int64 \n", " 1 sex 684 non-null category\n", " 2 age 684 non-null float64 \n", " 3 sibsp 684 non-null int64 \n", " 4 parch 684 non-null int64 \n", " 5 fare 684 non-null float64 \n", " 6 embarked 684 non-null category\n", " 7 home.dest 684 non-null object \n", "dtypes: category(2), float64(2), int64(3), object(1)\n", "memory usage: 39.0+ KB\n" ] } ], "source": [ "X_cleaned = X.drop(['body', 'cabin', 'name', 'ticket', 'boat'], axis=1).dropna() # not needed\n", "y_cleaned = y[X_cleaned.index]\n", "\n", "X_cleaned.info(verbose=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Partition dataset:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(458, 8)\n", "(226, 8)\n", "(array(['0', '1'], dtype=object), array([242, 216]))\n", "(array(['0', '1'], dtype=object), array([122, 104]))\n" ] } ], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(X_cleaned, y_cleaned, test_size=0.33, random_state=42)\n", "\n", 
"print(X_train.shape)\n", "print(X_test.shape)\n", "\n", "print(np.unique(y_train, return_counts=True))\n", "print(np.unique(y_test, return_counts=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train model and create local/global model explanation:\n", "\n", "Retrieve local model explanations. Here: Specify all numeric and categorical columns explicitly" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "num_cols = list(X_train.select_dtypes(include=['float', 'int']).columns) \n", "cat_cols = list(X_train.select_dtypes(include=['object', 'category']).columns)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Score your train set:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training predictions: (array([-1, 1]), array([ 5, 453]))\n" ] } ], "source": [ "model = BHAD(\n", " contamination=0.01,\n", " num_features=num_cols,\n", " cat_features=cat_cols,\n", " nbins=None, \n", " verbose=False\n", ")\n", "\n", "y_pred_train_new = model.fit_predict(X_train)\n", "scores_train_new = model.decision_function(X_train)\n", "\n", "print(\"Training predictions:\", np.unique(y_pred_train_new, return_counts=True))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--- BHAD Model Explainer ---\n", "\n", "Using fitted BHAD and discretizer.\n", "Marginal distributions estimated using train set of shape (458, 8)\n" ] } ], "source": [ "from bhad import explainer\n", "\n", "local_expl = explainer.Explainer(bhad_obj=model, discretize_obj=model._discretizer).fit()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Create local explanations for 458 observations.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": 
"968728c99c00492c90ab2592bf3c6222", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/458 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
avg ranks
embarked0.152058
sex0.257304
parch0.279548
sibsp0.444223
age0.491700
fare0.546813
pclass0.634462
home.dest1.000000
\n", "" ], "text/plain": [ " avg ranks\n", "embarked 0.152058\n", "sex 0.257304\n", "parch 0.279548\n", "sibsp 0.444223\n", "age 0.491700\n", "fare 0.546813\n", "pclass 0.634462\n", "home.dest 1.000000" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "global_feat_imp = local_expl.global_feat_imp # based on X_train\n", "global_feat_imp" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Get global model explanation (in decreasing order):" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from matplotlib import pyplot as plt\n", "\n", "plt.barh(global_feat_imp.index, global_feat_imp.values.flatten())\n", "plt.xlabel(\"Feature importances\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Get local explanations, i.e. feature importances (in decreasing order):" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Obs. 0:\n", " parch (Cumul.perc.: 0.996): 5.0\n", "home.dest (Perc.: 0.011): Sweden Winnipeg, MN\n", "sex (Perc.: 0.4): female\n", "\n", "Obs. 100:\n", " home.dest (Perc.: 0.002): Tofta, Sweden Joliet, IL\n", "fare (Cumul.perc.: 0.07): 7.78\n", "\n", "Obs. 200:\n", " home.dest (Perc.: 0.013): Brooklyn, NY\n", "\n", "Obs. 300:\n", " home.dest (Perc.: 0.007): Bournmouth, England\n", "age (Cumul.perc.: 0.05): 5.0\n", "sex (Perc.: 0.4): female\n", "\n", "Obs. 400:\n", " home.dest (Perc.: 0.002): Taalintehdas, Finland Hoboken, NJ\n" ] } ], "source": [ "for obs, ex in enumerate(df_train.explanation.values):\n", " if (obs % 100) == 0:\n", " print(f'\\nObs. {obs}:\\n', ex)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "y_pred_test = model.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Create local explanations for 226 observations.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "65da9be295504cdba17d044a1ae50524", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/226 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pclasssexagesibspparchfareembarkedhome.destexplanation
02.0male36.01.02.027.7500SBournmouth, Englandhome.dest (Perc.: 0.007): Bournmouth, England
11.0male49.01.01.0110.8833CHaverford, PAhome.dest (Perc.: 0.007): Haverford, PA\\nfare ...
\n", "" ], "text/plain": [ " pclass sex age sibsp parch fare embarked home.dest \\\n", "0 2.0 male 36.0 1.0 2.0 27.7500 S Bournmouth, England \n", "1 1.0 male 49.0 1.0 1.0 110.8833 C Haverford, PA \n", "\n", " explanation \n", "0 home.dest (Perc.: 0.007): Bournmouth, England \n", "1 home.dest (Perc.: 0.007): Haverford, PA\\nfare ... " ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_test = local_expl.get_explanation(nof_feat_expl = 4)\n", "df_test.head(2)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Obs. 0:\n", " home.dest (Perc.: 0.007): Bournmouth, England\n", "\n", "Obs. 50:\n", " home.dest (Perc.: 0.002): Deephaven, MN / Cedar Rapids, IA\n", "fare (Cumul.perc.: 0.91): 106.42\n", "\n", "Obs. 100:\n", " home.dest (Perc.: 0.002): Hudson, NY\n", "sex (Perc.: 0.4): female\n", "\n", "Obs. 150:\n", " home.dest (Perc.: 0.0): ?Havana, Cuba\n", "\n", "Obs. 200:\n", " embarked (Perc.: 0.048): Q\n", "home.dest (Perc.: 0.0): Co Sligo, Ireland Hartford, CT\n", "sex (Perc.: 0.4): female\n", "fare (Cumul.perc.: 0.061): 7.75\n" ] } ], "source": [ "for obs, ex in enumerate(df_test.explanation.values):\n", " if (obs % 50) == 0:\n", " print(f'\\nObs. {obs}:\\n', ex)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
avg ranks
embarked0.157711
parch0.245639
sex0.256804
sibsp0.441731
age0.480112
fare0.575715
pclass0.638521
home.dest1.000000
\n", "
" ], "text/plain": [ " avg ranks\n", "embarked 0.157711\n", "parch 0.245639\n", "sex 0.256804\n", "sibsp 0.441731\n", "age 0.480112\n", "fare 0.575715\n", "pclass 0.638521\n", "home.dest 1.000000" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "local_expl.global_feat_imp # based on X_test" ] } ], "metadata": { "kernelspec": { "display_name": "bayes-anomaly", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.8" } }, "nbformat": 4, "nbformat_minor": 4 }