diff options
author | Vee9ahd1 <[email protected]> | 2019-08-21 22:21:10 -0400 |
---|---|---|
committer | Vee9ahd1 <[email protected]> | 2019-08-21 22:21:10 -0400 |
commit | f6cbf568ff9815c180095733560a567dcb70e859 (patch) | |
tree | a6edd6071ff65537ac83124edfdf2df2dc8b83c8 /gantools | |
parent | b7b043ac983613f02166b0b42bee70daff1539ef (diff) |
added a check to make sure the correct amount of keys are passed to cubic interpolation (providing a more meaningful error message than TypeError('m > k must hold'))
Diffstat (limited to 'gantools')
-rw-r--r-- | gantools/cli.py | 28 | ||||
-rw-r--r-- | gantools/latent_space.py | 3 |
2 files changed, 22 insertions, 9 deletions
diff --git a/gantools/cli.py b/gantools/cli.py index a4159f5..bb7946d 100644 --- a/gantools/cli.py +++ b/gantools/cli.py @@ -1,4 +1,4 @@ -import sys, argparse +import sys, os, argparse from gantools import ganbreeder from gantools import biggan from gantools import latent_space @@ -18,7 +18,7 @@ def handle_args(argv=None): parser.add_argument('-b', '--nbatch', metavar='N', type=int, help='Number of frames in each \'batch\' \ (note: the truncation value can only change once per batch. Don\'t fuck with this unless you know \ what it does.).', default=1) - parser.add_argument('-o', '--output-dir', help='Directory path for output images.') + parser.add_argument('-o', '--output-dir', help='Directory path for output images.', default=os.getcwd()) parser.add_argument('--prefix', help='File prefix for output images.') parser.add_argument('--interp', choices=['linear', 'cubic'], default='cubic', help='Set interpolation method.') group_loop = parser.add_mutually_exclusive_group(required=False) @@ -41,12 +41,18 @@ def main(): # interpolate path through input space print('Interpolating path through input space...') - z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes( - keyframes, - args.nframes, - batch_size=args.nbatch, - interp_method=args.interp, - loop=args.loop) + try: + z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes( + keyframes, + args.nframes, + batch_size=args.nbatch, + interp_method=args.interp, + loop=args.loop) + except ValueError as e: + print(e) + print('ERROR: Interpolation failed. Make sure you are using at least 3 keys (4 if --no-loop is enabled)') + print('If you would like to use fewer keys, try using the --interp linear argument') + return 1 # sample the GAN print('Loading bigGAN...') @@ -55,7 +61,11 @@ def main(): path = '' if args.output_dir == None else str(args.output_dir) prefix = '' if args.prefix == None else str(args.prefix) saver = image_utils.ImageSaver(output_dir=path, prefix=prefix) - print('Saving image files to: '+path + prefix) + print('Image files will be saved to: '+path + prefix) print('Sampling from bigGAN...') gan.sample(z_seq, label_seq, truncation=truncation_seq, batch_size=args.nbatch, save_callback=saver.save) print('Done.') + return 0 + +if __name__ == '__main__': + sys.exit(main()) diff --git a/gantools/latent_space.py b/gantools/latent_space.py index 44c64a3..64e3dec 100644 --- a/gantools/latent_space.py +++ b/gantools/latent_space.py @@ -21,8 +21,11 @@ def cubic_spline_interp(points, step_count): tck = interpolate.splrep(x, y, s=0) xnew = np.linspace(0., 1., step_count) return interpolate.splev(xnew, tck, der=0) + if points.shape[0] < 4: + raise ValueError('Too few points for cubic interpolation: need 4, got {}'.format(points.shape[0])) return np.apply_along_axis(cubic_spline_interp1d, 0, points) + # TODO: the math in this function is embarrasingly bad. fix at some point. def sequence_keyframes(keyframes, num_frames, batch_size=1, interp_method='linear', loop=False): interp_fn = { |