author | Vee9ahd1 <[email protected]> | 2019-05-12 01:29:05 -0400
committer | Vee9ahd1 <[email protected]> | 2019-05-12 01:29:05 -0400
commit | 2e27eb73344d691b657f72c8e794f81ce47036c6 (patch)
tree | fafe6255167c74131f03e2a80b105278c28af30d /gantools/biggan.py
parent | 560f86d452277084a1be04fbc4c0e8c5f1206ff5 (diff)
implemented most of the basic functionality from the prototype script and created some messy tests
Diffstat (limited to 'gantools/biggan.py')
-rw-r--r-- | gantools/biggan.py | 48
1 file changed, 48 insertions, 0 deletions
diff --git a/gantools/biggan.py b/gantools/biggan.py
new file mode 100644
index 0000000..8a0a71d
--- /dev/null
+++ b/gantools/biggan.py
@@ -0,0 +1,48 @@
+# methods for setting up and interacting with biggan
+import tensorflow as tf
+import tensorflow_hub as hub
+import numpy as np
+from itertools import cycle
+
+MODULE_PATH = 'https://tfhub.dev/deepmind/biggan-512/2'
+
+class BigGAN:
+    def __init__(self, module_path=MODULE_PATH):
+        tf.reset_default_graph()
+        print('Loading BigGAN module from:', module_path)
+        module = hub.Module(module_path)
+        self.inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
+                       for k, v in module.get_input_info_dict().items()}
+        self.input_z = self.inputs['z']
+        self.dim_z = self.input_z.shape.as_list()[1]
+        self.input_y = self.inputs['y']
+        self.vocab_size = self.input_y.shape.as_list()[1]  # dimension of y (aka label count)
+        self.input_trunc = self.inputs['truncation']
+        self.output = module(self.inputs)
+
+        # initialize/instantiate tf variables
+        initializer = tf.global_variables_initializer()
+        self.sess = tf.Session()
+        self.sess.run(initializer)
+
+    def sample(self, vectors, labels, truncation=0.5, batch_size=1):
+        num = vectors.shape[0]
+
+        # deal with scalar input case
+        truncation = np.asarray(truncation)
+        if truncation.ndim == 0:  # truncation is a scalar
+            # TODO: there has to be a better way to do this...
+            truncation = cycle([truncation])
+
+        ims = []
+        for batch_start, trunc in zip(range(0, num, batch_size), truncation):
+            s = slice(batch_start, min(num, batch_start + batch_size))
+            feed_dict = {self.input_z: vectors[s], self.input_y: labels[s], self.input_trunc: trunc}
+            ims.append(self.sess.run(self.output, feed_dict=feed_dict))
+        ims = np.concatenate(ims, axis=0)
+        assert ims.shape[0] == num
+        ims = np.clip(((ims + 1) / 2.0) * 256, 0, 255)
+        ims = np.uint8(ims)
+        return ims
+
+    # TODO: make a version of sample() that includes a callback function to save ims somewhere
+    # instead of keeping them in memory.
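For context, here is a minimal sketch of how the class added in this diff might be driven (not part of the commit): it assumes latent vectors drawn from a truncated normal via scipy.stats.truncnorm and one-hot label rows of length vocab_size, matching the 'z', 'y', and 'truncation' placeholders the module exposes; the helper names truncated_z and one_hot and the specific ImageNet class ids are illustrative choices, not anything defined by this repository.

```python
# Hypothetical driver for gantools.biggan.BigGAN (illustrative, not part of the commit).
import numpy as np
from scipy.stats import truncnorm

from gantools.biggan import BigGAN


def truncated_z(n, dim_z, truncation=0.5, seed=0):
    # Sample n latent vectors from a truncated normal, scaled by the truncation value.
    state = np.random.RandomState(seed)
    return truncation * truncnorm.rvs(-2.0, 2.0, size=(n, dim_z), random_state=state)


def one_hot(class_ids, vocab_size):
    # Turn integer class ids into the one-hot rows the 'y' placeholder expects.
    y = np.zeros((len(class_ids), vocab_size), dtype=np.float32)
    y[np.arange(len(class_ids)), class_ids] = 1.0
    return y


gan = BigGAN()                                   # loads the tfhub module and builds the graph
z = truncated_z(4, gan.dim_z, truncation=0.5)    # 4 latent vectors
y = one_hot([207, 207, 8, 8], gan.vocab_size)    # ImageNet ids: golden retriever x2, hen x2
ims = gan.sample(z, y, truncation=0.5, batch_size=2)
print(ims.shape, ims.dtype)                      # e.g. (4, 512, 512, 3) uint8 for biggan-512
```

Since sample() already clips and converts to uint8, the returned array can be written out directly with any image library.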