deeplog 工具使用 torchtrain 整合训练代码。
找到已安装的 torchtrain 包,打开其中的 module.py 文件
删除文件夹 __pycache__
在 module.py 文件中找到 fit 函数,并修改其中的训练循环。原始代码如下:
# Loop over each epoch
def fit(self, X, y,
epochs = 10,
batch_size = 32,
learning_rate = 0.01,
criterion = nn.NLLLoss(),
optimizer = optim.SGD,
variable = False,
verbose = True,
**kwargs):
......
for epoch in range(1, epochs+1):
try:
# Loop over entire dataset
for X_, y_ in tqdm.tqdm(data,
desc="[Epoch {:{width}}/{:{width}}]".format(
epoch, epochs, width=len(str(epochs)))):
# Clear optimizer
optimizer.zero_grad()
......
修改为(新增 train_loss 累计每个 batch 的损失,并在每个 epoch 结束后打印平均训练损失):
# Loop over each epoch
for epoch in range(1, epochs+1):
try:
train_loss = 0
# Loop over entire dataset
for X_, y_ in tqdm.tqdm(data,
desc="[Epoch {:{width}}/{:{width}}]".format(
epoch, epochs, width=len(str(epochs)))):
# Clear optimizer
optimizer.zero_grad()
# Forward pass
# Get new input batch
X_ = X_.clone().detach().to(device)
# Run through module
y_pred = self(X_)
# Compute loss
loss = criterion(y_pred, y_)
train_loss += loss.item()
# Backward pass
# Propagate loss
loss.backward()
# Perform optimizer step
optimizer.step()
total_step = len(data)
print('Epoch [{}/{}], train_loss: {:.4f}'.format(epoch, epochs, train_loss / total_step))
except KeyboardInterrupt:
print("\nTraining interrupted, performing clean stop")
break
重启 kernel,使修改后的 module.py 被重新加载并生效