aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVee9ahd1 <[email protected]>2019-10-06 22:25:28 -0400
committerVee9ahd1 <[email protected]>2019-10-06 22:30:14 -0400
commita7466accc6fc68f117444d7231f486013c8c6b6b (patch)
treea8639eb00c17288be3ac2c7822f3908832a83880
parent90cf4e52f6f916d1ecd7703b387283ffa3b64bde (diff)
tensorflow 2.0 compatibility and a little bit of cleanup
-rw-r--r--gantools/biggan.py29
1 files changed, 18 insertions, 11 deletions
diff --git a/gantools/biggan.py b/gantools/biggan.py
index 71228ae..fbfde98 100644
--- a/gantools/biggan.py
+++ b/gantools/biggan.py
@@ -1,19 +1,11 @@
# methods for setting up and interacting with biggan
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
import tensorflow_hub as hub
import numpy as np
from itertools import cycle
-#-----------------------------------------------------------------
-# fix "could not create cudnn handle" error
-# see: https://github.com/tensorflow/tensorflow/issues/24496
-from tensorflow.compat.v1 import ConfigProto
-from tensorflow.compat.v1 import InteractiveSession
-config = ConfigProto()
-config.gpu_options.allow_growth = True
-#-----------------------------------------------------------------
-session = InteractiveSession(config=config)
+#session = InteractiveSession(config=config)
MODULE_PATH = 'https://tfhub.dev/deepmind/biggan-512/2'
@@ -22,6 +14,13 @@ class BigGAN(object):
def __init__(self, module_path=MODULE_PATH):
tf.reset_default_graph()
print('Loading BigGAN module from:', module_path)
+
+ #-----------------------------------------------------------------
+ # fix "RuntimeError: Exporting/importing meta graphs is not
+ # supported when eager execution is enabled." error when importing
+ # the tfhub module
+ tf.disable_eager_execution()
+ #-----------------------------------------------------------------
module = hub.Module(module_path)
self.inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
for k, v in module.get_input_info_dict().items()}
@@ -34,7 +33,15 @@ class BigGAN(object):
# initialize/instantiate tf variables
initializer = tf.global_variables_initializer()
- self.sess = tf.Session()
+
+ #-----------------------------------------------------------------
+ # fix "could not create cudnn handle" error
+ # see: https://github.com/tensorflow/tensorflow/issues/24496
+ config = tf.ConfigProto()
+ config.gpu_options.allow_growth = True
+ #-----------------------------------------------------------------
+
+ self.sess = tf.Session(config=config)
self.sess.run(initializer)
# NOTE: use save callback to save images once per batch. return type changes to None.