SARIMA模型超参数优化:网格搜索与时间序列预测
from math import sqrt
from multiprocessing import cpu_count
from joblib import Parallel
from joblib import delayed
from warnings import catch_warnings
from warnings import filterwarnings
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_squared_error
from pandas import read_csv
def sarima_forecast(history, config):
    """Fit a SARIMA model on ``history`` and return a one-step-ahead forecast.

    Args:
        history: Observations seen so far, oldest first.
        config: Three-item sequence ``(order, seasonal_order, trend)`` whose
            items are handed to ``SARIMAX`` unchanged.

    Returns:
        The first element of the prediction requested for index
        ``len(history)``, i.e. the step right after the last observation.
    """
    order, seasonal_order, trend = config
    # Predicting from and to the same index yields exactly one forecast value.
    next_step = len(history)
    fitted = SARIMAX(
        history,
        order=order,
        seasonal_order=seasonal_order,
        trend=trend,
        enforce_stationarity=False,
        enforce_invertibility=False,
    ).fit(disp=False)
    forecast = fitted.predict(next_step, next_step)
    return forecast[0]
def measure_rmse(actual, predicted):
    """Return the root mean squared error between ``actual`` and ``predicted``."""
    mse = mean_squared_error(actual, predicted)
    return sqrt(mse)
def train_test_split(data, n_test):
    """Split a time series into a leading train part and a trailing test part.

    Args:
        data: Sliceable sequence of observations, oldest first.
        n_test: Number of trailing observations to hold out for testing.

    Returns:
        A ``(train, test)`` tuple where ``test`` holds the last ``n_test``
        items of ``data`` and ``train`` holds everything before them. If
        ``n_test`` exceeds ``len(data)``, ``train`` is empty and ``test`` is
        all of ``data``.

    Raises:
        ValueError: If ``n_test`` is negative.
    """
    if n_test < 0:
        raise ValueError('n_test must be non-negative, got %r' % (n_test,))
    # Slice by an explicit index: ``data[:-n_test]`` is empty when n_test == 0
    # because ``-0 == 0``, which would swap the roles of train and test.
    split = max(len(data) - n_test, 0)
    return data[:split], data[split:]
def walk_forward_validation(data, n_test, cfg):
    """Score one configuration with walk-forward validation.

    The last ``n_test`` observations are held out. For each of them a
    forecast is produced from the history seen so far, after which the true
    observation is appended to that history before the next step.

    Args:
        data: The full series, oldest first.
        n_test: Number of trailing observations used as the test set.
        cfg: Model configuration forwarded to ``sarima_forecast``.

    Returns:
        The RMSE between the held-out observations and the forecasts.
    """
    train, test = train_test_split(data, n_test)
    # Seed the growing history with a copy of the training observations.
    history = list(train)
    predictions = []
    for observed in test:
        predictions.append(sarima_forecast(history, cfg))
        # Reveal the true value only after the forecast for it has been made.
        history.append(observed)
    return measure_rmse(test, predictions)
def score_model(data, n_test, cfg, debug=False):
    """Evaluate one configuration and return it together with its error.

    Args:
        data: The full series, oldest first.
        n_test: Number of trailing observations used as the test set.
        cfg: Model configuration; its ``str()`` form is used as the key.
        debug: When true, run the evaluation with warnings visible and let
            any exception propagate. When false, suppress warnings and treat
            a failing configuration as having no result.

    Returns:
        A ``(key, result)`` tuple. ``result`` is the value returned by
        ``walk_forward_validation``, or ``None`` if the evaluation raised an
        exception while ``debug`` was false.
    """
    key = str(cfg)
    result = None
    if debug:
        result = walk_forward_validation(data, n_test, cfg)
    else:
        try:
            # Silence warnings during the search so they do not flood the output.
            with catch_warnings():
                filterwarnings('ignore')
                result = walk_forward_validation(data, n_test, cfg)
        except Exception:
            # Best effort: a configuration that cannot be fitted is skipped.
            # Catching Exception (not a bare except) keeps KeyboardInterrupt
            # and SystemExit able to stop a long-running search.
            result = None
    if result is not None:
        print(' > Model[%s] %.3f' % (key, result))
    return (key, result)
def grid_search(data, cfg_list, n_test, parallel=True):
    """Score every configuration and return the successful ones, best first.

    Args:
        data: The full series, oldest first.
        cfg_list: Iterable of model configurations to evaluate.
        n_test: Number of trailing observations used as the test set.
        parallel: When true, evaluate configurations through a joblib
            ``Parallel`` executor with one job per CPU; otherwise evaluate
            them one after another in this process.

    Returns:
        A list of ``(key, error)`` tuples as produced by ``score_model``,
        with entries whose error is ``None`` removed, sorted by ascending
        error.
    """
    if parallel:
        executor = Parallel(n_jobs=cpu_count(), backend='multiprocessing')
        tasks = (delayed(score_model)(data, n_test, cfg) for cfg in cfg_list)
        scores = executor(tasks)
    else:
        scores = [score_model(data, n_test, cfg) for cfg in cfg_list]
    # Drop configurations that produced no result; identity comparison is the
    # correct way to test for None.
    scores = [entry for entry in scores if entry[1] is not None]
    scores.sort(key=lambda entry: entry[1])
    return scores
def sarima_configs(seasonal=None):
    """Build the list of SARIMA hyperparameter configurations to search.

    Args:
        seasonal: Iterable of seasonal period values ``m`` to include.
            Defaults to ``[0]`` when omitted or ``None``.

    Returns:
        A list of ``[(p, d, q), (P, D, Q, m), t]`` configurations covering
        every combination of the candidate values below, ordered with ``p``
        varying slowest and ``m`` varying fastest.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import list.
    from itertools import product

    # A None sentinel replaces the former mutable ``[0]`` default argument.
    m_params = [0] if seasonal is None else list(seasonal)
    p_params = [0, 1, 2]
    d_params = [0, 1]
    q_params = [0, 1, 2]
    t_params = ['n', 'c', 't', 'ct']
    P_params = [0, 1, 2]
    D_params = [0, 1]
    Q_params = [0, 1, 2]
    # product() iterates its last argument fastest, matching the original
    # nesting order p, d, q, t, P, D, Q, m.
    combos = product(
        p_params, d_params, q_params, t_params,
        P_params, D_params, Q_params, m_params,
    )
    return [[(p, d, q), (P, D, Q, m), t] for p, d, q, t, P, D, Q, m in combos]
if __name__ == '__main__':
    # read_csv cannot parse an .xlsx workbook (it is a zip archive, not
    # delimited text), so the Excel file is loaded with read_excel instead.
    # NOTE(review): reading .xlsx needs an Excel engine such as openpyxl
    # to be installed alongside pandas.
    from pandas import read_excel

    # First row is the header, first column is the index.
    series = read_excel('your_excel_file.xlsx', header=0, index_col=0)
    data = series.values
    # Keep only the last 5 * 12 rows (the most recent five years, assuming
    # monthly observations).
    data = data[-(5 * 12):]
    # Number of trailing observations held out for testing.
    n_test = 12
    # Candidate configurations, with and without a 12-step seasonal period.
    cfg_list = sarima_configs(seasonal=[0, 12])
    # Run the grid search sequentially (parallel=False).
    scores = grid_search(data, cfg_list, n_test, False)
    print('done')
    # Show the three configurations with the lowest error.
    for cfg, error in scores[:3]:
        print(cfg, error)
如何导入自己的Excel数据集:
- 将您的Excel文件命名为 your_excel_file.xlsx,并将其放置在与Python脚本相同的目录下。
- 确保您的Excel文件中包含时间序列数据,并且第一列是时间索引,第一行是列标题。
- 修改代码中读取数据的那一行,将文件名替换为您的Excel文件名。请注意:pandas 的 read_csv 只能解析文本格式(例如 .csv),无法直接读取 .xlsx 工作簿;读取Excel文件应使用 read_excel(需先执行 from pandas import read_excel,其 header、index_col 参数含义相同),或者先将工作表另存为CSV文件,再用 read_csv 读取。
例如,如果您的Excel文件名为 monthly_sales.xlsx,则代码应修改为:
series = read_excel('monthly_sales.xlsx', header=0, index_col=0)
注意:
- 如果您的Excel文件中没有列标题,则将 header=0 改为 header=None。
- 如果您的时间索引不在第一列,则修改 index_col 参数的值,使其对应于时间索引所在的列号(从0开始计数)。
原文地址: https://www.cveoy.top/t/topic/eeR8 著作权归作者所有。请勿转载和采集!