Cómo usar TensorFlow con Java

El aprendizaje automático está ganando popularidad y uso en todo el mundo. Ya ha cambiado drásticamente la forma en que se construyen ciertas aplicaciones y es probable que cont...

Introducción

El aprendizaje automático está ganando popularidad y uso en todo el mundo. Ya ha cambiado drásticamente la forma en que se construyen ciertas aplicaciones y probablemente seguirá siendo una parte enorme (y cada vez mayor) de nuestra vida diaria.

No hay forma de endulzarlo, el aprendizaje automático no es simple. Es bastante desalentador y puede parecer muy complejo para muchos.

Empresas como Google se encargaron de acercar los conceptos de Machine Learning a los desarrolladores y permitirles poco a poco, con una gran ayuda, dar sus primeros pasos.

Así nacieron frameworks como TensorFlow.

¿Qué es TensorFlow?

TensorFlow es un marco de aprendizaje automático de código abierto desarrollado por Google en Python y C++.

Ayuda a los desarrolladores a adquirir datos fácilmente, preparar y entrenar modelos, predecir estados futuros y realizar aprendizaje automático a gran escala.

Con él, podemos entrenar y ejecutar redes neuronales profundas que se utilizan con mayor frecuencia para Reconocimiento óptico de caracteres, [Reconocimiento/clasificación de imágenes]( /reconocimiento de imágenes -in-python-with-tensorflow-and-keras/)Procesamiento-del-lenguaje-natural, etc.

Tensores y operaciones

TensorFlow se basa en gráficos computacionales, que puedes imaginar como un gráfico clásico con nodos y bordes.

Cada nodo se denomina operación, y toman cero o más tensores y producen cero o más tensores. Una operación puede ser muy simple, como una suma básica, pero también puede ser muy compleja.

Los tensores se representan como bordes del gráfico y son la unidad de datos central. Realizamos diferentes funciones en estos tensores a medida que los alimentamos a las operaciones. Pueden tener una o varias dimensiones, que a veces se denominan sus [rangos] (http://mathworld.wolfram.com/TensorRank.html) - (Escalar: rango 0, Vector: rango 1, Matriz: rango 2 )

Estos datos fluyen a través del gráfico computacional a través de tensores, afectados por las operaciones, de ahí el nombre TensorFlow.

Los tensores pueden almacenar datos en cualquier cantidad de dimensiones, y hay tres tipos principales de tensores: marcadores de posición, variables y constantes.

Instalación de TensorFlow {#instalación de TensorFlow}

Usando Experto, instalar TensorFlow es tan fácil como incluir la dependencia:

1
2
3
4
5
<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>tensorflow</artifactId>
  <version>1.13.1</version>
</dependency>

Si su dispositivo es compatible con Soporte de GPU, use estas dependencias:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>libtensorflow</artifactId>
  <version>1.13.1</version>
</dependency>

<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>libtensorflow_jni_gpu</artifactId>
  <version>1.13.1</version>
</dependency>

Puedes verificar la versión de TensorFlow actualmente instalada usando el objeto TensorFlow:

1
System.out.println(TensorFlow.version());

API Java de TensorFlow

Las ofertas de Java API TensorFlow están contenidas dentro del paquete org.tensorflow. Actualmente es experimental, por lo que no se garantiza que sea estable.

Tenga en cuenta que el único lenguaje totalmente compatible con TensorFlow es Python y que la API de Java no es tan funcional.

Nos presenta nuevas clases, una interfaz, una enumeración y una excepción.

Clases

Las nuevas clases introducidas a través de la API son:

  • Graph: un gráfico de flujo de datos que representa un cálculo de TensorFlow
  • Operación: un nodo gráfico que realiza cálculos en tensores
  • OperationBuilder: una clase constructora para Operaciones
  • Output<T>: un identificador simbólico para un tensor producido por una operación
  • SavedModelBundle: representa un modelo cargado desde el almacenamiento.
  • SavedModelBundle.Loader: proporciona opciones para cargar un modelo guardado
  • Servidor: un servidor TensorFlow en proceso, para usar en capacitación distribuida
  • Session: Controlador para la ejecución de gráficos
  • Session.Run: Tensores de salida y metadatos obtenidos al ejecutar una sesión
  • Session.Runner: Ejecutar Operaciones y evaluar Tensores
  • Forma: La forma posiblemente parcialmente conocida de un tensor producido por una operación
  • Tensor<T>: Una matriz multidimensional tipada estáticamente cuyos elementos son de un tipo descrito por T
  • TensorFlow: métodos de utilidad estática que describen el tiempo de ejecución de TensorFlow
  • Tensores: métodos de fábrica con seguridad de tipos para crear objetos Tensor
Enumeración
  • DataType: representa el tipo de elementos en un tensor como una enumeración
Interfaz
  • Operand<T>: Interfaz implementada por operandos de una operación TensorFlow
Excepción
  • TensorFlowException: excepción no verificada lanzada al ejecutar TensorFlow Graphs

Si comparamos todo esto con el módulo tf en Python, hay una diferencia obvia. La API de Java no tiene casi la misma cantidad de funcionalidad, al menos por ahora.

Gráficos

Como se mencionó anteriormente, TensorFlow se basa en gráficos computacionales, donde org.tensorflow.Graph es la implementación de Java.

Nota: sus instancias son seguras para subprocesos, aunque necesitamos liberar explícitamente los recursos utilizados por Graph una vez que hayamos terminado con él.

Comencemos con un gráfico vacío:

1
Graph graph = new Graph();

Este gráfico no significa mucho, está vacío. Para hacer cualquier cosa con él, primero tenemos que cargarlo con Operaciones.

Para cargarlo con operaciones, usamos el método opBuilder(), que devuelve un objeto OperationBuilder que agregará las operaciones a nuestro gráfico una vez que llamemos al método .build().

Constantes

Agreguemos una constante a nuestro gráfico:

1
2
3
4
Operation x = graph.opBuilder("Const", "x")
               .setAttr("dtype", DataType.FLOAT)
               .setAttr("value", Tensor.create(3.0f))
               .build(); 

Marcadores de posición

Los marcadores de posición son un "tipo" de variable que no tiene un valor en la declaración. Sus valores serán asignados en una fecha posterior. Esto nos permite construir gráficos con operaciones sin ningún dato real:

1
2
3
Operation y = graph.opBuilder("Placeholder", "y")
        .setAttr("dtype", DataType.FLOAT)
        .build();

Funciones

Y ahora, finalmente, para redondear esto, necesitamos agregar ciertas funciones. Estos pueden ser tan simples como la multiplicación, la división o la suma, o tan complejos como las multiplicaciones de matrices. Al igual que antes, definimos funciones usando el método .opBuilder():

1
2
3
4
Operation xy = graph.opBuilder("Mul", "xy")
  .addInput(x.output(0))
  .addInput(y.output(0))
  .build();         

Nota: Estamos usando output(0) ya que un tensor puede tener más de una salida.

Visualización de gráficos

Lamentablemente, la API de Java aún no incluye ninguna herramienta que le permita visualizar gráficos como lo haría en Python. Cuando la API de Java se actualice, también lo hará este artículo.

Sesiones

Como se mencionó anteriormente, una ‘Sesión’ es el controlador para la ejecución de un ‘Gráfico’. Encapsula el entorno en el que se ejecutan ‘Operaciones’ y ‘Gráficos’ para calcular ‘Tensores’.

Lo que esto significa es que los tensores en nuestro gráfico que construimos en realidad no tienen ningún valor, ya que no ejecutamos el gráfico dentro de una sesión.

Primero agreguemos el gráfico a una sesión:

1
Session session = new Session(graph);

Nuestro cálculo simplemente multiplica el valor de x e y. Para ejecutar nuestro gráfico y calcularlo, buscamos() la operación xy y le damos los valores x e y:

1
2
Tensor tensor = session.runner().fetch("xy").feed("x", Tensor.create(5.0f)).feed("y", Tensor.create(2.0f)).run().get(0);
System.out.println(tensor.floatValue());

Ejecutar este fragmento de código producirá:

1
10.0f

Guardar modelos en Python y cargarlos en Java {#guardar modelos en Python y cargar en Java}

Esto puede sonar un poco extraño, pero dado que Python es el único lenguaje compatible, la API de Java todavía no tiene la funcionalidad para guardar modelos.

Esto significa que la API de Java está pensada solo para el caso de uso de servicio, al menos hasta que sea totalmente compatible con TensorFlow. Al menos, podemos entrenar y guardar modelos en Python y luego cargarlos en Java para servirlos, usando la clase SavedModelBundle:

1
2
3
4
SavedModelBundle model = SavedModelBundle.load("./model", "serve"); 
Tensor tensor = model.session().runner().fetch("xy").feed("x", Tensor.create(5.0f)).feed("y", Tensor.create(2.0f)).run().get(0);  

System.out.println(tensor.floatValue());

Conclusión

TensorFlow es un marco poderoso, robusto y ampliamente utilizado. Se está mejorando constantemente y últimamente se ha introducido en nuevos lenguajes, incluidos Java y JavaScript.

Aunque la API de Java aún no tiene tanta funcionalidad como TensorFlow para Python, aún puede servir como una buena introducción a TensorFlow para desarrolladores de Java.