_Copyright (C) 2025 Jagger Alexander_

_This program is free software: you can redistribute it and/or modify it as published._

_This program is distributed in the hope that it will be informative, but without any warranty; without even the implied warranty of merchantability or fitness for a particular purpose._


### Analyzing Drivers of Heat Waves

1. Import necessary packages 
2. Import dataset
3. Examine dataset  
   i. Count heatwaves by decade  
   ii. Plot heatwave metrics
4. Evaluate model accuracy  
   i. Baseline PR AUC  
   ii. Model PR AUC*  
   iii. Accuracy metrics*  
   iv. Confusion matrices*
5. Interpret models  
   i. SHAP value dictionary*  
   ii. Plot SHAP by city  
   iii. Create SHAP information for one feature*  
   iv. Plot SHAP feature information
   v. Plot bivariate MJO SHAP
6. Las Vegas case study*

_*Will take significant runtime/processing power_

Note: Code blocks should be run sequentially.

In [None]:
## 1. IMPORT NECESSARY PACKAGES

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from matplotlib.colors import LinearSegmentedColormap
import math
import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import KFold
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, recall_score, average_precision_score, precision_score, f1_score, precision_recall_curve, confusion_matrix
import shap
import seaborn as sns
from statsmodels.nonparametric.smoothers_lowess import lowess
from scipy.stats import binned_statistic, binned_statistic_2d, sem


In [None]:
## 2. IMPORT DATASET

# Load the per-city data. Later cells index this with lowercase city keys
# (e.g. 'austin') and treat each value as a pandas DataFrame.
# NOTE(review): pd.read_pickle can execute arbitrary code while unpickling --
# only load a city_dfs.pkl that you produced yourself or otherwise trust.
city_dfs = pd.read_pickle('city_dfs.pkl')

# Sanity check: show which columns are available for one city.
print(city_dfs['austin'].columns.tolist())

In [None]:
## 3. EXAMINE DATASET

## i. COUNT HEATWAVES BY DECADE 

# Two June-September summaries per city, each indexed by decade label
# ("1980s", ...): the number of heatwave days, and the number of heatwave
# stretches (runs of consecutive heatwave days). Both assume each city frame
# has a 'date' column and a boolean 'hwd' flag.

heatwave_summary = {}
decade_summary = {}
heatwave_stretch_decade_summary = {}

for city, df in city_dfs.items():
    label = city.capitalize()

    df = df.copy()
    df['date'] = pd.to_datetime(df['date'])
    df['year'] = df['date'].dt.year
    df['month'] = df['date'].dt.month
    df['decade'] = (df['year'] // 10) * 10

    # Sort so that "previous row" below really means "previous date".
    summer_df = df[df['month'].between(6, 9)].sort_values('date')

    # --- Heatwave days per year, then per decade ---
    hwd_by_year = summer_df[summer_df['hwd']].groupby('year').size()
    heatwave_summary[label] = hwd_by_year

    df_year = hwd_by_year.reset_index(name='count')
    df_year['decade'] = (df_year['year'] // 10) * 10
    hwd_by_decade = df_year.groupby('decade')['count'].sum()
    hwd_by_decade.index = hwd_by_decade.index.astype(int).astype(str) + 's'
    decade_summary[label] = hwd_by_decade

    # --- Heatwave stretches per decade ---
    # A heatwave day starts a new stretch unless the previous row is BOTH a
    # heatwave day AND exactly one calendar day earlier. The date check
    # matters because other months were filtered out: Sep 30 and Jun 1 of the
    # next year are adjacent rows here and must not be merged into one
    # stretch (likewise for any gap in the record).
    prev_is_hwd = summer_df['hwd'].shift(fill_value=False)
    prev_is_yesterday = summer_df['date'].diff() == pd.Timedelta(days=1)
    starts = summer_df[summer_df['hwd'] & ~(prev_is_hwd & prev_is_yesterday)]

    stretch_counts = starts.groupby('decade').size()
    stretch_counts.index = stretch_counts.index.astype(str) + 's'
    heatwave_stretch_decade_summary[label] = stretch_counts

decade_df = pd.DataFrame(decade_summary).fillna(0).astype(int)

# Column totals first, then row totals (so the bottom-right cell is the grand total).
decade_df.loc['Total'] = decade_df.sum()

decade_df['Total'] = decade_df.sum(axis=1)

print("\nHeatwave Days Per Decade (June–September) by City (with Totals):")
print(decade_df)

heatwave_stretch_decade_df = (
    pd.DataFrame(heatwave_stretch_decade_summary)
      .fillna(0)
      .astype(int)
)

heatwave_stretch_decade_df.loc['Total'] = heatwave_stretch_decade_df.sum()

heatwave_stretch_decade_df['Total'] = heatwave_stretch_decade_df.sum(axis=1)

print("\nHeatwave Stretches Per Decade (June–September) by City with Totals:")
print(heatwave_stretch_decade_df)

In [None]:
## ii. PLOT HEATWAVE METRICS

sns.set_style("white")

# Display names for city keys whose str.capitalize() form runs words together.
rename_map = {
    "Mexicocity": "Mexico City",
    "Lasvegas":   "Las Vegas"
}

# Inclusive (start_year, end_year) ranges mapped to x-axis labels, nine years
# each. The ranges must not overlap: assign_custom_period returns the first
# match, so a shared year would land only in the earlier period while the
# later label still claimed it.
custom_periods = {
    (1980, 1988): "1980-1988",
    (1989, 1997): "1989-1997",
    (1998, 2006): "1998-2006",
    (2007, 2015): "2007-2015",
    (2016, 2024): "2016-2024",
}

def assign_custom_period(year):
    """Return the label of the custom period containing *year*, or None.

    Periods are read from the module-level ``custom_periods`` dict, whose keys
    are inclusive ``(start, end)`` year tuples. If several ranges contain
    *year*, the first one in dict order wins.
    """
    matching_labels = (
        label
        for (first_year, last_year), label in custom_periods.items()
        if first_year <= year <= last_year
    )
    return next(matching_labels, None)

def get_heatwave_durations(df):
    """Summarise each contiguous heatwave run in *df* as one row.

    Rows are sorted by ``date``. Consecutive rows sharing the same ``hwd``
    value form a run, and only runs where ``hwd`` is true are kept. Calendar
    gaps between rows are not checked.

    Returns a DataFrame with columns ``event`` (run id), ``start_date``,
    ``end_date``, ``duration`` (count of dated rows in the run, floored at 3),
    ``year`` (of ``start_date``) and ``custom_period``. Runs whose start year
    falls outside every custom period are dropped. The caller's frame is not
    modified.
    """
    ordered = df.sort_values("date")
    # A new run id begins wherever hwd differs from the previous row.
    ordered["event"] = ordered["hwd"].ne(ordered["hwd"].shift()).cumsum()

    runs = (
        ordered.loc[ordered["hwd"]]
        .groupby("event")
        .agg(
            start_date=("date", "min"),
            end_date=("date", "max"),
            duration=("date", "count"),
        )
        .reset_index()
    )

    # NOTE(review): clip(lower=3) raises 1- and 2-day runs to 3 days instead of
    # dropping them. Presumably 'hwd' already encodes a >=3-day definition, so
    # this is a no-op -- confirm, otherwise short runs are overstated.
    runs["duration"] = runs["duration"].clip(lower=3)
    runs["year"] = runs["start_date"].dt.year
    runs["custom_period"] = runs["year"].apply(assign_custom_period)
    return runs.dropna(subset=["custom_period"])

# Per city, collect three tables tagged with city and custom period:
# temperature on every heatwave day, duration of every heatwave event, and
# the number of June-September heatwave days per year.
temp_frames = []
dur_frames = []
days_records = []

for city, df in city_dfs.items():
    cname = city.capitalize()
    df = df.copy()
    df['date'] = pd.to_datetime(df['date'])
    df['city'] = cname
    df['year'] = df['date'].dt.year
    df['month'] = df['date'].dt.month
    df['custom_period'] = df['year'].apply(assign_custom_period)
    # Kelvin -> degrees Celsius.
    df['temperature_c'] = df['temperature_2m_max'] - 273.15

    # NOTE(review): temperatures and durations below use all months, while the
    # day counts are restricted to June-September -- confirm this is intended.
    # Vectorized selection instead of building one dict per row with iterrows().
    hw_temp = df.loc[df['hwd'] & df['custom_period'].notna(),
                     ['city', 'custom_period', 'temperature_c']]
    temp_frames.append(hw_temp)

    hw_dur = get_heatwave_durations(df)
    dur_frames.append(hw_dur.assign(city=cname)[['city', 'custom_period', 'duration']])

    summer = df[df['month'].between(6, 9)]
    hw_days = summer[summer['hwd'] & summer['custom_period'].notna()]
    days_records.append(
        hw_days
        .groupby(['city','year','custom_period'])
        .size()
        .reset_index(name='num_heatwave_days')
    )

heatwave_temp_df = pd.concat(temp_frames, ignore_index=True)
heatwave_duration_df = pd.concat(dur_frames, ignore_index=True)
heatwave_days_df = pd.concat(days_records, ignore_index=True)

# Pooled per-period means and normal-approximation 95% CI half-widths.
mean_temp = heatwave_temp_df.groupby("custom_period")["temperature_c"].mean()
ci_temp   = heatwave_temp_df.groupby("custom_period")["temperature_c"].apply(lambda x: sem(x)*1.96)

mean_dur = heatwave_duration_df.groupby("custom_period")["duration"].mean()
ci_dur   = heatwave_duration_df.groupby("custom_period")["duration"].apply(lambda x: sem(x)*1.96)

mean_days = heatwave_days_df.groupby("custom_period")["num_heatwave_days"].mean()
ci_days   = heatwave_days_df.groupby("custom_period")["num_heatwave_days"].apply(lambda x: sem(x)*1.96)

# Per-city, per-period means (plotted as coloured points).
city_temp_means = heatwave_temp_df.groupby(["custom_period","city"])["temperature_c"].mean().reset_index()
city_dur_means = heatwave_duration_df.groupby(["custom_period","city"])["duration"].mean().reset_index()
city_day_means = heatwave_days_df.groupby(["custom_period","city"])["num_heatwave_days"].mean().reset_index()

for means_df in (city_temp_means, city_day_means, city_dur_means):
    means_df["city"] = means_df["city"].replace(rename_map)

palette = sns.color_palette("muted", len(city_dfs))


def plot_metric(mean_vals, ci_vals, df_all, df_city, y, ylabel, title, ylim=None, show_legend=False):
    """Plot one heatwave metric by custom period and show the figure.

    Draws, in order: white bars of ``mean_vals`` with black ``ci_vals`` error
    bars, a grey violin of every observation in ``df_all``, and one point per
    city from ``df_city``, coloured with the module-level ``palette``.

    Parameters
    ----------
    mean_vals, ci_vals : pandas.Series
        Per-period mean and CI half-width, indexed by period label. The index
        order fixes the x-axis order.
    df_all, df_city : pandas.DataFrame
        Pooled observations and per-city means. Both need a ``custom_period``
        column; ``df_city`` also needs ``city``. Neither is modified.
    y : str
        Column of ``df_all`` / ``df_city`` to plot.
    ylabel, title : str
        Y-axis label and figure title.
    ylim : tuple, optional
        Passed to ``Axes.set_ylim`` when given.
    show_legend : bool
        Add the city legend plus a "95% CI" entry.
    """
    period_order = mean_vals.index.tolist()

    def _with_ordered_periods(frame):
        # Work on a copy, with periods as an ordered categorical matching the bars.
        frame = frame.copy()
        frame["custom_period"] = pd.Categorical(
            frame["custom_period"], categories=period_order, ordered=True
        )
        return frame

    pooled = _with_ordered_periods(df_all)
    per_city = _with_ordered_periods(df_city)

    fig, ax = plt.subplots(figsize=(10, 6))

    sns.barplot(x=mean_vals.index, y=mean_vals.values, color="white", alpha=0.7, ax=ax)
    bar_positions = np.arange(len(mean_vals))
    ax.errorbar(
        bar_positions, mean_vals.values, yerr=ci_vals,
        fmt="none", capsize=3, capthick=2, color="black", linewidth=1.5, zorder=10,
    )

    sns.violinplot(
        x="custom_period", y=y, data=pooled, inner=None,
        color="gray", alpha=0.3, cut=0, ax=ax,
    )

    strip_legend = "brief" if show_legend else False
    sns.stripplot(
        x="custom_period", y=y, data=per_city, hue="city",
        palette=palette, dodge=False, jitter=False, size=8,
        ax=ax, legend=strip_legend, alpha=0.7,
    )

    ax.set_title(title)
    ax.set_xlabel("Period")
    ax.set_ylabel(ylabel)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.grid(False)

    if show_legend:
        handles, labels = ax.get_legend_handles_labels()
        ci_marker = mlines.Line2D(
            [], [], color="black", marker="_", linestyle="None",
            markersize=10, markeredgewidth=2, label="95% CI",
        )
        ax.legend([*handles, ci_marker], [*labels, "95% CI"], frameon=True)

    #plt.savefig("title.png",dpi=300,bbox_inches='tight')
    plt.tight_layout()
    plt.show()

# One figure per metric:
# (means, CIs, pooled data, city means, column, y-label, title, ylim, legend?)
metric_specs = [
    (mean_temp, ci_temp, heatwave_temp_df, city_temp_means,
     "temperature_c", "Heatwave Temperature (°C)",
     "Mean Heatwave Temperature", (10, 50), False),
    (mean_days, ci_days, heatwave_days_df, city_day_means,
     "num_heatwave_days", "Number of Heatwave Days",
     "Mean Number of Heatwave Days", None, False),
    (mean_dur, ci_dur, heatwave_duration_df, city_dur_means,
     "duration", "Heatwave Duration (Days)",
     "Mean Heatwave Duration", (3, None), True),
]

for *data_args, y_col, y_label, fig_title, y_limits, legend_on in metric_specs:
    plot_metric(*data_args, y_col, y_label, fig_title,
                ylim=y_limits, show_legend=legend_on)


In [None]:
## 4. Evaluate model accuracy

## i. BASELINE PR AUC

# Reference score for the models below: a classifier that ignores the
# predictors and always outputs the class prior of heatwave days.
baseline_results = {}

for city, df in city_dfs.items():
    # Convert dates once; this writes back into the shared city_dfs entry.
    if not np.issubdtype(df['date'].dtype, np.datetime64):
        df['date'] = pd.to_datetime(df['date'])
    summer = df[df['date'].dt.month.between(6, 9)]

    lag_columns = [
        col for col in summer.columns
        if any(tag in col for tag in ('_21', '_24', '_28'))
    ]
    X = summer[lag_columns]
    y = summer['hwd']

    prior_model = DummyClassifier(strategy='prior').fit(X, y)
    prior_scores = prior_model.predict_proba(X)[:, 1]
    baseline_results[city] = average_precision_score(y, prior_scores)

baseline_df = (
    pd.Series(baseline_results, name='Baseline PR AUC')
    .to_frame()
    .rename(index=str.title)
)
print(baseline_df)


In [None]:
## ii. MODEL PR AUC

# Shared XGBoost settings; each city adds its own scale_pos_weight.
base_params = dict(
    objective='binary:logistic',
    eval_metric='logloss',
    learning_rate=0.2,
    max_depth=4,
    subsample=0.8,
    colsample_bytree=0.8,
)

# Ten repetitions of shuffled 4-fold CV, one shuffle seed per repetition.
seed_list = list(range(1, 11))
n_models = len(seed_list)

pr_auc_results = {}

for city, df in city_dfs.items():
    print(f"Processing Precision-Recall AUC for city: {city}")

    # Convert dates once; this writes back into the shared city_dfs entry.
    if not np.issubdtype(df['date'].dtype, np.datetime64):
        df['date'] = pd.to_datetime(df['date'])
    df = df[df['date'].dt.month.between(6, 9)]

    features = [col for col in df.columns if any(tag in col for tag in ('_21', '_24', '_28'))]
    X, y = df[features], df['hwd']

    N_pos = y.sum()
    N_total = len(df)
    N_neg = N_total - N_pos
    # Up-weight the rare positive class by the negative:positive ratio.
    scale_pos_weight = 1 if N_pos <= 0 else N_neg / N_pos

    params = {**base_params, 'scale_pos_weight': scale_pos_weight}
    print(f"  For {city}: Positives = {N_pos}, Negatives = {N_neg}, scale_pos_weight = {scale_pos_weight:.3f}")

    # One average-precision score per held-out fold, across all seeds.
    pr_auc_list = []
    for seed in seed_list:
        folds = KFold(n_splits=4, shuffle=True, random_state=seed).split(X, y)
        for fit_rows, eval_rows in folds:
            y_eval = y.iloc[eval_rows]
            booster = xgb.train(
                params,
                xgb.DMatrix(X.iloc[fit_rows], label=y.iloc[fit_rows]),
                num_boost_round=100,
            )
            eval_scores = booster.predict(xgb.DMatrix(X.iloc[eval_rows], label=y_eval))
            pr_auc_list.append(average_precision_score(y_eval, eval_scores))

    # Normal-approximation 95% CI of the mean fold score.
    mean_pr_auc = np.mean(pr_auc_list)
    std_pr_auc = np.std(pr_auc_list, ddof=1)
    se = std_pr_auc / math.sqrt(len(pr_auc_list))
    ci_low, ci_high = mean_pr_auc - 1.96 * se, mean_pr_auc + 1.96 * se

    pr_auc_results[city] = {
        'Average PR AUC': mean_pr_auc,
        'Std Dev': std_pr_auc,
        '95% CI': (ci_low, ci_high)
    }
    print(f"Average PR AUC for {city}: {mean_pr_auc:.3f} (95% CI: {ci_low:.3f}, {ci_high:.3f})\n")

pr_auc_df = (
    pd.DataFrame.from_dict(pr_auc_results, orient='index')
    .rename(index=str.title)
    .reset_index()
    .rename(columns={'index': 'City'})
)
print(pr_auc_df.to_string(index=False))

In [None]:
## iii. ACCURACY METRICS

def _best_f1_threshold(y_true, y_score):
    """Return the score threshold that maximises F1 on (y_true, y_score).

    ``precision_recall_curve`` returns precision/recall arrays one element
    longer than ``thresholds``: element ``i`` of each belongs to
    ``thresholds[i]``, and the final (precision=1, recall=0) point has no
    threshold. F1 is therefore computed on ``[:-1]`` so its indices stay
    aligned with ``thresholds``. Falls back to 0.5 if no threshold exists.
    """
    prec, rec, thresholds = precision_recall_curve(y_true, y_score)
    if len(thresholds) == 0:
        return 0.5
    p, r = prec[:-1], rec[:-1]
    f1_scores = 2 * p * r / (p + r + 1e-8)
    return thresholds[np.argmax(f1_scores)]


def compute_stats(metric_list):
    """Return (mean, sample std, normal-approximation 95% CI of the mean)."""
    arr = np.array(metric_list)
    mean_val = np.mean(arr)
    std_val = np.std(arr, ddof=1)
    n = len(arr)
    ci_low = mean_val - 1.96 * std_val / np.sqrt(n)
    ci_high = mean_val + 1.96 * std_val / np.sqrt(n)
    return mean_val, std_val, (ci_low, ci_high)


results = {}  

for city, df in city_dfs.items():
    print(f"Processing metrics for city: {city}")

    # Convert dates once; this writes back into the shared city_dfs entry.
    if not np.issubdtype(df['date'].dtype, np.datetime64):
        df['date'] = pd.to_datetime(df['date'])
    df = df[df['date'].dt.month.between(6, 9)]
    
    features = [col for col in df.columns if '_21' in col or '_24' in col or '_28' in col]
    
    X = df[features]
    y = df['hwd']
    
    N_pos = y.sum()          
    N_total = len(df)
    N_neg = N_total - N_pos 
    scale_pos_weight = (N_neg / N_pos) if N_pos > 0 else 1
    
    params = base_params.copy()
    params['scale_pos_weight'] = scale_pos_weight
    print(f"  For {city}: Positives = {N_pos}, Negatives = {N_neg}, scale_pos_weight = {scale_pos_weight:.3f}")
    
    accuracy_list = []
    recall_list = []
    precision_list = []
    f1_list = []
    
    for model_idx, seed in enumerate(seed_list):
        kf = KFold(n_splits=4, shuffle=True, random_state=seed)
        print(f" Model {model_idx+1}: using seed {seed}")
        
        for train_index, test_index in kf.split(X, y):
            X_train, X_test = X.iloc[train_index], X.iloc[test_index]
            y_train, y_test = y.iloc[train_index], y.iloc[test_index]
            
            dtrain = xgb.DMatrix(X_train, label=y_train)
            dtest = xgb.DMatrix(X_test, label=y_test)
            
            bst = xgb.train(params, dtrain, num_boost_round=100)
            
            y_pred_prob = bst.predict(dtest)
            
            # NOTE(review): the threshold is tuned on the same fold it is scored
            # on, so these metrics are optimistic; choosing it on a separate
            # validation split would remove that bias.
            optimal_threshold = _best_f1_threshold(y_test, y_pred_prob)
            y_pred = (y_pred_prob >= optimal_threshold).astype(int)
            
            accuracy_list.append(accuracy_score(y_test, y_pred))
            recall_list.append(recall_score(y_test, y_pred))
            precision_list.append(precision_score(y_test, y_pred))
            f1_list.append(f1_score(y_test, y_pred))
    
    acc_mean, acc_std, acc_ci = compute_stats(accuracy_list)
    rec_mean, rec_std, rec_ci = compute_stats(recall_list)
    prec_mean, prec_std, prec_ci = compute_stats(precision_list)
    f1_mean, f1_std, f1_ci = compute_stats(f1_list)
    
    metrics_df = pd.DataFrame({
        'Metric': ['Accuracy', 'Recall', 'Precision', 'F1 Score'],
        'Mean': [acc_mean, rec_mean, prec_mean, f1_mean],
        'Standard Deviation': [acc_std, rec_std, prec_std, f1_std],
        '95% Confidence Interval': [acc_ci, rec_ci, prec_ci, f1_ci]
    })
    
    results[city] = metrics_df
    print(f"\nResults for {city.capitalize()}:")
    print(metrics_df)
    print("\n" + "="*60 + "\n")


In [None]:
## iv. CONFUSION MATRICES

def _best_f1_threshold(y_true, y_score):
    """Return the score threshold that maximises F1 on (y_true, y_score).

    ``precision_recall_curve`` returns precision/recall arrays one element
    longer than ``thresholds``: element ``i`` of each belongs to
    ``thresholds[i]``, and the final (precision=1, recall=0) point has no
    threshold. F1 is therefore computed on ``[:-1]`` so its indices stay
    aligned with ``thresholds``. Falls back to 0.5 if no threshold exists.
    """
    prec, rec, thresholds = precision_recall_curve(y_true, y_score)
    if len(thresholds) == 0:
        return 0.5
    p, r = prec[:-1], rec[:-1]
    f1_scores = 2 * p * r / (p + r + 1e-8)
    return thresholds[np.argmax(f1_scores)]


conf_matrices = {} 

for city, df in city_dfs.items():
    print(f"Processing confusion matrix for city: {city}")

    # Convert dates once; this writes back into the shared city_dfs entry.
    if not np.issubdtype(df['date'].dtype, np.datetime64):
        df['date'] = pd.to_datetime(df['date'])
    df = df[df['date'].dt.month.between(6, 9)]
    
    features = [col for col in df.columns if '_21' in col or '_24' in col or '_28' in col]
    X = df[features]
    y = df['hwd']
    
    N_pos = y.sum()          
    N_total = len(df)
    N_neg = N_total - N_pos  
    scale_pos_weight = (N_neg / N_pos) if N_pos > 0 else 1
    
    params = base_params.copy()
    params['scale_pos_weight'] = scale_pos_weight
    print(f"  {city}: Positives = {N_pos}, Negatives = {N_neg}, scale_pos_weight = {scale_pos_weight:.3f}")
    
    # Summed over every fold of every seed: rows = actual (0, 1), cols = predicted (0, 1).
    agg_cm = np.zeros((2, 2), dtype=int)
    
    for seed in seed_list:
        kf = KFold(n_splits=4, shuffle=True, random_state=seed)
        for train_index, test_index in kf.split(X, y):
            X_train, X_test = X.iloc[train_index], X.iloc[test_index]
            y_train, y_test = y.iloc[train_index], y.iloc[test_index]
            
            dtrain = xgb.DMatrix(X_train, label=y_train)
            dtest = xgb.DMatrix(X_test, label=y_test)
            
            bst = xgb.train(params, dtrain, num_boost_round=100)
            
            y_pred_prob = bst.predict(dtest)
            
            optimal_threshold = _best_f1_threshold(y_test, y_pred_prob)
            y_pred = (y_pred_prob >= optimal_threshold).astype(int)
            
            # labels=[0, 1] pins the result to 2x2. Without it, a fold holding a
            # single class yields a 1x1 matrix that broadcasts into all four
            # cells of agg_cm.
            agg_cm += confusion_matrix(y_test.astype(int), y_pred, labels=[0, 1])
            
    conf_matrices[city] = agg_cm
    print(f"Aggregated confusion matrix for {city}:\n{agg_cm}\n")

cities = list(conf_matrices.keys())
n_cities = len(cities)

# Three panels per row, as many rows as needed; unused panels are hidden.
n_cols = 3
n_rows = max(1, math.ceil(n_cities / n_cols))
fig, axs = plt.subplots(n_rows, n_cols, figsize=(18, 6 * n_rows),
                        constrained_layout=True, squeeze=False)
axs = axs.flatten()

vmin, vmax = 0, 100

for i, city in enumerate(cities):
    cm = conf_matrices[city]

    # Row-normalise: each cell as a percentage of its actual-class total
    # (rows with no samples stay 0).
    row_sums = cm.sum(axis=1, keepdims=True)
    percent_matrix = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float),
                               where=row_sums > 0) * 100
    annot = np.array(
        [[f"{count}\n({pct:.1f}%)" for count, pct in zip(cm_row, pct_row)]
         for cm_row, pct_row in zip(cm, percent_matrix)],
        dtype=object,
    )

    ax = axs[i]
    sns.heatmap(percent_matrix, 
                annot=annot, fmt="", cmap="Blues", 
                vmin=vmin, vmax=vmax,
                cbar=False, ax=ax, square=True, linewidths=0.5, linecolor='gray')
    ax.set_title(city.title(), fontsize=16)
    ax.set_xticklabels(['Pred Non-HWD', 'Pred HWD'], fontsize=14)
    ax.set_yticklabels(['Actual Non-HWD', 'Actual HWD'], fontsize=14, rotation=0)

for unused_ax in axs[n_cities:]:
    unused_ax.set_visible(False)

# One shared colour bar for all panels.
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
mappable = mpl.cm.ScalarMappable(norm=norm, cmap="Blues")
mappable.set_array([]) 
cbar = fig.colorbar(mappable, ax=axs, orientation='vertical', fraction=0.02, pad=0.04)
cbar.set_label("Percentage (%)", fontsize=16)
cbar.ax.tick_params(labelsize=14)

plt.suptitle("Confusion Matrices by City", fontsize=20)
plt.show()


In [None]:
## 5. Interpret models

## i. SHAP VALUE DICTIONARY* 

# Long-format table with one row per (city, feature). Each value is that
# feature's mean |SHAP| on held-out folds (averaged over all folds and seeds),
# normalised so each city's values sum to 1.
shap_values_by_city = []  

for city, df in city_dfs.items():
    print(f"Processing SHAP values for city: {city}")

    # Convert dates once; this writes back into the shared city_dfs entry.
    if not np.issubdtype(df['date'].dtype, np.datetime64):
        df['date'] = pd.to_datetime(df['date'])
    df = df[df['date'].dt.month.between(6, 9)]
    
    features = [col for col in df.columns if '_21' in col or '_24' in col or '_28' in col]
    
    X = df[features]
    y = df['hwd']
    
    N_pos = y.sum()  
    N_total = len(df)
    N_neg = N_total - N_pos  
    scale_pos_weight = (N_neg / N_pos) if N_pos > 0 else 1
    
    params = base_params.copy()
    params['scale_pos_weight'] = scale_pos_weight

    # One vector of per-feature mean |SHAP| per fold.
    fold_importances = []

    for seed in seed_list:
        kf = KFold(n_splits=4, shuffle=True, random_state=seed)

        for train_index, test_index in kf.split(X, y):
            X_train, X_test = X.iloc[train_index], X.iloc[test_index]

            dtrain = xgb.DMatrix(X_train, label=y.iloc[train_index])
            bst = xgb.train(params, dtrain, num_boost_round=100)

            # SHAP is evaluated directly on the held-out rows, so no test
            # DMatrix is needed.
            shap_values = shap.TreeExplainer(bst).shap_values(X_test)
            fold_importances.append(np.abs(shap_values).mean(axis=0))

    feature_means = dict(zip(features, np.mean(fold_importances, axis=0)))
    
    total_shap = sum(feature_means.values())

    for feature, mean_val in feature_means.items():
        normalized_shap = mean_val / total_shap if total_shap != 0 else 0
        shap_values_by_city.append({
            'City': city.capitalize(),
            'Feature': feature,
            'Normalized_SHAP': normalized_shap
        })

shap_df = pd.DataFrame(shap_values_by_city)

print(shap_df)


In [None]:
## ii. PLOT SHAP BY CITY

# Raw predictor column names -> display labels used on the plots.
feature_name_map = {
    'pv_21': 'Potential Vorticity 300hPa 21d',
    'pv_24': 'Potential Vorticity 300hPa 24d',
    'pv_28': 'Potential Vorticity 300hPa 28d',
    'v_21': 'V-Wind 300hPa 21d',
    'v_24': 'V-Wind 300hPa 24d',
    'v_28': 'V-Wind 300hPa 28d',
    'u_21': 'U-Wind 300hPa 21d',
    'u_24': 'U-Wind 300hPa 24d',
    'u_28': 'U-Wind 300hPa 28d',
    'z_21': 'Geopotential 500hPa 21d',
    'z_24': 'Geopotential 500hPa 24d',
    'z_28': 'Geopotential 500hPa 28d',
    'Phase_21': 'MJO Phase 21d',
    'Phase_24': 'MJO Phase 24d',
    'Phase_28': 'MJO Phase 28d',
    'Amplitude_21': 'MJO Amplitude 21d',
    'Amplitude_24': 'MJO Amplitude 24d',
    'Amplitude_28': 'MJO Amplitude 28d',
    'temperature_2m_min_21': 'Tmin. 21d',
    'temperature_2m_min_24': 'Tmin. 24d',
    'temperature_2m_min_28': 'Tmin. 28d',
    'temperature_2m_max_21': 'Tmax. 21d',
    'temperature_2m_max_24': 'Tmax. 24d',
    'temperature_2m_max_28': 'Tmax. 28d',
    'surface_pressure_21': 'Surface Pressure 21d',
    'surface_pressure_24': 'Surface Pressure 24d',
    'surface_pressure_28': 'Surface Pressure 28d',
    'v_component_of_wind_10m_21': 'V-Wind 10m 21d',
    'v_component_of_wind_10m_24': 'V-Wind 10m 24d',
    'v_component_of_wind_10m_28': 'V-Wind 10m 28d',
    'u_component_of_wind_10m_21': 'U-Wind 10m 21d',
    'u_component_of_wind_10m_24': 'U-Wind 10m 24d',
    'u_component_of_wind_10m_28': 'U-Wind 10m 28d',
    'total_precipitation_sum_21': 'Total Precip. 21d',
    'total_precipitation_sum_24': 'Total Precip. 24d',
    'total_precipitation_sum_28': 'Total Precip. 28d',
    'relative_humidity_21': 'Relative Humidity 21d',
    'relative_humidity_24': 'Relative Humidity 24d',
    'relative_humidity_28': 'Relative Humidity 28d',
    'cal_sst_21d': 'California Current SST',
    'mex_sst_21d': 'Gulf of Mexico SST',
    'nino_sst_21d': 'Nino 3.4 Region SST',
    'SM_21d': 'Soil Moisture',
    'NAO_21d': 'NAO Index',
    'AO_21d': 'AO Index'
}

shap_df['Feature Remapped'] = shap_df['Feature'].apply(lambda x: feature_name_map.get(x, x))

shap_df['City'] = shap_df['City'].str.title()
shap_df['City'] = shap_df['City'].replace({
    'Lasvegas': 'Las Vegas',
    'Mexicocity': 'Mexico City'
})

# Mean normalised SHAP per feature across the cities that have it; this sets
# the row order (most important at the top of the index, i.e. y=0).
overall_norm_mean = shap_df.groupby('Feature')['Normalized_SHAP'].mean()
shap_df['Overall Norm SHAP'] = shap_df['Feature'].map(overall_norm_mean)

ordered_features = overall_norm_mean.sort_values(ascending=False).index.tolist()
ordered_features_remapped = [feature_name_map.get(f, f) for f in ordered_features]

plt.figure(figsize=(10, 10))
sns.set(style="whitegrid")

unique_cities = shap_df['City'].unique()
city_palette = sns.color_palette("muted", n_colors=len(unique_cities))
# Fixed city -> colour mapping so every feature row and the legend agree. A
# list palette passed with hue= is re-assigned to whichever cities happen to
# appear in each per-feature subset, which shifts colours whenever a city
# lacks a feature.
city_colors = dict(zip(unique_cities, city_palette))

for i, feature in enumerate(ordered_features):
    feature_data = shap_df[shap_df['Feature'] == feature]
    
    # Grey connector from the cross-city mean to each city's value.
    for _, row in feature_data.iterrows():
        plt.plot([row['Overall Norm SHAP'], row['Normalized_SHAP']], [i, i],
                 color='gray', alpha=0.5, zorder=1)
    
    sns.scatterplot(
        data=feature_data,
        x='Normalized_SHAP',
        y=[i] * len(feature_data),
        hue='City',
        palette=city_colors,
        s=100,
        legend=False,
        alpha=0.7,
        zorder=2
    )
    
    overall_val = feature_data['Overall Norm SHAP'].iloc[0]
    plt.scatter(overall_val, i, color='black', s=100, zorder=3)

plt.xlabel("Normalized SHAP Value", fontsize=14)
plt.ylabel("Feature", fontsize=14)
plt.title("Normalized SHAP Values Across Cities", fontsize=16)

plt.yticks(range(len(ordered_features_remapped)), ordered_features_remapped)

handles = [mlines.Line2D([], [], color='black', marker='o', linestyle='None', markersize=8)]
labels = ['Overall Mean']
for city in unique_cities:
    handles.append(mlines.Line2D([], [], color=city_colors[city], marker='o', linestyle='None', markersize=8))
    labels.append(city)

plt.legend(handles=handles, labels=labels, loc='upper right', title='City')

# Lay out before saving so the saved file matches the displayed figure.
plt.tight_layout()
plt.savefig("featureimportance.png",dpi=300,bbox_inches='tight')
plt.show()


In [None]:
## iii. CREATE SHAP FEATURE INFORMATION*

feature_of_interest = "nino_sst_21d"  # Change this to the feature you want
n_bins = 50  # bins along the feature axis

# Per city, a binned SHAP-vs-feature-value curve for feature_of_interest,
# averaged over all folds/seeds. Each fold's curve is scaled so the sum of
# |binned mean SHAP| equals the feature's share of total mean |SHAP|.
shap_results = {}

for city, df in city_dfs.items():
    if feature_of_interest not in df.columns:
        print(f"Skipping {city} - Feature {feature_of_interest} not found.")
        continue

    # Convert dates once; this writes back into the shared city_dfs entry.
    if not np.issubdtype(df['date'].dtype, np.datetime64):
        df['date'] = pd.to_datetime(df['date'])
    df = df[df['date'].dt.month.between(6, 9)]

    features = [col for col in df.columns if '_21' in col or '_24' in col or '_28' in col]
    # The feature must be a model input for it to have SHAP values at all.
    if feature_of_interest not in features:
        print(f"Skipping {city} - Feature {feature_of_interest} is not a model predictor.")
        continue
    feat_idx = features.index(feature_of_interest)

    print(f"Processing SHAP for {feature_of_interest} in {city}...")

    X = df[features]
    y = df['hwd']

    N_pos = y.sum()
    N_total = len(df)
    N_neg = N_total - N_pos
    scale_pos_weight = (N_neg / N_pos) if N_pos > 0 else 1

    # One shared set of bin edges from this city's full feature range. Every
    # fold's binned curve then lines up bin-for-bin before averaging; per-fold
    # edges would shift with each training subset's min/max.
    x_all = X[feature_of_interest].to_numpy(dtype=float)
    bin_edges = np.histogram_bin_edges(x_all[np.isfinite(x_all)], bins=n_bins)
    grid_values = (bin_edges[:-1] + bin_edges[1:]) / 2

    all_shap_values = []

    for seed in seed_list:
        kf = KFold(n_splits=4, shuffle=True, random_state=seed)

        for train_index, _ in kf.split(X, y):
            X_train, y_train = X.iloc[train_index], y.iloc[train_index]

            model = xgb.XGBClassifier(
                objective="binary:logistic",
                eval_metric="logloss",
                learning_rate=0.2,
                max_depth=4,
                subsample=0.8,
                colsample_bytree=0.8,
                scale_pos_weight=scale_pos_weight
            )
            model.fit(X_train, y_train)

            # NOTE(review): SHAP is computed on the training fold here, whereas
            # the SHAP dictionary cell uses the held-out fold -- confirm intended.
            shap_values = shap.TreeExplainer(model).shap_values(X_train)

            x_vals = X_train.iloc[:, feat_idx].to_numpy()
            shap_feat = shap_values[:, feat_idx]
            valid_mask = np.isfinite(x_vals) & np.isfinite(shap_feat)

            stat, _, _ = binned_statistic(
                x_vals[valid_mask], shap_feat[valid_mask],
                statistic='mean', bins=bin_edges,
            )

            # Feature's share of the total mean |SHAP| across all features.
            per_feature_contrib = np.mean(np.abs(shap_values), axis=0)
            total_contrib = per_feature_contrib.sum()
            normalized_feat_contrib = per_feature_contrib[feat_idx] / total_contrib if total_contrib != 0 else 0

            binned_total = np.nansum(np.abs(stat))
            scale = normalized_feat_contrib / binned_total if binned_total != 0 else 0
            all_shap_values.append(stat * scale)

    all_shap_values = np.array(all_shap_values)

    mean_shap = np.nanmean(all_shap_values, axis=0)
    std_shap = np.nanstd(all_shap_values, axis=0, ddof=1)
    # Standard error from the number of fold curves that populated each bin,
    # rather than a fixed model count that can drift from seed_list.
    n_per_bin = np.sum(np.isfinite(all_shap_values), axis=0)
    with np.errstate(divide='ignore', invalid='ignore'):
        se_shap = std_shap / np.sqrt(n_per_bin)
    ci_low = mean_shap - 1.96 * se_shap
    ci_high = mean_shap + 1.96 * se_shap

    # statsmodels' lowess drops NaN (empty-bin) points by default.
    loess_result = lowess(mean_shap, grid_values, frac=0.3, return_sorted=True)
    loess_x = loess_result[:, 0]
    loess_y = loess_result[:, 1]

    shap_results[city] = {
        "feature_values": grid_values,
        "mean_shap": mean_shap,
        "ci_low": ci_low,
        "ci_high": ci_high,
        "loess_x": loess_x,
        "loess_y": loess_y
    }


In [None]:
## iv. PLOTTING SHAP FEATURE VALUE

plt.figure(figsize=(10, 6))
sns.set(style="whitegrid")

# Display title for the raw feature chosen in the previous cell. Kept in its
# own variable so feature_of_interest stays the raw column name.
plot_title = feature_name_map.get(feature_of_interest, feature_of_interest)

unique_cities = list(shap_results.keys())
city_palette = sns.color_palette("muted", n_colors=len(unique_cities))

# Per city: LOESS-smoothed curve plus the binned 95% CI band.
for i, (city, data) in enumerate(shap_results.items()):
    loess_x = data["loess_x"]
    loess_y = data["loess_y"]
    ci_low, ci_high = data["ci_low"], data["ci_high"]
    grid_values = data["feature_values"]

    plt.plot(loess_x, loess_y, label=city.capitalize(), color=city_palette[i], linewidth=2)
    plt.fill_between(grid_values, ci_low, ci_high, color=city_palette[i], alpha=0.2)

plt.axhline(0, color='black', linewidth=1, zorder=10)

plt.tick_params(axis='y', colors='black')
plt.gca().spines['left'].set_color('black')

# NOTE(review): the x-label is specific to the SST features; update it when
# feature_of_interest is changed to a non-SST predictor.
plt.xlabel('SST (K)', fontsize=14)
plt.ylabel("SHAP Value", fontsize=14, color='black')
plt.title(plot_title, fontsize=16)

# plt.legend(title="City", fontsize=12)

plt.tight_layout()
#plt.savefig(f"{plot_title}_pdp.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
## v. BIVARIATE MJO INFORMATION SHAP

# For each city:
# - train a seeds x 3-fold ensemble of XGBoost classifiers on June-September data;
# - sum the SHAP values of the two features below for each training sample;
# - average that sum on a 2-D grid over (feature1, feature2), pooling all models;
# - plot the result as a filled contour.
feature1 = "Phase_28"
feature2 = "Amplitude_28"
seed_list = [1, 2, 3, 4, 5]  # also used by the Las Vegas case study cell
grid_resolution = 20  # number of bins along each axis

city_name_map = {
    "austin": "Austin", "houston": "Houston", "dallas": "Dallas",
    "lasvegas": "Las Vegas", "mexicocity": "Mexico City", "phoenix": "Phoenix"
}

# Plot styling is the same for every city, so build it once.
vmin, vmax = -.2, .2
levels = np.linspace(vmin, vmax, 501)
ticks = np.arange(vmin, vmax + 0.05, 0.05)
blue_grey_red = LinearSegmentedColormap.from_list(
    "blue_grey_red", ["blue", "lightgrey", "red"]
)

bivariate_shap_results = {}

for city, df in city_dfs.items():
    if feature1 not in df.columns or feature2 not in df.columns:
        print(f"Skipping {city} – required features not found.")
        continue

    print(f"Processing bivariate SHAP for {city}...")
    predictors = [col for col in df.columns if '_21' in col or '_24' in col or '_28' in col]
    if feature1 not in predictors or feature2 not in predictors:
        print(f"Skipping {city} – required features not in predictor set.")
        continue
    feat_idx1 = predictors.index(feature1)
    feat_idx2 = predictors.index(feature2)

    # Keep June-September only. The parsed dates stay local so the shared
    # city_dfs entry is not modified.
    dates = pd.to_datetime(df['date'])
    df = df[dates.dt.month.between(6, 9)]

    X = df[predictors]
    y = df['hwd']
    N_pos = y.sum()
    N_total = len(df)
    # Up-weight the positive class by the negative:positive ratio.
    scale_pos_weight = (N_total - N_pos) / N_pos if N_pos > 0 else 1

    # Fix the bin range once from the full seasonal data so that every fold
    # bins onto the SAME grid. Without this, each fold's edges follow its own
    # min/max, and the element-wise accumulation below would mix bins that
    # cover different value ranges.
    x1_all = X[feature1].to_numpy(dtype=float)
    x2_all = X[feature2].to_numpy(dtype=float)
    finite_all = np.isfinite(x1_all) & np.isfinite(x2_all)
    if not finite_all.any():
        print(f"Skipping {city} – no finite values for {feature1}/{feature2}.")
        continue
    bin_range = [
        (x1_all[finite_all].min(), x1_all[finite_all].max()),
        (x2_all[finite_all].min(), x2_all[finite_all].max()),
    ]

    # Running per-bin SHAP sums and sample counts over all seeds and folds.
    shap_2d_accum = np.zeros((grid_resolution, grid_resolution))
    counts = np.zeros((grid_resolution, grid_resolution))

    scaling_numerator_total = 0
    scaling_denominator_total = 0

    for seed in seed_list:
        kf = KFold(n_splits=3, shuffle=True, random_state=seed)
        for train_index, test_index in kf.split(X, y):
            X_train = X.iloc[train_index]
            y_train = y.iloc[train_index]

            model = xgb.XGBClassifier(
                objective="binary:logistic", eval_metric="logloss",
                learning_rate=0.2, max_depth=4,
                subsample=0.8, colsample_bytree=0.8,
                scale_pos_weight=scale_pos_weight, n_jobs=-1
            )
            model.fit(X_train, y_train)

            # SHAP values are computed on the training rows; columns follow `predictors`.
            shap_values = shap.TreeExplainer(model).shap_values(X_train)
            shap_vals_sum = shap_values[:, feat_idx1] + shap_values[:, feat_idx2]

            # Share of the total SHAP mass attributable to the two features.
            scaling_numerator_total += np.sum(shap_vals_sum)
            scaling_denominator_total += np.sum(shap_values)

            x1_vals = X_train.iloc[:, feat_idx1].values
            x2_vals = X_train.iloc[:, feat_idx2].values

            valid_mask = np.isfinite(x1_vals) & np.isfinite(x2_vals) & np.isfinite(shap_vals_sum)
            x1_vals = x1_vals[valid_mask]
            x2_vals = x2_vals[valid_mask]
            shap_vals_sum = shap_vals_sum[valid_mask]

            # Per-bin sums (0 for empty bins) and counts on the shared grid.
            bin_sum, x_edge, y_edge, _ = binned_statistic_2d(
                x1_vals, x2_vals, shap_vals_sum,
                statistic='sum', bins=grid_resolution, range=bin_range
            )
            bin_count, _, _, _ = binned_statistic_2d(
                x1_vals, x2_vals, shap_vals_sum,
                statistic='count', bins=grid_resolution, range=bin_range
            )

            shap_2d_accum += bin_sum
            counts += bin_count

    # NOTE(review): numerator and denominator are signed sums, so the
    # denominator can be near zero (or of opposite sign), which makes this
    # constant unstable. Consider absolute SHAP sums if that is the intent.
    if scaling_denominator_total != 0:
        scaling_constant = scaling_numerator_total / scaling_denominator_total
    else:
        scaling_constant = 1

    # Pooled per-bin mean. Bins with no samples are left at 0.
    mean_shap = np.divide(
        shap_2d_accum, counts, out=np.zeros_like(shap_2d_accum), where=counts != 0
    )
    mean_shap = scaling_constant * mean_shap

    bivariate_shap_results[city] = {
        "x_edges": x_edge,
        "y_edges": y_edge,
        "mean_shap": mean_shap
    }

    # Place each bin's value at the bin centre rather than at its left edge.
    x_centers = 0.5 * (x_edge[:-1] + x_edge[1:])
    y_centers = 0.5 * (y_edge[:-1] + y_edge[1:])

    plt.figure(figsize=(8, 6))

    # mean_shap is indexed [x_bin, y_bin]; contourf expects [row=y, col=x].
    cp = plt.contourf(
        x_centers, y_centers, mean_shap.T,
        levels=levels, cmap=blue_grey_red, extend='both'
    )

    cb = plt.colorbar(cp, ticks=ticks)
    cb.ax.tick_params(labelsize=10)

    plt.xlabel("MJO Phase")
    plt.ylabel("MJO Amplitude")
    title_city = city_name_map.get(city.lower(), city)
    plt.title(f'{title_city}', fontsize=14)

    #plt.savefig(f"{city}_bivariate_shap_normalized.png", dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

In [None]:
## 6. LAS VEGAS HEAT WAVE TEST CASE 

# Case study for Las Vegas:
# - train a seed ensemble of XGBoost models on June-September data;
# - predict the heat-wave probability for the predictor values recorded on 2024-07-07;
# - plot the per-feature SHAP contributions whose ensemble 2.5-97.5 percentile
#   interval excludes zero.
# Depends on `seed_list` (section 5.v) and `feature_name_map` from earlier cells.

case_date = pd.Timestamp('2024-07-07')

lv_df = city_dfs['lasvegas'].copy()

lv_df['date'] = pd.to_datetime(lv_df['date'])

conditions_mask = (lv_df['date'] == case_date)
july7_conditions = lv_df[conditions_mask]
if july7_conditions.empty:
    # Without this guard the mean below is an all-NaN row that every model
    # would silently "predict" from.
    raise ValueError(f"No Las Vegas rows found for {case_date.date()}.")

features = [col for col in lv_df.columns if '_21' in col or '_24' in col or '_28' in col]

# One-row frame of predictor values for the case date (averaged if the date
# appears more than once).
mean_conditions = july7_conditions[features].mean().to_frame().T

# 'date' was already converted to datetime above, so only the season filter is needed.
# NOTE(review): the training data still contains the case date itself, so the
# prediction below is in-sample. Drop that date (or all of 2024) first if an
# out-of-sample test is intended.
lv_df = lv_df[lv_df['date'].dt.month.between(6, 9)]

X_lv = lv_df[features]
y_lv = lv_df['hwd']

# Up-weight the positive class by the negative:positive ratio.
N_pos = y_lv.sum()
N_total = len(lv_df)
N_neg = N_total - N_pos
scale_pos_weight = (N_neg / N_pos) if N_pos > 0 else 1

params = {
    'objective': 'binary:logistic',
    'eval_metric': 'logloss',
    'learning_rate': 0.2,
    'max_depth': 4,
    'subsample': 0.8,
    'colsample_bytree': 0.8,
    'scale_pos_weight': scale_pos_weight
}

ensemble_predictions = []
shap_values_ensemble = []
num_boost_round = 100

dtrain = xgb.DMatrix(X_lv, label=y_lv)
# The prediction input is identical for every seed, so build it once.
dmean = xgb.DMatrix(mean_conditions)
for seed in seed_list:
    params['seed'] = seed
    bst = xgb.train(params, dtrain, num_boost_round=num_boost_round)

    pred_prob = bst.predict(dmean)[0]
    ensemble_predictions.append(pred_prob)

    explainer = shap.TreeExplainer(bst)
    shap_vals = explainer.shap_values(mean_conditions)  # shape: (1, n_features)
    shap_values_ensemble.append(shap_vals[0])

ensemble_predictions = np.array(ensemble_predictions)
mean_prediction = np.mean(ensemble_predictions)
# NOTE(review): with only len(seed_list) models these percentiles lie close to
# the ensemble min/max. They describe spread across seeds, not a sampling CI.
lower_pred = np.percentile(ensemble_predictions, 2.5)
upper_pred = np.percentile(ensemble_predictions, 97.5)

print(f"Ensemble prediction for July 7, 2024 conditions: {mean_prediction:.3f}")
print(f"95% CI for prediction: [{lower_pred:.3f}, {upper_pred:.3f}]")

# One row per seed, one column per feature.
shap_array = np.vstack(shap_values_ensemble)
shap_df = pd.DataFrame(shap_array, columns=features)

shap_summary = shap_df.describe(percentiles=[0.025, 0.975]).T[['mean', '2.5%', '97.5%']]
shap_summary.rename(columns={'2.5%': 'lower', '97.5%': 'upper'}, inplace=True)

# Order features by absolute mean contribution, largest first.
shap_summary_sorted = shap_summary.reindex(
    shap_summary['mean'].abs().sort_values(ascending=False).index
)

# "Significant" here means the seed-ensemble percentile interval excludes zero.
significant_mask = (shap_summary_sorted['lower'] > 0) | (shap_summary_sorted['upper'] < 0)
shap_summary_sig = shap_summary_sorted[significant_mask]

if shap_summary_sig.empty:
    print("No significant features found.")
else:
    mapped_features = [feature_name_map.get(feat, feat) for feat in shap_summary_sig.index]
    means = shap_summary_sig['mean']
    lower = shap_summary_sig['lower']
    upper = shap_summary_sig['upper']

    # xerr takes distances from the bar end, not absolute bounds.
    error_lower = means - lower
    error_upper = upper - means
    asymmetric_errors = [error_lower.values, error_upper.values]

    colors = ['red' if m > 0 else 'blue' for m in means]

    fig, ax = plt.subplots(figsize=(6, 6))

    ax.barh(
        mapped_features,
        means,
        xerr=asymmetric_errors,
        color=colors,
        alpha=0.7,
        error_kw=dict(capsize=4, capthick=1)  
    )

    ax.spines['top'].set_visible(True)
    ax.spines['bottom'].set_visible(True)
    ax.spines['left'].set_visible(False)
    ax.spines['right'].set_visible(False)

    ax.tick_params(axis='y', which='major', length=5)

    ax.yaxis.grid(True, linestyle='--', alpha=0.6)
    ax.xaxis.grid(False)
    ax.set_axisbelow(True)

    ax.set_xlabel("Mean SHAP Value")
    ax.set_title("Significant Feature Contributions for Las Vegas\nJuly 7, 2024 Heat Wave Prediction")
    # Largest contribution at the top.
    ax.invert_yaxis()

    plt.tight_layout()
    #plt.savefig("testcase.png",dpi=300,bbox_inches='tight')
    plt.show()
    plt.close(fig)
