Visualisation de l'état de tir de la couche cachée du modèle appris dans le tutoriel TensorFlow MNIST

Aperçu

Pour comprendre l'apprentissage automatique, j'ai essayé le didacticiel MNIST de TensorFlow. Nous avons également mis en œuvre une application qui vous permet de saisir des caractères manuscrits à partir d'un navigateur et visualisé l'état d'allumage et les résultats de discrimination de la couche cachée.

Environnement d'exécution principal

Tutoriel TensorFlow MNIST

Suivez le tutoriel MNIST (Deep MNIST for Experts) sur le site officiel de TensorFlow. Ici, un réseau est construit avec un ensemble d'une couche de pliage et d'une couche de mise en commun en deux étapes, et une couche entièrement connectée en une étape. Ce qui a changé par rapport au didacticiel, c'est qu'il rend la session interactive et ajoute le processus d'enregistrement des poids et des biais après l'entraînement du modèle.

Une fois exécuté, l'apprentissage est effectué (environ 1 heure), et finalement une précision d'environ 99% peut être obtenue. Après l'exécution, les poids et biais entraînés sont enregistrés sous forme de fichier binaire.

train.py


import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

def weight_variable(shape, name):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial, name=name)

def bias_variable(shape, name):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial, name=name)

def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
x_image = tf.reshape(x, [-1, 28, 28, 1])

W_conv1 = weight_variable([5, 5, 1, 32], name='W_conv1')
b_conv1 = bias_variable([32], name='b_conv1')

h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

W_conv2 = weight_variable([5, 5, 32, 64], name='W_conv2')
b_conv2 = bias_variable([64], name='b_conv2')

h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

W_fc1 = weight_variable([7 * 7 * 64, 1024], name='W_fc1')
b_fc1 = bias_variable([1024], name='b_fc1')

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

W_fc2 = weight_variable([1024, 10], name='W_fc2')
b_fc2 = bias_variable([10], name='b_fc2')

y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

for i in range(20000):
    batch = mnist.train.next_batch(50)
    if i % 100 == 0:
        train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})
    print('step %d, training accuracy %g' % (i, train_accuracy))
    train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

print('test accuracy %g' % accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

saver = tf.train.Saver({'W_conv1':W_conv1, 'b_conv1':b_conv1, 'W_conv2':W_conv2, 'b_conv2':b_conv2, 'W_fc1':W_fc1, 'b_fc1':b_fc1, 'W_fc2':W_fc2, 'b_fc2':b_fc2})
saver.save(sess, 'cnn_model')

Interface de saisie de caractères manuscrits

Implémentez une interface de saisie pour les caractères manuscrits à l'aide de HTML5 Canvas. Chaque fois que l'entrée est terminée, la valeur de pixel de l'image d'entrée dessinée est POSTÉE sur le serveur au format JSON et le résultat de la discrimination est acquis (décrit plus loin).

templates/index.html


<!DOCTYPE html>
<html>
<head>
    <title>TensorFlow MNIST Demo</title>
    <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.2.1/jquery.min.js"></script>
    <link href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/css/bootstrap.min.css" rel="stylesheet" integrity="sha384-BVYiiSIFeK1dGmJRAkycuHAHRg32OmUcww7on3RYdg4Va+PmSTsz/K68vbdEjh4u" crossorigin="anonymous">
</head>
<body>

<div class="container">
    <h1>TensorFlow MNIST Demo</h1>
    <div class="row">
        <div class="col-md-4">
            <h2>Input</h2>
            <canvas id="input-image" width="28" height="28" style="width: 280px; height: 280px;"></canvas>
        </div>

        <div class="col-md-8">
            <h2>Prediction</h2>
            <table class="table table-bordered table-striped">
                <tr>
                    <th width="10%">0</th>
                    <th width="10%">1</th>
                    <th width="10%">2</th>
                    <th width="10%">3</th>
                    <th width="10%">4</th>
                    <th width="10%">5</th>
                    <th width="10%">6</th>
                    <th width="10%">7</th>
                    <th width="10%">8</th>
                    <th width="10%">9</th>
                </tr>
                <tr>
                    <td id="score0">-</td>
                    <td id="score1">-</td>
                    <td id="score2">-</td>
                    <td id="score3">-</td>
                    <td id="score4">-</td>
                    <td id="score5">-</td>
                    <td id="score6">-</td>
                    <td id="score7">-</td>
                    <td id="score8">-</td>
                    <td id="score9">-</td>
                </tr>
            </table>
            <button class="btn btn-large" id="clear">Clear</button>
        </div>
    </div>

    <div class="row">
        <h2>h_conv1</h2>
        <canvas id="conv1" width="260" height="132" style="width: 1300px; height: 660px;"></canvas>
    </div>
    <div class="row">
        <h2>h_pool1</h2>
        <canvas id="pool1" width="260" height="132" style="width: 1300px; height: 660px;"></canvas>
    </div>
    <div class="row">
        <h2>h_conv2</h2>
        <canvas id="conv2" width="260" height="66" style="width: 1300px; height: 330px;"></canvas>
    </div>    
    <div class="row">
        <h2>h_pool2</h2>
        <canvas id="pool2" width="260" height="66" style="width: 1300px; height: 330px;"></canvas>
    </div>
</div>

<script type="text/javascript">
    
    var canvas = document.getElementById('input-image');
    var context = canvas.getContext('2d');    
    var moveFlag = false;
    var Xpoint;
    var Ypoint;
    var offsetX = canvas.getBoundingClientRect().left;
    var offsetY = canvas.getBoundingClientRect().top;
    var size = 28;
    var scale = 10;

    context.lineWidth = 1;
    context.strokeStyle = '#FFF';
    context.fillStyle = '#000';
    context.fillRect(0, 0, size, size);

    canvas.addEventListener('mousedown', startPoint, false);
    canvas.addEventListener('mousemove', movePoint, false);
    canvas.addEventListener('mouseup', endPoint, false);

    document.getElementById('clear').addEventListener('click', clear, false);
    updateImage();

    function startPoint(e) {
        e.preventDefault();
        context.beginPath();
        Xpoint = Math.round((e.pageX - offsetX) / scale);
        Ypoint = Math.round((e.pageY - offsetY) / scale);
        context.moveTo(Xpoint, Ypoint);
    }

    function movePoint(e) {
        if(e.buttons === 1 || e.witch === 1) {
            Xpoint = Math.round((e.pageX - offsetX) / scale);
            Ypoint = Math.round((e.pageY - offsetY) / scale);
            moveFlag = true;
            context.lineTo(Xpoint, Ypoint);
            context.stroke();
        }
    }

    function endPoint(e) {
        if (moveFlag === true) {
            context.lineTo(Xpoint, Ypoint);
            context.stroke();
        }
        moveFlag = false;
        updateImage();
    }

    function clear() {
        context.fillStyle = '#000';
        context.fillRect(0, 0, size, size);
        updateImage();
    }

    function updateImage() {
        //Voir ci-dessous
    }

</script>

</body>
</html>

Application de saisie de caractères manuscrits

En même temps que le démarrage de Flask, créez le même réseau créé dans le didacticiel MNIST et chargez les poids et biais entraînés.

index.py


from flask import Flask, render_template, request, redirect, jsonify
import numpy as np
import tensorflow as tf

def weight_variable(shape, name):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial, name=name)

def bias_variable(shape, name):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial, name=name)

def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
x_image = tf.reshape(x, [-1, 28, 28, 1])

W_conv1 = weight_variable([5, 5, 1, 32], name='W_conv1')
b_conv1 = bias_variable([32], name='b_conv1')

h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

W_conv2 = weight_variable([5, 5, 32, 64], name='W_conv2')
b_conv2 = bias_variable([64], name='b_conv2')

h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

W_fc1 = weight_variable([7 * 7 * 64, 1024], name='W_fc1')
b_fc1 = bias_variable([1024], name='b_fc1')

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

W_fc2 = weight_variable([1024, 10], name='W_fc2')
b_fc2 = bias_variable([10], name='b_fc2')

y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

saver = tf.train.Saver({'W_conv1':W_conv1, 'b_conv1':b_conv1, 'W_conv2':W_conv2, 'b_conv2':b_conv2, 'W_fc1':W_fc1, 'b_fc1':b_fc1, 'W_fc2':W_fc2, 'b_fc2':b_fc2})
saver.restore(sess, 'cnn_model')

app = Flask(__name__)

if __name__ == '__main__':
    app.debug = True
    app.run(host='0.0.0.0')

L'action d'index affiche l'interface de saisie de caractères manuscrits et le résultat de l'inférence. Il n'y a rien de spécial côté serveur.

index.py


@app.route('/')
def index():
    return render_template('index.html')

Dans l'action de prédiction, la valeur de pixel de l'image d'entrée POSTed est entrée dans le réseau et le résultat de la discrimination est renvoyé dans JSON.

index.py


@app.route('/predict', methods=['POST'])
def predict():
    result = {}
    result['h_conv1'] = sess.run(h_conv1, feed_dict={x: [request.get_json()], keep_prob: 1.0})[0].transpose(2, 0, 1).tolist()
    result['h_pool1'] = sess.run(h_pool1, feed_dict={x: [request.get_json()], keep_prob: 1.0})[0].transpose(2, 0, 1).tolist()
    result['h_conv2'] = sess.run(h_conv2, feed_dict={x: [request.get_json()], keep_prob: 1.0})[0].transpose(2, 0, 1).tolist()
    result['h_pool2'] = sess.run(h_pool2, feed_dict={x: [request.get_json()], keep_prob: 1.0})[0].transpose(2, 0, 1).tolist()
    result['y_conv']  = sess.run(y_conv,  feed_dict={x: [request.get_json()], keep_prob: 1.0})[0].tolist()

    return jsonify(result)

Traitement de la communication et de l'affichage des résultats par Ajax. La zone autour du canevas est encombrée, mais l'état de déclenchement du calque masqué est simplement dessiné dans un carré.

templates/index.html


<script>
function updateImage() {
    rawImage = context.getImageData(0, 0, size, size);
    image = Array.from(rawImage.data.filter(function(element, index, array) {
        return index % 4 === 0;
    }));
    
    $.ajax({
        url: '/predict',
        type: 'POST',
        data: JSON.stringify(image),
        contentType: 'application/JSON',
        dataType : 'JSON',
        success: function(data, status, xhr) {
            console.log('success');
            console.log(data);

            var canvas = document.getElementById('conv1');
            var context = canvas.getContext('2d');
            context.fillStyle = '#ddf';
            context.fillRect(0, 0, 260, 132);

            for (var i = 0; i < data['h_conv1'].length; i++) {
                var x0 = 4 + (i%8) * 32;
                var y0 = 4 + Math.floor(i/8) * 32;
                for (var j = 0; j < 28; j++) {
                    for (var k = 0; k < 28; k++) {
                        val = Math.round(data['h_conv1'][i][j][k]);
                        context.fillStyle = 'rgb('+val+','+val+','+val+')';
                        context.fillRect(x0+k, y0+j, 1, 1);
                    }
                }
            }

            canvas = document.getElementById('pool1');
            context = canvas.getContext('2d');
            context.fillStyle = '#fdd';
            context.fillRect(0, 0, 260, 132);

            for (var i = 0; i < data['h_pool1'].length; i++) {
                var x0 = 4 + (i%8) * 32;
                var y0 = 4 + Math.floor(i/8) * 32;
                for (var j = 0; j < 14; j++) {
                    for (var k = 0; k < 14; k++) {
                        val = Math.round(data['h_pool1'][i][j][k]);
                        context.fillStyle = 'rgb('+val+','+val+','+val+')';
                        context.fillRect(x0+k*2, y0+j*2, 2, 2);
                    }
                }
            }

            canvas = document.getElementById('conv2');
            context = canvas.getContext('2d');
            context.fillStyle = '#ddf';
            context.fillRect(0, 0, 260, 132);

            for (var i = 0; i < data['h_conv2'].length; i++) {
                var x0 = 3 + (i%16) * 16;
                var y0 = 2 + Math.floor(i/16) * 16;
                for (var j = 0; j < 14; j++) {
                    for (var k = 0; k < 14; k++) {
                        val = Math.round(data['h_conv2'][i][j][k]);
                        context.fillStyle = 'rgb('+val+','+val+','+val+')';
                        context.fillRect(x0+k, y0+j, 1, 1);
                    }
                }
            }

            canvas = document.getElementById('pool2');
            context = canvas.getContext('2d');
            context.fillStyle = '#fdd';
            context.fillRect(0, 0, 260, 132);

            for (var i = 0; i < data['h_conv2'].length; i++) {
                var x0 = 3 + (i%16) * 16;
                var y0 = 2 + Math.floor(i/16) * 16;
                for (var j = 0; j < 7; j++) {
                    for (var k = 0; k < 7; k++) {
                        val = Math.round(data['h_pool2'][i][j][k]);
                        context.fillStyle = 'rgb('+val+','+val+','+val+')';
                        context.fillRect(x0+k*2, y0+j*2, 2, 2);
                    }
                }
            }

            $('[id^="score"]').removeClass('warning');
            for (var i = 0; i < 10; i++) {
                $('#score'+i).text(Math.round(data['y_conv'][i]));
                if (Math.max.apply(null, data['y_conv']) === data['y_conv'][i]) {
                    $('#score'+i).addClass('warning');
                }
            }

        },
        error: function(data, status, error) {
            console.log('error');
            console.log(error);
        }
    });
}
</script>

Résultat d'exécution

J'ai essayé d'entrer "5" à la main. Le score de «5» est le plus élevé et il peut être correctement identifié. Screen Shot 2017-07-24 at 17.07.13.png

Résultat de sortie de la couche de convolution par le filtre 5x5 du 1er étage. Impression qu'il réagit aux formes telles que la gauche et la droite des lignes verticales, le haut et le bas des lignes horizontales et les lignes diagonales. Screen Shot 2017-07-24 at 17.07.35.png

Résultat de sortie de la couche de mise en commun de la valeur maximale 2x2 du premier étage. Screen Shot 2017-07-24 at 17.07.45.png

Résultat de sortie de la couche de convolution par le filtre 5x5 du 2ème étage. Réagissez-vous à un modèle légèrement plus complexe que la première ligne? Screen Shot 2017-07-24 at 17.07.57.png

Résultat de sortie de la couche de mise en commun de valeur maximale 2x2 du deuxième étage. Screen Shot 2017-07-24 at 17.08.08.png

Après cela, il y a une couche entièrement connectée avec 1024 unités, qui mène à une couche de sortie avec 10 unités.

référence

Recommended Posts

Visualisation de l'état de tir de la couche cachée du modèle appris dans le tutoriel TensorFlow MNIST
J'ai essayé le tutoriel MNIST de tensorflow pour les débutants.
J'ai fait une démo qui permet au modèle formé dans le didacticiel mnist de Tensorflow de distinguer les nombres manuscrits écrits sur la toile.
Utilisez le vecteur appris par word2vec dans la couche Embedding de LSTM
Tutoriel TensorFlow J'ai essayé MNIST 3rd
Créez une API REST à l'aide du modèle appris dans Lobe et TensorFlow Serving.
Record of TensorFlow mnist Expert Edition (Visualisation de TensorBoard)
Réalisation du didacticiel TensorFlow MNIST pour débutants en ML
Apprentissage supervisé de mnist dans la couche entièrement connectée, clustering et évaluation de l'étape finale
Spécifier le modèle d'éclairage du matériau SCN dans Pythonista
Comptez le nombre de paramètres dans le modèle d'apprentissage en profondeur
L'idée de Tensorflow a appris de la fabrication de pommes de terre
Examinez les paramètres de RandomForestClassifier dans le didacticiel Kaggle / Titanic
Comment utiliser le modèle appris dans Lobe en Python
J'ai essayé de refactoriser le modèle CNN de TensorFlow en utilisant TF-Slim
[Blender] Connaître l'état de sélection des objets cachés sur l'outliner
L'histoire de la rétrogradation de la version de tensorflow dans la démo de Mask R-CNN.