该代码段是一个名为 CQRDQN 的类,用于实现基于深度强化学习的交易策略。该类包含了多个函数和变量,用于构建神经网络、生成交易决策和训练网络等功能。
在构建神经网络方面,该类使用了 CNet 类的成员函数,包括 'feedForward'、'backProp' 和 'getResults' 等函数,用于输入数据、反向传播和获取网络的输出结果。同时,该类还包含了一个名为 'cTargetNet' 的 CNet 类对象,用于实现目标网络的更新和交换。
在生成交易决策方面,该类使用了 'getAction' 和 'getSample' 函数,前者用于获取当前时刻的交易动作,后者用于在策略探索阶段随机生成交易动作。
在训练网络方面,该类使用了 'backProp' 函数,用于根据当前状态、动作和奖励信号更新网络参数。同时,该类还包含了一些辅助函数,如 'getRecentAverageError'、'Save' 和 'Load' 等,用于获取网络的性能指标和保存/加载网络模型。
class CQRDQN : protected CNet
{
private:
uint iCountBackProp;
protected:
uint iNumbers;
uint iActions;
uint iUpdateTarget;
matrix mTaus;
//---
CNet cTargetNet;
public:
/** Constructor */
CQRDQN(void);
CQRDQN(CArrayObj Description) { Create(Description, iActions); }
bool Create(CArrayObj Description, uint actions);
/ Destructor */~CQRDQN(void);
bool feedForward(CArrayFloat *inputVals, int window = 1, bool tem = true)
{ return CNet::feedForward(inputVals, window, tem); }
bool backProp(CBufferFloat *targetVals, float discount,
CArrayFloat *nextState, int window = 1, bool tem = true);
void getResults(CBufferFloat *&resultVals);
int getAction(void);
int getSample(void);
float getRecentAverageError() { return recentAverageError; }
bool Save(string file_name, datetime time, bool common = true)
{ return CNet::Save(file_name, getRecentAverageError(),
(float)iActions, 0, time, common); }
virtual bool Save(const int file_handle);
virtual bool Load(string file_name, datetime &time, bool common = true);
virtual bool Load(const int file_handle);
//---
virtual int Type(void) const { return defQRDQN; }
virtual bool TrainMode(bool flag) { return CNet::TrainMode(flag); }
virtual bool GetLayerOutput(uint layer, CBufferFloat *&result)
{ return CNet::GetLayerOutput(layer, result); }
//---
virtual void SetUpdateTarget(uint batch) { iUpdateTarget = batch; }
virtual bool UpdateTarget(string file_name);
//---
virtual bool SetActions(uint actions);
};