Перейти к содержимому
Compvision.ru
Pechkin80

Замарозка модели.

Recommended Posts

Добрый день, подскажите пожалуйста как правильно "заморозить" модель tensorflow, сохранённую с помощью tf.saved_model.simple_save как ./out/saved_model.pb ?

#init_op = tf.initialize_all_variables()
with tf.keras.backend.get_session() as sess:
    sess.run(init_op)
    tf.saved_model.simple_save(
        sess,
        export_path,
        inputs={'input': model.input},
        outputs={'output': model.output})
    saver.save(sess, 'tmp/my-weights')


Пробовал приспособить freeze_graph.freeze_graph, но безуспешно. Непонятно откуда брать outoute_node когда беру имя выходного узла из модели то в ответ читаю что данные узел не в графе.

Щас вообще беда, даже чекпоинт не могу сохранить, который вроде нужен для замарозки:

def write_temporary():
    init_op = tf.initialize_all_variables()
    export_path = './out'
    with tf.Session(graph=tf.Graph()) as sess:
        tf.saved_model.loader.load(sess, [tag_constants.SERVING], export_path)
        graph = tf.get_default_graph()
        tf.train.write_graph(graph, './outtmp', 'model_saved.pbtxt')
        #print([node.name for node in graph.as_graph_def().node])
        #sess.run(init_op)
    with tf.Session(graph=tf.Graph()) as sess:
        tf.saved_model.loader.load(sess, [tag_constants.SERVING], export_path)
        graph = tf.get_default_graph()
        saver = tf.train.Saver([graph])
        #
        #tf.train.write_graph(graph, './outtmp', 'model_saved.pbtxt')
        saver.save(sess, './outtmp/model.ckpt')
        #print([node.name for node in graph.as_graph_def().node])
        sess.run(init_op)
    
write_temporary()

 

Поделиться сообщением


Ссылка на сообщение
Поделиться на других сайтах

использую такой скрипт для иvпорта в opencv, сверточные слои работают

 

#region Imports
from tensorflow.python.tools import freeze_graph
import tensorflow as tf
import numpy as np
from keras.models import Model # basic class for specifying and training a neural network
from keras.layers import Input, Dense # the two types of neural network layer we will be using
from keras.utils import np_utils # utilities for one-hot encoding of ground truth values
from keras.models import Sequential
from keras.layers import Input, Convolution2D, MaxPooling2D, Dense, Dropout, Flatten, Reshape, Activation
from keras.engine.topology  import InputLayer
from keras.layers.core import Dense as DenseLayer
from keras import backend as K
import os
from tensorflow.core import framework
#endregion




def find_all_nodes(graph_def, **kwargs):
    for node in graph_def.node:
        for key, value in kwargs.items():
            if getattr(node, key) != value:
                break
        else:
            yield node
    raise StopIteration


def find_node(graph_def, **kwargs):
    try:
        return next(find_all_nodes(graph_def, **kwargs))
    except StopIteration:
        raise ValueError(
            'no node with attributes: {}'.format(
                ', '.join("'{}': {}".format(k, v) for k, v in kwargs.items())))


def walk_node_ancestors(graph_def, node_def, exclude=set()):
    openlist = list(node_def.input)
    closelist = set()
    while openlist:
        name = openlist.pop()
        if name not in exclude:
            node = find_node(graph_def, name=name)
            openlist += list(node.input)
            closelist.add(name)
    return closelist


def remove_nodes_by_name(graph_def, node_names):
    for i in reversed(range(len(graph_def.node))):
        if graph_def.node[i].name in node_names:
            del graph_def.node[i]


def make_shape_node_const(node_def, tensor_values):
    node_def.op = 'Const'
    node_def.ClearField('input')
    node_def.attr.clear()
    node_def.attr['dtype'].type = framework.types_pb2.DT_INT32
    tensor = node_def.attr['value'].tensor
    tensor.dtype = framework.types_pb2.DT_INT32
    tensor.tensor_shape.dim.add()
    tensor.tensor_shape.dim[0].size = len(tensor_values)
    for value in tensor_values:
        tensor.tensor_content += value.to_bytes(4, 'little')
    output_shape = node_def.attr['_output_shapes']
    output_shape.list.shape.add()
    output_shape.list.shape[0].dim.add()
    output_shape.list.shape[0].dim[0].size = len(tensor_values)


def make_cv2_compatible(graph_def):
    # A reshape node needs a shape node as its second input to know how it
    # should reshape its input tensor.
    # When exporting a model using Keras, this shape node is computed
    # dynamically using `Shape`, `StridedSlice` and `Pack` operators.
    # Unfortunately those operators are not supported yet by the OpenCV API.
    # The goal here is to remove all those unsupported nodes and hard-code the
    # shape layer as a const tensor instead.
    for reshape_node in find_all_nodes(graph_def, op='Reshape'):

        # Get a reference to the shape node
        shape_node = find_node(graph_def, name=reshape_node.input[1])

        # Find and remove all unsupported nodes
        garbage_nodes = walk_node_ancestors(graph_def, shape_node,
                                            exclude=[reshape_node.input[0]])
        remove_nodes_by_name(graph_def, garbage_nodes)

        # Infer the shape tensor from the reshape output tensor shape
        if not '_output_shapes' in reshape_node.attr:
            raise AttributeError(
                'cannot infer the shape node value from the reshape node. '
                'Please set the `add_shapes` argument to `True` when calling '
                'the `Session.graph.as_graph_def` method.')
        output_shape = reshape_node.attr['_output_shapes'].list.shape[0]
        output_shape = [dim.size for dim in output_shape.dim]

        # Hard-code the inferred shape in the shape node
        make_shape_node_const(shape_node, output_shape[1:])

def Save2(model_name,model):
    sess = K.get_session()
    graph_def = sess.graph.as_graph_def(add_shapes=True)
    graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, [model.output.name.split(':')[0]])
    make_cv2_compatible(graph_def)

    # Print the graph nodes
    print('\n'.join(node.name for node in graph_def.node))

    # Save the graph as a binary protobuf2 file
    tf.train.write_graph(graph_def, '', model_name+'.pb', as_text=False)

 

само сохранение такое (model  это keras, backend tf):

import TFfreez
TFfreez.Save2("d://number.pb",model )

 

Поделиться сообщением


Ссылка на сообщение
Поделиться на других сайтах

Подкину еще один скрипт, выдранный из работающего проекта. FreezeMe.py  тут некоторых вещей нехватает, но может чем и пригодится. И еще, чтобы смотреть что записалось попробуйте Netron, я его как то уже упоминал. 

Скриншот 2019-09-08 15.03.53.png

 

  • Like 1

Поделиться сообщением


Ссылка на сообщение
Поделиться на других сайтах
В 9/8/2019 at 15:05, Smorodov сказал:

Подкину еще один скрипт, выдранный из работающего проекта. FreezeMe.py  тут некоторых вещей нехватает, но может чем и пригодится. И еще, чтобы смотреть что записалось попробуйте Netron, я его как то уже упоминал. 

Скриншот 2019-09-08 15.03.53.png

 

Спасибо, интересная утилита. Но у меня после всех оптимизаций кудато потерялись атрибуты и код:
 

		GraphDef graph_def;
...
		auto shape = graph_def.node().Get(0).attr().at("shape").shape();
        for (int i = 0; i < shape.dim_size(); i++) {
            std::cout << shape.dim(0).size()<<std::endl;
        }

приводит к исключению:

Цитата

 

[libprotobuf FATAL /usr/local/include/google/protobuf/map.h:1064] CHECK failed: it != end(): key not found: shape

terminate called after throwing an instance of 'google::protobuf::FatalException'

what(): CHECK failed: it != end(): key not found: shape

22:06:41: The program has unexpectedly finished.

 

 

Поделиться сообщением


Ссылка на сообщение
Поделиться на других сайтах

Создайте учётную запись или войдите для комментирования

Вы должны быть пользователем, чтобы оставить комментарий

Создать учётную запись

Зарегистрируйтесь для создания учётной записи. Это просто!

Зарегистрировать учётную запись

Войти

Уже зарегистрированы? Войдите здесь.

Войти сейчас


  • Сейчас на странице   0 пользователей

    Нет пользователей, просматривающих эту страницу

×