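Preface

This is driving me crazy: RuntimeError: mat1 and mat2 shapes cannot be multiplied keeps popping up. I have never really known how to deal with it, and last night I spent ages on it without getting anywhere. So frustrating.

Time to sit down and learn this properly today!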
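
For context, here is a minimal sketch of one common way this error shows up: a fully connected layer whose in_features does not match the flattened image size. The batch size, the layer sizes, and the names wrong_fc and right_fc are made up for illustration. This is not the code that actually failed for me.

```python
import torch
import torch.nn as nn

# Hypothetical batch shaped like STL10 images: 4 images, 3 channels, 96x96 pixels.
x = torch.randn(4, 3, 96, 96)

# Flattening keeps the batch axis and merges the rest: 3 * 96 * 96 = 27648 features.
flat = torch.flatten(x, start_dim=1)

# A layer sized for 32x32 images (3 * 32 * 32 = 3072 inputs) does not fit this input.
wrong_fc = nn.Linear(3 * 32 * 32, 10)
try:
    wrong_fc(flat)
    error_message = None
except RuntimeError as exc:
    error_message = str(exc)
print(error_message)

# Sizing in_features from the flattened tensor removes the mismatch.
right_fc = nn.Linear(flat.shape[1], 10)
logits = right_fc(flat)
print(tuple(logits.shape))
```

The two shapes printed in the error message are the input (mat1) and the transposed layer weight (mat2). Their inner dimensions have to agree, so comparing them usually points straight at the layer with the wrong in_features.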

Hmm, I am setting the STL10 dataset aside for now and will fill in the rest later!

Main Text

1. Data Import

Since I have been working with the STL10 dataset lately, let's import that first.

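from torch.utils.data import DataLoader
from torchvision.datasets import STL10

# train_transform_96 and val_test_transform_96 are assumed to be defined elsewhere (not shown in this post)
train_data = STL10(root='./data', split='train', download=True, transform=train_transform_96)
test_data = STL10(root='./data', split='test', download=True, transform=val_test_transform_96)

# Shuffle only the training set; keep the test order fixed
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)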

(For a possible future AE-RL project)
My earlier CIFAR-10 project stored its data in a special format, cifar10_data.npz, so let's convert STL10 to the same form:

import numpy as np

# Extract the raw image arrays and labels
X_train = train_data.data    # Shape: (5000, 3, 96, 96), dtype uint8
y_train = train_data.labels  # Shape: (5000,), dtype uint8
X_test = test_data.data      # Shape: (8000, 3, 96, 96)
y_test = test_data.labels    # Shape: (8000,)

# Move channels last: (N, 3, 96, 96) -> (N, 96, 96, 3)
X_train = np.transpose(X_train, (0, 2, 3, 1))  # New shape: (5000, 96, 96, 3)
X_test = np.transpose(X_test, (0, 2, 3, 1))    # New shape: (8000, 96, 96, 3)

# torchvision already converts the STL10 labels to zero-based (0 to 9) when it
# reads the label files, so no shift is needed here. Subtracting 1 again would
# wrap label 0 around to 255 because the arrays are uint8.
assert y_train.min() == 0 and y_train.max() == 9

# Save the data into an .npz file (uncomment to write the file)
# np.savez('../data/stl10_data.npz', X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)
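
To convince myself the axis shuffle and the .npz round trip behave as expected, here is a small self-contained check on synthetic data. The array sizes, the temporary file name stl10_small.npz, and the variable names are made up for the test. It does not touch the real dataset.

```python
import os
import tempfile

import numpy as np

# Tiny synthetic stand-ins for the STL10 arrays (hypothetical sizes).
rng = np.random.default_rng(0)
X_small = rng.integers(0, 256, size=(6, 3, 96, 96), dtype=np.uint8)  # channel-first
y_small = np.arange(6, dtype=np.uint8)

# Move the channel axis to the end: (N, C, H, W) -> (N, H, W, C).
X_nhwc = np.transpose(X_small, (0, 2, 3, 1))

# The same pixel is reachable in both layouts; only the index order differs.
assert X_nhwc[2, 10, 20, 1] == X_small[2, 1, 10, 20]

# Save to a temporary .npz file and read the arrays back.
path = os.path.join(tempfile.mkdtemp(), "stl10_small.npz")
np.savez(path, X_train=X_nhwc, y_train=y_small)
with np.load(path) as data:
    X_back = data["X_train"]
    y_back = data["y_train"]

print(X_back.shape, y_back.shape)  # (6, 96, 96, 3) (6,)
```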

Loading it back then looks like this:

import numpy as np

def load_data():
    # Read the arrays back from the .npz archive
    data = np.load('../data/stl10_data.npz')
    X_train = data['X_train']
    y_train = data['y_train']
    X_test = data['X_test']
    y_test = data['y_test']
    # Drop any singleton axis so the labels have shape (N,)
    y_train = y_train.squeeze()
    y_test = y_test.squeeze()
    return X_train, y_train, X_test, y_test
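
The arrays come back as uint8 in channel-last layout, while PyTorch convolution layers expect float input in channel-first layout. One possible way to convert them is sketched below on dummy data. The helper name to_model_input and the scaling to [0, 1] are my own assumptions, not part of the original pipeline.

```python
import numpy as np

def to_model_input(X_nhwc):
    """Convert uint8 (N, H, W, C) images to float32 (N, C, H, W) in [0, 1]."""
    X = X_nhwc.astype(np.float32) / 255.0
    # ascontiguousarray makes the transposed view a real channel-first copy.
    return np.ascontiguousarray(np.transpose(X, (0, 3, 1, 2)))

# Hypothetical stand-ins for what load_data() returns.
X_demo = np.zeros((2, 96, 96, 3), dtype=np.uint8)
X_demo[0, 5, 7, 2] = 255          # one bright pixel in the last channel
y_demo = np.array([[3], [7]])     # labels stored with a trailing singleton axis

X_in = to_model_input(X_demo)
y_in = y_demo.squeeze()           # same squeeze as in load_data()

print(X_in.shape, X_in.dtype, y_in.shape)  # (2, 3, 96, 96) float32 (2,)
```

From here torch.from_numpy can wrap X_in and y_in without copying, and the resulting tensors have the channel-first shape that convolution layers expect.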

2.

References

Summary

Q.E.D.

