
How to save trained models?

See original GitHub issue

I would like to save my trained models for future use, especially as I am considering building my PhD predictors on top of this library. Is there any way to add a .save attribute to the demos/models.py models, to achieve something similar to what is explained here?
Many thanks

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Comments: 5

Top GitHub Comments

2 reactions
alvarosg commented, Feb 3, 2020

Here are some examples of checkpointing and model storage using Sonnet 2 and TF2: https://github.com/deepmind/sonnet#tensorflow-checkpointing
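
For reference, here is a minimal sketch of that approach, assuming Sonnet 2 and TF2; the module, shapes, and checkpoint directory below are illustrative, not part of the original answer:

import sonnet as snt
import tensorflow as tf

# Build a small Sonnet 2 module and create its variables by calling it once.
model = snt.Linear(output_size=4)
_ = model(tf.ones([1, 8]))

# Track the module with tf.train.Checkpoint and save via a CheckpointManager.
checkpoint = tf.train.Checkpoint(module=model)
manager = tf.train.CheckpointManager(checkpoint, directory="/tmp/sonnet_ckpt", max_to_keep=3)
manager.save()

# Later: rebuild a module with the same structure and restore its variables.
restored = snt.Linear(output_size=4)
_ = restored(tf.ones([1, 8]))
tf.train.Checkpoint(module=restored).restore(manager.latest_checkpoint)

The same pattern should apply to a graph_nets module built against Sonnet 2, since tf.train.Checkpoint tracks the variables of any tf.Module it is given.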

1 reaction
alvarosg commented, Aug 15, 2019

Because graph_nets is not built on top of Keras, saving and restoring models is slightly different. To save the model you need to use a tf.train.Saver and saver.save; to restore it, build the TensorFlow graph in the same way and then use saver.restore. See the example below:

# Assumes TF1-style graph mode (the API used in the original graph_nets demos).
import numpy as np
import sonnet as snt
import tensorflow as tf

from graph_nets import modules
from graph_nets import utils_tf

def get_input_graphs():
  # Some function that returns a graphs.GraphsTuple. The tiny constant graph
  # below is illustrative only; replace it with your own input pipeline.
  data_dict = {
      "globals": np.zeros([4], dtype=np.float32),
      "nodes": np.zeros([3, 4], dtype=np.float32),
      "edges": np.zeros([2, 4], dtype=np.float32),
      "senders": np.array([0, 1], dtype=np.int32),
      "receivers": np.array([1, 2], dtype=np.int32),
  }
  return utils_tf.data_dicts_to_graphs_tuple([data_dict])

def build_and_connect_model(input_graphs):
  graph_network = modules.GraphNetwork(
      edge_model_fn=lambda: snt.Linear(output_size=4),
      node_model_fn=lambda: snt.Linear(output_size=4),
      global_model_fn=lambda: snt.Linear(output_size=4))
  output_graphs = graph_network(input_graphs)
  return graph_network, output_graphs

def log_variables(sess, variables):
  # Print the first entry of each variable to verify values survive the restore.
  vars_out = sess.run(variables)
  print([(var.name, var_out.flatten()[0])
         for var, var_out in zip(variables, vars_out)])


# Saving it.
tf.reset_default_graph()
input_graphs = get_input_graphs()
graph_net, output_graphs = build_and_connect_model(input_graphs)
initializer = tf.global_variables_initializer()

saver = tf.train.Saver()
with tf.Session() as sess:
  sess.run(initializer)
  saver.save(sess, "/tmp/model")
  log_variables(sess, graph_net.variables)

# Reloading it later: rebuild the same graph, then restore the saved variables.
tf.reset_default_graph()
input_graphs = get_input_graphs()
graph_net, output_graphs = build_and_connect_model(input_graphs)
saver = tf.train.Saver()
with tf.Session() as sess:
  saver.restore(sess, "/tmp/model")
  log_variables(sess, graph_net.variables)

Hope this helps!

