您的当前位置:首页正文

pytorch criterion踩坑小结

2024-11-07 来源:个人技术集锦

1. 数据类型不匹配:

报错:Expected object of type torch.LongTensor but found type torch.FloatTensor for argument #2 ‘target’

criterion = nn.CrossEntropyLoss()
loss = criterion(y_pre, y_train)

这里的y_train类型一定要是LongTensor的,所以在写DataSet的时候返回的label就要是LongTensor类型的,如下

def__init__(self, ...)
显示全文