# Copyright (c) 2025, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# # Credit Card Transaction Data Cleanup and Prep
#
# This source code shows the steps for cleanup and preparing the credit card
# transaction data for training models with Training NIM.
#
# ### The dataset:
#  * IBM TabFormer: https://github.com/IBM/TabFormer
#  * Released under an Apache 2.0 license
#
# Contains 24M records with 15 fields, one field being the "is fraud" label which we use for training.
#
# ### Goals
# The goal is to:
#  * Cleanup the data
#    * Make field names just single word
#      * while field names are not used within the GNN, it makes accessing fields easier during cleanup
#    * Encode categorical fields
#      * use one-hot encoding for fields with less than 8 categories
#      * use binary encoding for fields with more than 8 categories
#    * Create a continuous node index across users, merchants, and transactions
#      * having node ID start at zero and then be contiguous is critical
#        for creation of Compressed Sparse Row (CSR) formatted data
#        without wasting memory.
#  * Produce:
#    * For XGBoost:
#      * Training   - all data before 2018
#      * Validation - all data during 2018
#      * Test.      - all data after 2018
#    * For GNN
#      * Training Data
#        * Edge List
#        * Feature data
#    * Test set - all data after 2018
#
#
#
# ### Graph formation
# Given that we are limited to just the data in the transaction file, the
# ideal model would be to have a bipartite graph of Users to Merchants where
# the edges represent the credit card transaction and then perform Link
# Classification on the Edges to identify fraud. Unfortunately the current
# version of cuGraph does not support GNN Link Prediction. That limitation
# will be lifted over the next few release at which time this code will be
# updated. Luckily, there is precedence for viewing transactions as nodes
# and then doing node classification using the popular GraphSAGE GNN. That
# is the approach this code takes. The produced graph will be a tri-partite
# graph where each transaction is represented as a node.
#
# <img src="../img/3-partite.jpg" width="35%"/>
#
#
# ### Features
# For the XGBoost approach, there is no need to generate empty features
# for the Merchants. However, for GNN processing, every node needs to
# have the same set of feature data. Therefore, we need to generate
# empty features for the User and Merchant nodes.
#
# -----

# #### Import the necessary libraries.  In this case will be use cuDF and perform most of the data prep in GPU
#


import logging
import os

import cudf
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import scipy.stats as ss
from category_encoders import BinaryEncoder
from scipy.stats import pointbiserialr
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, RobustScaler

COL_USER = "User"
COL_CARD = "Card"
COL_AMOUNT = "Amount"
COL_MCC = "MCC"
COL_TIME = "Time"
COL_DAY = "Day"
COL_MONTH = "Month"
COL_YEAR = "Year"

COL_MERCHANT = "Merchant"
COL_STATE = "State"
COL_CITY = "City"
COL_ZIP = "Zip"
COL_ERROR = "Errors"
COL_CHIP = "Chip"
COL_FRAUD = "Fraud"
COL_TRANSACTION_ID = "Tx_ID"
COL_MERCHANT_ID = "Merchant_ID"
COL_USER_ID = "User_ID"

UNKNOWN_STRING_MARKER = "XX"
UNKNOWN_ZIP_CODE = 0

COL_GRAPH_SRC = "src"
COL_GRAPH_DST = "dst"
COL_GRAPH_WEIGHT = "wgt"
MERCHANT_AND_USER_COLS = [COL_MERCHANT, COL_CARD, COL_MCC]

logger = logging.getLogger(__name__)


def cramers_v(x, y):
    """ "
    Compute correlation of categorical field x with target y.
    See https://en.wikipedia.org/wiki/Cram%C3%A9r's_V
    """
    confusion_matrix = cudf.crosstab(x, y).to_numpy()
    chi2 = ss.chi2_contingency(confusion_matrix)[0]
    n = confusion_matrix.sum().sum()
    r, k = confusion_matrix.shape
    return np.sqrt(chi2 / (n * (min(k - 1, r - 1))))


def create_feature_mask(columns, start_mask_id=0):
    # Dictionary to store mapping from original column to mask value
    mask_mapping = {}
    mask_values = []
    current_mask = start_mask_id

    for col in columns:
        # For encoded columns, assume the base is before the underscore
        if "_" in col:
            base_feature = col.split("_")[0]
        else:
            base_feature = col  # For non-encoded columns, use the column name directly

        # Assign a new mask value if this base feature hasn't been seen before
        if base_feature not in mask_mapping:
            mask_mapping[base_feature] = current_mask
            current_mask += 1

        # Append the mask value for this column
        mask_values.append(mask_mapping[base_feature])

    # Convert list to numpy array for further processing if needed
    feature_mask = np.array(mask_values)

    return mask_mapping, feature_mask


def preprocess_data(
    raw_csv_path: str,
    output_base_path: str,
    fraud_ratio: float = 0.1,
    under_sample: bool = True,
) -> tuple[dict, dict, dict, dict]:

    tabformer_xgb = os.path.join(output_base_path, "xgb")
    tabformer_gnn = os.path.join(output_base_path, "gnn")

    if not os.path.exists(tabformer_xgb):
        os.makedirs(tabformer_xgb)
    if not os.path.exists(tabformer_gnn):
        os.makedirs(tabformer_gnn)

    # Read the dataset

    data = cudf.read_csv(raw_csv_path)

    _ = data.rename(
        columns={
            "Merchant Name": COL_MERCHANT,
            "Merchant State": COL_STATE,
            "Merchant City": COL_CITY,
            "Errors?": COL_ERROR,
            "Use Chip": COL_CHIP,
            "Is Fraud?": COL_FRAUD,
        },
        inplace=True,
    )

    # #### Handle missing values
    # * Zip codes are numeral, replace missing zip codes by 0
    # * State and Error are string, replace missing values by marker 'XX'

    # Make sure that 'XX' doesn't exist in State and Error field before we replace missing values by 'XX'
    assert UNKNOWN_STRING_MARKER not in set(data[COL_STATE].unique().to_pandas())
    assert UNKNOWN_STRING_MARKER not in set(data[COL_ERROR].unique().to_pandas())

    # Make sure that 0 or 0.0 doesn't exist in Zip field before we replace missing values by 0
    assert float(0) not in set(data[COL_ZIP].unique().to_pandas())
    assert 0 not in set(data[COL_ZIP].unique().to_pandas())

    # Replace missing values with markers
    data[COL_STATE] = data[COL_STATE].fillna(UNKNOWN_STRING_MARKER)
    data[COL_ERROR] = data[COL_ERROR].fillna(UNKNOWN_STRING_MARKER)
    data[COL_ZIP] = data[COL_ZIP].fillna(UNKNOWN_ZIP_CODE)

    # There shouldn't be any missing values in the data now.
    assert data.isnull().sum().sum() == 0

    # ### Clean up the Amount field
    # * Drop the "$" from the Amount field and then convert from string to float
    # * Look into spread of Amount and choose right scaler for it

    # Drop the "$" from the Amount field and then convert from string to float
    data[COL_AMOUNT] = (
        data[COL_AMOUNT].str.replace("$", "", regex=False).astype("float")
    )

    # #### Change the 'Fraud' values to be integer where
    #   * 1 == Fraud
    #   * 0 == Non-fraud

    fraud_to_binary = {"No": 0, "Yes": 1}
    data[COL_FRAUD] = data[COL_FRAUD].map(fraud_to_binary).astype("int8")

    # Remove ',' in error descriptions
    data[COL_ERROR] = data[COL_ERROR].str.replace(",", "")

    # Split the time column into hours and minutes and then cast to int32
    T = data[COL_TIME].str.split(":", expand=True)
    T[0] = T[0].astype("int32")
    T[1] = T[1].astype("int32")

    # replace the 'Time' column with the new columns
    data[COL_TIME] = (T[0] * 60) + T[1]
    data[COL_TIME] = data[COL_TIME].astype("int32")

    # Delete temporary DataFrame
    del T

    # #### Convert Merchant column to str type
    data[COL_MERCHANT] = data[COL_MERCHANT].astype("str")

    # Combine User and Card to generate unique numbers
    data[COL_CARD] = data[COL_USER] * len(data[COL_CARD].unique()) + data[COL_CARD]
    data[COL_CARD] = data[COL_CARD].astype("int")

    # Collect unique merchant, card and MCC in a dataframe and fit a binary transformer
    data = data.to_pandas()

    data_ids = pd.DataFrame()

    nr_unique_card = data[COL_CARD].unique().shape[0]
    nr_unique_merchant = data[COL_MERCHANT].unique().shape[0]
    nr_unique_mcc = data[COL_MCC].unique().shape[0]
    nr_elements = max(nr_unique_merchant, nr_unique_card)

    data_ids[COL_CARD] = [data[COL_CARD][0]] * nr_elements
    data_ids[COL_MERCHANT] = [data[COL_MERCHANT][0]] * nr_elements
    data_ids[COL_MCC] = [data[COL_MCC][0]] * nr_elements

    data_ids.loc[np.arange(nr_unique_card), COL_CARD] = data[COL_CARD].unique()
    data_ids.loc[np.arange(nr_unique_merchant), COL_MERCHANT] = data[
        COL_MERCHANT
    ].unique()
    data_ids.loc[np.arange(nr_unique_mcc), COL_MCC] = data[COL_MCC].unique()

    data_ids = data_ids[MERCHANT_AND_USER_COLS].astype("category")

    id_bin_encoder = Pipeline(
        steps=[
            ("binary", BinaryEncoder(handle_missing="value", handle_unknown="value"))
        ]
    )

    id_transformer = ColumnTransformer(
        transformers=[
            ("binary", id_bin_encoder, MERCHANT_AND_USER_COLS),
        ],
        remainder="passthrough",
    )

    pd.set_option("future.no_silent_downcasting", True)
    id_transformer = id_transformer.fit(data_ids)

    preprocessed_id_data_raw = id_transformer.transform(
        data[MERCHANT_AND_USER_COLS].astype("category")
    )

    # transformed column names
    columns_of_transformed_id_data = list(
        map(
            lambda name: name.split("__")[1],
            list(id_transformer.get_feature_names_out(MERCHANT_AND_USER_COLS)),
        )
    )

    # data type of transformed columns
    id_col_type_mapping = {}
    for col in columns_of_transformed_id_data:
        if col.split("_")[0] in MERCHANT_AND_USER_COLS:
            id_col_type_mapping[col] = "int8"

    assert data_ids.isnull().sum().sum() == 0

    preprocessed_id_data = pd.DataFrame(
        preprocessed_id_data_raw, columns=columns_of_transformed_id_data
    )

    del data_ids
    del preprocessed_id_data_raw

    data = pd.concat(
        [data.reset_index(drop=True), preprocessed_id_data.reset_index(drop=True)],
        axis=1,
    )

    # ##### Compute correlation of different fields with target
    sparse_factor = 1
    columns_to_compute_corr = [
        COL_CARD,
        COL_CHIP,
        COL_ERROR,
        COL_STATE,
        COL_CITY,
        COL_ZIP,
        COL_MCC,
        COL_MERCHANT,
        COL_USER,
        COL_DAY,
        COL_MONTH,
        COL_YEAR,
    ]
    for c1 in columns_to_compute_corr:
        for c2 in [COL_FRAUD]:
            coff = 100 * cramers_v(data[c1][::sparse_factor], data[c2][::sparse_factor])
            logger.info(f"Correlation ({c1}, {c2}) = {coff:6.2f}%")

    # ### Correlation of target with numerical columns

    for col in [COL_TIME, COL_AMOUNT]:
        r_pb, p_value = pointbiserialr(
            # data[COL_FRAUD].to_pandas(), data[col].to_pandas()
            data[COL_FRAUD],
            data[col],
        )
        logger.info(f"r_pb ({col}) = {r_pb:3.2f} with p_value {p_value:3.2f}")

    numerical_predictors = [COL_AMOUNT]
    nominal_predictors = [
        COL_ERROR,
        COL_CARD,
        COL_CHIP,
        COL_CITY,
        COL_ZIP,
        COL_MCC,
        COL_MERCHANT,
    ]

    predictor_columns = numerical_predictors + nominal_predictors
    target_column = [COL_FRAUD]

    # #### Remove duplicates non-fraud data points

    # Remove duplicates data points
    fraud_data = data[data[COL_FRAUD] == 1]
    data = data[data[COL_FRAUD] == 0]

    data = data.drop_duplicates(subset=nominal_predictors)
    data = pd.concat([data, fraud_data])

    # ### Split the data into
    # The data will be split into thee groups based on event date
    #  * Training   - all data before 2018
    #  * Validation - all data during 2018
    #  * Test.      - all data after 2018

    if under_sample:
        fraud_df = data[data[COL_FRAUD] == 1]
        non_fraud_df = data[data[COL_FRAUD] == 0]
        nr_non_fraud_samples = min(
            (len(data) - len(fraud_df)), int(len(fraud_df) / fraud_ratio)
        )
        data = pd.concat(
            [fraud_df, non_fraud_df.sample(nr_non_fraud_samples, random_state=42)]
        )

    predictor_columns = list(set(predictor_columns) - set(MERCHANT_AND_USER_COLS))
    nominal_predictors = list(set(nominal_predictors) - set(MERCHANT_AND_USER_COLS))

    data = data.sample(frac=1, random_state=42).reset_index(drop=True)

    training_idx = data[COL_YEAR] < 2018
    validation_idx = data[COL_YEAR] == 2018
    test_idx = data[COL_YEAR] > 2018

    # ### Scale numerical columns and encode categorical columns of training data

    # As some of the encoder we want to use is not available in cuml, we can use pandas for now.
    # Move training data to pandas for preprocessing
    pdf_training = data[training_idx][predictor_columns + target_column]

    # Use one-hot encoding for columns with <= 8 categories, and binary encoding for columns with more categories
    columns_for_binary_encoding = []
    columns_for_one_hot_encoding = []
    for col in nominal_predictors:
        if len(data[col].unique()) <= 8:
            columns_for_one_hot_encoding.append(col)
        else:
            columns_for_binary_encoding.append(col)

    assert (training_idx.sum() + validation_idx.sum() + test_idx.sum()) == data.shape[0]

    # Mark categorical column as "category"
    pdf_training[nominal_predictors] = pdf_training[nominal_predictors].astype(
        "category"
    )

    # encoders to encode categorical columns and scalers to scale numerical columns

    bin_encoder = Pipeline(
        steps=[
            ("binary", BinaryEncoder(handle_missing="value", handle_unknown="value"))
        ]
    )
    one_hot_encoder = Pipeline(steps=[("onehot", OneHotEncoder())])

    robust_scaler = Pipeline(
        steps=[
            ("imputer", SimpleImputer(strategy="median")),
            ("robust", RobustScaler()),
        ],
    )

    # compose encoders and scalers in a column transformer
    transformer = ColumnTransformer(
        transformers=[
            ("binary", bin_encoder, columns_for_binary_encoding),
            ("onehot", one_hot_encoder, columns_for_one_hot_encoding),
            ("robust", robust_scaler, [COL_AMOUNT]),
        ],
        remainder="passthrough",
    )

    # Fit column transformer with training data

    pd.set_option("future.no_silent_downcasting", True)
    transformer = transformer.fit(pdf_training[predictor_columns])

    # transformed column names
    columns_of_transformed_txs = list(
        map(
            lambda name: name.split("__")[1],
            list(transformer.get_feature_names_out(predictor_columns)),
        )
    )

    # data type of transformed columns
    type_mapping = {}
    for col in columns_of_transformed_txs:
        if col.split("_")[0] in nominal_predictors:
            type_mapping[col] = "int8"
        elif col in numerical_predictors:
            type_mapping[col] = "float"
        elif col in target_column:
            type_mapping[col] = data.dtypes.to_dict()[col]

    # transform training data
    preprocessed_training_data = transformer.transform(pdf_training[predictor_columns])

    # Convert transformed data to panda DataFrame
    preprocessed_training_data = pd.DataFrame(
        preprocessed_training_data, columns=columns_of_transformed_txs
    )

    # Transform test data using the transformer fitted on training data
    pdf_test = data[test_idx][predictor_columns + target_column]
    pdf_test[nominal_predictors] = pdf_test[nominal_predictors].astype("category")

    preprocessed_test_data = transformer.transform(pdf_test[predictor_columns])
    preprocessed_test_data = pd.DataFrame(
        preprocessed_test_data, columns=columns_of_transformed_txs
    )

    # Transform validation data using the transformer fitted on training data
    pdf_validation = data[validation_idx][predictor_columns + target_column]
    pdf_validation[nominal_predictors] = pdf_validation[nominal_predictors].astype(
        "category"
    )

    preprocessed_validation_data = transformer.transform(
        pdf_validation[predictor_columns]
    )
    preprocessed_validation_data = pd.DataFrame(
        preprocessed_validation_data, columns=columns_of_transformed_txs
    )

    preprocessed_id_data_train = pd.DataFrame(
        id_transformer.transform(data[training_idx][MERCHANT_AND_USER_COLS]),
        columns=columns_of_transformed_id_data,
    )
    preprocessed_training_data = pd.concat(
        [preprocessed_training_data, preprocessed_id_data_train], axis=1
    )

    # ## Write out the data for XGB

    # Copy target column
    preprocessed_training_data[COL_FRAUD] = pdf_training[COL_FRAUD].values
    preprocessed_training_data = preprocessed_training_data.astype(type_mapping)

    assert preprocessed_training_data.columns[-1] == COL_FRAUD
    assert (
        set(preprocessed_training_data.columns)
        - set(
            columns_of_transformed_txs + columns_of_transformed_id_data + target_column
        )
        == set()
    )
    assert (
        set(columns_of_transformed_txs + columns_of_transformed_id_data + target_column)
        - set(preprocessed_training_data.columns)
        == set()
    )

    ## Training data
    out_path = os.path.join(tabformer_xgb, "training.csv")
    if not os.path.exists(os.path.dirname(out_path)):
        os.makedirs(os.path.dirname(out_path))
    preprocessed_training_data.to_csv(
        out_path,
        header=True,
        index=False,
        columns=preprocessed_training_data.columns,
    )

    preprocessed_id_data_val = pd.DataFrame(
        id_transformer.transform(data[validation_idx][MERCHANT_AND_USER_COLS]),
        columns=columns_of_transformed_id_data,
    )
    preprocessed_validation_data = pd.concat(
        [preprocessed_validation_data, preprocessed_id_data_val], axis=1
    )

    # Copy target column
    preprocessed_validation_data[COL_FRAUD] = pdf_validation[COL_FRAUD].values
    preprocessed_validation_data = preprocessed_validation_data.astype(type_mapping)

    assert preprocessed_validation_data.columns[-1] == COL_FRAUD
    assert (
        set(preprocessed_validation_data.columns)
        - set(
            columns_of_transformed_txs + columns_of_transformed_id_data + target_column
        )
        == set()
    )
    assert (
        set(columns_of_transformed_txs + columns_of_transformed_id_data + target_column)
        - set(preprocessed_validation_data.columns)
        == set()
    )

    ## validation data
    out_path = os.path.join(tabformer_xgb, "validation.csv")
    if not os.path.exists(os.path.dirname(out_path)):
        os.makedirs(os.path.dirname(out_path))
    preprocessed_validation_data.to_csv(
        out_path,
        header=True,
        index=False,
        columns=preprocessed_validation_data.columns,
    )
    # preprocessed_validation_data.to_parquet(out_path, index=False, compression='gzip')

    preprocessed_id_data_test = pd.DataFrame(
        id_transformer.transform(data[test_idx][MERCHANT_AND_USER_COLS]),
        columns=columns_of_transformed_id_data,
    )
    preprocessed_test_data = pd.concat(
        [preprocessed_test_data, preprocessed_id_data_test], axis=1
    )

    # Copy target column
    preprocessed_test_data[COL_FRAUD] = pdf_test[COL_FRAUD].values
    preprocessed_test_data = preprocessed_test_data.astype(type_mapping)

    assert preprocessed_test_data.columns[-1] == COL_FRAUD
    assert (
        set(preprocessed_test_data.columns)
        - set(
            columns_of_transformed_txs + columns_of_transformed_id_data + target_column
        )
        == set()
    )
    assert (
        set(columns_of_transformed_txs + columns_of_transformed_id_data + target_column)
        - set(preprocessed_test_data.columns)
        == set()
    )

    ## test data
    out_path = os.path.join(tabformer_xgb, "test.csv")
    preprocessed_test_data.to_csv(
        out_path,
        header=True,
        index=False,
        columns=preprocessed_test_data.columns,
    )

    # Delete dataFrames that are not needed anymore
    del pdf_training
    del pdf_validation
    del pdf_test
    del preprocessed_training_data
    del preprocessed_validation_data
    del preprocessed_test_data

    # ### GNN Data

    # #### Setting Vertex IDs
    # In order to create a graph, the different vertices need to be assigned
    # unique vertex IDs. Additionally, the IDs needs to be consecutive
    # and positive.
    #
    # There are three nodes groups here: Transactions, Users, and Merchants.
    #
    # This IDs are not used in training, just used for graph processing.

    # Use the same training data as used for XGBoost

    data_all = data.copy()
    data = pd.concat([data[training_idx], data[validation_idx]])
    data.reset_index(inplace=True, drop=True)

    # The number of transaction is the same as the size of the list, and hence the index value
    data[COL_TRANSACTION_ID] = data.index

    merchant_name_to_id = dict(
        zip(
            data[COL_MERCHANT].unique(),
            np.arange(len(data[COL_MERCHANT].unique())),
            strict=False,
        )
    )

    data[COL_MERCHANT_ID] = data[COL_MERCHANT].map(merchant_name_to_id)

    # ##### NOTE: the 'User' and 'Card' columns of the original data were used to crate updated 'Card' column
    # * You can use user or card as nodes

    id_to_consecutive_id = dict(
        zip(
            data[COL_CARD].unique(),
            np.arange(len(data[COL_CARD].unique())),
            strict=False,
        )
    )

    # Convert Card to consecutive IDs
    data[COL_USER_ID] = data[COL_CARD].map(id_to_consecutive_id)

    NR_USERS = data[COL_USER_ID].max() + 1
    NR_MXS = data[COL_MERCHANT_ID].max() + 1
    NR_TXS = data[COL_TRANSACTION_ID].max() + 1

    # Check the the transaction, merchant and user ids are consecutive
    id_range = data[COL_MERCHANT_ID].min(), data[COL_MERCHANT_ID].max()
    logger.info(f"Merchant ID range {id_range}")
    id_range = data[COL_USER_ID].min(), data[COL_USER_ID].max()
    logger.info(f"User ID range {id_range}")

    # #### Create Edge in COO format

    U_2_M = cudf.DataFrame()
    U_2_M[COL_GRAPH_SRC] = data[COL_USER_ID]
    U_2_M[COL_GRAPH_DST] = data[COL_MERCHANT_ID]

    Edge = cudf.concat([U_2_M])

    # Write out Edge data
    out_path = os.path.join(tabformer_gnn, "edges/user_to_merchant.csv")

    if not os.path.exists(os.path.dirname(out_path)):
        os.makedirs(os.path.dirname(out_path))

    Edge.to_csv(out_path, header=True, index=False)

    # ### Now the feature data
    # Feature data needs to be is sorted in order, where the row index corresponds to the node ID
    #
    # The data is comprised of three sets of features
    # * Transactions
    # * Merchants
    # * Users

    # #### To get feature vectors of Transaction, transform the training data using pre-fitted transformer

    transaction_feature_df = pd.DataFrame(
        transformer.transform(data[predictor_columns]),
        columns=columns_of_transformed_txs,
    ).astype(type_mapping)

    transaction_feature_df[COL_FRAUD] = data[COL_FRAUD]

    data_merchant = data[[COL_MERCHANT, COL_MCC, COL_CARD]].drop_duplicates(
        subset=[COL_MERCHANT]
    )
    data_merchant[COL_MERCHANT_ID] = data_merchant[COL_MERCHANT].map(
        merchant_name_to_id
    )
    data_merchant_sorted = data_merchant.sort_values(by=COL_MERCHANT_ID)

    data_user = data[[COL_MERCHANT, COL_MCC, COL_CARD]].drop_duplicates(
        subset=[COL_CARD]
    )
    data_user[COL_USER_ID] = data_user[COL_CARD].map(id_to_consecutive_id)
    data_user_sorted = data_user.sort_values(by=COL_USER_ID)

    user_feature_columns = []
    mx_feature_columns = []
    for c in columns_of_transformed_id_data:
        if c.startswith("Card"):
            user_feature_columns.append(c)
        else:
            mx_feature_columns.append(c)

    preprocessed_merchant_data = pd.DataFrame(
        id_transformer.transform(data_merchant_sorted[MERCHANT_AND_USER_COLS]),
        columns=columns_of_transformed_id_data,
    )[mx_feature_columns]

    preprocessed_user_data = pd.DataFrame(
        id_transformer.transform(data_user_sorted[MERCHANT_AND_USER_COLS]),
        columns=columns_of_transformed_id_data,
    )[user_feature_columns]

    # User features

    out_path = os.path.join(tabformer_gnn, "nodes/user.csv")
    if not os.path.exists(os.path.dirname(out_path)):
        os.makedirs(os.path.dirname(out_path))

    preprocessed_user_data.to_csv(
        out_path, header=True, index=False, columns=user_feature_columns
    )

    # Merchant features

    out_path = os.path.join(tabformer_gnn, "nodes/merchant.csv")
    if not os.path.exists(os.path.dirname(out_path)):
        os.makedirs(os.path.dirname(out_path))
    preprocessed_merchant_data.to_csv(
        out_path, header=True, index=False, columns=mx_feature_columns
    )

    # User to merchant edge labels

    out_path = os.path.join(tabformer_gnn, "edges/user_to_merchant_label.csv")
    if not os.path.exists(os.path.dirname(out_path)):
        os.makedirs(os.path.dirname(out_path))
    transaction_feature_df[[COL_FRAUD]].to_csv(
        out_path, header=True, index=False, columns=[COL_FRAUD]
    )

    # User to merchant edge features

    out_path = os.path.join(tabformer_gnn, "edges/user_to_merchant_attr.csv")
    if not os.path.exists(os.path.dirname(out_path)):
        os.makedirs(os.path.dirname(out_path))
    transaction_feature_df[columns_of_transformed_txs].to_csv(
        out_path, header=True, index=False, columns=columns_of_transformed_txs
    )

    # # Create and save feature masks for training data
    user_mask_map, user_mask = create_feature_mask(user_feature_columns, 0)
    mx_mask_map, mx_mask = create_feature_mask(
        mx_feature_columns, np.max(user_mask) + 1
    )
    tx_mask_map, tx_mask = create_feature_mask(
        columns_of_transformed_txs, np.max(mx_mask) + 1
    )

    # np.savetxt(
    #     os.path.join(tabformer_gnn, "nodes/user_feature_mask.csv"),
    #     user_mask,
    #     delimiter=",",
    #     fmt="%d",
    # )
    # np.savetxt(
    #     os.path.join(tabformer_gnn, "nodes/merchant_feature_mask.csv"),
    #     mx_mask,
    #     delimiter=",",
    #     fmt="%d",
    # )
    # np.savetxt(
    #     os.path.join(tabformer_gnn, "edges/user_to_merchant_feature_mask.csv"),
    #     tx_mask,
    #     delimiter=",",
    #     fmt="%d",
    # )

    ## Test data

    data = data_all[test_idx].copy()

    data.reset_index(inplace=True, drop=True)

    # The number of transaction is the same as the size of the list, and hence the index value
    data[COL_TRANSACTION_ID] = data.index

    merchant_name_to_id = dict(
        zip(
            data[COL_MERCHANT].unique(),
            np.arange(len(data[COL_MERCHANT].unique())),
            strict=False,
        )
    )

    data[COL_MERCHANT_ID] = data[COL_MERCHANT].map(merchant_name_to_id)

    # ##### NOTE: the 'User' and 'Card' columns of the original data were used to crate updated 'Card' column
    # * You can use user or card as nodes

    id_to_consecutive_id = dict(
        zip(
            data[COL_CARD].unique(),
            np.arange(len(data[COL_CARD].unique())),
            strict=False,
        )
    )

    # Convert Card to consecutive IDs
    data[COL_USER_ID] = data[COL_CARD].map(id_to_consecutive_id)

    # Check the the transaction, merchant and user ids are consecutive
    id_range = data[COL_MERCHANT_ID].min(), data[COL_MERCHANT_ID].max()
    logger.info(f"Merchant ID range {id_range}")
    id_range = data[COL_USER_ID].min(), data[COL_USER_ID].max()
    logger.info(f"User ID range {id_range}")

    NR_USERS = data[COL_USER_ID].max() + 1
    NR_MXS = data[COL_MERCHANT_ID].max() + 1
    NR_TXS = data[COL_TRANSACTION_ID].max() + 1

    # #### Test Edges in COO format

    # User to Merchant edges
    U_2_M = cudf.DataFrame()
    U_2_M[COL_GRAPH_SRC] = data[COL_USER_ID]
    U_2_M[COL_GRAPH_DST] = data[COL_MERCHANT_ID]

    Edge = cudf.concat([U_2_M])

    # Write out Edge data
    out_path = os.path.join(tabformer_gnn, "test_gnn/edges/user_to_merchant.csv")

    if not os.path.exists(os.path.dirname(out_path)):
        os.makedirs(os.path.dirname(out_path))

    Edge.to_csv(out_path, header=True, index=False)

    # ### Now the feature data
    # Feature data needs to be is sorted in order, where the row index corresponds to the node ID
    #
    # The data is comprised of three sets of features
    # * Transactions
    # * Merchants
    # * Users

    transaction_feature_df = pd.DataFrame(
        transformer.transform(data[predictor_columns]),
        columns=columns_of_transformed_txs,
    ).astype(type_mapping)

    transaction_feature_df[COL_FRAUD] = data[COL_FRAUD]

    data_merchant = data[[COL_MERCHANT, COL_MCC, COL_CARD]].drop_duplicates(
        subset=[COL_MERCHANT]
    )
    data_merchant[COL_MERCHANT_ID] = data_merchant[COL_MERCHANT].map(
        merchant_name_to_id
    )
    data_merchant_sorted = data_merchant.sort_values(by=COL_MERCHANT_ID)

    data_user = data[[COL_MERCHANT, COL_MCC, COL_CARD]].drop_duplicates(
        subset=[COL_CARD]
    )
    data_user[COL_USER_ID] = data_user[COL_CARD].map(id_to_consecutive_id)
    data_user_sorted = data_user.sort_values(by=COL_USER_ID)

    preprocessed_merchant_data = pd.DataFrame(
        id_transformer.transform(data_merchant_sorted[MERCHANT_AND_USER_COLS]),
        columns=columns_of_transformed_id_data,
    )[mx_feature_columns]

    preprocessed_user_data = pd.DataFrame(
        id_transformer.transform(data_user_sorted[MERCHANT_AND_USER_COLS]),
        columns=columns_of_transformed_id_data,
    )[user_feature_columns]

    ## feature matrix

    out_path = os.path.join(tabformer_gnn, "test_gnn/nodes/user.csv")
    if not os.path.exists(os.path.dirname(out_path)):
        os.makedirs(os.path.dirname(out_path))

    preprocessed_user_data.to_csv(
        out_path, header=True, index=False, columns=user_feature_columns
    )

    out_path = os.path.join(tabformer_gnn, "test_gnn/nodes/merchant.csv")
    if not os.path.exists(os.path.dirname(out_path)):
        os.makedirs(os.path.dirname(out_path))
    preprocessed_merchant_data.to_csv(
        out_path, header=True, index=False, columns=mx_feature_columns
    )

    out_path = os.path.join(tabformer_gnn, "test_gnn/edges/user_to_merchant_label.csv")
    if not os.path.exists(os.path.dirname(out_path)):
        os.makedirs(os.path.dirname(out_path))
    transaction_feature_df[[COL_FRAUD]].to_csv(
        out_path, header=True, index=False, columns=[COL_FRAUD]
    )

    out_path = os.path.join(tabformer_gnn, "test_gnn/edges/user_to_merchant_attr.csv")
    if not os.path.exists(os.path.dirname(out_path)):
        os.makedirs(os.path.dirname(out_path))

    transaction_feature_df[columns_of_transformed_txs].to_csv(
        out_path, header=True, index=False, columns=columns_of_transformed_txs
    )

    # Test feature masks were already created for training data, just save them for test
    np.savetxt(
        os.path.join(tabformer_gnn, "test_gnn/nodes/user_feature_mask.csv"),
        user_mask,
        delimiter=",",
        fmt="%d",
    )
    np.savetxt(
        os.path.join(tabformer_gnn, "test_gnn/nodes/merchant_feature_mask.csv"),
        mx_mask,
        delimiter=",",
        fmt="%d",
    )
    np.savetxt(
        os.path.join(tabformer_gnn, "test_gnn/edges/user_to_merchant_feature_mask.csv"),
        tx_mask,
        delimiter=",",
        fmt="%d",
    )

    metadata = {
        "row_count": len(data_all),
        "fraud_ratio": float(fraud_ratio),
        "under_sample": under_sample,
        "num_users": int(NR_USERS),
        "num_merchants": int(NR_MXS),
        "num_transactions": int(NR_TXS),
    }
    return metadata, user_mask_map, mx_mask_map, tx_mask_map


def load_hetero_graph(base):
    """
    Reads:
      - All node CSVs from nodes/, plus their matching feature masks (<node>_feature_mask.csv)
        If missing, a mask of all ones is created (np.int32).
      - All edge CSVs from edges/:
          base        -> edge_index_<edge> (np.int64)
          *_attr.csv  -> edge_attr_<edge>  (np.float32)
          *_label.csv -> exactly one -> edge_label_<edge> (DataFrame)
    """

    nodes_dir = os.path.join(base, "nodes")
    edges_dir = os.path.join(base, "edges")

    out = {}
    node_feature_mask = {}

    # --- Nodes: every CSV becomes x_<node>; also read/create feature_mask_<node> ---
    if os.path.isdir(nodes_dir):
        for fname in os.listdir(nodes_dir):
            if fname.lower().endswith(".csv") and not fname.lower().endswith(
                "_feature_mask.csv"
            ):
                node_name = fname[: -len(".csv")]
                node_path = os.path.join(nodes_dir, fname)
                node_df = pd.read_csv(node_path)
                out[f"x_{node_name}"] = node_df.to_numpy(dtype=np.float32)

                # feature mask file (optional)
                mask_fname = f"{node_name}_feature_mask.csv"
                mask_path = os.path.join(nodes_dir, mask_fname)
                if os.path.exists(mask_path):
                    mask_df = pd.read_csv(mask_path, header=None)
                    node_feature_mask[node_name] = mask_df
                    feature_mask = mask_df.to_numpy(dtype=np.int32).ravel()
                else:
                    # create a mask with all zeros
                    feature_mask = np.zeros(node_df.shape[1], dtype=np.int32)
                out[f"feature_mask_{node_name}"] = feature_mask

    # --- Edges: group into base, attr, label by filename suffix ---
    base_edges = {}
    edge_attrs = {}
    edge_labels = {}
    edge_feature_mask = {}

    if os.path.isdir(edges_dir):
        for fname in os.listdir(edges_dir):
            if not fname.lower().endswith(".csv"):
                continue
            path = os.path.join(edges_dir, fname)
            lower = fname.lower()
            if lower.endswith("_attr.csv"):
                edge_name = fname[: -len("_attr.csv")]
                edge_attrs[edge_name] = pd.read_csv(path)  # , header=None)
            elif lower.endswith("_label.csv"):
                edge_name = fname[: -len("_label.csv")]
                edge_labels[edge_name] = pd.read_csv(path)
            elif lower.endswith("_feature_mask.csv"):
                edge_name = fname[: -len("_feature_mask.csv")]
                edge_feature_mask[edge_name] = pd.read_csv(path, header=None)
            else:
                edge_name = fname[: -len(".csv")]
                base_edges[edge_name] = pd.read_csv(path)  # , header=None)

    # Enforce: only one label file total
    if len(edge_labels) == 0:
        raise FileNotFoundError(
            "No '*_label.csv' found in edges/. Exactly one label file is required."
        )
    if len(edge_labels) > 1:
        raise ValueError(
            f"Found multiple label files: {list(edge_labels.keys())}. Exactly one is allowed."
        )

    # Build output keys for edges
    for edge_name, df in base_edges.items():
        out[f"edge_index_{edge_name}"] = df.to_numpy(dtype=np.int64).T
        if edge_name in edge_attrs:
            out[f"edge_attr_{edge_name}"] = edge_attrs[edge_name].to_numpy(
                dtype=np.float32
            )
        if edge_name in edge_feature_mask:
            out[f"edge_feature_mask_{edge_name}"] = (
                edge_feature_mask[edge_name].to_numpy(dtype=np.int32).ravel()
            )
        else:
            # create a mask with all zeros
            out[f"edge_feature_mask_{edge_name}"] = np.zeros(
                edge_attrs[edge_name].shape[1], dtype=np.int32
            )

    # Add the single label file (kept as DataFrame)
    ((label_edge_name, label_df),) = edge_labels.items()
    out[f"edge_label_{label_edge_name}"] = label_df

    return out


def prepare_bipartite_structures(edge_list: np.ndarray):
    """
    From a (2, E) edge array (row0=A, row1=B):
      - returns unique node arrays A_nodes, B_nodes
      - returns neighbor dicts neighbors_A[a]->set(B), neighbors_B[b]->set(A)
      - returns highest-degree node in A, plus its 1-hop (B) and 2-hop (A) neighbors
    """
    assert (
        edge_list.ndim == 2 and edge_list.shape[0] == 2
    ), "edge_list must be shape (2, E)"
    A_nodes = np.unique(edge_list[0, :])
    B_nodes = np.unique(edge_list[1, :])

    # Neighbor maps
    neighbors_A = {int(a): set() for a in A_nodes}
    neighbors_B = {int(b): set() for b in B_nodes}
    for a, b in edge_list.T:
        a = int(a)
        b = int(b)
        neighbors_A[a].add(b)
        neighbors_B[b].add(a)

    # Degrees in A and highest-degree anchor
    degrees_A = {a: len(neighbors_A[a]) for a in neighbors_A}
    if not degrees_A:
        raise ValueError("No nodes found in partition A.")

    max_deg = min(5, max(degrees_A.values()))
    # deterministic tie-break: smallest node id
    anchor_A = min([a for a, d in degrees_A.items() if d == max_deg])

    # 1-hop (B) and 2-hop (A) around the anchor
    one_hop_B = sorted(neighbors_A[anchor_A])
    two_hop_A = set()

    for b in one_hop_B:
        two_hop_A.update(list(neighbors_B[b])[:3])
    two_hop_A.discard(anchor_A)
    two_hop_A = sorted(two_hop_A)

    return {
        "A_nodes": A_nodes,
        "B_nodes": B_nodes,
        "neighbors_A": neighbors_A,
        "neighbors_B": neighbors_B,
        "degrees_A": degrees_A,
        "anchor_A": anchor_A,
        "anchor_degree": max_deg,
        "one_hop_B": one_hop_B,
        "two_hop_A": two_hop_A,
    }


def build_bipartite_graph(edge_list: np.ndarray) -> nx.Graph:
    """Build a NetworkX bipartite graph with node attribute 'bipartite' (0 for A, 1 for B)."""
    A_nodes = np.unique(edge_list[0, :]).astype(int)
    B_nodes = np.unique(edge_list[1, :]).astype(int)
    G = nx.Graph()
    G.add_nodes_from(A_nodes, bipartite=0)
    G.add_nodes_from(B_nodes, bipartite=1)
    G.add_edges_from([(int(a), int(b)) for a, b in edge_list.T])
    return G


def induced_ego_two_hop_subgraph_namespaced(edge_list: np.ndarray):
    """
    Same logic as your induced_ego_two_hop_subgraph, but nodes are namespaced:
      A-node a  -> ('A', a)
      B-node b  -> ('B', b)
    This guarantees edges draw only between the partitions.
    """
    info = prepare_bipartite_structures(edge_list)

    A_nodes = info["A_nodes"].astype(int)
    B_nodes = info["B_nodes"].astype(int)
    A_map = {a: ("A", a) for a in A_nodes}
    B_map = {b: ("B", b) for b in B_nodes}

    anchor_A = info["anchor_A"]
    one_hop_B = set(info["one_hop_B"])
    two_hop_A = set(info["two_hop_A"])

    # Build subgraph with namespaced nodes and bipartite attribute
    G_sub = nx.Graph()
    G_sub.add_nodes_from([A_map[anchor_A], *[A_map[a] for a in two_hop_A]], bipartite=0)
    G_sub.add_nodes_from([B_map[b] for b in one_hop_B], bipartite=1)

    for a, b in edge_list.T:
        a = int(a)
        b = int(b)
        if (a == anchor_A or a in two_hop_A) and (b in one_hop_B):
            G_sub.add_edge(A_map[a], B_map[b])

    # Return maps so the plotter can highlight the anchor
    return G_sub, info, A_map, B_map


def plot_bipartite_subgraph_namespaced(G_sub: nx.Graph, info: dict, A_map: dict):
    """
    Plot the 2-hop ego subgraph with namespaced nodes.
    Circles = A, Squares = B. The anchor A-node is larger.
    """
    anchor_tuple = A_map[info["anchor_A"]]
    sub_A = [n for n, d in G_sub.nodes(data=True) if d.get("bipartite") == 0]
    sub_B = [n for n, d in G_sub.nodes(data=True) if d.get("bipartite") == 1]

    pos = nx.bipartite_layout(G_sub, nodes=sub_A)

    plt.figure(figsize=(6, 5))
    nx.draw_networkx_nodes(
        G_sub,
        pos,
        nodelist=sub_A,
        node_size=[600 if n == anchor_tuple else 300 for n in sub_A],
        node_shape="o",
    )
    nx.draw_networkx_nodes(G_sub, pos, nodelist=sub_B, node_size=400, node_shape="s")
    nx.draw_networkx_edges(G_sub, pos)

    # Pretty labels: show only the numeric part (n[1])
    labels = {n: str(n[1]) for n in G_sub.nodes()}
    nx.draw_networkx_labels(G_sub, pos, labels=labels, font_size=10)

    plt.title("A subgraph with a few user and merchant nodes.")
    plt.axis("off")
    plt.show()
