diff --git a/environment.yml b/environment.yml new file mode 100644 index 00000000..3b2ae478 --- /dev/null +++ b/environment.yml @@ -0,0 +1,50 @@ +name: ncsn +channels: + - anaconda + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - ca-certificates=2023.08.22=h06a4308_0 + - certifi=2021.5.30=py36h06a4308_0 + - ld_impl_linux-64=2.38=h1181459_1 + - libffi=3.3=he6710b0_2 + - libgcc-ng=9.1.0=hdf63c60_0 + - libstdcxx-ng=9.1.0=hdf63c60_0 + - ncurses=6.3=h7f8727e_2 + - openssl=1.1.1w=h7f8727e_0 + - pip=21.2.2=py36h06a4308_0 + - python=3.6.13=h12debd9_1 + - pyyaml=5.4.1=py36h27cfd23_1 + - readline=8.1.2=h7f8727e_1 + - setuptools=58.0.4=py36h06a4308_0 + - sqlite=3.38.5=hc218d9a_0 + - tk=8.6.12=h1ccaba5_0 + - wheel=0.37.1=pyhd3eb1b0_0 + - xz=5.2.5=h7f8727e_1 + - yaml=0.2.5=h7b6447c_0 + - zlib=1.2.12=h7f8727e_2 + - pip: + - cycler==0.11.0 + - dataclasses==0.8 + - importlib-resources==5.4.0 + - kiwisolver==1.3.1 + - matplotlib==3.3.4 + - numpy==1.19.5 + - packaging==21.3 + - pandas==1.1.5 + - pillow==8.4.0 + - protobuf==4.21.0 + - pyparsing==3.1.1 + - python-dateutil==2.8.2 + - pytz==2024.1 + - scipy==1.5.4 + - seaborn==0.11.2 + - six==1.16.0 + - tensorboardx==2.6.2.2 + - torch==1.10.1 + - torchvision==0.11.2 + - tqdm==4.64.1 + - typing-extensions==4.1.1 + - zipp==3.6.0 +prefix: /opt/conda/envs/ncsn + diff --git a/main.py b/main.py index 8a04f9b7..9b46151c 100644 --- a/main.py +++ b/main.py @@ -37,8 +37,10 @@ def parse_args_and_config(): config = yaml.load(f) new_config = dict2namespace(config) else: - with open(os.path.join(args.log, 'config.yml'), 'r') as f: - config = yaml.load(f) + # Register the constructor for the `!!python/object:argparse.Namespace` tag + yaml.SafeLoader.add_constructor('tag:yaml.org,2002:python/object:argparse.Namespace', construct_namespace) + config_path = os.path.join(args.log, 'config.yml') + config = load_config(config_path) new_config = config if not args.test: @@ -103,6 +105,20 @@ def dict2namespace(config): setattr(namespace, key, new_value) return namespace +# Define a constructor for handling `argparse.Namespace` +def construct_namespace(loader, node): + # Construct a mapping from the node + value = loader.construct_mapping(node) + # Return an argparse.Namespace object constructed from the mapping + return argparse.Namespace(**value) + +# Function to load the YAML file +def load_config(file_path): + with open(file_path, 'r') as f: + # Load the YAML using the SafeLoader + config = yaml.load(f, Loader=yaml.SafeLoader) + return config + def main(): args, config = parse_args_and_config() diff --git a/runners/anneal_runner.py b/runners/anneal_runner.py index 14b58245..015cc329 100644 --- a/runners/anneal_runner.py +++ b/runners/anneal_runner.py @@ -149,7 +149,8 @@ def train(self): optimizer.step() tb_logger.add_scalar('loss', loss, global_step=step) - logging.info("step: {}, loss: {}".format(step, loss.item())) + if step % 200 == 0: + logging.info("step: {}, loss: {}".format(step, loss.item())) if step >= self.config.training.n_iters: return 0 @@ -180,8 +181,16 @@ def train(self): score.state_dict(), optimizer.state_dict(), ] + logging.info("step: {}, checkpoint_: {}".format(step, os.path.join(self.args.log, 'checkpoint.pth'))) torch.save(states, os.path.join(self.args.log, 'checkpoint_{}.pth'.format(step))) torch.save(states, os.path.join(self.args.log, 'checkpoint.pth')) + states = [ + score.state_dict(), + optimizer.state_dict(), + ] + logging.info("step: {}, checkpoint_: {}".format(step, os.path.join(self.args.log, 'checkpoint.pth'))) + torch.save(states, os.path.join(self.args.log, 'checkpoint_{}.pth'.format(step))) + torch.save(states, os.path.join(self.args.log, 'checkpoint.pth')) def Langevin_dynamics(self, x_mod, scorenet, n_steps=200, step_lr=0.00005): images = []