解决TensorFlow中的InvalidArgumentError: cannot compute MatMul as input #1(zero-based) was expected to be a float tensor but is a double tensor
解决TensorFlow中的InvalidArgumentError: cannot compute MatMul as input #1(zero-based) was expected to be a float tensor but is a double tensor
在使用TensorFlow进行矩阵乘法时,可能会遇到以下错误:
InvalidArgumentError Traceback (most recent call last)
Cell In[12], line 15
11 for i in range(iter + 1):
13 with tf.GradientTape() as tape:
---> 15 PRED_train = tf.matmul(X_train,W)
17 Loss_train = 0.5 * tf.reduce_mean(tf.square(Y_train - PRED_train))
19 PRED_test = tf.matmul(X_test,W)
File ~\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\util\traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
151 except Exception as e:
152 filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153 raise e.with_traceback(filtered_tb) from None
154 finally:
155 del filtered_tb
File ~\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\framework\ops.py:7262, in raise_from_not_ok_status(e, name)
7260 def raise_from_not_ok_status(e, name):
7261 e.message += (' name: ' + name if name is not None else '')
-> 7262 raise core._status_to_exception(e) from None
InvalidArgumentError: cannot compute MatMul as input #1(zero-based) was expected to be a float tensor but is a double tensor [Op:MatMul]
出错的原因是输入的张量类型不匹配。 根据错误信息,输入的张量X_train和W的类型应该是float类型,但实际上它们的类型是double。
如何修改内容:
要解决这个问题,可以将输入的张量类型转换为float类型。可以使用tf.cast()函数来实现类型转换。修改代码如下:
X_train = tf.cast(X_train, tf.float32)
W = tf.cast(W, tf.float32)
在计算tf.matmul()之前,将X_train和W转换为float类型。这样就可以解决InvalidArgumentError错误。
原文地址: https://www.cveoy.top/t/topic/fM1t 著作权归作者所有。请勿转载和采集!