大数据分析与挖掘
¶

05. Imbalanced data with classification model
¶

主讲人:丁平尖

什么是类别不平衡?¶

类别不平衡是指机器学习中的一个问题,即数据中的类别代表性不均等。例如,如果有 100 个数据点,其中 90 个属于 A 类,10 个属于 B 类,则这些类别是不平衡的。

image.png

存在类别不平衡问题的实际问题示例¶

类别不平衡类型 应用场景 实际案例
二分类不平衡 最直接的类别不平衡类型。仅有两个类别,其中一个类别远多于另一个。
- 欺诈检测
- 疾病诊断
多分类不平衡 一个或多个类别的样本数量明显少于其他类别。
- 自然语言处理(NLP)
- 多分类疾病诊断
时序不平衡 类别分布随时间变化,导致时序上的不平衡。
- 社交媒体趋势
- 股票价格预测
空间不平衡 类别分布在不同区域或位置上存在差异。
- 疾病在不同地区的传播
- 不同地区的客户行为

二分类不平衡问题¶

  • 少数类泛化能力差:在垃圾邮件检测系统中,如果“非垃圾邮件”邮件的数量明显多于“垃圾邮件”,模型可能非常擅长识别“非垃圾邮件”,但却无法正确分类“垃圾邮件”。这可能会导致许多垃圾邮件被遗漏。
  • 评估模型性能的难度:在信用卡欺诈检测中,如果欺诈交易很少见,模型可能会通过将所有交易预测为非欺诈交易来实现高精度。然而,这种高精度具有误导性,因为模型无法识别那些罕见但重要的欺诈案例。
  • 过度拟合多数类的趋势:在机械预测性维护中,如果与正常运行数据(多数类)相比,机器故障数据(少数类)稀缺,则模型可能会过度拟合正常运行模式,并且无法准确预测故障。
  • 模型偏差风险:在招聘算法中,如果数据集包含更多来自特定人群的成功候选人的例子,则该模型可能会对该人群产生偏见,从而导致不公平的招聘行为。
  • 平衡精确度和召回率:在法律文件分析中,为了识别相关文件,平衡精确度(确保识别的文件是相关的)和召回率(确保识别所有相关文件)至关重要,因为假阳性和假阴性都会带来严重后果。

处理类别不平衡的技术¶

  • 类别加权:例如,如果数据集有两个类别 A 和 B,其中 A 类出现的频率高于 B 类,那么分配给 A 类的权重将低于分配给 B 类的权重。这意味着在训练期间,算法将更多地关注 A 类,而较少关注 B 类,以确保两个类别具有同等重要性。
  • 与多数类相关的欠采样数据:减少多数类的样本数量,以平衡类别分布。
  • 与少数类相关的过采样数据:增加少数类的样本数量,以平衡类别分布。
  • 生成合成数据:通过从少数类中生成新的合成样本来解决。合成采样背后的理念是创建与少数类相似的新样本,从而增加少数类在数据集中的代表性。

模型中的类别权重(Class weights)¶

  • 按类别在输入数据中的频率反比调整权重
    • 权重 = 样本总数 / (类别数 × 当前类别样本数)
  • 权重示例*
  • 9 行数据分为 3 类:
    • 猫:$9 / (3 \times 3) = 1.0$
    • 狗:$9 / (3 \times 5) = 0.6$
    • 马:$9 / (3 \times 1) = 3.0$

image.png

不平衡数据¶

  • 使用加权模型纠正类别不平衡
  • 信用卡欺诈数据
  • 数据集下载:
wget https://raw.githubusercontent.com/nsethi31/Kaggle-Data-Credit-Card-Fraud-Detection/master/creditcard.csv
In [1]:
import findspark
findspark.init()
import warnings
warnings.filterwarnings('ignore')

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("spark://spark-master:7077").config("spark.ui.port", "8080").config("spark.acls.enable", "false").config("spark.ui.view.acls", "*").config("spark.modify.acls", "*").getOrCreate()
sc = spark.sparkContext

df = spark.read.csv('../dataset/creditcard.csv', inferSchema=True, header=True, mode='DROPMALFORMED')
In [2]:
## Review data
df.show(2)
+----+------------+------------+-----------+-----------+-----------+------------+------------+-----------+------------+------------+------------+------------+------------+------------+-----------+------------+------------+-----------+------------+------------+------------+------------+-----------+------------+-----------+------------+------------+------------+------+-----+
|Time|          V1|          V2|         V3|         V4|         V5|          V6|          V7|         V8|          V9|         V10|         V11|         V12|         V13|         V14|        V15|         V16|         V17|        V18|         V19|         V20|         V21|         V22|        V23|         V24|        V25|         V26|         V27|         V28|Amount|Class|
+----+------------+------------+-----------+-----------+-----------+------------+------------+-----------+------------+------------+------------+------------+------------+------------+-----------+------------+------------+-----------+------------+------------+------------+------------+-----------+------------+-----------+------------+------------+------------+------+-----+
| 0.0|-1.359807134|-0.072781173|2.536346738|1.378155224|-0.33832077| 0.462387778| 0.239598554|0.098697901|  0.36378697| 0.090794172|-0.551599533|-0.617800856|-0.991389847|-0.311169354|1.468176972|-0.470400525| 0.207971242| 0.02579058|  0.40399296| 0.251412098|-0.018306778| 0.277837576|-0.11047391| 0.066928075|0.128539358|-0.189114844| 0.133558377|-0.021053053|149.62|    0|
| 0.0| 1.191857111| 0.266150712|0.166480113|0.448154078|0.060017649|-0.082360809|-0.078802983|0.085101655|-0.255425128|-0.166974414| 1.612726661| 1.065235311| 0.489095016|-0.143772296|0.635558093| 0.463917041|-0.114804663|-0.18336127|-0.145783041|-0.069083135|-0.225775248|-0.638671953|0.101288021|-0.339846476|0.167170404| 0.125894532|-0.008983099| 0.014724169|  2.69|    0|
+----+------------+------------+-----------+-----------+-----------+------------+------------+-----------+------------+------------+------------+------------+------------+------------+-----------+------------+------------+-----------+------------+------------+------------+------------+-----------+------------+-----------+------------+------------+------------+------+-----+
only showing top 2 rows

In [3]:
df.summary().show()
+-------+-----------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+--------------------+
|summary|             Time|                  V1|                  V2|                  V3|                  V4|                  V5|                  V6|                  V7|                  V8|                  V9|                 V10|                 V11|                 V12|                 V13|                 V14|                 V15|                 V16|                 V17|                 V18|                 V19|                 V20|                 V21|                 V22|                 V23|                 V24|                 V25|                 V26|                 V27|                 V28|           Amount|               Class|
+-------+-----------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+--------------------+
|  count|           284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|              284807|           284807|              284807|
|   mean|94813.85957508067|1.759046617564544...|-8.24816056209747...|-9.65586281111362...|8.319229056158216...|1.650622057742291...|4.248446153940018...|-3.05450822720766...|8.76532715895825E-14|-1.17974498483826...|7.09132013778678E-13|1.875112916604381...|1.052928198930065...|7.13127870194277E-13|-1.47817996925521...|-5.23388882220346...|-2.28215303605140...|-6.42337199498324...|4.950189652922139...|7.057091680529474...|1.766135504253067...|-3.40550172400656...|-5.72547650927198...|-9.72570584866590...|1.464143627330664...|-6.98998795740368...|-5.61766228267496...|3.332084255351006...|-3.51887745743949...|88.34961925092918|0.001727485630620034|
| stddev| 47488.1459545661|  1.9586958038519584|  1.6513085794832496|  1.5162550051772055|   1.415868574942034|  1.3802467340261948|   1.332271089758185|   1.237093598166435|   1.194352902672247|  1.0986320892235726|   1.088849765406632|  1.0207130277116543|  0.9992013895273293|  0.9952742301246713|  0.9585956112586888|  0.9153160116102039|  0.8762528873875522|  0.8493370636754588|   0.838176209528987|  0.8140405007679462|    0.77092502488458|  0.7345240143739621|  0.7257015604417664|  0.6244602955966106|  0.6056470678277418|  0.5212780705402165| 0.48222701326032286|  0.4036324949661106| 0.33008326415627354|250.1201092401886| 0.04152718963546506|
|    min|              0.0|        -56.40750963|        -72.71572756|        -48.32558936|        -5.683171198|        -113.7433067|        -26.16050594|        -43.55724157|        -73.21671846|        -13.43406632|        -24.58826244|        -4.797473465|        -18.68371463|        -5.791881206|        -19.21432549|        -4.498944677|        -14.12985452|        -25.16279937|        -9.498745921|         -7.21352743|        -54.49772049|        -34.83038214|         -10.9331437|         -44.8077352|        -2.836626919|        -10.29539707|        -2.604550553|        -22.56567932|        -15.43008391|              0.0|                   0|
|    25%|          54196.0|         -0.92047203|         -0.59865528|         -0.89057473|        -0.848761618|         -0.69169996|        -0.768401417|        -0.554146437|        -0.208684621|        -0.643182642|        -0.535567489|        -0.762639554|        -0.405616682|        -0.648611855|        -0.425694631|        -0.583045603|        -0.468251967|        -0.483813379|        -0.498849799|        -0.456379548|        -0.211763269|        -0.228425296|         -0.54245288|        -0.161873518|        -0.354651882|        -0.317194011|         -0.32702731|        -0.070847739|        -0.052968026|             5.59|                   0|
|    50%|          84680.0|         0.017870266|         0.065279961|         0.179707515|          -0.0200417|        -0.054535811|        -0.274311571|         0.039987629|         0.022307799|        -0.051571616|        -0.093027333|        -0.032892971|         0.139851034|         -0.01371465|         0.050473753|         0.047824253|          0.06629415|        -0.065738381|        -0.003733459|          0.00356731|        -0.062510572|        -0.029518194|         0.006674895|        -0.011221835|         0.040882334|         0.016423246|        -0.052211672|         0.001325831|         0.011230722|             22.0|                   0|
|    75%|         139309.0|          1.31551523|          0.80351967|         1.026839686|         0.742906322|         0.611691997|          0.39818411|         0.570242205|         0.327210628|         0.596890798|         0.453649692|         0.739309809|         0.618001764|         0.662166002|         0.493008924|         0.648503344|         0.523061582|         0.399525418|         0.500614101|         0.458715159|         0.132942461|         0.186259032|          0.52833769|          0.14755632|         0.439396308|         0.350653477|         0.240797514|         0.091008841|          0.07824027|            77.09|                   0|
|    max|         172792.0|         2.454929991|         22.05772899|         9.382558433|         16.87534403|         34.80166588|         73.30162555|         120.5894939|         20.00720837|         15.59499461|         23.74513612|         12.01891318|         7.848392076|         7.126882959|         10.52676605|         8.877741598|         17.31511152|          9.25352625|         5.041069185|         5.591971427|         39.42090425|         27.20283916|         10.50309009|         22.52841169|         4.584549137|         7.519588679|         3.517345612|         31.61219811|         33.84780782|         25691.16|                   1|
+-------+-----------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+--------------------+

In [4]:
df.groupBy('Class').count().show()
+-----+------+
|Class| count|
+-----+------+
|    1|   492|
|    0|284315|
+-----+------+

In [5]:
from pyspark.sql import DataFrame
import pyspark.sql.functions as F
import pyspark.sql.types as T
import pandas as pd

# 计算类别权重和权重列的相关函数
# class_weights: 计算每个类别的权重,权重与类别样本数成反比
def class_weights(df:DataFrame, column='Class'):
  """
  Compute rebalancing weights for the given `column`.
  """
  class_counts = list(df.groupBy(column).count().toPandas().to_dict(orient='list').values())
  num_classes = len(class_counts[0])
  total_samples = sum(class_counts[1])
  pairs = list(zip(class_counts[0], class_counts[1]))
  return {k:total_samples/(num_classes * v) for (k, v) in pairs}
In [6]:
# weight_column: 通过pandas_udf为每个样本分配对应类别的权重
def weight_column(classes: pd.Series, class_weights) -> pd.Series:
# UDF 用于计算类别权重。
# 参数 `class_weights` 是预先计算好的类别权重字典。
# 示例调用:sdf.withColumn('weight', weight_column(sdf['Class'], class_weights(sdf, column='Class')))
  @F.pandas_udf(T.DoubleType())
  def weight_column_udf(classes: pd.Series) -> pd.Series:
    # UDF 用于根据类别分配权重。
    # 依赖于预先计算好的类别权重字典 class_weights。
    # 注意:不在 pandas_udf 内部重新计算权重,使用外部传入的权重字典。
    result = []
    for _, value in classes.items():
      result += [class_weights[value]]
    return pd.Series(result)
  
  return weight_column_udf(classes)
In [7]:
print(f"The class weights are {class_weights(df)}")
The class weights are {1: 289.4380081300813, 0: 0.5008652375006595}
In [8]:
df_weighted = df.withColumn('weight', weight_column(df['Class'], class_weights(df)))
df_weighted.groupBy('Class', 'weight').agg(F.count('Class').alias('count')).show()
+-----+------------------+------+
|Class|            weight| count|
+-----+------------------+------+
|    1| 289.4380081300813|   492|
|    0|0.5008652375006595|284315|
+-----+------------------+------+

In [9]:
cc_fraud = df_weighted
cc_fraud.show(2)
+----+------------+------------+-----------+-----------+-----------+------------+------------+-----------+------------+------------+------------+------------+------------+------------+-----------+------------+------------+-----------+------------+------------+------------+------------+-----------+------------+-----------+------------+------------+------------+------+-----+------------------+
|Time|          V1|          V2|         V3|         V4|         V5|          V6|          V7|         V8|          V9|         V10|         V11|         V12|         V13|         V14|        V15|         V16|         V17|        V18|         V19|         V20|         V21|         V22|        V23|         V24|        V25|         V26|         V27|         V28|Amount|Class|            weight|
+----+------------+------------+-----------+-----------+-----------+------------+------------+-----------+------------+------------+------------+------------+------------+------------+-----------+------------+------------+-----------+------------+------------+------------+------------+-----------+------------+-----------+------------+------------+------------+------+-----+------------------+
| 0.0|-1.359807134|-0.072781173|2.536346738|1.378155224|-0.33832077| 0.462387778| 0.239598554|0.098697901|  0.36378697| 0.090794172|-0.551599533|-0.617800856|-0.991389847|-0.311169354|1.468176972|-0.470400525| 0.207971242| 0.02579058|  0.40399296| 0.251412098|-0.018306778| 0.277837576|-0.11047391| 0.066928075|0.128539358|-0.189114844| 0.133558377|-0.021053053|149.62|    0|0.5008652375006595|
| 0.0| 1.191857111| 0.266150712|0.166480113|0.448154078|0.060017649|-0.082360809|-0.078802983|0.085101655|-0.255425128|-0.166974414| 1.612726661| 1.065235311| 0.489095016|-0.143772296|0.635558093| 0.463917041|-0.114804663|-0.18336127|-0.145783041|-0.069083135|-0.225775248|-0.638671953|0.101288021|-0.339846476|0.167170404| 0.125894532|-0.008983099| 0.014724169|  2.69|    0|0.5008652375006595|
+----+------------+------------+-----------+-----------+-----------+------------+------------+-----------+------------+------------+------------+------------+------------+------------+-----------+------------+------------+-----------+------------+------------+------------+------------+-----------+------------+-----------+------------+------------+------------+------+-----+------------------+
only showing top 2 rows

In [10]:
from pyspark.ml.feature import RFormula
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline, Model
# Split into train and test
train_data,test_data  = cc_fraud.randomSplit([0.6, 0.4], 24)   # proportions [], seed for random
# all the columns except these
x_cols = list(set(cc_fraud.columns) - {'Class', 'Time', 'weight'})
# RFormula will create the feature column, instead of VectorAssembler
formula = "{} ~ {}".format("Class", " + ".join(x_cols))
print("Formula : {}".format(formula))
# Steps in sequence; output of one is input to next:
pipeline = Pipeline(stages=[RFormula(formula = formula, featuresCol='features'),
                            LogisticRegression(labelCol ='Class') # no weight
                           ])
# And then run the steps with one command:
fitted_model = pipeline.fit(train_data)
Formula : Class ~ V21 + V12 + V1 + V3 + V10 + V25 + V13 + V14 + V19 + V27 + V20 + V22 + V18 + V4 + V23 + V28 + V16 + V8 + V6 + V11 + V9 + V2 + V17 + V7 + V5 + V15 + V24 + V26 + Amount
In [11]:
lm = fitted_model.stages[-1]
In [12]:
def equation_str(coeffs, intercept, columns):
   return ' + '.join([f"{p[0]:,.2f} * {p[1]}" for p in zip(coeffs, columns)]) + f"  {intercept:,.2f}"
  
def classification_model_summary(lm, formula, columns):
  """
  Multi-class classification summary.
  """
  print("Classification model summary")
  print(f"Formula:\n\t{formula}")
  print(f"Equation:")
  if lm.numClasses <3:
    print(f"\t{equation_str(lm.coefficients.toArray(), lm.intercept, columns)}")
    print(f"\nAuC ROC: {lm.summary.areaUnderROC:,.3f}")
  else:
    for i, eq in enumerate([equation_str(coeffs, intercept, x_cols) for coeffs,intercept in zip(lm.coefficientMatrix.toArray(), lm.interceptVector.toArray())]):
      print(f"\t{i}: {eq}\n")
  print()
  trainingSummary = lm.summary
  accuracy = trainingSummary.accuracy
  falsePositiveRate = trainingSummary.weightedFalsePositiveRate
  truePositiveRate = trainingSummary.weightedTruePositiveRate
  fMeasure = trainingSummary.weightedFMeasure()
  precision = trainingSummary.weightedPrecision
  recall = trainingSummary.weightedRecall
  print(f"Accuracy: {accuracy:.3f} FPR: {falsePositiveRate:.3f} TPR: {truePositiveRate:.3f}")
  print(f"F-measure: {fMeasure:.3f} Precision: {precision:.3f} Recall: {recall:.3f}")
  print(f"{'Labels':>20}")
  print("         " + ''.join([f"{l:>10}" for l in lm.summary.labels]))
  print("F-measure" + ''.join([f"{l:>10,.3f}" for l in lm.summary.fMeasureByLabel()]))
  print("Precision" + ''.join([f"{l:>10,.3f}" for l in lm.summary.precisionByLabel]))
  print("Recall   " + ''.join([f"{l:>10,.3f}" for l in lm.summary.recallByLabel]))
In [13]:
classification_model_summary(lm, formula, x_cols)
Classification model summary
Formula:
	Class ~ V21 + V12 + V1 + V3 + V10 + V25 + V13 + V14 + V19 + V27 + V20 + V22 + V18 + V4 + V23 + V28 + V16 + V8 + V6 + V11 + V9 + V2 + V17 + V7 + V5 + V15 + V24 + V26 + Amount
Equation:
	0.44 * V21 + 0.16 * V12 + 0.09 * V1 + -0.02 * V3 + -0.73 * V10 + -0.00 * V25 + -0.44 * V13 + -0.63 * V14 + 0.10 * V19 + -0.87 * V27 + -0.42 * V20 + 0.69 * V22 + -0.09 * V18 + 0.69 * V4 + -0.11 * V23 + -0.36 * V28 + -0.12 * V16 + -0.17 * V8 + -0.18 * V6 + -0.16 * V11 + -0.27 * V9 + 0.01 * V2 + -0.01 * V17 + -0.09 * V7 + 0.07 * V5 + -0.10 * V15 + 0.07 * V24 + -0.04 * V26 + 0.00 * Amount  -8.77

AuC ROC: 0.975

Accuracy: 0.999 FPR: 0.412 TPR: 0.999
F-measure: 0.999 Precision: 0.999 Recall: 0.999
              Labels
                0.0       1.0
F-measure     1.000     0.696
Precision     0.999     0.854
Recall        1.000     0.588
In [14]:
from pyspark.mllib.evaluation import MulticlassMetrics
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_context('notebook')
sns.set_style('white')
sns.set_palette("bright")
predictions = fitted_model.transform(train_data)
# Confusion matrix from labels and predictions
metrics = MulticlassMetrics(predictions.select('label', 'prediction').rdd.map(tuple))
confusion_matrix = metrics.confusionMatrix().toArray()
# Pandas DataFrame from Spark confusion matrix
cnf_matrix = pd.DataFrame(confusion_matrix)
plt.figure(figsize = (10,7))
p = sns.heatmap(cnf_matrix/np.sum(cnf_matrix), annot=True, fmt=".2%", linewidth=0.5, annot_kws={'fontsize':10}, cmap='RdBu')
# If you want numbers, instead of percent
# p = sns.heatmap(cnf_matrix, annot=True, fmt=",.1f", linewidth=0.5, annot_kws={'fontsize':10}, cmap='RdBu');
p.set(xlabel='Predicted', ylabel='Actual', title='Confusion Matrix')
Out[14]:
[Text(0.5, 43.249999999999986, 'Predicted'),
 Text(91.25, 0.5, 'Actual'),
 Text(0.5, 1.0, 'Confusion Matrix')]
No description has been provided for this image
In [15]:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_context('notebook')
sns.set_style('white')
sns.set_palette("bright")

# data
roc_data = lm.summary.roc.select('FPR', 'TPR').toPandas()
# basic figure
plt.figure(figsize=(5,5))
p = sns.lineplot(data=roc_data, x='FPR', y='TPR', linewidth=2)
# figure properties
p.set(title=f"ROC {lm.summary.areaUnderROC:.3f}")
p.set_xlabel("False positive rate (FPR)")
p.set_ylabel("True positive rate (TPR)")
# Get the xy data from the lines so that we can shade
l1 = p.lines[0]
x1 = l1.get_xydata()[:,0]
y1 = l1.get_xydata()[:,1]
p.fill_between(x1,y1, color="lightblue", alpha=0.3)
# Add dashed line with a slope of 1
plt.plot([0,1], [0,1], linestyle=(0, (5, 5)), linewidth=2);
No description has been provided for this image
In [16]:
# Use the weightCol
pipeline = Pipeline(stages=[RFormula(formula = formula, featuresCol='features'),
                            LogisticRegression(labelCol ='Class', weightCol='weight') # use weight
                           ])


# And then run the steps with one command:
fitted_model = pipeline.fit(train_data)

lm = fitted_model.stages[1]
In [17]:
classification_model_summary(lm, formula, x_cols)
Classification model summary
Formula:
	Class ~ V21 + V12 + V1 + V3 + V10 + V25 + V13 + V14 + V19 + V27 + V20 + V22 + V18 + V4 + V23 + V28 + V16 + V8 + V6 + V11 + V9 + V2 + V17 + V7 + V5 + V15 + V24 + V26 + Amount
Equation:
	0.21 * V21 + -1.38 * V12 + 0.77 * V1 + 0.45 * V3 + -1.03 * V10 + -0.01 * V25 + -0.59 * V13 + -1.37 * V14 + 0.38 * V19 + -0.10 * V27 + -1.00 * V20 + 0.95 * V22 + -0.43 * V18 + 0.87 * V4 + 0.50 * V23 + 1.19 * V28 + -0.64 * V16 + -0.46 * V8 + -0.68 * V6 + 0.53 * V11 + -0.55 * V9 + 0.51 * V2 + -0.84 * V17 + -0.69 * V7 + 0.62 * V5 + -0.28 * V15 + -0.27 * V24 + -0.42 * V26 + 0.01 * Amount  -4.68

AuC ROC: 0.988

Accuracy: 0.950 FPR: 0.051 TPR: 0.950
F-measure: 0.950 Precision: 0.951 Recall: 0.950
              Labels
                0.0       1.0
F-measure     0.953     0.948
Precision     0.932     0.972
Recall        0.974     0.925
In [18]:
predictions = fitted_model.transform(train_data)
# Confusion matrix from labels and predictions
metrics = MulticlassMetrics(predictions.select('label', 'prediction').rdd.map(tuple))
confusion_matrix = metrics.confusionMatrix().toArray()
# Pandas DataFrame from Spark confusion matrix
cnf_matrix = pd.DataFrame(confusion_matrix)
plt.figure(figsize = (10,7))
p = sns.heatmap(cnf_matrix/np.sum(cnf_matrix), annot=True, fmt=".2%", linewidth=0.5, annot_kws={'fontsize':10}, cmap='RdBu')
# If you want numbers, instead of percent
# p = sns.heatmap(cnf_matrix, annot=True, fmt=",.1f", linewidth=0.5, annot_kws={'fontsize':10}, cmap='RdBu');
p.set(xlabel='Predicted', ylabel='Actual', title='Confusion Matrix')
Out[18]:
[Text(0.5, 43.249999999999986, 'Predicted'),
 Text(91.25, 0.5, 'Actual'),
 Text(0.5, 1.0, 'Confusion Matrix')]
No description has been provided for this image
In [19]:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_context('notebook')
sns.set_style('white')
sns.set_palette("bright")
# data
roc_data = lm.summary.roc.select('FPR', 'TPR').toPandas()
# basic figure
plt.figure(figsize=(5,5))
p = sns.lineplot(data=roc_data, x='FPR', y='TPR', linewidth=2)
# figure properties
p.set(title=f"ROC {lm.summary.areaUnderROC:.3f}")
p.set_xlabel("False positive rate (FPR)")
p.set_ylabel("True positive rate (TPR)")
# Get the xy data from the lines so that we can shade
l1 = p.lines[0]
x1 = l1.get_xydata()[:,0]
y1 = l1.get_xydata()[:,1]
p.fill_between(x1,y1, color="lightblue", alpha=0.3)
# Add dashed line with a slope of 1
plt.plot([0,1], [0,1], linestyle=(0, (5, 5)), linewidth=2);
No description has been provided for this image

Multiple class classification¶

  • https://www.kaggle.com/code/krishnamsheth31/faulty-steel-plate-classification/data
  • https://github.com/daines-analytics/tabular-data-projects/blob/master/r-classification-faulty-steel-plates/faults.csv
  • 故障钢板数据集用于多类别分类任务,包含多种钢板缺陷类型(如 Pastry、Z_Scratch、K_Scatch、Stains、Dirtiness、Bumps 等),适合演示类别不平衡问题及其处理方法。
wget https://raw.githubusercontent.com/daines-analytics/tabular-data-projects/master/r-classification-faulty-steel-plates/faults.csv
In [20]:
df = spark.read.csv('../dataset/faults.csv', inferSchema=True, header=True, mode='DROPMALFORMED')
In [21]:
# Review data
df.show(2)
+---------+---------+---------+---------+------------+-----------+-----------+-----------------+---------------------+---------------------+------------------+----------------+----------------+---------------------+-----------+-----------+------------+---------------+-------------+-------------+--------------------+----------+-----------+-----------+-----------------+----------------+--------------+------+---------+--------+------+---------+-----+------------+
|X_Minimum|X_Maximum|Y_Minimum|Y_Maximum|Pixels_Areas|X_Perimeter|Y_Perimeter|Sum_of_Luminosity|Minimum_of_Luminosity|Maximum_of_Luminosity|Length_of_Conveyer|TypeOfSteel_A300|TypeOfSteel_A400|Steel_Plate_Thickness|Edges_Index|Empty_Index|Square_Index|Outside_X_Index|Edges_X_Index|Edges_Y_Index|Outside_Global_Index|LogOfAreas|Log_X_Index|Log_Y_Index|Orientation_Index|Luminosity_Index|SigmoidOfAreas|Pastry|Z_Scratch|K_Scatch|Stains|Dirtiness|Bumps|Other_Faults|
+---------+---------+---------+---------+------------+-----------+-----------+-----------------+---------------------+---------------------+------------------+----------------+----------------+---------------------+-----------+-----------+------------+---------------+-------------+-------------+--------------------+----------+-----------+-----------+-----------------+----------------+--------------+------+---------+--------+------+---------+-----+------------+
|       42|       50|   270900|   270944|         267|         17|         44|            24220|                   76|                  108|              1687|               1|               0|                   80|     0.0498|     0.2415|      0.1818|         0.0047|       0.4706|          1.0|                 1.0|    2.4265|     0.9031|     1.6435|           0.8182|         -0.2913|        0.5822|     1|        0|       0|     0|        0|    0|           0|
|      645|      651|  2538079|  2538108|         108|         10|         30|            11397|                   84|                  123|              1687|               1|               0|                   80|     0.7647|     0.3793|      0.2069|         0.0036|          0.6|       0.9667|                 1.0|    2.0334|     0.7782|     1.4624|           0.7931|         -0.1756|        0.2984|     1|        0|       0|     0|        0|    0|           0|
+---------+---------+---------+---------+------------+-----------+-----------+-----------------+---------------------+---------------------+------------------+----------------+----------------+---------------------+-----------+-----------+------------+---------------+-------------+-------------+--------------------+----------+-----------+-----------+-----------------+----------------+--------------+------+---------+--------+------+---------+-----+------------+
only showing top 2 rows

In [22]:
import pyspark.sql.functions as F
# 该代码用于将原始钢板故障数据集中的多类别故障类型(如 Pastry、Z_Scratch 等)合并为一个新的“fault”列。
# 每一行根据各故障类型的数值,判断属于哪种故障,并赋予对应的类别标签。
# 然后删除原有的各故障类型列,只保留新生成的“fault”类别列和其他特征列,便于后续多类别分类建模。
dff = df.withColumn('fault',
             F.when(df.Pastry>0,'Pastry') \
                    .when(df.Z_Scratch>0,'Z_Scratch') \
                    .when(df.K_Scatch>0,'K_Scatch') \
                    .when(df.Stains>0,'Stains') \
                    .when(df.Dirtiness>0,'Dirtiness') \
                    .when(df.Bumps>0,'Bumps') \
                    .otherwise('Other_Faults'))
dff = dff.drop('Pastry').drop('Z_Scratch').drop('K_Scatch').drop('Stains').drop('Dirtiness').drop('Bumps').drop('Other_Faults')
dff.show(2)
+---------+---------+---------+---------+------------+-----------+-----------+-----------------+---------------------+---------------------+------------------+----------------+----------------+---------------------+-----------+-----------+------------+---------------+-------------+-------------+--------------------+----------+-----------+-----------+-----------------+----------------+--------------+------+
|X_Minimum|X_Maximum|Y_Minimum|Y_Maximum|Pixels_Areas|X_Perimeter|Y_Perimeter|Sum_of_Luminosity|Minimum_of_Luminosity|Maximum_of_Luminosity|Length_of_Conveyer|TypeOfSteel_A300|TypeOfSteel_A400|Steel_Plate_Thickness|Edges_Index|Empty_Index|Square_Index|Outside_X_Index|Edges_X_Index|Edges_Y_Index|Outside_Global_Index|LogOfAreas|Log_X_Index|Log_Y_Index|Orientation_Index|Luminosity_Index|SigmoidOfAreas| fault|
+---------+---------+---------+---------+------------+-----------+-----------+-----------------+---------------------+---------------------+------------------+----------------+----------------+---------------------+-----------+-----------+------------+---------------+-------------+-------------+--------------------+----------+-----------+-----------+-----------------+----------------+--------------+------+
|       42|       50|   270900|   270944|         267|         17|         44|            24220|                   76|                  108|              1687|               1|               0|                   80|     0.0498|     0.2415|      0.1818|         0.0047|       0.4706|          1.0|                 1.0|    2.4265|     0.9031|     1.6435|           0.8182|         -0.2913|        0.5822|Pastry|
|      645|      651|  2538079|  2538108|         108|         10|         30|            11397|                   84|                  123|              1687|               1|               0|                   80|     0.7647|     0.3793|      0.2069|         0.0036|          0.6|       0.9667|                 1.0|    2.0334|     0.7782|     1.4624|           0.7931|         -0.1756|        0.2984|Pastry|
+---------+---------+---------+---------+------------+-----------+-----------+-----------------+---------------------+---------------------+------------------+----------------+----------------+---------------------+-----------+-----------+------------+---------------+-------------+-------------+--------------------+----------+-----------+-----------+-----------------+----------------+--------------+------+
only showing top 2 rows

In [23]:
print(f"The class weights are {class_weights(dff, 'fault')}")
The class weights are {'Stains': 3.8511904761904763, 'Z_Scratch': 1.4593984962406015, 'Other_Faults': 0.4120144343026958, 'Bumps': 0.6897654584221748, 'K_Scatch': 0.7091706247716478, 'Dirtiness': 5.041558441558442, 'Pastry': 1.7549728752260398}
In [24]:
df_weighted = dff.withColumn('weight', weight_column(dff['fault'], class_weights(dff, column='fault')))
df_weighted.groupBy('fault', 'weight').agg(F.count('fault').alias('count')).show()
+------------+------------------+-----+
|       fault|            weight|count|
+------------+------------------+-----+
|Other_Faults|0.4120144343026958|  673|
|       Bumps|0.6897654584221748|  402|
|      Stains|3.8511904761904763|   72|
|   Z_Scratch|1.4593984962406015|  190|
|      Pastry|1.7549728752260398|  158|
|   Dirtiness| 5.041558441558442|   55|
|    K_Scatch|0.7091706247716478|  391|
+------------+------------------+-----+

In [25]:
faults = df_weighted
faults.show(2)
+---------+---------+---------+---------+------------+-----------+-----------+-----------------+---------------------+---------------------+------------------+----------------+----------------+---------------------+-----------+-----------+------------+---------------+-------------+-------------+--------------------+----------+-----------+-----------+-----------------+----------------+--------------+------+------------------+
|X_Minimum|X_Maximum|Y_Minimum|Y_Maximum|Pixels_Areas|X_Perimeter|Y_Perimeter|Sum_of_Luminosity|Minimum_of_Luminosity|Maximum_of_Luminosity|Length_of_Conveyer|TypeOfSteel_A300|TypeOfSteel_A400|Steel_Plate_Thickness|Edges_Index|Empty_Index|Square_Index|Outside_X_Index|Edges_X_Index|Edges_Y_Index|Outside_Global_Index|LogOfAreas|Log_X_Index|Log_Y_Index|Orientation_Index|Luminosity_Index|SigmoidOfAreas| fault|            weight|
+---------+---------+---------+---------+------------+-----------+-----------+-----------------+---------------------+---------------------+------------------+----------------+----------------+---------------------+-----------+-----------+------------+---------------+-------------+-------------+--------------------+----------+-----------+-----------+-----------------+----------------+--------------+------+------------------+
|       42|       50|   270900|   270944|         267|         17|         44|            24220|                   76|                  108|              1687|               1|               0|                   80|     0.0498|     0.2415|      0.1818|         0.0047|       0.4706|          1.0|                 1.0|    2.4265|     0.9031|     1.6435|           0.8182|         -0.2913|        0.5822|Pastry|1.7549728752260398|
|      645|      651|  2538079|  2538108|         108|         10|         30|            11397|                   84|                  123|              1687|               1|               0|                   80|     0.7647|     0.3793|      0.2069|         0.0036|          0.6|       0.9667|                 1.0|    2.0334|     0.7782|     1.4624|           0.7931|         -0.1756|        0.2984|Pastry|1.7549728752260398|
+---------+---------+---------+---------+------------+-----------+-----------+-----------------+---------------------+---------------------+------------------+----------------+----------------+---------------------+-----------+-----------+------------+---------------+-------------+-------------+--------------------+----------+-----------+-----------+-----------------+----------------+--------------+------+------------------+
only showing top 2 rows

In [26]:
from pyspark.ml.feature import RFormula
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline, Model
# Training and testing data
train_data,test_data  = faults.randomSplit([0.6, 0.4], 24)   # proportions [], seed for random
# all the columns except these
x_cols = list(set(faults.columns) - {'fault', 'weight'})
# RFormula will create the feature column, instead of VectorAssembler
formula = "{} ~ {}".format("fault", " + ".join(x_cols))
print("Formula : {}".format(formula))
pipeline = Pipeline(stages=[RFormula(formula = formula),
                            LogisticRegression() # no weight
                           ])
# And then run the steps with one command:
fitted_model = pipeline.fit(train_data)
Formula : fault ~ Edges_Index + Empty_Index + Length_of_Conveyer + SigmoidOfAreas + Pixels_Areas + Log_Y_Index + LogOfAreas + Y_Minimum + Y_Perimeter + Y_Maximum + Edges_Y_Index + TypeOfSteel_A300 + Outside_X_Index + TypeOfSteel_A400 + Luminosity_Index + Edges_X_Index + X_Perimeter + Orientation_Index + Steel_Plate_Thickness + X_Minimum + Sum_of_Luminosity + Maximum_of_Luminosity + X_Maximum + Minimum_of_Luminosity + Log_X_Index + Square_Index + Outside_Global_Index
In [27]:
lm = fitted_model.stages[-1]
classification_model_summary(lm, formula, x_cols)
Classification model summary
Formula:
	fault ~ Edges_Index + Empty_Index + Length_of_Conveyer + SigmoidOfAreas + Pixels_Areas + Log_Y_Index + LogOfAreas + Y_Minimum + Y_Perimeter + Y_Maximum + Edges_Y_Index + TypeOfSteel_A300 + Outside_X_Index + TypeOfSteel_A400 + Luminosity_Index + Edges_X_Index + X_Perimeter + Orientation_Index + Steel_Plate_Thickness + X_Minimum + Sum_of_Luminosity + Maximum_of_Luminosity + X_Maximum + Minimum_of_Luminosity + Log_X_Index + Square_Index + Outside_Global_Index
Equation:
	0: -0.10 * Edges_Index + 0.54 * Empty_Index + 0.01 * Length_of_Conveyer + 0.49 * SigmoidOfAreas + -0.00 * Pixels_Areas + -1.21 * Log_Y_Index + 1.36 * LogOfAreas + 0.00 * Y_Minimum + 0.00 * Y_Perimeter + 0.00 * Y_Maximum + -3.14 * Edges_Y_Index + -0.60 * TypeOfSteel_A300 + 53.09 * Outside_X_Index + 0.60 * TypeOfSteel_A400 + -7.03 * Luminosity_Index + 1.63 * Edges_X_Index + -0.01 * X_Perimeter + 2.62 * Orientation_Index + 0.04 * Steel_Plate_Thickness + -0.00 * X_Minimum + -0.00 * Sum_of_Luminosity + 0.03 * Maximum_of_Luminosity + 0.00 * X_Maximum + 0.02 * Minimum_of_Luminosity + 2.02 * Log_X_Index + 0.25 * Square_Index + -0.49 * Outside_Global_Index  -19.78

	1: 0.81 * Edges_Index + -1.54 * Empty_Index + 0.01 * Length_of_Conveyer + 1.85 * SigmoidOfAreas + -0.00 * Pixels_Areas + 0.48 * Log_Y_Index + 0.57 * LogOfAreas + 0.00 * Y_Minimum + -0.01 * Y_Perimeter + 0.00 * Y_Maximum + -1.22 * Edges_Y_Index + 0.11 * TypeOfSteel_A300 + 51.67 * Outside_X_Index + -0.11 * TypeOfSteel_A400 + -7.05 * Luminosity_Index + 0.47 * Edges_X_Index + 0.00 * X_Perimeter + 0.94 * Orientation_Index + 0.03 * Steel_Plate_Thickness + 0.00 * X_Minimum + 0.00 * Sum_of_Luminosity + 0.04 * Maximum_of_Luminosity + -0.00 * X_Maximum + 0.01 * Minimum_of_Luminosity + 0.95 * Log_X_Index + 1.53 * Square_Index + 0.09 * Outside_Global_Index  -19.60

	2: -4.23 * Edges_Index + -8.87 * Empty_Index + 0.01 * Length_of_Conveyer + 10.28 * SigmoidOfAreas + 0.00 * Pixels_Areas + 5.50 * Log_Y_Index + -1.52 * LogOfAreas + -0.00 * Y_Minimum + -0.00 * Y_Perimeter + -0.00 * Y_Maximum + -5.06 * Edges_Y_Index + -0.11 * TypeOfSteel_A300 + 0.89 * Outside_X_Index + 0.11 * TypeOfSteel_A400 + 21.06 * Luminosity_Index + 10.23 * Edges_X_Index + 0.01 * X_Perimeter + -4.28 * Orientation_Index + -0.16 * Steel_Plate_Thickness + -0.00 * X_Minimum + 0.00 * Sum_of_Luminosity + -0.05 * Maximum_of_Luminosity + -0.00 * X_Maximum + -0.04 * Minimum_of_Luminosity + -4.03 * Log_X_Index + 3.70 * Square_Index + 2.25 * Outside_Global_Index  2.86

	3: -2.05 * Edges_Index + -0.18 * Empty_Index + -0.03 * Length_of_Conveyer + 2.77 * SigmoidOfAreas + 0.00 * Pixels_Areas + 0.66 * Log_Y_Index + 0.66 * LogOfAreas + -0.00 * Y_Minimum + 0.00 * Y_Perimeter + -0.00 * Y_Maximum + -0.87 * Edges_Y_Index + 2.18 * TypeOfSteel_A300 + -57.75 * Outside_X_Index + -2.18 * TypeOfSteel_A400 + 5.18 * Luminosity_Index + -0.67 * Edges_X_Index + 0.00 * X_Perimeter + 0.86 * Orientation_Index + 0.02 * Steel_Plate_Thickness + -0.00 * X_Minimum + 0.00 * Sum_of_Luminosity + -0.18 * Maximum_of_Luminosity + -0.00 * X_Maximum + 0.09 * Minimum_of_Luminosity + 3.23 * Log_X_Index + 1.67 * Square_Index + -0.75 * Outside_Global_Index  51.27

	4: -0.13 * Edges_Index + -3.29 * Empty_Index + 0.01 * Length_of_Conveyer + 2.09 * SigmoidOfAreas + -0.00 * Pixels_Areas + -3.85 * Log_Y_Index + -0.11 * LogOfAreas + 0.00 * Y_Minimum + 0.00 * Y_Perimeter + 0.00 * Y_Maximum + 9.73 * Edges_Y_Index + -0.63 * TypeOfSteel_A300 + -47.89 * Outside_X_Index + 0.63 * TypeOfSteel_A400 + -1.66 * Luminosity_Index + 1.56 * Edges_X_Index + 0.00 * X_Perimeter + 6.52 * Orientation_Index + 0.03 * Steel_Plate_Thickness + 0.00 * X_Minimum + -0.00 * Sum_of_Luminosity + 0.05 * Maximum_of_Luminosity + -0.00 * X_Maximum + -0.05 * Minimum_of_Luminosity + 2.69 * Log_X_Index + -3.45 * Square_Index + -1.42 * Outside_Global_Index  -24.75

	5: 3.11 * Edges_Index + 16.24 * Empty_Index + -0.02 * Length_of_Conveyer + -14.60 * SigmoidOfAreas + 0.00 * Pixels_Areas + -4.26 * Log_Y_Index + -2.75 * LogOfAreas + -0.00 * Y_Minimum + 0.00 * Y_Perimeter + -0.00 * Y_Maximum + 2.48 * Edges_Y_Index + -0.52 * TypeOfSteel_A300 + 19.59 * Outside_X_Index + 0.52 * TypeOfSteel_A400 + 3.33 * Luminosity_Index + -6.65 * Edges_X_Index + 0.00 * X_Perimeter + -4.63 * Orientation_Index + 0.01 * Steel_Plate_Thickness + 0.00 * X_Minimum + 0.00 * Sum_of_Luminosity + 0.06 * Maximum_of_Luminosity + 0.00 * X_Maximum + -0.10 * Minimum_of_Luminosity + -4.59 * Log_X_Index + 3.55 * Square_Index + 0.69 * Outside_Global_Index  32.22

	6: 2.59 * Edges_Index + -2.89 * Empty_Index + 0.01 * Length_of_Conveyer + -2.88 * SigmoidOfAreas + -0.00 * Pixels_Areas + 2.67 * Log_Y_Index + 1.79 * LogOfAreas + 0.00 * Y_Minimum + -0.00 * Y_Perimeter + 0.00 * Y_Maximum + -1.93 * Edges_Y_Index + -0.45 * TypeOfSteel_A300 + -19.59 * Outside_X_Index + 0.45 * TypeOfSteel_A400 + -13.83 * Luminosity_Index + -6.56 * Edges_X_Index + -0.01 * X_Perimeter + -2.03 * Orientation_Index + 0.03 * Steel_Plate_Thickness + 0.00 * X_Minimum + -0.00 * Sum_of_Luminosity + 0.05 * Maximum_of_Luminosity + 0.00 * X_Maximum + 0.07 * Minimum_of_Luminosity + -0.27 * Log_X_Index + -7.24 * Square_Index + -0.37 * Outside_Global_Index  -22.22


Accuracy: 0.751 FPR: 0.086 TPR: 0.751
F-measure: 0.749 Precision: 0.750 Recall: 0.751
              Labels
                0.0       1.0       2.0       3.0       4.0       5.0       6.0
F-measure     0.693     0.636     0.954     0.855     0.604     0.882     0.679
Precision     0.672     0.659     0.954     0.822     0.655     0.872     0.731
Recall        0.715     0.614     0.954     0.890     0.561     0.891     0.633
In [28]:
predictions = fitted_model.transform(train_data)
# Confusion matrix from labels and predictions
metrics = MulticlassMetrics(predictions.select('label', 'prediction').rdd.map(tuple))
confusion_matrix = metrics.confusionMatrix().toArray()
# Pandas DataFrame from Spark confusion matrix
cnf_matrix = pd.DataFrame(confusion_matrix)
plt.figure(figsize = (10,7))
p = sns.heatmap(cnf_matrix/np.sum(cnf_matrix), annot=True, fmt=".2%", linewidth=0.5, annot_kws={'fontsize':10}, cmap='RdBu')
# If you want numbers, instead of percent
# p = sns.heatmap(cnf_matrix, annot=True, fmt=",.1f", linewidth=0.5, annot_kws={'fontsize':10}, cmap='RdBu');
p.set(xlabel='Predicted', ylabel='Actual', title='Confusion Matrix')
Out[28]:
[Text(0.5, 43.249999999999986, 'Predicted'),
 Text(91.25, 0.5, 'Actual'),
 Text(0.5, 1.0, 'Confusion Matrix')]
No description has been provided for this image
In [29]:
from pyspark.ml.feature import RFormula
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline, Model
# Training and testing data
train_data,test_data  = faults.randomSplit([0.6, 0.4], 24)   # proportions [], seed for random
# all the columns except these
x_cols = list(set(faults.columns) - {'fault', 'weight'})
# RFormula will create the feature column, instead of VectorAssembler
formula = "{} ~ {}".format("fault", " + ".join(x_cols))
print("Formula : {}".format(formula))
pipeline = Pipeline(stages=[RFormula(formula = formula),
                            LogisticRegression(weightCol='weight') # use weight
                           ])
# And then run the steps with one command:
fitted_model = pipeline.fit(train_data)
Formula : fault ~ Edges_Index + Empty_Index + Length_of_Conveyer + SigmoidOfAreas + Pixels_Areas + Log_Y_Index + LogOfAreas + Y_Minimum + Y_Perimeter + Y_Maximum + Edges_Y_Index + TypeOfSteel_A300 + Outside_X_Index + TypeOfSteel_A400 + Luminosity_Index + Edges_X_Index + X_Perimeter + Orientation_Index + Steel_Plate_Thickness + X_Minimum + Sum_of_Luminosity + Maximum_of_Luminosity + X_Maximum + Minimum_of_Luminosity + Log_X_Index + Square_Index + Outside_Global_Index
In [30]:
lm = fitted_model.stages[1]
classification_model_summary(lm, formula, x_cols)
Classification model summary
Formula:
	fault ~ Edges_Index + Empty_Index + Length_of_Conveyer + SigmoidOfAreas + Pixels_Areas + Log_Y_Index + LogOfAreas + Y_Minimum + Y_Perimeter + Y_Maximum + Edges_Y_Index + TypeOfSteel_A300 + Outside_X_Index + TypeOfSteel_A400 + Luminosity_Index + Edges_X_Index + X_Perimeter + Orientation_Index + Steel_Plate_Thickness + X_Minimum + Sum_of_Luminosity + Maximum_of_Luminosity + X_Maximum + Minimum_of_Luminosity + Log_X_Index + Square_Index + Outside_Global_Index
Equation:
	0: -0.74 * Edges_Index + -0.14 * Empty_Index + 0.01 * Length_of_Conveyer + 2.43 * SigmoidOfAreas + -0.00 * Pixels_Areas + -1.64 * Log_Y_Index + 1.93 * LogOfAreas + 0.00 * Y_Minimum + 0.00 * Y_Perimeter + 0.00 * Y_Maximum + -3.53 * Edges_Y_Index + -0.58 * TypeOfSteel_A300 + 77.12 * Outside_X_Index + 0.58 * TypeOfSteel_A400 + -12.94 * Luminosity_Index + 3.24 * Edges_X_Index + -0.01 * X_Perimeter + 3.73 * Orientation_Index + 0.05 * Steel_Plate_Thickness + -0.00 * X_Minimum + -0.00 * Sum_of_Luminosity + 0.06 * Maximum_of_Luminosity + 0.00 * X_Maximum + 0.03 * Minimum_of_Luminosity + 0.94 * Log_X_Index + 0.54 * Square_Index + -1.11 * Outside_Global_Index  -26.20

	1: 0.32 * Edges_Index + -2.15 * Empty_Index + 0.01 * Length_of_Conveyer + 4.07 * SigmoidOfAreas + -0.00 * Pixels_Areas + -0.22 * Log_Y_Index + 0.65 * LogOfAreas + 0.00 * Y_Minimum + -0.01 * Y_Perimeter + 0.00 * Y_Maximum + -1.54 * Edges_Y_Index + 0.08 * TypeOfSteel_A300 + 78.18 * Outside_X_Index + -0.08 * TypeOfSteel_A400 + -13.87 * Luminosity_Index + 1.85 * Edges_X_Index + 0.00 * X_Perimeter + 2.21 * Orientation_Index + 0.04 * Steel_Plate_Thickness + -0.00 * X_Minimum + 0.00 * Sum_of_Luminosity + 0.07 * Maximum_of_Luminosity + -0.00 * X_Maximum + 0.02 * Minimum_of_Luminosity + 0.66 * Log_X_Index + 2.03 * Square_Index + -0.32 * Outside_Global_Index  -25.91

	2: -5.25 * Edges_Index + -9.24 * Empty_Index + 0.01 * Length_of_Conveyer + 12.22 * SigmoidOfAreas + 0.00 * Pixels_Areas + 1.67 * Log_Y_Index + -1.42 * LogOfAreas + -0.00 * Y_Minimum + -0.00 * Y_Perimeter + -0.00 * Y_Maximum + -4.57 * Edges_Y_Index + 0.44 * TypeOfSteel_A300 + 14.14 * Outside_X_Index + -0.44 * TypeOfSteel_A400 + 22.06 * Luminosity_Index + 11.18 * Edges_X_Index + 0.01 * X_Perimeter + -1.14 * Orientation_Index + -0.20 * Steel_Plate_Thickness + -0.00 * X_Minimum + 0.00 * Sum_of_Luminosity + -0.05 * Maximum_of_Luminosity + -0.00 * X_Maximum + -0.05 * Minimum_of_Luminosity + -2.24 * Log_X_Index + 4.68 * Square_Index + 1.57 * Outside_Global_Index  6.00

	3: -2.61 * Edges_Index + -3.03 * Empty_Index + -0.02 * Length_of_Conveyer + 5.20 * SigmoidOfAreas + 0.00 * Pixels_Areas + 3.32 * Log_Y_Index + -0.39 * LogOfAreas + -0.00 * Y_Minimum + 0.01 * Y_Perimeter + -0.00 * Y_Maximum + -0.26 * Edges_Y_Index + 2.22 * TypeOfSteel_A300 + -35.29 * Outside_X_Index + -2.22 * TypeOfSteel_A400 + 2.35 * Luminosity_Index + -0.54 * Edges_X_Index + 0.00 * X_Perimeter + -0.62 * Orientation_Index + 0.03 * Steel_Plate_Thickness + -0.00 * X_Minimum + 0.00 * Sum_of_Luminosity + -0.21 * Maximum_of_Luminosity + -0.00 * X_Maximum + 0.13 * Minimum_of_Luminosity + 1.44 * Log_X_Index + 1.47 * Square_Index + -0.75 * Outside_Global_Index  44.93

	4: -1.10 * Edges_Index + -5.01 * Empty_Index + 0.01 * Length_of_Conveyer + 3.65 * SigmoidOfAreas + -0.00 * Pixels_Areas + -2.58 * Log_Y_Index + -1.29 * LogOfAreas + 0.00 * Y_Minimum + 0.01 * Y_Perimeter + 0.00 * Y_Maximum + 10.58 * Edges_Y_Index + -0.61 * TypeOfSteel_A300 + -96.91 * Outside_X_Index + 0.61 * TypeOfSteel_A400 + 3.16 * Luminosity_Index + 3.57 * Edges_X_Index + 0.00 * X_Perimeter + 6.48 * Orientation_Index + 0.04 * Steel_Plate_Thickness + 0.00 * X_Minimum + -0.00 * Sum_of_Luminosity + 0.03 * Maximum_of_Luminosity + -0.00 * X_Maximum + -0.07 * Minimum_of_Luminosity + 3.59 * Log_X_Index + -4.46 * Square_Index + -1.60 * Outside_Global_Index  -16.86

	5: 7.87 * Edges_Index + 17.75 * Empty_Index + -0.01 * Length_of_Conveyer + -22.16 * SigmoidOfAreas + 0.00 * Pixels_Areas + -5.31 * Log_Y_Index + -3.86 * LogOfAreas + -0.00 * Y_Minimum + 0.00 * Y_Perimeter + -0.00 * Y_Maximum + 1.56 * Edges_Y_Index + -0.77 * TypeOfSteel_A300 + 18.33 * Outside_X_Index + 0.77 * TypeOfSteel_A400 + 15.57 * Luminosity_Index + -12.69 * Edges_X_Index + 0.00 * X_Perimeter + -8.64 * Orientation_Index + 0.00 * Steel_Plate_Thickness + 0.00 * X_Minimum + 0.00 * Sum_of_Luminosity + 0.00 * Maximum_of_Luminosity + 0.00 * X_Maximum + -0.13 * Minimum_of_Luminosity + -6.15 * Log_X_Index + 3.21 * Square_Index + 2.24 * Outside_Global_Index  50.92

	6: 1.50 * Edges_Index + 1.82 * Empty_Index + 0.01 * Length_of_Conveyer + -5.40 * SigmoidOfAreas + -0.00 * Pixels_Areas + 4.76 * Log_Y_Index + 4.38 * LogOfAreas + 0.00 * Y_Minimum + -0.01 * Y_Perimeter + 0.00 * Y_Maximum + -2.25 * Edges_Y_Index + -0.78 * TypeOfSteel_A300 + -55.56 * Outside_X_Index + 0.78 * TypeOfSteel_A400 + -16.33 * Luminosity_Index + -6.62 * Edges_X_Index + -0.02 * X_Perimeter + -2.03 * Orientation_Index + 0.04 * Steel_Plate_Thickness + 0.00 * X_Minimum + -0.00 * Sum_of_Luminosity + 0.09 * Maximum_of_Luminosity + 0.00 * X_Maximum + 0.06 * Minimum_of_Luminosity + 1.77 * Log_X_Index + -7.47 * Square_Index + -0.03 * Outside_Global_Index  -32.87


Accuracy: 0.796 FPR: 0.034 TPR: 0.796
F-measure: 0.791 Precision: 0.788 Recall: 0.796
              Labels
                0.0       1.0       2.0       3.0       4.0       5.0       6.0
F-measure     0.485     0.668     0.951     0.898     0.745     0.963     0.815
Precision     0.557     0.665     0.943     0.880     0.717     0.948     0.798
Recall        0.429     0.671     0.959     0.917     0.776     0.978     0.833
In [31]:
predictions = fitted_model.transform(train_data)
# Confusion matrix from labels and predictions
metrics = MulticlassMetrics(predictions.select('label', 'prediction').rdd.map(tuple))
confusion_matrix = metrics.confusionMatrix().toArray()
# Pandas DataFrame from Spark confusion matrix
cnf_matrix = pd.DataFrame(confusion_matrix)
plt.figure(figsize = (10,7))
p = sns.heatmap(cnf_matrix/np.sum(cnf_matrix), annot=True, fmt=".2%", linewidth=0.5, annot_kws={'fontsize':10}, cmap='RdBu')
# If you want numbers, instead of percent
# p = sns.heatmap(cnf_matrix, annot=True, fmt=",.1f", linewidth=0.5, annot_kws={'fontsize':10}, cmap='RdBu');
p.set(xlabel='Predicted', ylabel='Actual', title='Confusion Matrix')
Out[31]:
[Text(0.5, 43.249999999999986, 'Predicted'),
 Text(91.25, 0.5, 'Actual'),
 Text(0.5, 1.0, 'Confusion Matrix')]
No description has been provided for this image

数据重采样方法¶

  • 对多数类进行欠采样:随机从多数类中删除样本,使多数类和少数类的样本数量相等。
  • 对少数群体进行过度采样:随机重复少数类中的样本,使少数类和多数类的样本数量相等。
  • 合成少数过采样技术(SMOTE):通过合成新的少数类样本,使少数类和多数类的样本数量相等。

image.png

数据重采样比较和用例¶

  • 欠采样会减少多数类别的规模,可能会导致信息丢失,但计算效率较高。

  • 过采样会增加少数类的规模,当您拥有较小的数据集但有足够计算资源时,过采样是合适的。

  • SMOTE通过在现有少数类样本之间进行插值、引入多样性和处理复杂的决策边界来生成合成样本。

数据重采样比较和用例¶

数据重采样方法 适用场景 典型用例
多数类欠采样 当数据集较大且类别极度不平衡时,通过减少多数类样本数量来实现类别平衡。 - 大型数据集但计算资源有限。
- 多数类样本全部保留并非必要。
- 希望通过减少多数类简化问题。
少数类过采样 当少数类样本较少时,通过增加少数类样本数量来实现类别平衡。 - 小型数据集且类别分布不均。
- 希望增强少数类信息以提升模型表现。
- 有足够计算资源处理数据扩充。
SMOTE(合成少数类过采样技术) 通过生成少数类的合成样本来平衡数据集,适用于少数类样本极少且简单过采样易导致过拟合的场景。 - 小型且类别不平衡的数据集。
- 希望通过合成样本引入多样性,减少过拟合。
- 需要捕捉少数类复杂决策边界和模式。

classification on imbanlance data¶

In [32]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, DoubleType, StringType, IntegerType
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.classification import RandomForestClassifier, RandomForestClassificationModel
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.mllib.evaluation import MulticlassMetrics
from sklearn.metrics import confusion_matrix
In [33]:
from pyspark.sql import SparkSession
pyspark_df = spark.read.csv('../dataset/creditcard.csv', inferSchema=True, header=True, mode='DROPMALFORMED')
pyspark_df.printSchema()
root
 |-- Time: double (nullable = true)
 |-- V1: double (nullable = true)
 |-- V2: double (nullable = true)
 |-- V3: double (nullable = true)
 |-- V4: double (nullable = true)
 |-- V5: double (nullable = true)
 |-- V6: double (nullable = true)
 |-- V7: double (nullable = true)
 |-- V8: double (nullable = true)
 |-- V9: double (nullable = true)
 |-- V10: double (nullable = true)
 |-- V11: double (nullable = true)
 |-- V12: double (nullable = true)
 |-- V13: double (nullable = true)
 |-- V14: double (nullable = true)
 |-- V15: double (nullable = true)
 |-- V16: double (nullable = true)
 |-- V17: double (nullable = true)
 |-- V18: double (nullable = true)
 |-- V19: double (nullable = true)
 |-- V20: double (nullable = true)
 |-- V21: double (nullable = true)
 |-- V22: double (nullable = true)
 |-- V23: double (nullable = true)
 |-- V24: double (nullable = true)
 |-- V25: double (nullable = true)
 |-- V26: double (nullable = true)
 |-- V27: double (nullable = true)
 |-- V28: double (nullable = true)
 |-- Amount: double (nullable = true)
 |-- Class: integer (nullable = true)

In [34]:
pyspark_df.filter(F.col('Class')==1).count()
Out[34]:
492
In [35]:
pyspark_df.filter(F.col('Class')==0).count()
Out[35]:
284315
In [36]:
# Set the train/test ratio
train_test_ratio = 0.8
# Split the data into training and testing sets using stratified sampling
train_data_class_b = pyspark_df.filter(F.col('Class')=='1').sample(fraction=train_test_ratio, seed=88)
train_data_class_a = pyspark_df.filter(F.col('Class')=='0').sample(fraction=train_test_ratio, seed=88)
train_data = train_data_class_a.union(train_data_class_b)
test_data = pyspark_df.subtract(train_data)
# Print the number of samples in each set
print("Training Set Size:", train_data.count())
print("Testing Set Size:", test_data.count())
print("Class B fraction in Training Set Size:", train_data.filter(F.col('Class')=='1').count()/train_data.count())
print("Class B fraction in Test Set Size:", test_data.filter(F.col('Class')=='1').count()/test_data.count())
Training Set Size: 227866
Testing Set Size: 56580
Class B fraction in Training Set Size: 0.0016720353190032738
Class B fraction in Test Set Size: 0.001855779427359491
In [37]:
pyspark_df = pyspark_df.withColumnRenamed("V1", "PC1") \
    .withColumnRenamed("V2", "PC2") \
    .withColumnRenamed("V3", "PC3") \
    .withColumnRenamed("V4", "PC4") \
    .withColumnRenamed("V5", "PC5")
# Set the train/test ratio
train_test_ratio = 0.8
# Split the data into training and testing sets using stratified sampling
train_data_class_b = pyspark_df.filter(F.col('Class')=='1').sample(fraction=train_test_ratio, seed=88)
train_data_class_a = pyspark_df.filter(F.col('Class')=='0').sample(fraction=train_test_ratio, seed=88)
train_data = train_data_class_a.union(train_data_class_b)
# dataframe.subtract 用于从完整数据集中移除训练集样本,得到测试集
test_data = pyspark_df.subtract(train_data)
# Print the number of samples in each set
print("Training Set Size:", train_data.count())
print("Testing Set Size:", test_data.count())
print("Class B fraction in Training Set Size:", train_data.filter(F.col('Class')=='1').count()/train_data.count())
print("Class B fraction in Test Set Size:", test_data.filter(F.col('Class')=='1').count()/test_data.count())
Training Set Size: 227866
Testing Set Size: 56580
Class B fraction in Training Set Size: 0.0016720353190032738
Class B fraction in Test Set Size: 0.001855779427359491
In [38]:
# Assemble the features into a Vector to feed the Model.
feature_list = ["PC1","PC2","PC3","PC4","PC5"]
assembler = VectorAssembler(inputCols=feature_list, outputCol='features_for_model')
model_df1 = assembler.transform(train_data)
# Convert the label from String to Numeric using a StringIndexer
indexer = StringIndexer(inputCol='Class', outputCol='label')
model_df2 = indexer.fit(model_df1).transform(model_df1)
In [39]:
# To see what model_df looks like (after Vector Assembler)
model_df1.toPandas().head()
Out[39]:
Time PC1 PC2 PC3 PC4 PC5 V6 V7 V8 V9 ... V22 V23 V24 V25 V26 V27 V28 Amount Class features_for_model
0 0.0 1.191857 0.266151 0.166480 0.448154 0.060018 -0.082361 -0.078803 0.085102 -0.255425 ... -0.638672 0.101288 -0.339846 0.167170 0.125895 -0.008983 0.014724 2.69 0 [1.191857111, 0.266150712, 0.166480113, 0.4481...
1 1.0 -1.358354 -1.340163 1.773209 0.379780 -0.503198 1.800499 0.791461 0.247676 -1.514654 ... 0.771679 0.909412 -0.689281 -0.327642 -0.139097 -0.055353 -0.059752 378.66 0 [-1.358354062, -1.340163075, 1.773209343, 0.37...
2 2.0 -1.158233 0.877737 1.548718 0.403034 -0.407193 0.095921 0.592941 -0.270533 0.817739 ... 0.798278 -0.137458 0.141267 -0.206010 0.502292 0.219422 0.215153 69.99 0 [-1.158233093, 0.877736755, 1.548717847, 0.403...
3 2.0 -0.425966 0.960523 1.141109 -0.168252 0.420987 -0.029728 0.476201 0.260314 -0.568671 ... -0.559825 -0.026398 -0.371427 -0.232794 0.105915 0.253844 0.081080 3.67 0 [-0.425965884, 0.960523045, 1.141109342, -0.16...
4 4.0 1.229658 0.141004 0.045371 1.202613 0.191881 0.272708 -0.005159 0.081213 0.464960 ... -0.270710 -0.154104 -0.780055 0.750137 -0.257237 0.034507 0.005168 4.99 0 [1.229657635, 0.141003507, 0.045370774, 1.2026...

5 rows × 32 columns

In [40]:
# To see what model_df looks like (after StringIndexer) - Where Class = 0
model_df2.filter(F.col('Class')=='0').toPandas().head()
Out[40]:
Time PC1 PC2 PC3 PC4 PC5 V6 V7 V8 V9 ... V23 V24 V25 V26 V27 V28 Amount Class features_for_model label
0 0.0 1.191857 0.266151 0.166480 0.448154 0.060018 -0.082361 -0.078803 0.085102 -0.255425 ... 0.101288 -0.339846 0.167170 0.125895 -0.008983 0.014724 2.69 0 [1.191857111, 0.266150712, 0.166480113, 0.4481... 0.0
1 1.0 -1.358354 -1.340163 1.773209 0.379780 -0.503198 1.800499 0.791461 0.247676 -1.514654 ... 0.909412 -0.689281 -0.327642 -0.139097 -0.055353 -0.059752 378.66 0 [-1.358354062, -1.340163075, 1.773209343, 0.37... 0.0
2 2.0 -1.158233 0.877737 1.548718 0.403034 -0.407193 0.095921 0.592941 -0.270533 0.817739 ... -0.137458 0.141267 -0.206010 0.502292 0.219422 0.215153 69.99 0 [-1.158233093, 0.877736755, 1.548717847, 0.403... 0.0
3 2.0 -0.425966 0.960523 1.141109 -0.168252 0.420987 -0.029728 0.476201 0.260314 -0.568671 ... -0.026398 -0.371427 -0.232794 0.105915 0.253844 0.081080 3.67 0 [-0.425965884, 0.960523045, 1.141109342, -0.16... 0.0
4 4.0 1.229658 0.141004 0.045371 1.202613 0.191881 0.272708 -0.005159 0.081213 0.464960 ... -0.154104 -0.780055 0.750137 -0.257237 0.034507 0.005168 4.99 0 [1.229657635, 0.141003507, 0.045370774, 1.2026... 0.0

5 rows × 33 columns

In [41]:
# To see what model_df looks like (after StringIndexer) - Where Class = 1
model_df2.filter(F.col('Class')=='1').toPandas().head()
Out[41]:
Time PC1 PC2 PC3 PC4 PC5 V6 V7 V8 V9 ... V23 V24 V25 V26 V27 V28 Amount Class features_for_model label
0 472.0 -3.043541 -3.157307 1.088463 2.288644 1.359805 -1.064823 0.325574 -0.067794 -0.270953 ... 1.375966 -0.293803 0.279798 -0.145362 -0.252773 0.035764 529.00 1 [-3.043540624, -3.157307121, 1.08846278, 2.288... 1.0
1 4462.0 -2.303350 1.759247 -0.359745 2.330243 -0.821628 -0.075788 0.562320 -0.399147 -0.238253 ... 0.172726 -0.087330 -0.156114 -0.542628 0.039566 -0.153029 239.93 1 [-2.303349568, 1.75924746, -0.359744743, 2.330... 1.0
2 7519.0 1.234235 3.019740 -4.304597 4.732795 3.624201 -1.357746 1.713445 -0.496358 -1.282858 ... -0.656805 -1.632653 1.488901 0.566797 -0.010016 0.146793 1.00 1 [1.234235046, 3.019740421, -4.304596885, 4.732... 1.0
3 7526.0 0.008430 4.137837 -6.240697 6.675732 0.768307 -3.353060 -1.631735 0.154612 -2.795892 ... -0.539528 0.128940 1.488481 0.507963 0.735822 0.513574 1.00 1 [0.008430365, 4.137836835, -6.240696572, 6.675... 1.0
4 7535.0 0.026779 4.132464 -6.560600 6.348557 1.329666 -2.513479 -1.689102 0.303253 -3.139409 ... -0.669605 -0.759908 1.605056 0.540675 0.737040 0.496699 1.00 1 [0.026779226, 4.132463897, -6.560599968, 6.348... 1.0

5 rows × 33 columns

In [42]:
def preprocess_data(df):
    # Assemble the features into a Vector to feed the Model.
    feature_list = ["PC1","PC2","PC3","PC4","PC5"]
    assembler = VectorAssembler(inputCols=feature_list, outputCol='features_for_model')
    model_df1 = assembler.transform(df)
    # Convert the label from String to Numeric using a StringIndexer
    indexer = StringIndexer(inputCol='Class', outputCol='label')
    model_df2 = indexer.fit(model_df1).transform(model_df1)
    return model_df2

def train_model(train_df, smote=0):
    if smote == 0:
        model_df = preprocess_data(train_df)
    elif smote == 1:
        model_df = train_df
    # Initiate the Random Forest Classifier Model
    rf = RandomForestClassifier(featuresCol='features_for_model', labelCol='label', seed=88)
    rf_model = rf.fit(model_df)
    return rf_model

def test_model(test_df, model):
    model_df = preprocess_data(test_df)
    prediction = model.transform(model_df)
    return prediction
In [43]:
# Initiate the Random Forest Classifier Model
rf_model = train_model(train_data)
In [44]:
# Test the Model
prediction = test_model(test_data, rf_model)
# To see what the output looks like. There will be additional columns.
# rawPrediction, probability, prediction
prediction.toPandas().head()
Out[44]:
Time PC1 PC2 PC3 PC4 PC5 V6 V7 V8 V9 ... V26 V27 V28 Amount Class features_for_model label rawPrediction probability prediction
0 1632.0 1.299392 1.284574 -2.115055 1.203150 1.645984 -0.704279 0.714218 -0.210203 -0.766219 ... -0.230843 0.040773 0.075337 1.99 0 [1.299391629, 1.284574315, -2.115054733, 1.203... 0.0 [19.964261192832808, 0.035738807167193014] [0.9982130596416404, 0.0017869403583596506] 0.0
1 38843.0 1.175618 0.062691 0.457418 0.685153 -0.548273 -0.787123 -0.037092 -0.023735 0.238564 ... 0.187040 -0.031048 0.013726 14.40 0 [1.175617558, 0.062691047, 0.457418485, 0.6851... 0.0 [19.98916846637526, 0.010831533624742635] [0.9994584233187629, 0.0005415766812371317] 0.0
2 55051.0 -0.431918 -3.826174 -1.665834 -0.012201 -1.317788 0.027159 0.920153 -0.368210 -0.513858 ... -0.080679 -0.213854 0.162163 1022.40 0 [-0.431917549, -3.826173594, -1.665833845, -0.... 0.0 [19.98916846637526, 0.010831533624742635] [0.9994584233187629, 0.0005415766812371317] 0.0
3 86500.0 -0.188338 0.753188 0.535544 -0.242559 0.103888 -0.044430 0.423675 0.077219 0.369646 ... -0.238929 0.166404 0.236026 43.74 0 [-0.188337516, 0.753188458, 0.535544219, -0.24... 0.0 [19.98916846637526, 0.010831533624742635] [0.9994584233187629, 0.0005415766812371317] 0.0
4 87405.0 -1.232733 0.931978 1.008111 -0.686739 -0.433994 -0.430197 0.157923 0.295404 0.874859 ... 0.589815 0.305465 0.079097 29.99 0 [-1.232733294, 0.931978281, 1.008110975, -0.68... 0.0 [19.98916846637526, 0.010831533624742635] [0.9994584233187629, 0.0005415766812371317] 0.0

5 rows × 36 columns

In [45]:
# Evaluate the model using a BinaryClassificationEvaluator
evaluator = BinaryClassificationEvaluator(labelCol='label')
auroc = evaluator.evaluate(prediction,{evaluator.metricName:'areaUnderROC'})
aupr = evaluator.evaluate(prediction,{evaluator.metricName:'areaUnderPR'})
print(f"AUROC: {auroc}")
print(f"AUPR: {aupr}")
AUROC: 0.9429864879107906
AUPR: 0.4550431778587215
In [46]:
pred_rdd = prediction.select('prediction','label').rdd.map(tuple)
# Evaluate using MultiClass Metrics
metrics = MulticlassMetrics(pred_rdd)
accuracy_o = metrics.accuracy   # Positive class
precision_o = metrics.precision(1.0)  # Positive class
recall_o = metrics.recall(1.0)  # Positive class
f1_o = metrics.fMeasure(1.0)  # Positive class
print("Accuracy:", accuracy_o)
print("Precision:", precision_o)
print("Recall:", recall_o)
print("F1-score:", f1_o)
Accuracy: 0.9986037469070342
Precision: 0.7241379310344828
Recall: 0.4
F1-score: 0.5153374233128835
In [47]:
# Returns confusion matrix: predicted classes are in columns, they are ordered by class label ascending, as in “labels”.
# Meaning 0 1 in columns. Rows also 0 then 1.
print(metrics.confusionMatrix())
confusion_matrix = metrics.confusionMatrix().toArray()
labels = ['Class 0', 'Class 1']
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(confusion_matrix, cmap=plt.cm.Blues)
# Add actual values to the cells
for i in range(len(labels)):
    for j in range(len(labels)):
        plt.text(j, i, str(int(confusion_matrix[i, j])), fontsize=12, ha='center', va='center', color='red')
fig.colorbar(cax)
ax.set_xticklabels([''] + labels)
ax.set_yticklabels([''] + labels)
plt.xlabel('Predicted')
plt.ylabel('Expected')
plt.show()
DenseMatrix([[5.6459e+04, 1.6000e+01],
             [6.3000e+01, 4.2000e+01]])
No description has been provided for this image
In [48]:
# Traditional way to evaluate TN/TP/FP/FN
# Countercheck with confusion matrix above.
tn = prediction.filter((F.col('label')==0)&(F.col('prediction')==0)).count()
tp = prediction.filter((F.col('label')==1)&(F.col('prediction')==1)).count()
fp = prediction.filter((F.col('label')==0)&(F.col('prediction')==1)).count()
fn = prediction.filter((F.col('label')==1)&(F.col('prediction')==0)).count()
print(f"TN: {tn}")
print(f"TP: {tp}")
print(f"FP: {fp}")
print(f"FN: {fn}")
TN: 56459
TP: 42
FP: 16
FN: 63

undersampling majority¶

In [49]:
# 该函数用于对多数类样本进行欠采样,以实现类别平衡。
# 参数说明:
# df: 输入的 Spark DataFrame,包含类别标签列 'class'(0 为多数类,1 为少数类)。
# ratio: 欠采样后多数类与少数类的比例。例如 ratio=1 表示多数类与少数类数量相等,ratio=5 表示多数类数量为少数类的 5 倍。
# 返回值:返回一个新的 DataFrame,包含所有少数类样本和按指定比例随机抽取的多数类样本。
def undersample_majority(df, ratio=1):
    minority_count = df.filter(F.col('class')==1).count()
    whole_count = df.count()
    undersampled_majority = df.filter(F.col('class')==0)\
                                .sample(withReplacement=False, fraction=(ratio*minority_count/whole_count),seed=88)
    undersampled_df = df.filter(F.col('class')==1).union(undersampled_majority)
    return undersampled_df
In [50]:
# Create evaluation function
def evaluate_test_data(test_df, model):
    model_df = preprocess_data(test_df)
    prediction = model.transform(model_df)
    
    # Evaluate the model using a BinaryClassificationEvaluator
    evaluator = BinaryClassificationEvaluator(labelCol='label')
    auroc = evaluator.evaluate(prediction,{evaluator.metricName:'areaUnderROC'})
    aupr = evaluator.evaluate(prediction,{evaluator.metricName:'areaUnderPR'})

    print(f"AUROC: {auroc}")
    print(f"AUPR: {aupr}")
    
    # Evaluate using MultiClass Metrics
    pred_rdd = prediction.select('prediction','label').rdd.map(tuple)
    metrics = MulticlassMetrics(pred_rdd)
    accuracy_o = metrics.accuracy   # Positive class
    precision_o = metrics.precision(1.0)  # Positive class
    recall_o = metrics.recall(1.0)  # Positive class
    f1_o = metrics.fMeasure(1.0)  # Positive class

    print("Accuracy:", accuracy_o)
    print("Precision:", precision_o)
    print("Recall:", recall_o)
    print("F1-score:", f1_o)
    
    confusion_matrix = metrics.confusionMatrix().toArray()
    labels = ['Class 0', 'Class 1']
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(confusion_matrix, cmap=plt.cm.Blues)
    # Add actual values to the cells
    for i in range(len(labels)):
        for j in range(len(labels)):
            plt.text(j, i, str(int(confusion_matrix[i, j])), fontsize=12, ha='center', va='center', color='red')
    fig.colorbar(cax)
    ax.set_xticklabels([''] + labels)
    ax.set_yticklabels([''] + labels)
    plt.xlabel('Predicted')
    plt.ylabel('Expected')
    plt.show()

    # Traditional way to evaluate TN/TP/FP/FN
    # Countercheck with confusion matrix above.
    tn = prediction.filter((F.col('label')==0)&(F.col('prediction')==0)).count()
    tp = prediction.filter((F.col('label')==1)&(F.col('prediction')==1)).count()
    fp = prediction.filter((F.col('label')==0)&(F.col('prediction')==1)).count()
    fn = prediction.filter((F.col('label')==1)&(F.col('prediction')==0)).count()

    print(f"TN: {tn}")
    print(f"TP: {tp}")
    print(f"FP: {fp}")
    print(f"FN: {fn}")
    
In [51]:
# Evaluate Undersampling 1:1
undersampled_train_df_1_1 = undersample_majority(train_data,1)
rf_model1 = train_model(undersampled_train_df_1_1)
evaluate_test_data(test_data, rf_model1)
AUROC: 0.9558591875882712
AUPR: 0.29089072740117095
Accuracy: 0.9579710144927536
Precision: 0.037444037444037445
Recall: 0.8761904761904762
F1-score: 0.07181889149102264
No description has been provided for this image
TN: 54110
TP: 92
FP: 2365
FN: 13
In [52]:
# Evaluate Undersampling 1:2
undersampled_train_df_1_2 = undersample_majority(train_data,2)
rf_model2 = train_model(undersampled_train_df_1_2)
evaluate_test_data(test_data, rf_model2)
AUROC: 0.953676257931238
AUPR: 0.2836127409470774
Accuracy: 0.9836161187698833
Precision: 0.0840080971659919
Recall: 0.7904761904761904
F1-score: 0.15187557182067704
No description has been provided for this image
TN: 55570
TP: 83
FP: 905
FN: 22
In [53]:
# Evaluate Undersampling 1:5
undersampled_train_df_1_5 = undersample_majority(train_data,5)
rf_model3 = train_model(undersampled_train_df_1_5)
evaluate_test_data(test_data, rf_model3)
AUROC: 0.9508069731655391
AUPR: 0.43192419643636615
Accuracy: 0.9900141392718275
Precision: 0.1254071661237785
Recall: 0.7333333333333333
F1-score: 0.2141863699582754
No description has been provided for this image
TN: 55938
TP: 77
FP: 537
FN: 28

Oversampling minority¶

In [54]:
# Create oversampling function
def oversample_minority(df, ratio=1):
    '''
    ratio is the ratio of majority to minority
    Eg. ratio 1 is equivalent to majority:minority = 1:1
    ratio 5 is equivalent to majority:minority = 5:1
    '''
    minority_count = df.filter(F.col('class')==1).count()
    majority_count = df.filter(F.col('class')==0).count()
    
    balance_ratio = majority_count / minority_count
    
    print(f"Initial Majority:Minority ratio is {balance_ratio:.2f}:1")
    if ratio >= balance_ratio:
        print("No oversampling of minority was done as the input ratio was more than or equal to the initial ratio.")
    else:
        print(f"Oversampling of minority done such that Majority:Minority ratio is {ratio}:1")
    
    oversampled_minority = df.filter(F.col('class')==1)\
                                .sample(withReplacement=True, fraction=(balance_ratio/ratio),seed=88)
    oversampled_df = df.filter(F.col('class')==0).union(oversampled_minority)
    
    return oversampled_df
In [55]:
# Train 3 different oversamplings
# Evaluate Oversampling 1:1
oversampled_train_df_1_1 = oversample_minority(train_data,1)
rf_model4 = train_model(oversampled_train_df_1_1)
evaluate_test_data(test_data, rf_model4)
Initial Majority:Minority ratio is 597.07:1
Oversampling of minority done such that Majority:Minority ratio is 1:1
AUROC: 0.05157334682434283
AUPR: 0.0009473638365625561
Accuracy: 0.02817249911629551
Precision: 0.00032776735801300143
Recall: 0.17142857142857143
F1-score: 0.0006542837410490348
No description has been provided for this image
TN: 1576
TP: 18
FP: 54899
FN: 87
In [56]:
# Evaluate Oversampling 1:2
oversampled_train_df_1_2 = oversample_minority(train_data,2)
rf_model5 = train_model(oversampled_train_df_1_2)
evaluate_test_data(test_data, rf_model5)
Initial Majority:Minority ratio is 597.07:1
Oversampling of minority done such that Majority:Minority ratio is 2:1
AUROC: 0.9571983389194546
AUPR: 0.47697441606078206
Accuracy: 0.984729586426299
Precision: 0.08794788273615635
Recall: 0.7714285714285715
F1-score: 0.15789473684210525
No description has been provided for this image
TN: 55635
TP: 81
FP: 840
FN: 24
In [57]:
# Evaluate Oversampling 1:5
oversampled_train_df_1_5 = oversample_minority(train_data,5)
rf_model6 = train_model(oversampled_train_df_1_5)
evaluate_test_data(test_data, rf_model6)
Initial Majority:Minority ratio is 597.07:1
Oversampling of minority done such that Majority:Minority ratio is 5:1
AUROC: 0.9448737115031933
AUPR: 0.49382442323763825
Accuracy: 0.9915694591728526
Precision: 0.1423076923076923
Recall: 0.7047619047619048
F1-score: 0.23679999999999998
No description has been provided for this image
TN: 56029
TP: 74
FP: 446
FN: 31

SMOTE¶

  • 主要步骤:
    • 对每个少数类样本,选取K个最近邻。
    • 随机选择一个邻居,计算两者之间的差值向量。
    • 按照随机比例,将差值加到原样本上,得到新样本。
    • 重复上述过程,直到生成所需数量的合成样本。

SMOTE伪代码¶

image.png

In [58]:
## Used the code from below:
## https://medium.com/@hwangdb/smote-implementation-in-pyspark-76ec4ffa2f1d
## https://gist.github.com/hwang018/420e288021e9bdacd133076600a9ea8c
## https://gist.github.com/inguelberth/547db5aef8fb82527c79b1d6e2fc368c
In [59]:
# 1. pre_smote_df_process: 对原始Spark DataFrame进行预处理,包括数值特征的向量化和类别特征的字符串索引。
# 该函数确保目标列为二分类,并将数值特征组装为特征向量,类别特征进行编码,最终返回适合SMOTE处理的DataFrame。
# 2. subtract_vector_fn: 用于计算两个向量之间的差值,并乘以一个0到1之间的随机数,返回一个新的DenseVector。
# 该函数常用于SMOTE算法中生成合成样本的插值步骤。
# 3. add_vector_fn: 用于将两个向量相加,返回一个新的DenseVector。
# 该函数在SMOTE算法中用于将原始样本与插值向量相加,生成新的合成样本。

import random
import numpy as np
from functools import reduce
from pyspark.sql import DataFrame, SparkSession, Row
import pyspark.sql.functions as F
from pyspark.sql.functions import array, create_map, struct, rand,col,when,concat,substring,lit,udf,lower,sum as ps_sum,count as ps_count,row_number
from pyspark.sql.window import *
from pyspark.sql.window import Window
from pyspark.sql.types import StructField, StructType, StringType, IntegerType, LongType
from pyspark.ml.feature import VectorAssembler,BucketedRandomProjectionLSH,VectorSlicer,StringIndexer
from pyspark.ml.linalg import Vectors, VectorUDT, SparseVector, DenseVector
from pyspark.ml import Pipeline

def pre_smote_df_process(df,num_cols,cat_cols,target_col,index_suffix="_index"):
    '''
    string indexer and vector assembler
    inputs:
    * df: spark df, original
    * num_cols: numerical cols to be assembled
    * cat_cols: categorical cols to be stringindexed
    * target_col: prediction target
    * index_suffix: will be the suffix after string indexing
    output:
    * vectorized: spark df, after stringindex and vector assemble, ready for smote
    '''
    if(df.select(target_col).distinct().count() != 2):
         raise ValueError("Target col must have exactly 2 classes")
    if target_col in num_cols:
        num_cols.remove(target_col)
    # only assembled numeric columns into features
    assembler = VectorAssembler(inputCols = num_cols, outputCol = 'features_for_model')
    # index the string cols, except possibly for the label col
    assemble_stages = [StringIndexer(inputCol=column, outputCol=column+index_suffix) for column in list(set(cat_cols)-set([target_col]))]
    # add the stage of numerical vector assembler
    assemble_stages.append(assembler)
    pipeline = Pipeline(stages=assemble_stages)
    pos_vectorized = pipeline.fit(df).transform(df)
    # drop original num cols and cat cols
    drop_cols = num_cols+cat_cols
    keep_cols = [a for a in pos_vectorized.columns if a not in drop_cols]
    vectorized = pos_vectorized.select(*keep_cols).withColumn('label',pos_vectorized[target_col]).drop(target_col)
    return vectorized

def subtract_vector_fn(arr):
    a = arr[0]
    b = arr[1]
    if isinstance(a, SparseVector):
        a = a.toArray()   
    if isinstance(b, SparseVector):
        b = b.toArray()
    return DenseVector(random.uniform(0, 1)*(a-b))
    
def add_vector_fn(arr):
    a = arr[0]
    b = arr[1]
    if isinstance(a, SparseVector):
        a = a.toArray()
        
    if isinstance(b, SparseVector):
        b = b.toArray()
    return DenseVector(a+b)
In [60]:
def smote(vectorized_sdf,smote_config):
    '''
    contains logic to perform smote oversampling, given a spark df with 2 classes
    inputs:
    * vectorized_sdf: cat cols are already stringindexed, num cols are assembled into 'features' vector
      df target col should be 'label'
    * smote_config: config obj containing smote parameters
    output:
    * oversampled_df: spark df after smote oversampling
    '''
    dataInput_min = vectorized_sdf[vectorized_sdf['label'] == smote_config.positive_label]
    dataInput_maj = vectorized_sdf[vectorized_sdf['label'] == smote_config.negative_label]
    # LSH, bucketed random projection
    brp = BucketedRandomProjectionLSH(inputCol="features_for_model", outputCol="hashes",seed=int(smote_config.seed),\
                                      bucketLength=float(smote_config.bucketLength))
    # smote only applies on existing minority instances    
    model = brp.fit(dataInput_min)
    model.transform(dataInput_min)
    # here distance is calculated from brp's param inputCol
    self_join_w_distance = model.approxSimilarityJoin(dataInput_min, dataInput_min, float('inf'), distCol="EuclideanDistance")
    # remove self-comparison (distance 0)
    self_join_w_distance = self_join_w_distance.filter(self_join_w_distance.EuclideanDistance > 0)
    over_original_rows = Window.partitionBy("datasetA").orderBy("EuclideanDistance")
    self_similarity_df = self_join_w_distance.withColumn("r_num", F.row_number().over(over_original_rows))
    self_similarity_df_selected = self_similarity_df.filter(self_similarity_df.r_num <= int(smote_config.k))
    over_original_rows_no_order = Window.partitionBy('datasetA')
    # list to store batches of synthetic data
    res = []
    # two udf for vector add and subtract, subtraction include a random factor [0,1]
    subtract_vector_udf = F.udf(subtract_vector_fn, VectorUDT())
    add_vector_udf = F.udf(add_vector_fn, VectorUDT())
    # retain original columns
    original_cols = dataInput_min.columns
    for i in range(int(smote_config.multiplier)):
        #print("generating batch %s of synthetic instances"%i)
        # logic to randomly select neighbour: pick the largest random number generated row as the neighbour
        df_random_sel = self_similarity_df_selected\
                            .withColumn("rand", F.rand())\
                            .withColumn('max_rand', F.max('rand').over(over_original_rows_no_order))\
                            .where(F.col('rand') == F.col('max_rand')).drop(*['max_rand','rand','r_num'])
        # create synthetic feature numerical part
        df_vec_diff = df_random_sel\
            .select('*', subtract_vector_udf(F.array('datasetA.features_for_model', 'datasetB.features_for_model')).alias('vec_diff'))
        df_vec_modified = df_vec_diff\
            .select('*', add_vector_udf(F.array('datasetB.features_for_model', 'vec_diff')).alias('features_for_model'))
        # for categorical cols, either pick original or the neighbour's cat values
        for c in original_cols:
            # randomly select neighbour or original data
            col_sub = random.choice(['datasetA','datasetB'])
            val = "{0}.{1}".format(col_sub,c)
            if c != 'features_for_model':
                # do not unpack original numerical features
                df_vec_modified = df_vec_modified.withColumn(c,F.col(val))
        # this df_vec_modified is the synthetic minority instances,
        df_vec_modified = df_vec_modified.drop(*['datasetA','datasetB','vec_diff','EuclideanDistance'])
        res.append(df_vec_modified)
    dfunion = reduce(DataFrame.union, res)
    dfunion = dfunion.union(dataInput_min.select(dfunion.columns))\
        .sort(F.rand(seed=smote_config.seed))\
        .withColumn('row_number', row_number().over(Window.orderBy(lit('A'))))
    dataInput_maj = dataInput_maj.withColumn('row_number', row_number().over(Window.orderBy(lit('A'))))
    # union synthetic instances with original full (both minority and majority) df
    oversampled_df = dfunion.union(dataInput_maj.select(dfunion.columns))
    return oversampled_df.sort('row_number').drop(*['row_number'])
In [61]:
# SmoteConfig 类用于存储 SMOTE 算法的参数配置。
# 参数说明:
# seed: 随机种子,用于保证结果可复现。
# bucketLength: LSH 算法的桶长度,用于近邻搜索。 
# bucketLength 控制局部敏感哈希(LSH)算法中每个桶的宽度,影响近邻检索的精度和效率。
# 较大的 bucketLength 会使更多样本落入同一个桶,增加找到近邻的概率,但可能降低检索的精度;
# 较小的 bucketLength 则提高检索精度,但可能遗漏部分近邻。
# k: 每个少数类样本选择的最近邻数量。
# multiplier: 合成样本批次数(即生成多少倍的少数类样本)。
# positive_label: 少数类标签。
# negative_label: 多数类标签。
class SmoteConfig:
    def __init__(self, seed, bucketLength, k, multiplier, positive_label, negative_label):
        self.seed = seed
        self.bucketLength = bucketLength
        self.k = k
        self.multiplier = multiplier
        self.positive_label = positive_label
        self.negative_label = negative_label
In [62]:
def wrapper_smote(config_file):
    feature_list = ["PC1","PC2","PC3","PC4","PC5"]
    smote_df = pre_smote_df_process(train_data,feature_list,[],'Class',index_suffix="_index")
    train_smoted_df = smote(smote_df, config_file)
    rf_model_smote = train_model(train_smoted_df,smote=1)
    return rf_model_smote
In [63]:
# 使用 SMOTE 进行少数类合成采样,参数说明:
# k = 10:每个少数类样本选择 10 个最近邻
# Multiplier = 25:生成 25 倍的少数类合成样本
# bucketLength = 200:LSH 算法桶长度
# seed = 76:随机种子,保证结果可复现
config4a = SmoteConfig(76, 200, 10, 25, 1, 0)
rf_model10a = wrapper_smote(config4a)
evaluate_test_data(test_data, rf_model10a)
AUROC: 0.9513709816817386
AUPR: 0.5151196209369299
Accuracy: 0.9966772711205373
Precision: 0.3051643192488263
Recall: 0.6190476190476191
F1-score: 0.40880503144654096
No description has been provided for this image
TN: 56327
TP: 65
FP: 148
FN: 40
In [64]:
# 使用 SMOTE 进行少数类合成采样,参数说明:
# k = 5:每个少数类样本选择 5 个最近邻
# Multiplier = 25:生成 25 倍的少数类合成样本
# bucketLength = 20:LSH 算法桶长度
# seed = 76:随机种子,保证结果可复现
config4a = SmoteConfig(76, 20, 5, 25, 1, 0)
rf_model10a = wrapper_smote(config4a)
evaluate_test_data(test_data, rf_model10a)
AUROC: 0.946739096523957
AUPR: 0.5068729833024639
Accuracy: 0.9965889006716154
Precision: 0.2981651376146789
Recall: 0.6190476190476191
F1-score: 0.4024767801857585
No description has been provided for this image
TN: 56322
TP: 65
FP: 153
FN: 40
In [65]:
# 使用 SMOTE 进行少数类合成采样,参数说明:
# k = 5:每个少数类样本选择 5 个最近邻
# Multiplier = 25:生成 25 倍的少数类合成样本
# bucketLength = 1:LSH 算法桶长度
# seed = 76:随机种子,保证结果可复现
config4a = SmoteConfig(76, 1, 5, 25, 1, 0)
rf_model10a = wrapper_smote(config4a)
evaluate_test_data(test_data, rf_model10a)
AUROC: 0.9408430194565652
AUPR: 0.5155511335413011
Accuracy: 0.9974726051608342
Precision: 0.38414634146341464
Recall: 0.6
F1-score: 0.46840148698884754
No description has been provided for this image
TN: 56374
TP: 63
FP: 101
FN: 42