介绍¶
用股票涨跌预测为基础,在 PySpark 中实现, 探讨了与数据科学有关的不同主题
- 数据收集和清理
- 特征工程
- 创建机器学习模型进行预测
注意:本项目并非财务或投资建议。它不保证在大多数情况下都能提供正确的结果。
目标¶
股票价格会涨吗?
- 价格模型(一行股票指标) -> [0, 1],表示在下一时期内是否上涨
该模型用于回答上述问题。它预测在给定时间段内股票价格是否会上涨(是/否)。 建模问题的参数包括:
- 股票代码(symbol)
- 股票价格上涨的百分比
- 股票应在其时间窗口内(以天为单位)上涨
这是一个分类问题。
相关Python库¶
- yfinance 是一个开源的 Python 库,用来从 Yahoo! Inc. 的 “Yahoo Finance” 数据接口(非官方 API)获取金融市场的数据。
- finta (“Financial Technical Analysis”) 是一个专注于 技术指标计算 的 Python 库。
pip install yfinance finta
- sparkdl(全称 Spark Deep Learning Pipelines)是由 Databricks 团队开发的一个 Python 库,旨在让用户在 Apache Spark 上轻松运行深度学习任务。
- sparkdl 的设计目标是:让 TensorFlow/Keras 模型可以无缝集成进 Spark 的 DataFrame 与 ML Pipeline 中
- Databricks 官方在 2021 年之后几乎不再维护 sparkdl,也没有扩展到 PyTorch 的计划。新方向是用 TorchDistributor
- 选择一个带有 ML(机器学习运行时) 的 Databricks Runtime,其中已安装了 sparkdl 库
- 如果不是,请安装sparkdl
pip install sparkdl
In [1]:
# 初始化 findspark 以便能够找到 Spark
import findspark
findspark.init()
import warnings
warnings.filterwarnings('ignore')
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("spark://spark-master:7077").config("spark.ui.port", "8080").config("spark.acls.enable", "false").config("spark.ui.view.acls", "*").config("spark.modify.acls", "*").getOrCreate()
sc = spark.sparkContext
In [2]:
# 导入所需的库
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import yfinance as yf
from finta import TA
import datetime
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql import DataFrame
In [ ]:
# 定义用于数据挖掘的一些常量
"""
定义数据挖掘相关的常量
"""
NUM_DAYS = 10000 # 要获取的历史数据天数
INTERVAL = '1d' # 历史数据的采样频率
# 股票代码列表
# SPY: 标普500指数ETF,代表美国股市整体表现
# XOM: 埃克森美孚公司,美国最大的石油和天然气公司之一
# AAPL: 苹果公司,全球知名的科技公司
# AMD: 超威半导体公司,主要生产CPU和GPU芯片
# NVDA: 英伟达公司,全球领先的GPU和AI芯片制造商
STOCK_SYMBOLS = ['SPY', 'XOM', 'AAPL', 'AMD', 'NVDA']
In [ ]:
# 技术指标的符号列表
# RSI: 相对强弱指数,用于衡量股票价格的超买或超卖状态
# MACD: 移动平均收敛/发散指标,反映价格趋势的强度、方向和变化速度
# STOCH: 随机振荡器,评估当前价格相对于一定时期内价格区间的位置
# ADL: 累积/派发线,衡量资金流入流出
# ATR: 平均真实波幅,衡量市场波动性
# MOM: 动量指标,反映价格变动的速度
# MFI: 资金流量指标,结合价格和成交量判断买卖压力
# ROC: 变动率指标,衡量价格变动的百分比
# OBV: 能量潮指标,分析成交量变化与价格趋势的关系
# CCI: 商品通道指数,判断价格偏离均值的程度
# EMV: 易动性指标,结合价格和成交量衡量价格变动的难易程度
# VORTEX: 涡轮指标,识别趋势的开始和结束
INDICATORS = ['RSI', 'MACD', 'STOCH','ADL', 'ATR', 'MOM', 'MFI', 'ROC', 'OBV', 'CCI', 'EMV', 'VORTEX']
In [4]:
# 下载股票历史数据并合并为一个 DataFrame
# 定义下载函数,使用 yfinance 获取数据,并重命名列以适配 finta
def download_stock(symbol, num_days, interval) -> DataFrame:
"""
下载指定 `symbol` 的 `num_days` 天历史价格数据,采样间隔为 `interval`
返回下载后的 Spark DataFrame。
"""
start = (datetime.date.today() - datetime.timedelta(num_days) )
end = datetime.datetime.today()
# 使用 yfinance 下载数据
pdf = yf.download(symbol, start=start, end=end, interval=interval)
# 重命名列名为小写,便于后续指标计算
pdf.rename(columns={"Close": 'close', "High": 'high', "Low": 'low', 'Volume': 'volume', 'Open': 'open'}, inplace=True)
pdf.columns = pdf.columns.levels[0]
# 将索引转为日期列
pdf['date'] = pdf.index
prices = None
if not pdf.empty:
# 转为 Spark DataFrame,并添加 symbol 列
prices = spark.createDataFrame(pdf)
prices = prices.withColumn('symbol', F.lit(symbol) )
else:
print(f"No price data for {symbol}")
return prices
# 循环下载所有股票的数据,并合并到 price_data
price_data = None
for symbol in STOCK_SYMBOLS:
df = download_stock(symbol, NUM_DAYS, INTERVAL)
if df:
price_data = price_data.union(df) if price_data else df
[*********************100%***********************] 1 of 1 completed [*********************100%***********************] 1 of 1 completed [*********************100%***********************] 1 of 1 completed [*********************100%***********************] 1 of 1 completed [*********************100%***********************] 1 of 1 completed
In [5]:
# 显示下载的股票数据的前两行,便于快速查看数据结构和内容
df.show(2)
+-------------------+--------------------+-------------------+--------------------+----------+-------------------+------+ | close| high| low| open| volume| date|symbol| +-------------------+--------------------+-------------------+--------------------+----------+-------------------+------+ |0.03760697320103645|0.044769679366149916|0.03557698923942547| 0.04011373608373559|2714688000|1999-01-22 00:00:00| NVDA| |0.04154682159423828| 0.04202360042703383|0.03760696623676553|0.040591426301578235| 510480000|1999-01-25 00:00:00| NVDA| +-------------------+--------------------+-------------------+--------------------+----------+-------------------+------+ only showing top 2 rows
In [6]:
# 显示下载的股票数据的前两行,便于快速查看数据结构和内容
price_data.show(2)
+-----------------+-----------------+-----------------+-----------------+-------+-------------------+------+ | close| high| low| open| volume| date|symbol| +-----------------+-----------------+-----------------+-----------------+-------+-------------------+------+ |69.23340606689453|69.60084593149537|69.09803348519948|69.27208394737883|4159900|1998-06-08 00:00:00| SPY| |69.44613647460938|69.57183958992125|68.94332401336185|69.13671342153397|2725300|1998-06-09 00:00:00| SPY| +-----------------+-----------------+-----------------+-----------------+-------+-------------------+------+ only showing top 2 rows
In [7]:
# 下载并合并经济指标数据(如利率等),用于分析经济环境对股票表现的影响
# ^IRX: 13周美国国库券利率(短期利率,反映货币市场流动性和政策变化)
# ^FVX: 5年期美国国债收益率(中期利率,衡量经济增长和通胀预期)
# ^TNX: 10年期美国国债收益率(长期利率,常用于评估经济前景和风险偏好)
# ^TYX: 30年期美国国债收益率(超长期利率,反映市场对未来经济和通胀的预期)
econ_df = None
ECONOMIC_CONDITIONS = ['^IRX', '^FVX', '^TNX', '^TYX']
for symbol in ECONOMIC_CONDITIONS:
df = download_stock(symbol, NUM_DAYS, INTERVAL).select('date', 'close').withColumnRenamed('close', symbol.replace("^", "_"))
econ_df = econ_df.join(df, 'date', 'left') if econ_df else df
econ_df.orderBy('date').show(2)
[*********************100%***********************] 1 of 1 completed [*********************100%***********************] 1 of 1 completed [*********************100%***********************] 1 of 1 completed [*********************100%***********************] 1 of 1 completed
+-------------------+-----------------+------------------+-----------------+------------------+ | date| _IRX| _FVX| _TNX| _TYX| +-------------------+-----------------+------------------+-----------------+------------------+ |1998-06-08 00:00:00|4.980000019073486| 5.586999893188477|5.574999809265137|5.7870001792907715| |1998-06-09 00:00:00| 5.0|5.6020002365112305|5.574999809265137| 5.784999847412109| +-------------------+-----------------+------------------+-----------------+------------------+ only showing top 2 rows
Data quality check¶
In [8]:
# 统计每列中的 NaN 值数量
def nan_column_count(df):
"""
返回每列中 NaN 值的数量。
适用于数值类型的列。
"""
return(df.select([F.count(F.when(F.isnan(df[c]), c)).alias(c) for c in df.columns]))
# 统计每列中的 NaN 值数量(将列强制转换为 DoubleType)
def nan_column_count(df):
"""
返回每列中 NaN 值的数量。
先将列转换为 DoubleType,适用于可能为非数值类型的列。
"""
return(df.select([F.count(F.when(F.isnan(df[c].cast(T.DoubleType())), c)).alias(c) for c in df.columns]))
In [9]:
# 统计每列中的 Null 值数量
def null_column_count(df):
"""
返回每列中 Null 值的数量。
"""
return(df.select([F.count(F.when(df[c].isNull(), c)).alias(c) for c in df.columns]))
# 统计每个字符串列中的空字符串数量
def empty_string_column_count(df):
"""
返回每个字符串类型列中的空字符串数量。
"""
cols = [c[0] for c in df.dtypes if c[1] == 'string']
df = df.select(*cols)
return(df.select([F.count(F.when(F.col(c) == "", c)).alias(c) for c in df.columns]))
In [10]:
# 删除所有值为 Null 的列
def drop_null_columns(df):
"""
删除所有值均为 Null 的列,返回新的 DataFrame。
"""
nrows = df.count()
cols = [col for col in df.columns if df.where(F.col(col).isNull()).count() == nrows]
df = df.drop(*cols)
return df
In [11]:
# 检查 price_data 数据中的缺失值和空字符串情况
# 分别统计 NaN、Null 和空字符串的数量,便于后续数据清洗和处理
print("Nans")
nan_column_count(price_data).show()
print("Nulls")
null_column_count(price_data).show()
print("Empty strings")
empty_string_column_count(price_data).show()
Nans +-----+----+---+----+------+----+------+ |close|high|low|open|volume|date|symbol| +-----+----+---+----+------+----+------+ | 0| 0| 0| 0| 0| 0| 0| +-----+----+---+----+------+----+------+ Nulls +-----+----+---+----+------+----+------+ |close|high|low|open|volume|date|symbol| +-----+----+---+----+------+----+------+ | 0| 0| 0| 0| 0| 0| 0| +-----+----+---+----+------+----+------+ Empty strings +------+ |symbol| +------+ | 0| +------+
In [12]:
price_data.show()
+-----------------+-----------------+-----------------+-----------------+--------+-------------------+------+ | close| high| low| open| volume| date|symbol| +-----------------+-----------------+-----------------+-----------------+--------+-------------------+------+ |69.23340606689453|69.60084593149537|69.09803348519948|69.27208394737883| 4159900|1998-06-08 00:00:00| SPY| |69.44613647460938|69.57183958992125|68.94332401336185|69.13671342153397| 2725300|1998-06-09 00:00:00| SPY| | 69.0013427734375|69.96828984145093|68.84663124255536|69.07869853887858| 6186900|1998-06-10 00:00:00| SPY| |67.70561218261719|69.23338806448716|67.68627324740365|68.96264297149754| 8056600|1998-06-11 00:00:00| SPY| |68.34383392333984|68.34383392333984|66.99010772791432| 67.995732901659| 9779000|1998-06-12 00:00:00| SPY| | 66.5452880859375| 68.0150474275624|66.52594914723191|67.33818457286672|10234200|1998-06-15 00:00:00| SPY| | 67.5702896118164|67.58962856133324|66.68069793404207|67.08681587389557| 7471500|1998-06-16 00:00:00| SPY| |68.98198699951172| 69.1753763659247|68.03437910408809|68.16975166057718|12057700|1998-06-17 00:00:00| SPY| | 68.672607421875|68.96269165485954|68.53723477981553|68.80798006393445| 3844600|1998-06-18 00:00:00| SPY| |68.32848358154297|69.05600150724084|68.05687688928244|68.94929887813848| 3549100|1998-06-19 00:00:00| SPY| |68.65827941894531|68.94928653713808|68.32847135166018| 68.4448741989373| 5438300|1998-06-22 00:00:00| SPY| |69.47309875488281|69.57010112678296|68.91048499786199|68.96868642100206| 5004600|1998-06-23 00:00:00| SPY| |70.40435028076172|70.57895462149659| 69.2888225482889|69.62833098860672| 8291900|1998-06-24 00:00:00| SPY| |70.28791809082031|71.04453658531162|70.03571192598987|70.71472852361028| 5854800|1998-06-25 00:00:00| SPY| |70.57892608642578|70.67592845872709| 70.2297175461411|70.34612039290266| 4833100|1998-06-26 00:00:00| SPY| |70.83116912841797|71.19977833505723|70.64686452509834|70.92817155121777| 5953400|1998-06-29 00:00:00| SPY| |70.34608459472656|70.88929760318005|70.19088087802557|70.71469342189143| 4284800|1998-06-30 00:00:00| SPY| |71.16095733642578|71.35496212741931| 70.5401420052465|70.81174871263744| 3489500|1998-07-01 00:00:00| SPY| |71.29673767089844|71.31613814373405|70.92812868702168|71.21913577955596| 3503900|1998-07-02 00:00:00| SPY| |72.01455688476562|72.01455688476562|71.12213511302554|71.25793842611643| 3144400|1998-07-06 00:00:00| SPY| +-----------------+-----------------+-----------------+-----------------+--------+-------------------+------+ only showing top 20 rows
In [13]:
# 显示 price_data 的统计摘要信息,便于了解数据的分布和基本统计特征
price_data.summary().show()
+-------+------------------+-------------------+------------------+-------------------+--------------------+------+ |summary| close| high| low| open| volume|symbol| +-------+------------------+-------------------+------------------+-------------------+--------------------+------+ | count| 34277| 34277| 34277| 34277| 34277| 34277| | mean| 64.88604790785323| 65.45473909235245| 64.26958136208378| 64.87546120827079|2.2293126407795316E8| NULL| | stddev| 98.91854901772871| 99.48771643355069| 98.26014821886608| 98.89576849405182| 3.435202280062264E8| NULL| | min|0.0312795527279377|0.03259251973800169|0.0305625465934129|0.03199562382333364| 0| AAPL| | 25%| 4.239999771118164| 4.32890190400655| 4.159999847412109| 4.260000228881836| 18409400| NULL| | 50%|26.751354217529297| 27.146293540483786|26.329999923706055| 26.738335592846923| 71346400| NULL| | 75%| 84.51000213623047| 85.21276032528878| 83.65921222628269| 84.45166637855665| 307398000| NULL| | max| 673.1099853515625| 673.9500122070312| 669.97998046875| 673.530029296875| 9230856000| XOM| +-------+------------------+-------------------+------------------+-------------------+--------------------+------+
In [14]:
# 按股票代码分组,计算每只股票的平均收盘价
price_data.groupBy('symbol').avg('close').show()
+------+------------------+ |symbol| avg(close)| +------+------------------+ | SPY|186.80662559109362| | XOM| 47.48217498925526| | AAPL| 44.07443397157776| | AMD| 32.55231486295293| | NVDA| 12.30846712991666| +------+------------------+
In [15]:
# 按年和月对股票数据进行分组,计算每月的平均成交量和收盘价,并生成新的日期字段
price_data_monthly = (price_data \
.withColumn("year", F.year(F.col("date"))) \
.withColumn("month", F.month(F.col("date"))) \
.groupBy('symbol', 'year', 'month')\
.agg(F.avg('volume').alias('volume'),
F.avg('close').alias('close'),
F.concat_ws('/', F.col('year'), F.col('month')).alias('date'))
.orderBy('year', 'month'))
price_data_monthly.show(2)
+------+----+-----+--------------------+-------------------+------+ |symbol|year|month| volume| close| date| +------+----+-----+--------------------+-------------------+------+ | SPY|1998| 6| 6336805.882352941| 69.08287138097427|1998/6| | AAPL|1998| 6|1.7566145882352942E8|0.20972690336844502|1998/6| +------+----+-----+--------------------+-------------------+------+ only showing top 2 rows
In [16]:
# 将 Spark DataFrame 转换为 Pandas DataFrame,并绘制每月平均收盘价的分布图
pdf = price_data_monthly.toPandas()
sns.displot(data=pdf, x='close', hue='symbol', kde=True, bins=25, height=3, aspect=2).set(title='Stock closing prices (monthly avg)')
Out[16]:
<seaborn.axisgrid.FacetGrid at 0x70c54e064990>
In [17]:
# 绘制每月平均收盘价的折线图,按股票代码分色
g = sns.relplot(data=pdf, x='date', y='close', hue='symbol', kind='line', height=3, aspect=2).set(title='Stock closing prices (monthly avg)')
g.set(xticks=[])
Out[17]:
<seaborn.axisgrid.FacetGrid at 0x70c54d68d5d0>
In [18]:
# 绘制每月平均成交量的柱状图,按股票代码分色
plt.figure(figsize=(8,3))
g = sns.barplot(data=pdf, x='date', y='volume', hue='symbol')
g.set(title='Stock volume (monthly avg)')
for index, label in enumerate(g.get_xticklabels()):
if index % 24 == 0:
label.set_visible(True)
else:
label.set_visible(False)
g.tick_params(bottom=False)
g.plot()
Out[18]:
[]
特征工程¶
本节将介绍我们选择用于特征工程的技术指标,并简要列出 finta 支持的全部指标类型。
选择的技术指标
INDICATORS = ['RSI', 'MACD', 'STOCH','ADL', 'ATR', 'MOM', 'MFI', 'ROC', 'OBV', 'CCI', 'EMV', 'VORTEX']
技术指标的计算使用了 finta。
从支持的指标来看:Finta 支持 80 多种交易指标:
* Simple Moving Average 'SMA'
* Simple Moving Median 'SMM'
* Smoothed Simple Moving Average 'SSMA'
* Exponential Moving Average 'EMA'
* Double Exponential Moving Average 'DEMA'
* Triple Exponential Moving Average 'TEMA'
* Triangular Moving Average 'TRIMA'
* Triple Exponential Moving Average Oscillator 'TRIX'
* Volume Adjusted Moving Average 'VAMA'
* Kaufman Efficiency Indicator 'ER'
* Kaufman's Adaptive Moving Average 'KAMA'
* Zero Lag Exponential Moving Average 'ZLEMA'
* Weighted Moving Average 'WMA'
* Hull Moving Average 'HMA'
* Elastic Volume Moving Average 'EVWMA'
* Volume Weighted Average Price 'VWAP'
* Smoothed Moving Average 'SMMA'
* Fractal Adaptive Moving Average 'FRAMA'
* Moving Average Convergence Divergence 'MACD'
* Percentage Price Oscillator 'PPO'
* Volume-Weighted MACD 'VW_MACD'
* Elastic-Volume weighted MACD 'EV_MACD'
* Market Momentum 'MOM'
* Rate-of-Change 'ROC'
* Relative Strenght Index 'RSI'
* Inverse Fisher Transform RSI 'IFT_RSI'
* True Range 'TR'
* Average True Range 'ATR'
* Stop-and-Reverse 'SAR'
* Bollinger Bands 'BBANDS'
* Bollinger Bands Width 'BBWIDTH'
* Momentum Breakout Bands 'MOBO'
* Percent B 'PERCENT_B'
* Keltner Channels 'KC'
* Donchian Channel 'DO'
* Directional Movement Indicator 'DMI'
* Average Directional Index 'ADX'
* Pivot Points 'PIVOT'
* Fibonacci Pivot Points 'PIVOT_FIB'
* Stochastic Oscillator %K 'STOCH'
* Stochastic oscillator %D 'STOCHD'
* Stochastic RSI 'STOCHRSI'
* Williams %R 'WILLIAMS'
* Ultimate Oscillator 'UO'
* Awesome Oscillator 'AO'
* Mass Index 'MI'
* Vortex Indicator 'VORTEX'
* Know Sure Thing 'KST'
* True Strength Index 'TSI'
* Typical Price 'TP'
* Accumulation-Distribution Line 'ADL'
* Chaikin Oscillator 'CHAIKIN'
* Money Flow Index 'MFI'
* On Balance Volume 'OBV'
* Weighter OBV 'WOBV'
* Volume Zone Oscillator 'VZO'
* Price Zone Oscillator 'PZO'
* Elder's Force Index 'EFI'
* Cummulative Force Index 'CFI'
* Bull power and Bear Power 'EBBP'
* Ease of Movement 'EMV'
* Commodity Channel Index 'CCI'
* Coppock Curve 'COPP'
* Buy and Sell Pressure 'BASP'
* Normalized BASP 'BASPN'
* Chande Momentum Oscillator 'CMO'
* Chandelier Exit 'CHANDELIER'
* Qstick 'QSTICK'
* Twiggs Money Index 'TMF'
* Wave Trend Oscillator 'WTO'
* Fisher Transform 'FISH'
* Ichimoku Cloud 'ICHIMOKU'
* Adaptive Price Zone 'APZ'
* Squeeze Momentum Indicator 'SQZMI'
* Volume Price Trend 'VPT'
* Finite Volume Element 'FVE'
* Volume Flow Indicator 'VFI'
* Moving Standard deviation 'MSD'
* Schaff Trend Cycle 'STC'
* Mark Whistler's WAVE PM 'WAVEPM'
In [19]:
# 该函数用于计算股票的技术指标特征,并对部分特征进行归一化处理
# 输入参数 data 为单只股票的历史行情数据(pandas DataFrame)
# 输出为包含原始数据和所有技术指标特征的新 DataFrame
import re
def _get_indicator_data(data):
"""
使用 finta API 计算作为特征的技术指标
:return:
"""
# 在 pandas_udf 或 applyInPandas 这样的分布式函数中,代码会在每个 worker 进程中独立运行。
# 这时全局 import 的包在 worker 进程中未必可见,所以需要在函数内部再次 import 依赖包,确保每个 worker 都能找到依赖。
# 需要在 pandas_udf 内部导入依赖
from finta import TA
# 计算技术指标前,确保数据包含 finta 所需的标准列名
# finta 要求输入 DataFrame 至少包含 'open', 'high', 'low', 'close', 'volume' 这些列
# 如果缺失这些列,部分指标将无法计算
for indicator in INDICATORS:
ind_data = eval('TA.' + indicator + '(data)')
if not isinstance(ind_data, pd.DataFrame):
ind_data = ind_data.to_frame()
data = data.merge(ind_data, left_index=True, right_index=True)
for c in data.columns:
data.rename(columns={c: re.sub("[ \.!%\(\)]", '_', c)}, inplace=True)
# 也计算常用的移动平均线作为特征
data['ema50'] = data['close'] / data['close'].ewm(50).mean()
data['ema21'] = data['close'] / data['close'].ewm(21).mean()
data['ema15'] = data['close'] / data['close'].ewm(14).mean()
data['ema5'] = data['close'] / data['close'].ewm(5).mean()
# 与其直接使用原始成交量(volume),更合理的做法是用短期移动平均归一化成交量,消除不同时间段的量级变化影响
data['normVol'] = data['volume'] / data['volume'].ewm(5).mean()
return data
In [20]:
from pyspark.sql import DataFrame
# 定义一个函数,根据 pandas DataFrame 的列类型自动生成 Spark SQL schema 字符串
def indicator_schema(df: DataFrame, indicators_function=_get_indicator_data):
schema = ""
pdf = indicators_function(df.limit(1).toPandas())
for name, typ in zip(pdf.columns, pdf.dtypes):
if str(typ).startswith('float'):
schema += f"{name} float, "
elif str(typ).startswith('int'):
schema += f"{name} int, "
elif str(typ).startswith('datetime'):
schema += f"{name} timestamp, "
else:
schema += f"{name} string, "
return schema[:-2]
# 生成 Spark SQL schema 字符串,用于 applyInPandas 时指定输出 DataFrame 的结构
indicator_schema = indicator_schema(price_data)
print(indicator_schema)
print(price_data.columns)
close float, high float, low float, open float, volume int, date timestamp, symbol string, 14_period_RSI float, MACD float, SIGNAL float, 14_period_STOCH__K float, MFV float, 14_period_ATR float, MOM float, 14_period_MFI float, ROC float, OBV float, 20_period_CCI float, 14_period_EMV_ float, VIm float, VIp float, ema50 float, ema21 float, ema15 float, ema5 float, normVol float ['close', 'high', 'low', 'open', 'volume', 'date', 'symbol']
In [21]:
# 对每只股票分别计算技术指标特征,并合并为一个 Spark DataFrame
# 使用 applyInPandas 结合自定义函数 add_indicators 进行批量特征工程
import pandas as pd
def add_indicators(prices: pd.DataFrame) -> pd.DataFrame:
return _get_indicator_data(prices)
stock_data = price_data.groupBy('symbol').applyInPandas(add_indicators, schema=indicator_schema)
print(type(stock_data))
stock_data.show(2)
<class 'pyspark.sql.dataframe.DataFrame'> +----------+----------+----------+----------+---------+-------------------+------+-------------+------------+------------+------------------+------------+-------------+----+-------------+----+---------+-------------+--------------+----+----+---------+---------+---------+---------+---------+ | close| high| low| open| volume| date|symbol|14_period_RSI| MACD| SIGNAL|14_period_STOCH__K| MFV|14_period_ATR| MOM|14_period_MFI| ROC| OBV|20_period_CCI|14_period_EMV_| VIm| VIp| ema50| ema21| ema15| ema5| normVol| +----------+----------+----------+----------+---------+-------------------+------+-------------+------------+------------+------------------+------------+-------------+----+-------------+----+---------+-------------+--------------+----+----+---------+---------+---------+---------+---------+ |0.20451239|0.20779562| 0.2012283| 0.2026354|126627200|1998-06-08 00:00:00| AAPL| NULL| 0.0| 0.0| NULL| 16180.97| NULL|NULL| NULL|NULL| NULL| NULL| NULL|NULL|NULL| 1.0| 1.0| 1.0| 1.0| 1.0| |0.21201693|0.21389307|0.20545046|0.20545046|275744000|1998-06-09 00:00:00| AAPL| 100.0|1.6837104E-4|9.3539464E-5| NULL|1.53206384E8| NULL|NULL| NULL|NULL|2.75744E8| 66.666664| NULL|NULL|NULL|1.0178353|1.0175905|1.0173848|1.0163522|1.3259242| +----------+----------+----------+----------+---------+-------------------+------+-------------+------------+------------+------------------+------------+-------------+----+-------------+----+---------+-------------+--------------+----+----+---------+---------+---------+---------+---------+ only showing top 2 rows
标注训练数据¶
设置这些参数用于数据分析:
- 窗口大小(window size)
- 股票价格百分比变化(stock price percent change)
- 预测标签列(prediction label column)
- 如果股票在窗口期内至少上涨了设定的百分比,则为 1.0
- 否则为 0.0
In [22]:
from pyspark.sql.window import Window
# 标注训练数据:为每一行计算未来 window_size 天后的收盘价涨跌幅,并根据涨幅是否超过 percent_change 生成标签
window_size=365
percent_change = 0.1
label = 'label'
# 计算当前行与未来 window_size 天收盘价的百分比变化
# 如果涨幅超过 percent_change,则标记为 1.0,否则为 0.0
windowSpec = Window.partitionBy("symbol").orderBy("date")
df =stock_data\
.withColumn("lead",F.lead("close",window_size).over(windowSpec))\
.withColumn('p_change', (F.col('lead') - F.col('close')) / F.col('close')) \
.withColumn(label, F.when(F.col('p_change') > F.lit(percent_change), F.lit(1.0)).otherwise(F.lit(0.0)))
# Drop null rows
labeled_data = df.dropna()
labeled_data.where('p_change < 0').select('symbol', 'date', 'close', 'lead', 'p_change', label).show(2)
+------+-------------------+---------+----------+-------------------+-----+ |symbol| date| close| lead| p_change|label| +------+-------------------+---------+----------+-------------------+-----+ | AAPL|1999-04-30 00:00:00| 0.345231| 0.3264689|-0.0543464580115901| 0.0| | AAPL|1999-05-03 00:00:00|0.3719677|0.31333515|-0.1576280739818759| 0.0| +------+-------------------+---------+----------+-------------------+-----+ only showing top 2 rows
Data balancing¶
In [23]:
labeled_data.groupBy('symbol', 'label').count().show()
+------+-----+-----+ |symbol|label|count| +------+-----+-----+ | AAPL| 1.0| 4738| | AAPL| 0.0| 1735| | XOM| 1.0| 3223| | XOM| 0.0| 3238| | AMD| 1.0| 3505| | AMD| 0.0| 2857| | NVDA| 1.0| 4337| | NVDA| 0.0| 1962| | SPY| 1.0| 4180| | SPY| 0.0| 2302| +------+-----+-----+
In [24]:
from pyspark.sql import DataFrame
import pyspark.sql.functions as F
import pyspark.sql.types as T
import pandas as pd
def class_weights(df:DataFrame, column='Class'):
"""
计算给定 `column` 的类别平衡权重。
"""
class_counts = list(df.groupBy(column).count().toPandas().to_dict(orient='list').values())
num_classes = len(class_counts[0])
total_samples = sum(class_counts[1])
pairs = list(zip(class_counts[0], class_counts[1]))
return {k:total_samples/(num_classes * v) for (k, v) in pairs}
def weight_column(classes: pd.Series, class_weights) -> pd.Series:
"""
该UDF用于计算类别权重。
参数 `class_weights` 是预先计算好的类别权重字典。
用法示例:sdf.withColumn('weight', weight_column(sdf['Class'], class_weights(sdf, column='Class')))
"""
@F.pandas_udf(T.DoubleType())
def weight_column_udf(classes: pd.Series) -> pd.Series:
"""
UDF to compute rebalancing weights for the given `classes` (as a column).
Depends on the within lexically scoped variable, class_weights, which contains a dictionary of weights.
"""
result = []
for _, value in classes.items():
result += [class_weights[value]]
return pd.Series(result)
return weight_column_udf(classes)
In [25]:
# 为每只股票分别计算类别权重,并为每行添加权重列
# 这样可以在训练模型时对类别不平衡进行调整
def add_weight_column(df: DataFrame, symbols=STOCK_SYMBOLS) -> DataFrame:
result_df =None
for symbol in symbols:
dfs = df.where(f"symbol = '{symbol}'")
dfs_weighted = dfs.withColumn('weight', weight_column(dfs['label'], class_weights(dfs, column='label')))
result_df = result_df.union(dfs_weighted) if result_df else dfs_weighted
return result_df
In [26]:
# 为训练数据添加类别权重列,以便后续模型训练时对类别不平衡进行调整
df_weighted = add_weight_column(labeled_data)
df_weighted.groupBy('symbol', 'label', 'weight').count().show()
+------+-----+------------------+-----+ |symbol|label| weight|count| +------+-----+------------------+-----+ | SPY| 1.0|0.7779644743158906| 4166| | SPY| 0.0|1.3993955094991364| 2316| | AMD| 1.0|0.9110824742268041| 3491| | AMD| 0.0|1.1081504702194358| 2871| | XOM| 1.0|0.9979919678714859| 3237| | XOM| 0.0| 1.002016129032258| 3224| | AAPL| 0.0|1.8502001143510578| 1749| | AAPL| 1.0|0.6851577387253864| 4723| | NVDA| 0.0|1.6052497451580021| 1962| | NVDA| 1.0|0.7261932211205903| 4337| +------+-----+------------------+-----+
In [27]:
# 尝试加权和不加权训练模型
USE_WEIGHTS = True
if USE_WEIGHTS:
labeled_data = df_weighted
else:
labeled_data = labeled_data.withColumn('weight', F.lit(1.0))
Model training¶
In [28]:
# 模型训练前的特征处理和管道设置
# 包括特征选择、向量化、标准化等步骤
from pyspark.ml.feature import StandardScaler, VectorAssembler
from pyspark.ml import Pipeline, Model
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier, GBTClassifier, MultilayerPerceptronClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from xgboost.spark import SparkXGBClassifier
# 所有特征列作为自变量,排除以下列:
# ['date', 'symbol', 'lead', 'p_change', label, 'weight']
x_cols = list(set(labeled_data.columns) - {'date', 'symbol', 'lead', 'p_change', label, 'weight'})
# Pipeline basic to be shared across model fitting and testing
pipeline = Pipeline(stages=[]) # Must initialize with empty list!
# base pipeline (the processing here is reused across pipelines)
basePipeline =[VectorAssembler(inputCols=[*x_cols], outputCol='vector_features', handleInvalid='error'),
StandardScaler(inputCol='vector_features', outputCol='features')]
# One grid from the individual grids
param_grid = []
In [ ]:
# 逻辑回归是一种广泛应用于二分类问题的统计学习方法。
# Logistic Regression model for binary classification
# 使用 LogisticRegression 作为分类器,并将其加入参数网格
lr = LogisticRegression(featuresCol="features", labelCol=label, predictionCol="prediction", weightCol='weight')
param_grid += ParamGridBuilder()\
.baseOn({pipeline.stages: basePipeline + [lr]})\
.build()
In [ ]:
# 随机森林是一种集成学习方法,通过构建多个决策树并进行投票或平均来提高预测准确性和稳定性。
# 它属于Bagging(Bootstrap Aggregating)类算法。
# 随机森林分类器(RandomForestClassifier)用于二分类任务
rf = RandomForestClassifier(featuresCol="features", labelCol=label, predictionCol="prediction", weightCol='weight')
param_grid += ParamGridBuilder()\
.baseOn({pipeline.stages: basePipeline + [rf]})\
.build()
Boosting 和 Bagging 的区别¶
| 特性 | Bagging | Boosting |
|---|---|---|
| 代表算法 | 随机森林 | GBDT, XGBoost, AdaBoost |
| 训练方式 | 并行训练 | 顺序训练 |
| 样本权重 | 等权重采样 | 动态调整权重 |
| 基学习器关系 | 相互独立 | 相互依赖 |
In [ ]:
# 梯度提升树(GBT) 是一种强大的集成学习方法,
# 通过顺序构建多个弱学习器(通常是决策树)来逐步减少残差。
# 它属于Boosting算法家族,与随机森林的Bagging思想形成对比。
# Gradient Boosted Trees (GBTClassifier) 用于二分类任务,适合处理非线性关系
gb = GBTClassifier(featuresCol="features", labelCol=label, predictionCol="prediction", weightCol='weight')
param_grid += ParamGridBuilder()\
.baseOn({pipeline.stages: basePipeline + [gb]})\
.build()
In [ ]:
# XGBoost(eXtreme Gradient Boosting)是一种基于梯度提升框架的机器学习算法,
# 主要用于分类和回归问题。
# 它通过集成多个弱学习器(通常是决策树)来构建一个强学习器。X
# GBoost在梯度提升决策树(GBDT)的基础上进行了优化,引入了二阶导数、正则化、并行处理等技术,从而在性能和速度上都有显著提升。
# XGBoost 分类器(SparkXGBClassifier)用于二分类任务
xgb = SparkXGBClassifier(features_col="features", label_col=label, prediction_col="prediction", weight_col='weight',\
missing=-999, num_workers=2)
param_grid += ParamGridBuilder()\
.baseOn({pipeline.stages: basePipeline + [xgb]})\
.build()
In [ ]:
# MLP是一种前馈人工神经网络,由输入层、一个或多个隐藏层和输出层组成。
# MLP是深度学习中最基础的模型之一。
# 指定神经网络的层结构:
# 输入层大小必须等于特征数量,中间层可自定义,输出层为2(用于二分类)。
num_inputs = len(basePipeline[0].getInputCols())
pc = MultilayerPerceptronClassifier(featuresCol="features", labelCol=label, predictionCol="prediction", maxIter=100, blockSize=128, seed=1234)
param_grid += ParamGridBuilder()\
.baseOn({pipeline.stages: basePipeline + [pc]})\
.addGrid(pc.layers, [[num_inputs, 5, 5, 2]]) \
.build()
In [34]:
# 定义评估器,使用 F1 分数作为模型评估指标
evaluator = MulticlassClassificationEvaluator()\
.setMetricName('f1')\
.setPredictionCol("prediction")\
.setLabelCol(label)
In [35]:
# 定义用于交叉验证训练股票模型的函数
# data: 输入的训练数据
# pipeline: 机器学习管道
# evaluator: 评估器(如F1分数)
# param_map: 参数网格(模型和参数组合)
# folds: 交叉验证折数,默认为3
def cv_stocks(data, pipeline, evaluator, param_map, folds=3):
spark.conf.set("spark.sql.shuffle.partitions", "16")
# set to number of cores on cluster
cv = CrossValidator()\
.setParallelism(2) \
.setEstimator(pipeline)\
.setEvaluator(evaluator)\
.setEstimatorParamMaps(param_grid) \
.setNumFolds(folds)
# 缓存数据以提升交叉验证效率,避免重复计算
return cv.fit(data.repartition(16).cache())
In [36]:
# 对每只股票(NVDA、SPY)和全市场(ALL)分别训练模型
# 1. 对于每只股票,先筛选出对应的训练数据,再进行训练和交叉验证
# 2. 对于全市场模型,直接用全部数据训练
# 全市场模型(ALL)是指用所有股票的数据(不区分具体股票代码)一起训练的模型。
# 它可以捕捉整体市场的共性规律,而不是只针对某一只股票的特性进行建模。
# 这样训练出的模型可以用于泛化预测,适用于任意股票的涨跌趋势判断。
symbols = ['NVDA', 'SPY', None]
models = {}
for symbol in symbols:
if symbol:
# Stock specific modeling
train_data, test_data = labeled_data.where(f"symbol = '{symbol}'").randomSplit([0.6, 0.4], 24) # proportions [], seed for random
models[symbol] = cv_stocks(train_data, pipeline, evaluator, param_grid)
else:
train_data, test_data = labeled_data.randomSplit([0.6, 0.4], 24) # proportions [], seed for random
models['ALL'] = cv_stocks(train_data, pipeline, evaluator, param_grid)
2025-10-22 06:31:44,183 INFO XGBoost-PySpark: _fit Running xgboost-3.1.0 on 2 workers with
booster params: {'device': 'cpu', 'objective': 'binary:logistic', 'nthread': 1}
train_call_kwargs_params: {'verbose_eval': True, 'num_boost_round': 100}
dmatrix_kwargs: {'nthread': 1, 'missing': -999.0}
2025-10-22 06:31:51,423 INFO XGBoost-PySpark: _fit Finished xgboost training!
2025-10-22 06:32:15,199 INFO XGBoost-PySpark: _fit Running xgboost-3.1.0 on 2 workers with
booster params: {'device': 'cpu', 'objective': 'binary:logistic', 'nthread': 1}
train_call_kwargs_params: {'verbose_eval': True, 'num_boost_round': 100}
dmatrix_kwargs: {'nthread': 1, 'missing': -999.0}
2025-10-22 06:32:19,773 INFO XGBoost-PySpark: _fit Finished xgboost training!
2025-10-22 06:32:40,123 INFO XGBoost-PySpark: _fit Running xgboost-3.1.0 on 2 workers with
booster params: {'device': 'cpu', 'objective': 'binary:logistic', 'nthread': 1}
train_call_kwargs_params: {'verbose_eval': True, 'num_boost_round': 100}
dmatrix_kwargs: {'nthread': 1, 'missing': -999.0}
2025-10-22 06:32:44,429 INFO XGBoost-PySpark: _fit Finished xgboost training!
2025-10-22 06:32:50,578 INFO XGBoost-PySpark: _fit Running xgboost-3.1.0 on 2 workers with
booster params: {'device': 'cpu', 'objective': 'binary:logistic', 'nthread': 1}
train_call_kwargs_params: {'verbose_eval': True, 'num_boost_round': 100}
dmatrix_kwargs: {'nthread': 1, 'missing': -999.0}
2025-10-22 06:32:54,551 INFO XGBoost-PySpark: _fit Finished xgboost training!
2025-10-22 06:33:10,109 INFO XGBoost-PySpark: _fit Running xgboost-3.1.0 on 2 workers with
booster params: {'device': 'cpu', 'objective': 'binary:logistic', 'nthread': 1}
train_call_kwargs_params: {'verbose_eval': True, 'num_boost_round': 100}
dmatrix_kwargs: {'nthread': 1, 'missing': -999.0}
2025-10-22 06:33:14,576 INFO XGBoost-PySpark: _fit Finished xgboost training!
2025-10-22 06:33:33,294 INFO XGBoost-PySpark: _fit Running xgboost-3.1.0 on 2 workers with
booster params: {'device': 'cpu', 'objective': 'binary:logistic', 'nthread': 1}
train_call_kwargs_params: {'verbose_eval': True, 'num_boost_round': 100}
dmatrix_kwargs: {'nthread': 1, 'missing': -999.0}
2025-10-22 06:33:37,635 INFO XGBoost-PySpark: _fit Finished xgboost training!
2025-10-22 06:33:54,221 INFO XGBoost-PySpark: _fit Running xgboost-3.1.0 on 2 workers with
booster params: {'device': 'cpu', 'objective': 'binary:logistic', 'nthread': 1}
train_call_kwargs_params: {'verbose_eval': True, 'num_boost_round': 100}
dmatrix_kwargs: {'nthread': 1, 'missing': -999.0}
2025-10-22 06:33:58,526 INFO XGBoost-PySpark: _fit Finished xgboost training!
2025-10-22 06:34:03,478 INFO XGBoost-PySpark: _fit Running xgboost-3.1.0 on 2 workers with
booster params: {'device': 'cpu', 'objective': 'binary:logistic', 'nthread': 1}
train_call_kwargs_params: {'verbose_eval': True, 'num_boost_round': 100}
dmatrix_kwargs: {'nthread': 1, 'missing': -999.0}
2025-10-22 06:34:07,541 INFO XGBoost-PySpark: _fit Finished xgboost training!
2025-10-22 06:34:22,186 INFO XGBoost-PySpark: _fit Running xgboost-3.1.0 on 2 workers with
booster params: {'device': 'cpu', 'objective': 'binary:logistic', 'nthread': 1}
train_call_kwargs_params: {'verbose_eval': True, 'num_boost_round': 100}
dmatrix_kwargs: {'nthread': 1, 'missing': -999.0}
2025-10-22 06:34:26,919 INFO XGBoost-PySpark: _fit Finished xgboost training!
2025-10-22 06:34:41,446 INFO XGBoost-PySpark: _fit Running xgboost-3.1.0 on 2 workers with
booster params: {'device': 'cpu', 'objective': 'binary:logistic', 'nthread': 1}
train_call_kwargs_params: {'verbose_eval': True, 'num_boost_round': 100}
dmatrix_kwargs: {'nthread': 1, 'missing': -999.0}
2025-10-22 06:34:46,108 INFO XGBoost-PySpark: _fit Finished xgboost training!
2025-10-22 06:35:00,518 INFO XGBoost-PySpark: _fit Running xgboost-3.1.0 on 2 workers with
booster params: {'device': 'cpu', 'objective': 'binary:logistic', 'nthread': 1}
train_call_kwargs_params: {'verbose_eval': True, 'num_boost_round': 100}
dmatrix_kwargs: {'nthread': 1, 'missing': -999.0}
2025-10-22 06:35:05,300 INFO XGBoost-PySpark: _fit Finished xgboost training!
2025-10-22 06:35:10,381 INFO XGBoost-PySpark: _fit Running xgboost-3.1.0 on 2 workers with
booster params: {'device': 'cpu', 'objective': 'binary:logistic', 'nthread': 1}
train_call_kwargs_params: {'verbose_eval': True, 'num_boost_round': 100}
dmatrix_kwargs: {'nthread': 1, 'missing': -999.0}
2025-10-22 06:35:14,741 INFO XGBoost-PySpark: _fit Finished xgboost training!
Model training summary¶
In [37]:
# 输出每个模型的名称、模型数量以及各折交叉验证的平均指标
for name in models:
print(f"Model {name}")
num_models = len(models[name].getEstimatorParamMaps())
print(f"Ran {num_models} models")
print(f"Average model metric over the folds are: {models[name].avgMetrics}")
Model NVDA Ran 5 models Average model metric over the folds are: [np.float64(0.7367919029645188), np.float64(0.8715688664909829), np.float64(0.9331476566258687), np.float64(0.9720014074330402), np.float64(0.7656554905392223)] Model SPY Ran 5 models Average model metric over the folds are: [np.float64(0.6430573528965232), np.float64(0.8562090782446531), np.float64(0.926819987221774), np.float64(0.9590439416105984), np.float64(0.6809762589791603)] Model ALL Ran 5 models Average model metric over the folds are: [np.float64(0.5521539711297984), np.float64(0.7553098877994303), np.float64(0.8506427534689439), np.float64(0.951477063359074), np.float64(0.6109382841172067)]
In [38]:
import re
from pyspark.ml.tuning import CrossValidatorModel
# 定义函数,用于从参数网格中提取模型名称和参数
def paramGrid_model_name(model):
params = [v for v in model.values() if type(v) is not list]
name = [v[-1] for v in model.values() if type(v) is list][0]
name = re.match(r'([a-zA-Z]*)', str(name)).groups()[0]
return f"{name}{params}"
def cv_metrics(cv: CrossValidatorModel):
"""
返回 CrossValidator 运行多个模型(参数管道)时的指标和模型名称。
"""
# Resulting metric and model description
# get the metric from the CrossValidator's resulting avgMetrics
# get the model name & params from the paramGrid
# put them together here:
measures = zip(cv.avgMetrics, [paramGrid_model_name(m) for m in cv.getEstimatorParamMaps()])
metrics, model_names = zip(*measures)
return metrics, model_names
In [39]:
# 绘制交叉验证各模型的评估指标(如F1分数)柱状图
# 输入参数 fitted_cv: 训练好的 CrossValidatorModel
# title: 图表标题
def cv_plot_metrics(fitted_cv: CrossValidatorModel, title="Model Metrics"):
metrics, model_names = cv_metrics(fitted_cv)
metric_name = fitted_cv.getEvaluator().getMetricName()
sns.set_context('notebook')
sns.set_style('white')
sns.set_palette("bright")
def add_metric_labels(metrics):
for i in range(len(metrics)):
plt.text(i, metrics[i], f"{metrics[i]:.3f}", ha = 'center', fontsize=10)
plt.figure(figsize=(6, 3))
pdf = pd.DataFrame(zip(metrics, model_names), columns=[metric_name, 'model'])
sns.barplot(data=pdf, x='model', y=metric_name).set(title=title)
plt.xticks(rotation=45)
add_metric_labels(metrics)
In [40]:
for name in models:
cv_plot_metrics(models[name], title=f"{name} models")
In [41]:
names = iter(models.keys())
name = next(names)
cv_plot_metrics(models[name], title=f"{name} models")
In [42]:
name = next(names)
cv_plot_metrics(models[name], title=f"{name} models")
In [43]:
name = next(names)
cv_plot_metrics(models[name], title=f"{name} models")
Model evaluation¶
In [ ]:
def classification_metrics(predictions: DataFrame, label='label', prediction='prediction'):
"""
给定预测结果的汇总指标(如ROC、PR等)。
返回指标字典和混淆矩阵。
提供不同beta值下的F-measure,参考:https://machinelearningmastery.com/fbeta-measure-for-machine-learning
较小的beta值(如0.5)更侧重于精确率,较大beta值(如2.0)更侧重于召回率。
对于股票预测,我们更关注精确率(即正确预测上涨的比例)。
"""
metrics_b = BinaryClassificationMetrics(predictions.select(label, prediction).rdd.map(tuple))
metrics = {}
metrics['PR AUC'] = metrics_b.areaUnderPR
metrics['ROC AUC'] = metrics_b.areaUnderROC
metrics_m = MulticlassMetrics(predictions.select(label, prediction).rdd.map(tuple))
# F0.5 = (1 + 0.5²) × (Precision × Recall) / (0.5² × Precision + Recall)
# F0.5 假阳性代价高的场景
metrics['F0.5 Score'] = metrics_m.fMeasure(label=1.0, beta=0.5)
metrics['F1 Score'] = metrics_m.fMeasure(label=1.0, beta=1.0)
metrics['F2 Score'] = metrics_m.fMeasure(label=1.0, beta=2.0)
metrics['Recall'] = metrics_m.recall(label=1)
metrics['Precision'] = metrics_m.precision(1)
metrics['Accuracy'] = metrics_m.accuracy
return metrics, metrics_m.confusionMatrix().toArray()
In [45]:
def precision_recall(pred_df: DataFrame, col_name):
"""
Output: precision, recall used in evaluation function
"""
rdd_pred = pred_df.select([col_name, 'label']).rdd
metrics_m = MulticlassMetrics(rdd_pred)
precision = metrics_m.precision(1)
recall = metrics_m.recall(label=1)
f2 = metrics_m.fMeasure(1.0, 2.0)
f1 = metrics_m.fMeasure(1.0, 1.0)
f05 = metrics_m.fMeasure(1.0, 0.5)
return (precision, recall, f2, f1, f05)
In [46]:
# https://machinelearningmastery.com/threshold-moving-for-imbalanced-classification/
def threshold_tuning(valid_df: DataFrame):
"""
输入: 一个验证后的数据框(valid_df)
输出: 一个包含从0到1的阈值及其对应的精确率/召回率/F2分数的Pandas数据框
"""
pr_results = []
preds_new = valid_df
preds_new = preds_new.withColumn('pred_probability', firstelement('probability'))
thresholds = np.arange(start=0.1, stop=1.1, step=0.1)
c = ['c1','c2','c3','c4','c5','c6','c7','c8','c9','c10']
i=0
for threshold in thresholds:
preds_new = preds_new.withColumn(c[i], F.when(preds_new["pred_probability"].cast(T.DoubleType()) >= threshold , 1.0).otherwise(0.0).cast(T.DoubleType()))
i = i+1
for i in range(len(thresholds)-1):
precision, recall, f2, f1, f05 = precision_recall(preds_new, c[i])
pr_results.append((thresholds[i], precision, recall, f2, f1, f05))
pr_df = pd.DataFrame(pr_results).rename(columns={0:'Threshold',1:'Precision',2:'Recall', 3:'f2-score', 4:'f1-score', 5:'f0.5-score'})
return pr_df
In [ ]:
from pyspark.sql.functions import udf
from pyspark.mllib.evaluation import BinaryClassificationMetrics, MulticlassMetrics
import sparkdl
# Function to graph first position of the dense vector probability
# Used in threshold_tuning function
# 从输入数组中提取索引为1的元素(即第二个元素)
firstelement = udf(lambda item:float(item[1]),T.FloatType())
In [51]:
class CurveMetrics(BinaryClassificationMetrics):
"""
Helper function to plot roc curve
"""
def __init__(self, *args):
super(CurveMetrics, self).__init__(*args)
def _to_list(self, rdd):
points = []
# Note this collect could be inefficient for large datasets
# considering there may be one probability per datapoint (at most)
# The Scala version takes a numBins parameter,
# but it doesn't seem possible to pass this from Python to Java
for row in rdd.collect():
# Results are returned as type scala.Tuple2,
# which doesn't appear to have a py4j mapping
points += [(float(row._1()), float(row._2()))]
return points
def get_curve(self, method):
rdd = getattr(self._java_model, method)().toJavaRDD()
return self._to_list(rdd)
In [52]:
def evaluate(preds_train: DataFrame, preds_valid: DataFrame, model, label='label', features='features', cm_percent=True):
"""
Input: predicted model for train and validation set, model
Output: PySpark DataFrame of evaluation metrics for class 0, 1
Confusion Matrix, PR-Curve
"""
plt.rcParams.update({
'font.size': 6, # 默认字体
'axes.titlesize': 6, # 子图标题
'axes.labelsize': 8, # 坐标轴标签
'xtick.labelsize': 7, # X轴刻度
'ytick.labelsize': 7, # Y轴刻度
'legend.fontsize': 7, # 图例字体
})
tr_metrics, _ = classification_metrics(preds_train)
ts_metrics, confusion_matrix = classification_metrics(preds_valid)
preds_valid_pr = threshold_tuning(preds_valid)
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(8, 2))
# plot the Precision-Recall curve
sns.set(font_scale=1, style='whitegrid')
sns.lineplot(x='Recall', y='Precision', data=preds_valid_pr, label='PR Curve', ax=axes[0])
axes[0].set_title('Precision-Recall Curve', fontsize=5)
axes[0].legend()
# https://machinelearningmastery.com/roc-curves-and-precision-recall-curves-for-classification-in-python/
# False Positive Rate = 1 - Specificity
# Used by plot ROC AUC
rdd_valid_b = preds_valid.select(label,'probability').rdd.map(lambda row: (float(row['probability'][1]), float(row[label])))
# plot ROC AUC
metrics_valid = CurveMetrics(rdd_valid_b)
points_roc = metrics_valid.get_curve('roc')
x_val = [x[0] for x in points_roc]
y_val = [x[1] for x in points_roc]
sns.lineplot(x=x_val, y=y_val, color='lightsteelblue',label='ROC AUC',ax= axes[1])
# Get the xy data from the lines so that we can shade
l1 = axes[1].lines[0]
x1 = l1.get_xydata()[:,0]
y1 = l1.get_xydata()[:,1]
axes[1].fill_between(x1,y1, color="lightblue", alpha=0.3)
axes[1].set_ylim([0.1, 1])
axes[1].set_xlabel('FPR (1-Specificity)')
axes[1].set_ylabel('TPR (Recall)')
axes[1].set_title('ROC AUC curve (Validation)', fontsize=5)
axes[1].legend()
# Plot confusion matrix
cm = confusion_matrix
confusion_matrix = pd.DataFrame(cm)
if cm_percent:
sns.heatmap(confusion_matrix/np.sum(confusion_matrix), annot=True, fmt=".1%", linewidth=0.5, cmap='Blues', ax=axes[2])
else:
sns.heatmap(confusion_matrix, annot=True, fmt=",.1f", linewidth=0.5, annot_kws={'fontsize':5}, cmap='RdBu', ax=axes[2])
size = int(preds_valid.count())
size = f'{size:,}'
axes[2].set_title('Confusion Matrix - N={}'.format(size), fontsize=5)
axes[2].set_ylabel('Actual Values')
axes[2].set_xlabel('Predicted Values')
# plt.show()
return tr_metrics, ts_metrics, model
def text_print(tr_metrics, ts_metrics, model):
print(str(model))
print(f"{'Metric': <10} {'Train': >7} {'Test': >7}")
for key in tr_metrics.keys():
print(f"{key: <10} {tr_metrics[key]: >7,.4f} {ts_metrics[key]: >7,.4f}")
print(' Validation Plots')
In [53]:
# Using F-measure (beta =0.5, to weight precision more because don't wan't false positives)
for name in models:
#print(f"Evaluating for symbol {name}")
if name != "ALL":
# Stock specific modeling
train_data, test_data = labeled_data.where(f"symbol = '{name}'").randomSplit([0.6, 0.4], 24) # proportions [], seed for random
train_predictions = models[name].bestModel.transform(train_data)
test_predictions = models[name].bestModel.transform(test_data)
tr_metrics, ts_metrics, m = evaluate(train_predictions, test_predictions, models[name].bestModel.stages[-1], features='vector_features')
text_print(tr_metrics, ts_metrics, m)
else:
train_data, test_data = labeled_data.randomSplit([0.6, 0.4], 24) # proportions [], seed for random
train_predictions = models[name].bestModel.transform(train_data)
test_predictions = models[name].bestModel.transform(test_data)
tr_metrics, ts_metrics, m = evaluate(train_predictions, test_predictions, models[name].bestModel.stages[-1], features='vector_features')
text_print(tr_metrics, ts_metrics, m)
SparkXGBClassifier_78763f166ecc
Metric Train Test
PR AUC 0.7610 0.8227
ROC AUC 0.7077 0.7509
F0.5 Score 0.8887 0.8751
F1 Score 0.8888 0.8613
F2 Score 0.8889 0.8478
Recall 0.8890 0.8391
Precision 0.8886 0.8846
Accuracy 0.8481 0.8014
Validation Plots
SparkXGBClassifier_78763f166ecc
Metric Train Test
PR AUC 0.9130 0.9743
ROC AUC 0.7707 0.9650
F0.5 Score 0.9307 0.9772
F1 Score 0.9329 0.9730
F2 Score 0.9351 0.9688
Recall 0.9366 0.9660
Precision 0.9292 0.9800
Accuracy 0.9130 0.9653
Validation Plots
SparkXGBClassifier_78763f166ecc
Metric Train Test
PR AUC 0.8662 0.8169
ROC AUC 0.8161 0.7306
F0.5 Score 0.7816 0.8695
F1 Score 0.7744 0.8242
F2 Score 0.7673 0.7835
Recall 0.7626 0.7585
Precision 0.7865 0.9025
Accuracy 0.7131 0.7620
Validation Plots
In [54]:
names = iter(models.keys())
name = next(names)
print(f"Evaluating for symbol {name}")
# Stock specific modeling
train_data, test_data = labeled_data.where(f"symbol = '{name}'").randomSplit([0.6, 0.4], 24) # proportions [], seed for random
train_predictions = models[name].bestModel.transform(train_data)
test_predictions = models[name].bestModel.transform(test_data)
tr_metrics, ts_metrics, m = evaluate(train_predictions, test_predictions, models[name].bestModel.stages[-1], features='vector_features')
Evaluating for symbol NVDA
In [55]:
text_print(tr_metrics, ts_metrics, m)
SparkXGBClassifier_78763f166ecc
Metric Train Test
PR AUC 0.7806 0.7912
ROC AUC 0.7265 0.7498
F0.5 Score 0.8374 0.9304
F1 Score 0.8342 0.9177
F2 Score 0.8310 0.9053
Recall 0.8289 0.8972
Precision 0.8395 0.9391
Accuracy 0.7721 0.8822
Validation Plots
In [56]:
name = next(names)
print(f"Evaluating for symbol {name}")
# Stock specific modeling
train_data, test_data = labeled_data.where(f"symbol = '{name}'").randomSplit([0.6, 0.4], 24) # proportions [], seed for random
train_predictions = models[name].bestModel.transform(train_data)
test_predictions = models[name].bestModel.transform(test_data)
tr_metrics, ts_metrics, m = evaluate(train_predictions, test_predictions, models[name].bestModel.stages[-1], features='vector_features')
Evaluating for symbol SPY
In [57]:
text_print(tr_metrics, ts_metrics, m)
SparkXGBClassifier_78763f166ecc
Metric Train Test
PR AUC 0.9202 0.9014
ROC AUC 0.9034 0.7624
F0.5 Score 1.0000 0.8811
F1 Score 1.0000 0.8302
F2 Score 1.0000 0.7848
Recall 1.0000 0.7572
Precision 1.0000 0.9187
Accuracy 1.0000 0.7595
Validation Plots
In [58]:
name = next(names)
print(f"Evaluating for symbol {name}")
train_data, test_data = labeled_data.randomSplit([0.6, 0.4], 24) # proportions [], seed for random
train_predictions = models[name].bestModel.transform(train_data)
test_predictions = models[name].bestModel.transform(test_data)
tr_metrics, ts_metrics, m = evaluate(train_predictions, test_predictions, models[name].bestModel.stages[-1], features='vector_features')
Evaluating for symbol ALL
In [59]:
text_print(tr_metrics, ts_metrics, m)
SparkXGBClassifier_78763f166ecc
Metric Train Test
PR AUC 0.8331 0.8546
ROC AUC 0.7607 0.7996
F0.5 Score 0.8452 0.8571
F1 Score 0.8276 0.8308
F2 Score 0.8108 0.8060
Recall 0.8000 0.7903
Precision 0.8573 0.8756
Accuracy 0.7768 0.7794
Validation Plots
结论¶
该数据集表现最好的分类器是梯度提升(XgboostClassifier)
- 无论是否使用权重,表现都最佳
它能够在超过 90% 的情况下正确预测股票上涨
改进方向¶
使用模型选择下一时间窗口的股票
将分析应用于一组股票
选择同时满足以下条件的股票:
模型最准确(模型置信度更高)
最近时间窗口中正向预测次数最多
价格涨幅最大(投资回报率最高)
使用时间序列模型(替代或结合当前方法)进行同类分析
增加指标维度
- 引入反映经济环境的其他指标,例如利率等