Adding Swin Transformer architecture #5491
Conversation
@xiaohu2015 Can you please share the logs with @jdsgomes, along with all the information that will allow us to reproduce your experiment (for example, the git commit hashes you used)? It's also unclear to me whether you used TorchVision's reference scripts or something else in your experiments. Could you please clarify?
Although the official repo uses an iteration-based LR scheduler, it parameterises the scheduler builder with the number of epochs and simply converts epochs to iterations, so I think the two should be equivalent.
For swin_t, I have shared the training logs (produced with TorchVision's reference script) with @jdsgomes. @jdsgomes yes, an iteration-based LR scheduler and an epoch-based LR scheduler should be equivalent. I only suspected the scheduler because a minor difference there can change the results.
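The epochs-to-iterations equivalence discussed above can be sketched independently of any framework. This is an illustrative cosine schedule, not torchvision's scheduler code; the names are hypothetical:

```python
import math

def lr_at(step, total_steps, base_lr=1.0):
    """Cosine decay evaluated at a given step out of total_steps."""
    return base_lr * 0.5 * (1 + math.cos(math.pi * step / total_steps))

epochs, iters_per_epoch = 300, 1000
total_iters = epochs * iters_per_epoch  # epochs converted to iterations

# Stepping once per iteration over total_iters traces the same curve as
# stepping once per epoch over `epochs`, just sampled more finely; sampling
# the iteration-based schedule at epoch boundaries recovers the same values.
per_epoch = [lr_at(e, epochs) for e in range(epochs)]
per_iter = [lr_at(i, total_iters) for i in range(0, total_iters, iters_per_epoch)]
assert all(abs(a - b) < 1e-9 for a, b in zip(per_epoch, per_iter))
```

This is why parameterising the builder in epochs and converting to iterations, as the official repo does, should not by itself change the result.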
After discussing offline with @datumbox, I think we should proceed to merge the PR with swin_t only, since it is clear that we can reproduce that result. Great work @xiaohu2015! After that, we can continue investigating to close the gap and aim to merge the other variants in a different PR. I will do the final cleanups between today and tomorrow.
Just wanted to echo what Joao said. A big thank you @xiaohu2015 for your awesome contribution: top-notch code and excellent research-reproduction skills. Also, apologies for taking so long to review and reproduce the PR; it's something we want to improve upon. Looking forward to seeing this merged!
@@ -416,7 +416,7 @@ class Swin_T_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/swin_t-81486767.pth",
         transforms=partial(
-            ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC
+            ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
This value was determined via post-training optimisation, similarly to what we did for ConvNeXt.
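The kind of post-training optimisation mentioned above can be sketched as a sweep over inference resize values. Everything here is a hypothetical stand-in: `evaluate` represents a real validation loop over the trained checkpoint, and the candidate values are illustrative:

```python
def sweep_resize(evaluate, candidates=(224, 232, 238, 246, 256)):
    """Evaluate a trained checkpoint at several inference resize values
    and return the best one; `evaluate` stands in for a real val loop."""
    scores = {r: evaluate(r) for r in candidates}
    best = max(scores, key=scores.get)
    return best, scores

# Toy accuracy curve that happens to peak at 238, for illustration only.
best, scores = sweep_resize(lambda r: 81.0 - abs(r - 238) * 0.01)
```

The training recipe stays fixed; only the inference-time preprocessing parameter is tuned against the validation set.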
For the other models, I can convert the official weights to the torchvision format, just like EfficientNet.
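Such a conversion usually amounts to renaming state-dict keys from the official layout to torchvision's. The sketch below shows the idea only; the key mapping is illustrative and not the real Swin correspondence:

```python
def convert_state_dict(official, rename=(("patch_embed.proj", "features.0.0"),)):
    """Rename keys of an official checkpoint's state dict to a
    torchvision-style layout. `rename` maps old prefixes to new ones;
    the default mapping here is a made-up example."""
    converted = {}
    for key, value in official.items():
        for old, new in rename:
            if key.startswith(old):
                key = new + key[len(old):]
                break
        converted[key] = value
    return converted

# Tensors are left untouched; only the key names change.
sd = {"patch_embed.proj.weight": 1, "head.weight": 2}
out = convert_state_dict(sd)
```

The converted dict can then be loaded with `model.load_state_dict(out)` once the shapes and ordering match.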
@xiaohu2015 I understand it would be useful to include pre-trained weights from the initial implementation, but we would prefer to add the other variants once we can replicate the results fully. I am running a few experiments now and hopefully we can get good results, but for now I will remove even the constructors so this PR can be merged.
LGTM, thanks again @xiaohu2015 for the awesome contribution.
@jdsgomes thanks as well for your support and guidance.
I think we are good to merge. Just make sure we remove the unnecessary expect files ModelTester.test_swin_*_expect.pk for variants s/b/l that were removed.
Thanks @xiaohu2015 for the great contribution and @datumbox for the feedback
This work relates to #2707 and #5410: add Swin Transformer to the torchvision model zoo.
Refactor code.
I made some modifications compared to the official code:

1. Remove the absolute position embedding: as we can see from Table 4 in the paper, the Swin model with relative position bias gets the best results, so the default Swin model does not use an absolute position embedding. Another problem with the absolute position embedding is that we would have to pass the input size to the model to initialize the pos_embedding.
2. Remove the `input_resolution` parameter: this lets the Swin model handle input of arbitrary shape, which is necessary for tasks such as segmentation and object detection. Compared to the official code, we keep tensors with shape [B, H, W, C] instead of [B, N, C], so we can get the width and height without `input_resolution`. After doing that, however, one must dynamically compare the window size and the input size in the shifted window attention; for example, if the input size is smaller than the window size (when the image size is 224, the feature size of the last stage is 7x7), the shift operation is not needed. But this dynamic behaviour is not well supported in torch.fx, so I created a `shifted_window_attention` function and wrapped it. Note: this modification adds some run time because we have to generate the `attention_mask` dynamically, but the cost is insignificant.
3. Validate the training, which gives: Acc@1 81.222 Acc@5 95.332 (train logs).

I also modified the reference code (https://github.com/xiaohu2015/vision/blob/main/references/classification/utils.py#L406), as the current code only supports disabling weight decay for norm layers.
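The no-weight-decay change mentioned above can be sketched as a generic parameter-splitting helper. The function name, keywords, and parameter names below are illustrative, not the actual utils.py API:

```python
def split_params(named_params, norm_keywords=("norm",), extra_no_decay=()):
    """Split parameter names into a weight-decay group and a no-decay
    group, exempting norm layers and any explicitly listed names (e.g.
    a relative position bias table) from weight decay."""
    decay, no_decay = [], []
    for name in named_params:
        if any(k in name for k in norm_keywords) or name in extra_no_decay:
            no_decay.append(name)
        else:
            decay.append(name)
    return decay, no_decay

names = ["features.0.weight", "norm.weight", "attn.relative_position_bias_table"]
decay, no_decay = split_params(
    names, extra_no_decay=("attn.relative_position_bias_table",)
)
```

In a real setup the two groups would become optimizer param groups, with `weight_decay=0` on the no-decay group.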