
PyTorch VAE

Update 22/12/2021: Added support for PyTorch Lightning 1.5.6 and cleaned up the code.

A collection of Variational AutoEncoders (VAEs) implemented in PyTorch, with a focus on reproducibility. The aim of this project is to provide a quick and simple working example for many of the cool VAE models out there. All the models are trained on the CelebA dataset for consistency and comparison. The architectures of all the models are kept as similar as possible, with the same layers, except where the original paper necessitates a radically different architecture (e.g., VQ-VAE uses residual layers and no batch norm, unlike the other models). Here are the results of each model.
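Before the individual models, here is a minimal sketch of the objective most of these VAEs share: a reconstruction term plus a KL divergence between the approximate posterior and a unit Gaussian prior, with sampling done via the reparameterization trick. The function names and the MSE reconstruction term are illustrative choices, not the exact code from this repo.

```python
import torch
import torch.nn.functional as F

def vae_loss(recons, x, mu, log_var, kld_weight=1.0):
    """Standard VAE objective: reconstruction loss + weighted KL term.

    mu, log_var: (batch, latent_dim) parameters of the approximate posterior.
    """
    recons_loss = F.mse_loss(recons, x)
    kld_loss = torch.mean(
        -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1)
    )
    return recons_loss + kld_weight * kld_loss

def reparameterize(mu, log_var):
    """Sample z = mu + sigma * eps so the sampling step stays differentiable."""
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + eps * std
```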

Requirements

- Python >= 3.5
- PyTorch >= 1.3
- PyTorch Lightning >= 0.6.0 (GitHub Repo)
- CUDA-enabled computing device

Installation

$ git clone https://github.com/AntixK/PyTorch-VAE
$ cd PyTorch-VAE
$ pip install -r requirements.txt

Usage

$ cd PyTorch-VAE
$ python run.py -c configs/<config-file-name.yaml>

Config file template

model_params:
  name: "<name of the model>"
  in_channels: 3
  latent_dim:
    .             # Other parameters required by the model
    .
    .

data_params:
  data_path: "<path to the dataset>"
  train_batch_size: 64   # Better to have a square number
  val_batch_size: 64
  patch_size: 64         # Models are designed to work for this size
  num_workers: 4

exp_params:
  manual_seed: 1265
  LR: 0.005
  weight_decay:
    .             # Other arguments required for training, like scheduler etc.
    .
    .

trainer_params:
  gpus: 1
  max_epochs: 100
  gradient_clip_val: 1.5
    .
    .
    .

logging_params:
  save_dir: "logs/"
  name: "<experiment name>"
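As a rough illustration of how such a config maps onto PyTorch Lightning, the sketch below parses the YAML and builds a Trainer from trainer_params. This is not the repo's actual run.py; the file name "configs/vae.yaml" and the commented-out experiment/data objects are assumptions for the example.

```python
# Illustrative sketch only -- not the repository's run.py.
import yaml
import pytorch_lightning as pl

with open("configs/vae.yaml", "r") as f:   # assumed config file name
    config = yaml.safe_load(f)

pl.seed_everything(config["exp_params"]["manual_seed"])

trainer = pl.Trainer(
    gpus=config["trainer_params"]["gpus"],
    max_epochs=config["trainer_params"]["max_epochs"],
    gradient_clip_val=config["trainer_params"]["gradient_clip_val"],
)
# trainer.fit(experiment, datamodule=data)  # built from model_params / data_params
```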

View TensorBoard Logs

$ cd logs/<experiment name>/version_<version>
$ tensorboard --logdir .
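For context, the logs/<experiment name>/version_<n> layout is what PyTorch Lightning's TensorBoard logger produces when wired to the logging_params above. A minimal sketch, assuming an experiment name of "VanillaVAE" purely for illustration:

```python
from pytorch_lightning.loggers import TensorBoardLogger

# save_dir and name mirror logging_params; "VanillaVAE" is just an example name.
tb_logger = TensorBoardLogger(save_dir="logs/", name="VanillaVAE")
# Passing logger=tb_logger to the Trainer writes runs under logs/VanillaVAE/version_<n>,
# which is the directory you point `tensorboard --logdir` at.
```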

Note: The default dataset is CelebA. However, there have been many issues with downloading the dataset from Google Drive (owing to some file-structure changes). So the recommendation is to download the file from Google Drive directly and extract it to a path of your choice. The default path assumed in the config files is `Data/celeba/img_align_celeba`, but you can change it according to your preference.
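If you extract the archive manually, a bare img_align_celeba folder is just a flat directory of JPEGs. The sketch below shows one way to wrap it in a Dataset that yields 64x64 patches; the class name, the CenterCrop(148) choice, and the dummy label are assumptions for illustration, not the loader this repo ships with.

```python
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class CelebAFolder(Dataset):
    """Reads the extracted img_align_celeba folder directly (illustrative only)."""

    def __init__(self, root="Data/celeba/img_align_celeba", patch_size=64):
        self.files = sorted(Path(root).glob("*.jpg"))
        self.transform = transforms.Compose([
            transforms.CenterCrop(148),   # assumed crop; pick what suits your model
            transforms.Resize(patch_size),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img = Image.open(self.files[idx]).convert("RGB")
        return self.transform(img), 0.0  # dummy label to mimic (image, target) pairs
```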

Results

Each model links in the repository to its code and config file, to the original paper, and to reconstruction and sample images (VQ-VAE has no sampled images):

- VAE
- Conditional VAE
- WAE - MMD (RBF Kernel)
- WAE - MMD (IMQ Kernel)
- Beta-VAE
- Disentangled Beta-VAE
- Beta-TC-VAE
- IWAE (K = 5)
- MIWAE (K = 5, M = 3)
- DFCVAE
- MSSIM VAE
- Categorical VAE
- Joint VAE
- Info VAE
- LogCosh VAE
- SWAE (200 Projections)
- VQ-VAE (K = 512, D = 64)
- DIP VAE

Contributing

If you have trained a better model using these implementations by fine-tuning the hyper-parameters in the config file, I would be happy to include your results (along with your config file) in this repo, citing your name 😊.

Additionally, if you would like to contribute some models, please submit a PR.

License

Apache License 2.0

Permissions:
- ✔️ Commercial use
- ✔️ Modification
- ✔️ Distribution
- ✔️ Patent use
- ✔️ Private use

Limitations:
- ❌ Trademark use
- ❌ Liability
- ❌ Warranty

Conditions:
- ⓘ License and copyright notice
- ⓘ State changes

Citation

@misc{Subramanian2020,
  author = {Subramanian, A.K},
  title = {PyTorch-VAE},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/AntixK/PyTorch-VAE}}
}

