# LSTM, GRU, and MNN Model Performance Comparison for Latitude Error
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Model predictions and ground truth, read in this order:
#   y_hat     <- y_hat.csv        (LSTM predictions, 93 rows when written)
#   predict_Y <- predict_Y.csv    (GRU predictions, 92 rows when written)
#   original  <- original.csv     (ground truth, 100 rows when written)
#   y_pred    <- y_pred_mnn.csv   (MNN predictions, 100 rows when written)
y_hat, predict_Y, original, y_pred = (
    pd.read_csv(csv_name)
    for csv_name in ('y_hat.csv', 'predict_Y.csv', 'original.csv', 'y_pred_mnn.csv')
)
# NOTE(review): a 10x10 figure with "Longitude Error (degree)" /
# "Latitude Error (degree)" axis labels used to be opened here. Nothing was ever
# plotted on it: the bar chart further down builds its own figure with
# plt.subplots(), so the only effect was an extra, empty window from plt.show().
# It has been removed. If a longitude-vs-latitude error scatter plot was
# intended, create it on its own figure next to the code that draws it.
# LSTM
# Prediction errors (ground truth minus prediction) for the LSTM model.
# Row i of y_hat is compared with row i + LSTM_OFFSET of `original`. With the
# row counts recorded where the CSVs are loaded (93 predictions vs 100 truth
# rows), this aligns the predictions with the last len(y_hat) ground-truth rows.
# Presumably the offset is the model's input window; confirm against how y_hat
# was produced.
#
# Column-1 differences are stored in a_lat_LSTM, column-0 differences in
# b_lon_LSTM.
# NOTE(review): whether column 1 really is latitude is unverified. These same
# values were previously held in per-row variables named `long` (column 1) and
# `lat` (column 0), which contradicts the list names. The contents are kept
# unchanged because the metrics and chart are keyed on the list names.
LSTM_OFFSET = 7
_lstm_end = LSTM_OFFSET + len(y_hat)
if len(original) < _lstm_end:
    raise ValueError(
        f'original has {len(original)} rows but {_lstm_end} are needed to '
        f'align {len(y_hat)} LSTM predictions at offset {LSTM_OFFSET}'
    )
_lstm_truth = original.iloc[LSTM_OFFSET:_lstm_end]
a_lat_LSTM = (
    _lstm_truth.iloc[:, 1].to_numpy() - y_hat.iloc[:, 1].to_numpy()
).tolist()
b_lon_LSTM = (
    _lstm_truth.iloc[:, 0].to_numpy() - y_hat.iloc[:, 0].to_numpy()
).tolist()
# GRU
# Prediction errors (ground truth minus prediction) for the GRU model.
# Row i of predict_Y is compared with row i + GRU_OFFSET of `original`. With the
# row counts recorded where the CSVs are loaded (92 predictions vs 100 truth
# rows), this aligns the predictions with the last len(predict_Y) ground-truth
# rows. Confirm against how predict_Y was produced.
#
# Column-1 differences are stored in a_lat_GRU, column-0 differences in
# b_lon_GRU.
# NOTE(review): whether column 1 really is latitude is unverified. These same
# values were previously held in per-row variables named `long` (column 1) and
# `lat` (column 0), which contradicts the list names. The contents are kept
# unchanged because the metrics and chart are keyed on the list names.
GRU_OFFSET = 8
_gru_end = GRU_OFFSET + len(predict_Y)
if len(original) < _gru_end:
    raise ValueError(
        f'original has {len(original)} rows but {_gru_end} are needed to '
        f'align {len(predict_Y)} GRU predictions at offset {GRU_OFFSET}'
    )
_gru_truth = original.iloc[GRU_OFFSET:_gru_end]
a_lat_GRU = (
    _gru_truth.iloc[:, 1].to_numpy() - predict_Y.iloc[:, 1].to_numpy()
).tolist()
b_lon_GRU = (
    _gru_truth.iloc[:, 0].to_numpy() - predict_Y.iloc[:, 0].to_numpy()
).tolist()
# MNN
# Prediction errors (ground truth minus prediction) for the MNN model.
# MNN predictions are row-aligned with `original` (no offset). The predicted
# values sit in columns 1 and 2 of y_pred, compared against columns 0 and 1 of
# `original` respectively; column 0 of y_pred is not used here.
#
# original[:, 1] - y_pred[:, 2] is stored in a_lat_MNN, and
# original[:, 0] - y_pred[:, 1] in b_lon_MNN.
# NOTE(review): whether these columns really are latitude/longitude as the list
# names claim is unverified. These same values were previously held in per-row
# variables named `long` and `lat` respectively, which contradicts the list
# names. The contents are kept unchanged because the metrics and chart are keyed
# on the list names.
if len(original) < len(y_pred):
    raise ValueError(
        f'original has {len(original)} rows but {len(y_pred)} are needed to '
        f'align {len(y_pred)} MNN predictions'
    )
_mnn_truth = original.iloc[:len(y_pred)]
a_lat_MNN = (
    _mnn_truth.iloc[:, 1].to_numpy() - y_pred.iloc[:, 2].to_numpy()
).tolist()
b_lon_MNN = (
    _mnn_truth.iloc[:, 0].to_numpy() - y_pred.iloc[:, 1].to_numpy()
).tolist()
# Error statistics for every model and both error series.
# For a series of signed errors e: MSE = mean(e**2), MAE = mean(|e|),
# RMSE = sqrt(MSE).


def _error_stats(errors):
    """Return (MSE, MAE, RMSE) of the signed errors in *errors*."""
    err = np.array(errors)
    mse = np.square(err).mean()
    return mse, np.abs(err).mean(), np.sqrt(mse)


# LSTM
MSE_lat_LSTM, MAE_lat_LSTM, RMSE_lat_LSTM = _error_stats(a_lat_LSTM)
MSE_lon_LSTM, MAE_lon_LSTM, RMSE_lon_LSTM = _error_stats(b_lon_LSTM)

# GRU
MSE_lat_GRU, MAE_lat_GRU, RMSE_lat_GRU = _error_stats(a_lat_GRU)
MSE_lon_GRU, MAE_lon_GRU, RMSE_lon_GRU = _error_stats(b_lon_GRU)

# MNN
MSE_lat_MNN, MAE_lat_MNN, RMSE_lat_MNN = _error_stats(a_lat_MNN)
MSE_lon_MNN, MAE_lon_MNN, RMSE_lon_MNN = _error_stats(b_lon_MNN)
# Print every model's error statistics: a "<model>:" header line, then one
# "<caption> <value>" line per statistic.
_metric_report = (
    ('LSTM:', (
        ('MSE of Latitude Error:', MSE_lat_LSTM),
        ('MAE of Latitude Error:', MAE_lat_LSTM),
        ('RMSE of Latitude Error:', RMSE_lat_LSTM),
        ('MSE of Longitude Error:', MSE_lon_LSTM),
        ('MAE of Longitude Error:', MAE_lon_LSTM),
        ('RMSE of Longitude Error:', RMSE_lon_LSTM),
    )),
    ('GRU:', (
        ('MSE of Latitude Error:', MSE_lat_GRU),
        ('MAE of Latitude Error:', MAE_lat_GRU),
        ('RMSE of Latitude Error:', RMSE_lat_GRU),
        ('MSE of Longitude Error:', MSE_lon_GRU),
        ('MAE of Longitude Error:', MAE_lon_GRU),
        ('RMSE of Longitude Error:', RMSE_lon_GRU),
    )),
    ('MNN:', (
        ('MSE of Latitude Error:', MSE_lat_MNN),
        ('MAE of Latitude Error:', MAE_lat_MNN),
        ('RMSE of Latitude Error:', RMSE_lat_MNN),
        ('MSE of Longitude Error:', MSE_lon_MNN),
        ('MAE of Longitude Error:', MAE_lon_MNN),
        ('RMSE of Longitude Error:', RMSE_lon_MNN),
    )),
)
for _model_header, _model_rows in _metric_report:
    print(_model_header)
    for _caption, _value in _model_rows:
        print(_caption, _value)
# Build the chart data.
# One bar group per metric (in `labels` order); each model contributes its
# latitude-error MSE, MAE and RMSE.
labels = ['MSE', 'MAE', 'RMSE']
data_mnn, data_lstm, data_gru = (
    [MSE_lat_MNN, MAE_lat_MNN, RMSE_lat_MNN],
    [MSE_lat_LSTM, MAE_lat_LSTM, RMSE_lat_LSTM],
    [MSE_lat_GRU, MAE_lat_GRU, RMSE_lat_GRU],
)

x = np.arange(len(labels))  # x positions of the metric groups: 0, 1, 2
width = 0.25                # width of a single bar

plt.rcParams['font.family'] = "Times New Roman"
# Plots
# Grouped bar chart of the latitude-error metrics: one group per metric, with
# the MNN, LSTM and GRU bars side by side and the group centred on its tick.
fig, ax = plt.subplots(figsize=(5, 3), dpi=200)
bar_a = ax.bar(x - width, data_mnn, width, label='MNN',
               color='#130074', ec='black', lw=.5)
bar_b = ax.bar(x, data_lstm, width, label='LSTM',
               color='#CB181B', ec='black', lw=.5)
# The GRU series is drawn exactly once; drawing it twice adds a duplicate
# "GRU" entry to the legend.
bar_c = ax.bar(x + width, data_gru, width, label='GRU',
               color='#008B45', ec='black', lw=.5)

# Custom styling
ax.tick_params(axis='x', direction='in', bottom=False)
ax.tick_params(axis='y', direction='out', labelsize=8, length=3)
ax.set_xticks(x)
ax.set_xticklabels(labels, size=10)

# Fixed y range with ticks every 5. The tick list stops at the upper limit:
# set_yticks() expands the view to include every tick, so ticks beyond 40
# would silently override set_ylim().
# NOTE(review): the 0-40 range is hard-coded; bars taller than 40 are clipped.
y_top, y_step = 40, 5
ax.set_ylim(bottom=0, top=y_top)
ax.set_yticks(np.arange(0, y_top + y_step, step=y_step))

for spine in ('top', 'right'):
    ax.spines[spine].set_color('none')

# Dashed horizontal grid only, drawn behind the bars. Configured before saving
# so that it appears in the exported image.
ax.yaxis.grid(True, linestyle='--', alpha=0.5)
ax.xaxis.grid(False)
ax.set_axisbelow(True)

ax.legend(fontsize=7, frameon=False)

text_font = {'size': '14', 'weight': 'bold', 'color': 'black'}  # NOTE(review): unused

# The 5x3 inch size comes from figsize above. savefig() has no width/height
# parameters: older Matplotlib ignored them with a deprecation warning, and newer
# releases reject them with TypeError.
fig.savefig('Column1.png', dpi=900, bbox_inches='tight')
plt.show()
# Original source: https://www.cveoy.top/t/topic/l6F7 - copyright belongs to the author; please do not repost or scrape.