
DeepLog: printing the training loss

2024-11-19 Source: Personal Tech Collection (个人技术集锦)

DeepLog uses torchtrain to wrap its training code, so printing the training loss means editing torchtrain's fit function.

Step 1: Locate the file

Find the installed torchtrain package and open its module.py file.
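If you are not sure where pip put torchtrain, Python can report the location. This is a minimal sketch using importlib; it assumes the package is importable as `torchtrain`, and `find_package_file` is a hypothetical helper name.

```python
import importlib.util
from pathlib import Path
from typing import Optional


def find_package_file(package: str, filename: str) -> Optional[Path]:
    """Return the path of `filename` inside an installed package, or None."""
    # find_spec returns None for a top-level name that is not installed
    spec = importlib.util.find_spec(package)
    # submodule_search_locations is only set for packages (directories)
    if spec is None or not spec.submodule_search_locations:
        return None
    for location in spec.submodule_search_locations:
        candidate = Path(location) / filename
        if candidate.is_file():
            return candidate
    return None


if __name__ == "__main__":
    # Assumption: torchtrain is the import name of the installed package
    path = find_package_file("torchtrain", "module.py")
    print(path if path else "torchtrain not found in this environment")
```

Run it with the same interpreter (or kernel) that runs DeepLog, since each virtual environment has its own copy of the package.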

Delete the __pycache__ folder in that same package directory.
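The folder can also be removed from Python. A minimal sketch; `remove_pycache` is a hypothetical helper, and you would pass it the torchtrain package directory (the folder that contains module.py).

```python
import shutil
from pathlib import Path


def remove_pycache(package_dir) -> bool:
    """Delete package_dir/__pycache__ if present; return True if it was removed."""
    cache_dir = Path(package_dir) / "__pycache__"
    if cache_dir.is_dir():
        # Removes the cached .pyc files together with the folder itself
        shutil.rmtree(cache_dir)
        return True
    return False
```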

Step 2: Modify the file

In module.py, find the fit function and modify it. The relevant part of the original function:

    def fit(self, X, y,
            epochs        = 10,
            batch_size    = 32,
            learning_rate = 0.01,
            criterion     = nn.NLLLoss(),
            optimizer     = optim.SGD,
            variable      = False,
            verbose       = True,
            **kwargs):
        ......
        # Loop over each epoch
        for epoch in range(1, epochs+1):
            try:
                # Loop over entire dataset
                for X_, y_ in tqdm.tqdm(data,
                    desc="[Epoch {:{width}}/{:{width}}]".format(
                        epoch, epochs, width=len(str(epochs)))):
                    # Clear optimizer
                    optimizer.zero_grad()
        ......

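Change it to the following, which adds up the loss of every batch and prints the average at the end of each epoch: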

        # Loop over each epoch
        for epoch in range(1, epochs+1):
            try:
                # Running sum of the batch losses in this epoch
                train_loss = 0
                # Loop over entire dataset
                for X_, y_ in tqdm.tqdm(data,
                    desc="[Epoch {:{width}}/{:{width}}]".format(
                        epoch, epochs, width=len(str(epochs)))):
                    # Clear optimizer
                    optimizer.zero_grad()

                    # Forward pass
                    # Get new input batch
                    X_ = X_.clone().detach().to(device)
                    # Run through module
                    y_pred = self(X_)
                    # Compute loss
                    loss = criterion(y_pred, y_)
                    # Add this batch's loss (as a Python float) to the running sum
                    train_loss += loss.item()

                    # Backward pass
                    # Propagate loss
                    loss.backward()
                    # Perform optimizer step
                    optimizer.step()

                # Print the mean per-batch loss for this epoch
                total_step = len(data)
                print('Epoch [{}/{}], train_loss: {:.4f}'.format(epoch, epochs, train_loss / total_step))

            except KeyboardInterrupt:
                print("\nTraining interrupted, performing clean stop")
                break
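The change boils down to a running sum divided by the number of batches. The sketch below isolates that pattern from PyTorch so it runs on its own; `run_epoch`, `train_step` and `format_epoch_line` are hypothetical names, and `train_step` stands in for the zero_grad/forward/backward/step sequence in the loop above, returning `loss.item()`.

```python
from typing import Any, Callable, Iterable


def run_epoch(batches: Iterable[Any], train_step: Callable[[Any], float]) -> float:
    """Run one epoch and return the mean of the per-batch losses."""
    total_loss = 0.0
    num_batches = 0
    for batch in batches:
        # train_step is expected to do zero_grad, forward, backward and
        # optimizer.step on one batch, then return loss.item()
        total_loss += train_step(batch)
        num_batches += 1
    if num_batches == 0:
        raise ValueError("no batches in this epoch; cannot average the loss")
    return total_loss / num_batches


def format_epoch_line(epoch: int, epochs: int, train_loss: float) -> str:
    # Same format string as the print() in the modified fit()
    return 'Epoch [{}/{}], train_loss: {:.4f}'.format(epoch, epochs, train_loss)


if __name__ == "__main__":
    # Dummy "batches" whose loss is the value itself, standing in for a real model
    fake_batch_losses = [0.9, 0.7, 0.5]
    epochs = 1
    for epoch in range(1, epochs + 1):
        avg = run_epoch(fake_batch_losses, lambda batch: batch)
        print(format_epoch_line(epoch, epochs, avg))  # Epoch [1/1], train_loss: 0.7000
```

Counting batches inside the loop gives the same divisor as `len(data)` for a DataLoader, but also works for iterables that have no length. Either way the result is the mean of per-batch means, so a smaller final batch counts as much as a full one; multiplying each batch loss by its batch size and dividing by the number of samples would give the per-sample mean instead.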

Step 3: Restart the environment

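Restart the kernel (for example the Jupyter kernel) so that the modified torchtrain module is imported afresh.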
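After the restart, one way to confirm that the edit landed in the copy of torchtrain this environment actually imports is to search that file for the new variable name. A minimal sketch, assuming the module is importable as `torchtrain.module`; `module_file_contains` is a hypothetical helper that only reads the source file on disk and does not inspect code that is already loaded.

```python
import importlib.util
from pathlib import Path


def module_file_contains(module_name: str, marker: str) -> bool:
    """Return True if the source file that `module_name` resolves to contains `marker`."""
    try:
        # For dotted names, find_spec imports the parent package first
        spec = importlib.util.find_spec(module_name)
    except ModuleNotFoundError:
        return False
    if spec is None or spec.origin is None:
        return False
    origin = Path(spec.origin)
    # Built-in or frozen modules have no readable source file
    if not origin.is_file():
        return False
    return marker in origin.read_text(encoding="utf-8", errors="replace")


if __name__ == "__main__":
    # Assumption: the patched file is importable as torchtrain.module
    print(module_file_contains("torchtrain.module", "train_loss"))
```

If the edit was saved to the right file, the call above should print True.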
