from math import sqrt
from multiprocessing import cpu_count
from joblib import Parallel
from joblib import delayed
from warnings import catch_warnings
from warnings import filterwarnings
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_squared_error
from pandas import read_csv

# 进行一步SARIMA预测
def sarima_forecast(history, config):
    order, sorder, trend = config
    # 定义模型
    model = SARIMAX(history, order=order, seasonal_order=sorder, trend=trend, enforce_stationarity=False, enforce_invertibility=False)
    # 拟合模型
    model_fit = model.fit(disp=False)
    # 进行一步预测
    yhat = model_fit.predict(len(history), len(history))
    return yhat[0]

# 计算均方根误差(RMSE)
def measure_rmse(actual, predicted):
    return sqrt(mean_squared_error(actual, predicted))

# 将时间序列数据划分为训练集和测试集
def train_test_split(data, n_test):
    return data[:-n_test], data[-n_test:]

# 逐步前向验证
def walk_forward_validation(data, n_test, cfg):
    predictions = list()
    # 划分数据集
    train, test = train_test_split(data, n_test)
    # 使用训练集初始化历史数据
    history = [x for x in train]
    # 遍历测试集中的每一个时间步
    for i in range(len(test)):
        # 拟合模型并进行预测
        yhat = sarima_forecast(history, cfg)
        # 将预测结果存储到列表中
        predictions.append(yhat)
        # 将实际观测值添加到历史数据中,用于下一轮预测
        history.append(test[i])
    # 计算预测误差
    error = measure_rmse(test, predictions)
    return error

# 评估模型性能
def score_model(data, n_test, cfg, debug=False):
    result = None
    # 将模型配置转换为字符串作为键值
    key = str(cfg)
    # 如果debug为真,则显示所有警告信息并在出现异常时停止程序
    if debug:
        result = walk_forward_validation(data, n_test, cfg)
    else:
        # 在进行网格搜索时不显示警告信息,避免干扰
        try:
            with catch_warnings():
                filterwarnings('ignore')
                result = walk_forward_validation(data, n_test, cfg)
        except:
            error = None
    # 如果评估结果不为空,则打印模型配置和评估结果
    if result is not None:
        print(' > Model[%s] %.3f' % (key, result))
    return (key, result)

# 网格搜索
def grid_search(data, cfg_list, n_test, parallel=True):
    scores = None
    if parallel:
        # 并行处理任务
        executor = Parallel(n_jobs=cpu_count(), backend='multiprocessing')
        tasks = (delayed(score_model)(data, n_test, cfg) for cfg in cfg_list)
        scores = executor(tasks)
    else:
        scores = [score_model(data, n_test, cfg) for cfg in cfg_list]
    # 删除空结果
    scores = [r for r in scores if r[1] != None]
    # 根据预测误差对模型配置进行排序
    scores.sort(key=lambda tup: tup[1])
    return scores

# 生成SARIMA模型的超参数配置列表
def sarima_configs(seasonal=[0]):
    models = list()
    # 定义各个超参数的可能取值列表
    p_params = [0, 1, 2]
    d_params = [0, 1]
    q_params = [0, 1, 2]
    t_params = ['n', 'c', 't', 'ct']
    P_params = [0, 1, 2]
    D_params = [0, 1]
    Q_params = [0, 1, 2]
    m_params = seasonal
    # 遍历所有可能的超参数组合
    for p in p_params:
        for d in d_params:
            for q in q_params:
                for t in t_params:
                    for P in P_params:
                        for D in D_params:
                            for Q in Q_params:
                                for m in m_params:
                                    cfg = [(p, d, q), (P, D, Q, m), t]
                                    models.append(cfg)
    return models

if __name__ == '__main__':
    # 从Excel文件中读取时间序列数据
    series = read_csv('your_excel_file.xlsx', header=0, index_col=0)
    data = series.values
    # 选择最近的5年数据
    data = data[-(5 * 12):]
    # 设置测试集长度
    n_test = 12
    # 生成SARIMA模型的超参数配置列表
    cfg_list = sarima_configs(seasonal=[0, 12])
    # 进行网格搜索
    scores = grid_search(data, cfg_list, n_test, False)
    print('done')
    # 打印最佳的3个模型配置和对应的预测误差
    for cfg, error in scores[:3]:
        print(cfg, error)

如何导入自己的Excel数据集:

  1. 将您的Excel文件命名为your_excel_file.xlsx,并将其放置在与Python脚本相同的目录下。
  2. 确保您的Excel文件中包含时间序列数据,并且第一列是时间索引,第一行是列标题。
  3. 修改代码中的read_csv('your_excel_file.xlsx', header=0, index_col=0),将文件名替换为您的Excel文件名。

例如,如果您的Excel文件名为monthly_sales.xlsx,则代码应修改为:

series = read_csv('monthly_sales.xlsx', header=0, index_col=0)

注意:

  • 如果您的Excel文件中没有列标题,则将header=0改为header=None
  • 如果您的时间索引不在第一列,则修改index_col参数的值,使其对应于时间索引所在的列号(从0开始计数)。
SARIMA模型超参数优化:网格搜索与时间序列预测

原文地址: https://www.cveoy.top/t/topic/eeR8 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录