PyTorch Lightning CIFAR10


About this fork

A modified version of the PyTorch CIFAR project (kuangliu/pytorch-cifar) that exploits the PyTorch Lightning library.

In addition:

  • Improved the main.py script, allowing you to train one or more models with a single command.
  • Used the PyTorch Lightning data module for the dataset (from the lightning-bolts package).
  • Changed the optimizer to use the OneCycleLR scheduler.
  • Applied the black formatter to all files.
  • Added more consistent configuration of the VGG and ShuffleNetV2 models.
  • Added TensorBoard logging, plus a JSON file (final_metrics.json) that saves the final accuracy.
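The OneCycleLR scheduler warms the learning rate up to a peak and then anneals it back down over training. A minimal pure-Python sketch of that schedule's shape (the parameter names and defaults mirror torch.optim.lr_scheduler.OneCycleLR; the settings this repo actually uses may differ):

```python
import math

def one_cycle_lr(step, total_steps, max_lr=0.05, pct_start=0.3,
                 div_factor=25.0, final_div_factor=1e4):
    """Cosine-annealed one-cycle learning rate at a given step
    (illustrative sketch, not torch's implementation)."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_steps = int(pct_start * total_steps)
    if step < warmup_steps:
        # Phase 1: anneal up from initial_lr to max_lr
        pct = step / max(warmup_steps, 1)
        return initial_lr + (max_lr - initial_lr) * (1 - math.cos(math.pi * pct)) / 2
    # Phase 2: anneal down from max_lr to min_lr
    pct = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    return min_lr + (max_lr - min_lr) * (1 + math.cos(math.pi * pct)) / 2
```

With the defaults above, training starts at max_lr / 25, peaks at max_lr 30% of the way through, and finishes several orders of magnitude lower, which is why the scheduler pairs well with a fixed epoch budget.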

The library was initially developed while the author was at the University of Glasgow's gicLAB.

Why is this useful?

CIFAR-10 is a small image classification dataset that is useful for validating research ideas, since the models are smaller and cheaper to train. It has two key advantages over the MNIST database: the problem is a bit harder, and the images are not greyscale. When doing research, however, you want access to as much functionality as possible without having to write a lot of boilerplate code. That is the design philosophy of the PyTorch Lightning library, which is why combining the two makes sense. Note that few (if any) of these model architectures can be considered "official implementations", since the architectures have to be changed slightly to support the different input sizes. However, I have seen these models used in research; you need merely say that the models are defined for CIFAR-10 (and ideally cite this repo; see the righthand side of the GitHub page!).
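To see why the architectures must change for the smaller input size, consider the standard convolution output-size formula. The 7×7 stride-2 stem numbers below are the usual ImageNet ResNet ones, used here purely for illustration, not taken from this repo's code:

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a convolution (or pooling) layer."""
    return (size + 2 * padding - kernel) // stride + 1

# ImageNet-style ResNet stem: 7x7 stride-2 conv, then 3x3 stride-2 max-pool
imagenet_stem = conv_out(conv_out(224, 7, 2, 3), 3, 2, 1)   # 224 -> 56

# The same stem applied to 32x32 CIFAR-10 images collapses the feature map
cifar_with_imagenet_stem = conv_out(conv_out(32, 7, 2, 3), 3, 2, 1)  # 32 -> 8

# CIFAR variants therefore typically use a 3x3 stride-1 stem with no pooling
cifar_stem = conv_out(32, 3, 1, 1)                          # 32 -> 32
```

An ImageNet stem would shrink a CIFAR image to 8×8 before any residual block runs, leaving too little spatial resolution for the rest of the network, which is why CIFAR-defined variants keep the full 32×32 map at the stem.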

Contributing

PRs are especially appreciated for features like:

  • Fixing ShuffleNetV1 (model initialization currently fails).
  • Exposing more PyTorch Lightning features.
  • Better logging systems.
  • Systematic testing.
  • CI/CD with Travis.

Prerequisites

  • Python 3.6+
  • PyTorch 1.7+
  • lightning-bolts>=0.5.0
  • pytorch-lightning>=1.6.2

Training

# Start training with:
python main.py --model_arch [your model, e.g. `mobilenetv2`, `resnet18 resnet50 vgg16`, or `all` for all models]
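The `--model_arch` flag accepts one or more model names, or `all`. The real main.py's parsing is not shown here; as a hypothetical sketch of how multi-model selection might be handled with argparse (the `KNOWN_MODELS` list is an illustrative subset, not the repo's full registry):

```python
import argparse

# Illustrative subset of the supported architectures
KNOWN_MODELS = ["mobilenetv2", "resnet18", "resnet50", "vgg16"]

def parse_model_archs(argv):
    """Return the list of model names to train from CLI arguments."""
    parser = argparse.ArgumentParser()
    # nargs="+" allows e.g. --model_arch resnet18 resnet50 vgg16
    parser.add_argument("--model_arch", nargs="+", required=True)
    args = parser.parse_args(argv)
    if args.model_arch == ["all"]:
        return list(KNOWN_MODELS)
    unknown = [m for m in args.model_arch if m not in KNOWN_MODELS]
    if unknown:
        parser.error(f"unknown model(s): {unknown}")
    return args.model_arch
```

Each returned name would then be trained in turn, which is what makes single-command multi-model runs possible.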

Accuracy

Models are compared against the accuracy table reported in the original repo. Entries that were not reported are represented with a -.

Model  Orig. Acc.  New Acc.¹
DenseNet-CIFAR10 - 93.44%
DenseNet121 95.04% 94.94%
DenseNet161 - -
DenseNet169 - -
DenseNet201 - -
DLA 95.47% 94.29%
SimpleDLA 94.89% 93.04%
DPN26 - 95.03%
DPN92 95.16% -
EfficientNetB0 - -
LeNet - -
MobileNetV1 - -
MobileNetV2 94.43% -
PNASNetA - -
PNASNetB - -
PreActResNet18 95.11% -
PreActResNet34 - -
PreActResNet50 - -
PreActResNet101 - 94.33%
PreActResNet152 - 94.83%
RegNetX_200MF 94.24% 94.27%
RegNetX_400MF - 94.61%
RegNetY_400MF 94.29% 94.61%
ResNet18 93.02% 94.30%
ResNet34 - 93.86%
ResNet50 93.62% 94.73%
ResNet101 93.75% -
ResNet152 - -
ResNeXt29(32x4d) 94.73% -
ResNeXt29(2x64d) 94.82% -
VGG16 92.64% -

Footnotes

  1. Using the default training configuration of main.py (i.e., 200 training epochs, an initial learning rate of 0.05, etc.); further configuration tuning could give better results.

About

Common CNN models defined for PyTorch Lightning
