Neural Induction Matrix Completion: Implementation and Example Call
Below is a Python implementation of the neural induction matrix completion algorithm:
import numpy as np

def neural_induction_completing_matrix(X, max_iter=1000, tol=1e-4):
    """
    Complete the matrix X with a small one-hidden-layer neural network
    trained by gradient descent.

    X: the matrix to complete; missing entries are marked with NaN.
    max_iter: maximum number of training iterations.
    tol: convergence threshold; training stops once the loss drops below it.
    """
    # Initialize the network parameters
    m, n = X.shape
    hidden_size = 5
    W1 = np.random.randn(n, hidden_size)
    W2 = np.random.randn(hidden_size, n)
    b1 = np.zeros((1, hidden_size))
    b2 = np.zeros((1, n))
    learning_rate = 1e-3

    # Zero-fill the missing entries so NaN does not propagate through the
    # forward pass, and keep a mask of the observed positions.
    mask = ~np.isnan(X)
    X0 = np.where(mask, X, 0.0)

    # Training loop
    for i in range(max_iter):
        # Forward pass
        Z1 = np.dot(X0, W1) + b1
        A1 = np.maximum(0, Z1)        # ReLU activation
        Z2 = np.dot(A1, W2) + b2
        # Squared reconstruction error, measured only on observed entries
        loss = np.sum(((Z2 - X0) * mask) ** 2)
        # Backward pass
        dZ2 = (Z2 - X0) * mask
        dW2 = np.dot(A1.T, dZ2)
        db2 = np.sum(dZ2, axis=0, keepdims=True)
        dA1 = np.dot(dZ2, W2.T)
        dZ1 = dA1 * (Z1 > 0)
        dW1 = np.dot(X0.T, dZ1)
        db1 = np.sum(dZ1, axis=0, keepdims=True)
        # Gradient-descent parameter update
        W1 -= learning_rate * dW1
        b1 -= learning_rate * db1
        W2 -= learning_rate * dW2
        b2 -= learning_rate * db2
        # Report progress and check for convergence
        if i % 10 == 0:
            print("Iter {}: loss = {}".format(i, loss))
        if loss < tol:
            print("Converged in {} iterations.".format(i))
            break

    # Final forward pass; fill the missing entries with the network's output
    Z1 = np.dot(X0, W1) + b1
    A1 = np.maximum(0, Z1)
    Z2 = np.dot(A1, W2) + b2
    X_filled = np.where(mask, X, Z2)
    return X_filled
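Note that the reconstruction loss is computed only over the observed entries (via the mask array). The zero-filled positions carry no real information, and fitting the network to them would pull the predictions at the missing positions toward zero, so they are excluded from both the loss and the gradient.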
Next, we call the function to complete an example matrix:
# Create an example matrix with missing entries (NaN)
X = np.array([[1, 2, np.nan, 4], [5, np.nan, 7, 8], [np.nan, 10, 11, 12]])
# Complete the matrix
X_filled = neural_induction_completing_matrix(X)
# Print the completed matrix
print(X_filled)
The output looks something like the following (the exact numbers vary from run to run, since the weights are randomly initialized):
Iter 0: loss = 1683.5748592161842
Iter 10: loss = 142.02297651036934
Iter 20: loss = 45.21895527169411
Iter 30: loss = 19.06324589295968
Iter 40: loss = 9.782305118770806
Iter 50: loss = 5.662824240416703
Iter 60: loss = 3.622861107194121
Iter 70: loss = 2.53769141788391
Iter 80: loss = 1.9219126946839564
Iter 90: loss = 1.5376519029945154
Iter 100: loss = 1.2792915003929494
Iter 110: loss = 1.1020524422071174
Iter 120: loss = 0.9705295266786533
Iter 130: loss = 0.869845726065694
Iter 140: loss = 0.7918027580260175
Iter 150: loss = 0.7306842044594363
Iter 160: loss = 0.6820146667622329
Iter 170: loss = 0.6433835480131195
Iter 180: loss = 0.6129711159463839
Iter 190: loss = 0.5894299273585503
Iter 200: loss = 0.5716794484448829
Iter 210: loss = 0.5588187181798231
Iter 220: loss = 0.5490761783232959
Iter 230: loss = 0.5417717676664836
Iter 240: loss = 0.5363228576182516
Iter 250: loss = 0.5322324199283946
Iter 260: loss = 0.5291096171299389
Iter 270: loss = 0.5266469255700666
Iter 280: loss = 0.524613543824056
Iter 290: loss = 0.5228442297206708
Iter 300: loss = 0.52125977851908
Iter 310: loss = 0.519805773410049
Iter 320: loss = 0.5184446317311455
Iter 330: loss = 0.5171597467179148
Iter 340: loss = 0.5159367721288466
Iter 350: loss = 0.5147656913191296
Iter 360: loss = 0.5136375232633089
Iter 370: loss = 0.5125450334448516
Iter 380: loss = 0.5114828916830649
Iter 390: loss = 0.5104468192954755
Iter 400: loss = 0.5094327995514819
Iter 410: loss = 0.5084385487897843
Iter 420: loss = 0.5074625647568759
Iter 430: loss = 0.5065037241839013
Iter 440: loss = 0.5055615136182539
Iter 450: loss = 0.504635828837771
Iter 460: loss = 0.5037270685355159
Iter 470: loss = 0.5028357391158046
Iter 480: loss = 0.5019621862170953
Iter 490: loss = 0.5011064801919866
Iter 500: loss = 0.5002689884480434
Iter 510: loss = 0.499450035070197
Iter 520: loss = 0.4986493362960792
Iter 530: loss = 0.4978670975782917
Iter 540: loss = 0.49710353728497885
Iter 550: loss = 0.4963581908676767
Iter 560: loss = 0.4956311394738613
Iter 570: loss = 0.4949226722175492
Iter 580: loss = 0.49423250520462385
Iter 590: loss = 0.493560967745674
Iter 600: loss = 0.49290748985189706
Iter 610: loss = 0.492272705219343
Iter 620: loss = 0.4916566196941655
Iter 630: loss = 0.491058237414591
Iter 640: loss = 0.49047821217429387
Iter 650: loss = 0.4899160202985554
Iter 660: loss = 0.48937171722693477
Iter 670: loss = 0.4888454582075114
Iter 680: loss = 0.48833663180375476
Iter 690: loss = 0.4878455032308083
Iter 700: loss = 0.4873711054405768
Iter 710: loss = 0.4869132733584845
Iter 720: loss = 0.4864732047344073
Iter 730: loss = 0.48604948182946346
Iter 740: loss = 0.485641083565621
Iter 750: loss = 0.4852473337256679
Iter 760: loss = 0.48486862543350165
Iter 770: loss = 0.4845044834707219
Iter 780: loss = 0.4841544135473031
Iter 790: loss = 0.4838188532434786
Iter 800: loss = 0.48349743479264055
Iter 810: loss = 0.4831899556096919
Iter 820: loss = 0.4828966349487916
Iter 830: loss = 0.4826164044162662
Iter 840: loss = 0.482349420771557
Iter 850: loss = 0.48209491254196833
Iter 860: loss = 0.4818534717084144
Iter 870: loss = 0.4816241309145539
Iter 880: loss = 0.48140627729746364
Iter 890: loss = 0.4811997452125819
Iter 900: loss = 0.4810044553397392
Iter 910: loss = 0.4808193433148845
Iter 920: loss = 0.4806446247061472
Iter 930: loss = 0.4804804451427289
Iter 940: loss = 0.4803257346575365
Iter 950: loss = 0.4801805249548029
Iter 960: loss = 0.480044102129017
Iter 970: loss = 0.4799158480583829
Iter 980: loss = 0.47979559679284955
Iter 990: loss = 0.47968298196954025
[[ 1. 2. 3.60648965 4. ]
[ 5. 11.77792136 7. 8. ]
[ 9.30543909 10. 11. 12. ]]
As you can see, the missing entries of the matrix have been filled in.
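To go beyond a visual check, a simple way to assess completion quality is to hide a few entries whose true values are known, run the completion, and measure the error on exactly those entries. Below is a minimal sketch of that idea; the held-out positions and the RMSE metric are illustrative choices, not part of the algorithm above.

import numpy as np

# Ground-truth matrix (fully observed, used only for evaluation)
X_true = np.array([[1.0, 2, 3, 4],
                   [5, 6, 7, 8],
                   [9, 10, 11, 12]])

# Hide a few entries whose true values we know
X_test = X_true.copy()
holdout = [(0, 2), (1, 1), (2, 0)]
for r, c in holdout:
    X_test[r, c] = np.nan

# Complete the matrix and measure the error on the hidden entries only
X_hat = neural_induction_completing_matrix(X_test)
errors = [X_hat[r, c] - X_true[r, c] for r, c in holdout]
rmse = np.sqrt(np.mean(np.square(errors)))
print("RMSE on held-out entries: {:.4f}".format(rmse))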