前言
气死我了,这个RuntimeError: mat1 and mat2 shapes cannot be multiplied
疯狂报错,很久以前我就不知道这么处理这个事情,昨天晚上搞了半天没有搞好,气死人了。
今天好好学习一下吧!
额额,暂时不用STL数据库了,之后再补!
正文
1. 数据导入
因为最近学的是STL10数据库,那就先导入一下吧。
from torchvision.datasets import STL10
train_data = STL10(root='./data', split='train', download=True, transform=train_transform_96)
test_data = STL10(root='./data', split='test', download=True, transform=val_test_transform_96)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)
(给以后有可能的AE-RL 项目用)
因为之前cifar10的项目,存储的是特殊的格式,cifar10_data.npz
。于是乎我们将STL10也转换为这个形式:
# Extract data and labels
X_train = train_data.data # Shape: (5000, 3, 96, 96)
y_train = train_data.labels # Labels from 1 to 10
X_test = test_data.data # Shape: (8000, 3, 96, 96)
y_test = test_data.labels # Labels from 1 to 10
# Adjust to (nb, 96, 96, 3)
X_train = np.transpose(X_train, (0, 2, 3, 1)) # New shape: (5000, 96, 96, 3)
X_test = np.transpose(X_test, (0, 2, 3, 1)) # New shape: (8000, 96, 96, 3)
# Adjust labels to be zero-based
y_train = y_train - 1 # Now labels are from 0 to 9
y_test = y_test - 1
# Save the data into an .npz file
# np.savez('../data/stl10_data.npz', X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)
然后load的时候就是这个形式:
def load_data():
data = np.load('../data/stl10_data.npz')
X_train = data['X_train']
y_train = data['y_train']
X_test = data['X_test']
y_test = data['y_test']
y_train = y_train.squeeze()
y_test = y_test.squeeze()
return X_train, y_train, X_test, y_test
2.
参考
总结
Q.E.D.