# Source code for textattack.constraints.grammaticality.language_models.google_language_model.lm_utils

"""
Utils for loading the 1B Word Benchmark language model.
------------------------------------------------------------

    Author: Moustafa Alzantot (malzantot@ucla.edu)
    All rights reserved.
"""

import sys

from textattack.shared.utils import LazyLoader

tf = LazyLoader("tensorflow", globals(), "tensorflow")
protobuf = LazyLoader("google.protobuf", globals(), "google.protobuf")


def LoadModel(sess, graph, gd_file, ckpt_file):
    """Load the model from a GraphDef and restore its checkpoint.

    The GraphDef text proto is imported into ``graph`` with an empty name
    prefix, the checkpoint is restored through ``sess``, and the imported
    ``states_init`` element is run once before returning.

    Args:
        sess: TensorFlow session used to run the restore op and
            ``states_init``.
        graph: TensorFlow graph that the GraphDef is imported into.
        gd_file: GraphDef proto text file.
        ckpt_file: TensorFlow checkpoint file.

    Returns:
        Dict mapping short names (e.g. ``"softmax_out"``, ``"inputs_in"``)
        to the imported graph elements.
    """
    # Import the submodule explicitly: attribute access on the lazily loaded
    # ``google.protobuf`` package only finds ``text_format`` if some other
    # module happens to have imported that submodule already.
    from google.protobuf import text_format

    # (key in the returned dict, element name in the GraphDef). Names with a
    # ":0" suffix are returned by import_graph_def as Tensors; "states_init"
    # has no suffix, so it is returned as an Operation.
    elements = (
        ("states_init", "states_init"),
        ("lstm/lstm_0/control_dependency", "lstm/lstm_0/control_dependency:0"),
        ("lstm/lstm_1/control_dependency", "lstm/lstm_1/control_dependency:0"),
        ("softmax_out", "softmax_out:0"),
        ("class_ids_out", "class_ids_out:0"),
        ("class_weights_out", "class_weights_out:0"),
        ("log_perplexity_out", "log_perplexity_out:0"),
        ("inputs_in", "inputs_in:0"),
        ("targets_in", "targets_in:0"),
        ("target_weights_in", "target_weights_in:0"),
        ("char_inputs_in", "char_inputs_in:0"),
        # The next two keys intentionally differ from their graph names.
        ("all_embs", "all_embs_out:0"),
        ("softmax_weights", "Reshape_3:0"),
        ("global_step", "global_step:0"),
    )
    keys = [key for key, _ in elements]
    names = [name for _, name in elements]

    # NOTE: this sets the level of TensorFlow's global logger as a side effect.
    tf.get_logger().setLevel("INFO")
    with graph.as_default():
        sys.stderr.write("Recovering graph.\n")
        with tf.io.gfile.GFile(gd_file) as f:
            s = f.read()
        gd = tf.compat.v1.GraphDef()
        text_format.Merge(s, gd)

        tf.compat.v1.logging.info("Recovering Graph %s", gd_file)
        # name="" imports the elements without a name-scope prefix, so the
        # restore op below can be addressed by its original name.
        imported = tf.import_graph_def(gd, {}, names, name="")
        t = dict(zip(keys, imported))

        sys.stderr.write("Recovering checkpoint %s\n" % ckpt_file)
        # Feed the checkpoint path into the graph's own restore op
        # (presumably exported by a saver along with the graph -- verify).
        sess.run("save/restore_all", {"save/Const:0": ckpt_file})
        sess.run(t["states_init"])
        return t