# PyTorch Lightning CIFAR10
## About this fork
Modified version of the PyTorch CIFAR project to exploit the PyTorch Lightning library. In addition:

- Improved the `main.py` script, allowing you to train one or more models in a single command.
- Used a PyTorch Lightning data module for the dataset (from the `lightning-bolts` package).
- Changed the optimizer to use the OneCycleLR scheduler.
- Applied the `black` formatter to all files.
- Added more consistent config of the VGG and ShuffleNetV2 models.
- Added TensorBoard logging, and a JSON file `final_metrics.json` that saves the final accuracy.
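As a rough sketch of how the OneCycleLR change can look inside a LightningModule's `configure_optimizers` hook. The hyperparameters here (momentum, weight decay, and the `steps_per_epoch` value) are illustrative defaults, not necessarily the settings this repo uses:

```python
import torch


def configure_optimizers(model, lr=0.05, epochs=200, steps_per_epoch=98):
    """Sketch of a LightningModule `configure_optimizers` hook using OneCycleLR.

    Hyperparameters are illustrative; check main.py for the actual config.
    """
    # SGD with momentum is a common baseline optimizer for CIFAR-10.
    optimizer = torch.optim.SGD(
        model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4
    )
    # OneCycleLR anneals the learning rate up then down over the whole run,
    # so it needs the total step count (epochs * steps_per_epoch) up front.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=lr,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
    )
    # "interval": "step" tells Lightning to step the scheduler every batch,
    # which is what OneCycleLR expects.
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
    }
```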
Library initially developed while at the University of Glasgow's gicLAB.
## Why is this useful?
CIFAR-10 is a small image classification dataset, which can be useful for validating research ideas, since the models are smaller and cheaper to train. It has two key advantages over the MNIST database: the problem is a bit harder, and the images are not greyscale. When doing research, you want access to as much functionality as possible without having to write a lot of boilerplate code. That is the design philosophy of the PyTorch Lightning library, which is why combining the two makes sense.

Note that few (if any) of these model architectures can be considered "official implementations", since the architectures have to be changed slightly to support the different data sizes. However, I have seen these models used in research; you need merely say that the models are defined for CIFAR-10 (and ideally cite this repo; see the righthand side of the GitHub page!).
## Contributing
PRs are especially appreciated for features like:

- Fixing ShuffleNetV1 (currently, initializing the model fails).
- Exposing more PyTorch Lightning features.
- Better logging.
- Systematic testing.
- CI/CD with Travis.
## Prerequisites
- Python 3.6+
- PyTorch 1.7+
- `lightning-bolts>=0.5.0`
- `pytorch-lightning>=1.6.2`
## Training
```shell
# Start training with:
python main.py --model_arch [your model, e.g. mobilenetv2; a space-separated list such as resnet18 resnet50 vgg16; or all for all models]
```
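After training, the final accuracy is saved to `final_metrics.json`. A minimal sketch of loading it, assuming the file is written to the run's output directory (the exact directory layout is an assumption; adjust `run_dir` to wherever your run writes it):

```python
import json
from pathlib import Path


def load_final_metrics(run_dir="."):
    """Load the final-metrics JSON written at the end of training.

    The filename `final_metrics.json` comes from the README; the
    directory it lands in depends on your run configuration.
    """
    path = Path(run_dir) / "final_metrics.json"
    with path.open() as f:
        return json.load(f)
```

This makes it easy to collect results across several runs (e.g. when training multiple models with a single command) without scraping TensorBoard logs.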
## Accuracy
Models are compared against the accuracy table reported in the original repo. Models not reported in the original repo are marked with a `-`.
| Model | Orig Acc. | New Acc.[^1] |
|---|---|---|
| DenseNet-CIFAR10 | - | 93.44% |
| DenseNet121 | 95.04% | 94.94% |
| DenseNet161 | - | - |
| DenseNet169 | - | - |
| DenseNet201 | - | - |
| DLA | 95.47% | 94.29% |
| SimpleDLA | 94.89% | 93.04% |
| DPN26 | - | 95.03% |
| DPN92 | 95.16% | - |
| EfficientNetB0 | - | - |
| LeNet | - | - |
| MobileNetV1 | - | - |
| MobileNetV2 | 94.43% | - |
| PNASNetA | - | - |
| PNASNetB | - | - |
| PreActResNet18 | 95.11% | - |
| PreActResNet34 | - | - |
| PreActResNet50 | - | - |
| PreActResNet101 | - | 94.33% |
| PreActResNet152 | - | 94.83% |
| RegNetX_200MF | 94.24% | 94.27% |
| RegNetX_400MF | - | 94.61% |
| RegNetY_400MF | 94.29% | 94.61% |
| ResNet18 | 93.02% | 94.30% |
| ResNet34 | - | 93.86% |
| ResNet50 | 93.62% | 94.73% |
| ResNet101 | 93.75% | - |
| ResNet152 | - | - |
| ResNeXt29(32x4d) | 94.73% | - |
| ResNeXt29(2x64d) | 94.82% | - |
| VGG16 | 92.64% | - |
[^1]: Using the default training config of `main.py` (i.e., 200 training epochs, initial learning rate of 0.05, etc.); more configuration could give better results.