测试 precompute_freqs_cis 函数的 Python 代码
下面是一个测试 precompute_freqs_cis 函数的例子:\n\npython\nimport torch\nfrom typing import Tuple\n\ndef test_precompute_freqs_cis():\n dim = 4\n end = 5\n theta = 10000.0\n \n # 调用被测试的函数\n freqs_cis = precompute_freqs_cis(dim, end, theta)\n \n # 验证输出的类型和形状是否正确\n assert isinstance(freqs_cis, torch.Tensor)\n assert freqs_cis.dtype == torch.complex64\n assert freqs_cis.shape == (end, dim // 2)\n \n # 验证计算是否正确\n freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))\n t = torch.arange(end, device=freqs.device)\n expected_freqs_cis = torch.polar(torch.ones(end, dim // 2), torch.outer(t, freqs))\n assert torch.allclose(freqs_cis, expected_freqs_cis)\n\n\n你可以运行 test_precompute_freqs_cis() 来执行测试。
原文地址: https://www.cveoy.top/t/topic/p8Vx 著作权归作者所有。请勿转载和采集!