def test(model, net):
    with torch.no_grad():
        dataset = MyDataset(opt.video_path,
                            opt.anno_path,
                            opt.val_list,
                            opt.vid_padding,
                            opt.txt_padding,
                            'test')
        
        print('num_test_data:{}'.format(len(dataset.data)))  
        model.eval()
        loader = dataset2dataloader(dataset, shuffle=False)
        loss_list, wer, cer = [], [], []
        crit = nn.CTCLoss()
        tic = time.time()
        
        for i_iter, input in enumerate(loader):            
            vid, txt, vid_len, txt_len = [input.get(key).cuda() for key in ['vid', 'txt', 'vid_len', 'txt_len']]
            y = net(vid)
            
            loss = crit(y.transpose(0, 1).log_softmax(-1), txt, vid_len.view(-1), txt_len.view(-1)).detach().cpu().numpy()
            loss_list.append(loss)
            pred_txt = ctc_decode(y)
            
            truth_txt = [MyDataset.arr2txt(txt[_], start=1) for _ in range(txt.size(0))]
            wer.extend(MyDataset.wer(pred_txt, truth_txt)) 
            cer.extend(MyDataset.cer(pred_txt, truth_txt)) 
                
            if i_iter % opt.display == 0:
                v = 1.0*(time.time()-tic)/(i_iter+1)
                eta = v * (len(loader)-i_iter) / 3600.0
                
                print(''.join(101*'-'))                
                print('{:<50}|{:>50}'.format('predict', 'truth'))
                print(''.join(101*'-'))                
                for predict, truth in list(zip(pred_txt, truth_txt))[:10]:
                    print('{:<50}|{:>50}'.format(predict, truth))                
                print(''.join(101 *'-'))
                print('test_iter={},eta={},wer={},cer={}'.format(i_iter, eta, np.array(wer).mean(), np.array(cer).mean()))                
                print(''.join(101 *'-'))
                
        return np.array(loss_list).mean(), np.array(wer).mean(), np.array(cer).mean()


# Chúng ta đã sử dụng unpacking để giảm thiểu số lượng code được sử dụng để lấy các đối tượng input.
# Chúng ta cũng đã sử dụng list comprehension để giảm thiểu số lượng code được sử dụng để lấy truth_txt.

# Chúng ta đã tạo các list loss_list, wer và cer từ đầu để giảm thiểu các câu lệnh được sử dụng để tạo list.
# Các biến loss, pred_txt, truth_txt, wer, cer cũng đã được giảm thiểu. 

# Cuối cùng, chúng ta đã sử dụng format string để giảm thiểu số lượng code để in ra các giá trị liên quan đến test_iter.
Tối ưu hóa đoạn code Python cho mô hình mạng nơ-ron

原文地址: http://www.cveoy.top/t/topic/lOGa 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录