aboutsummaryrefslogtreecommitdiff
path: root/gantools
diff options
context:
space:
mode:
authorVee9ahd1 <[email protected]>2019-08-21 22:21:10 -0400
committerVee9ahd1 <[email protected]>2019-08-21 22:21:10 -0400
commitf6cbf568ff9815c180095733560a567dcb70e859 (patch)
treea6edd6071ff65537ac83124edfdf2df2dc8b83c8 /gantools
parentb7b043ac983613f02166b0b42bee70daff1539ef (diff)
added a check to make sure the correct amount of keys are passed to cubic interpolation (providing a more meaningful error message than TypeError('m > k must hold'))
Diffstat (limited to 'gantools')
-rw-r--r--gantools/cli.py28
-rw-r--r--gantools/latent_space.py3
2 files changed, 22 insertions, 9 deletions
diff --git a/gantools/cli.py b/gantools/cli.py
index a4159f5..bb7946d 100644
--- a/gantools/cli.py
+++ b/gantools/cli.py
@@ -1,4 +1,4 @@
-import sys, argparse
+import sys, os, argparse
from gantools import ganbreeder
from gantools import biggan
from gantools import latent_space
@@ -18,7 +18,7 @@ def handle_args(argv=None):
parser.add_argument('-b', '--nbatch', metavar='N', type=int, help='Number of frames in each \'batch\' \
(note: the truncation value can only change once per batch. Don\'t fuck with this unless you know \
what it does.).', default=1)
- parser.add_argument('-o', '--output-dir', help='Directory path for output images.')
+ parser.add_argument('-o', '--output-dir', help='Directory path for output images.', default=os.getcwd())
parser.add_argument('--prefix', help='File prefix for output images.')
parser.add_argument('--interp', choices=['linear', 'cubic'], default='cubic', help='Set interpolation method.')
group_loop = parser.add_mutually_exclusive_group(required=False)
@@ -41,12 +41,18 @@ def main():
# interpolate path through input space
print('Interpolating path through input space...')
- z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(
- keyframes,
- args.nframes,
- batch_size=args.nbatch,
- interp_method=args.interp,
- loop=args.loop)
+ try:
+ z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(
+ keyframes,
+ args.nframes,
+ batch_size=args.nbatch,
+ interp_method=args.interp,
+ loop=args.loop)
+ except ValueError as e:
+ print(e)
+ print('ERROR: Interpolation failed. Make sure you are using at least 3 keys (4 if --no-loop is enabled)')
+ print('If you would like to use fewer keys, try using the --interp linear argument')
+ return 1
# sample the GAN
print('Loading bigGAN...')
@@ -55,7 +61,11 @@ def main():
path = '' if args.output_dir == None else str(args.output_dir)
prefix = '' if args.prefix == None else str(args.prefix)
saver = image_utils.ImageSaver(output_dir=path, prefix=prefix)
- print('Saving image files to: '+path + prefix)
+ print('Image files will be saved to: '+path + prefix)
print('Sampling from bigGAN...')
gan.sample(z_seq, label_seq, truncation=truncation_seq, batch_size=args.nbatch, save_callback=saver.save)
print('Done.')
+ return 0
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/gantools/latent_space.py b/gantools/latent_space.py
index 44c64a3..64e3dec 100644
--- a/gantools/latent_space.py
+++ b/gantools/latent_space.py
@@ -21,8 +21,11 @@ def cubic_spline_interp(points, step_count):
tck = interpolate.splrep(x, y, s=0)
xnew = np.linspace(0., 1., step_count)
return interpolate.splev(xnew, tck, der=0)
+ if points.shape[0] < 4:
+ raise ValueError('Too few points for cubic interpolation: need 4, got {}'.format(points.shape[0]))
return np.apply_along_axis(cubic_spline_interp1d, 0, points)
+
# TODO: the math in this function is embarrasingly bad. fix at some point.
def sequence_keyframes(keyframes, num_frames, batch_size=1, interp_method='linear', loop=False):
interp_fn = {