Quickstart

Welcome to traintool!

In this quickstart, we will train a few models on MNIST. This should give you a rough overview of what traintool can do.

You can follow along interactively in Google Colab (a free Jupyter notebook service):

Open In Colab

We highly recommend using Colab for this tutorial because it gives you free GPU access, which makes training much faster. Important: To enable GPU support, click on "Runtime" -> "Change runtime type", select "GPU" and hit "Save".


First, let's install traintool:

!pip install -U git+https://github.com/jrieke/traintool

Next, we import traintool and load the mnist dataset (installed with traintool):

import traintool
import mnist
train_images = mnist.train_images()[:, None]  # add color dimension
train_labels = mnist.train_labels()
test_images = mnist.test_images()[:, None]
test_labels = mnist.test_labels()

print("Images shape:", train_images.shape)
print("Labels shape:", train_labels.shape)

As you can see, all data from the mnist package comes as numpy arrays. Images have the shape num samples x color channels x height x width (60000 x 1 x 28 x 28 for the training set). Note that traintool can handle numpy arrays, as used here, as well as image files on your machine (see here).
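To see what the [:, None] indexing does without downloading MNIST, here is a small sketch on a dummy array. The array size is illustrative only, not the real dataset.

```python
import numpy as np

# Dummy stand-in for grayscale images: 5 samples of 28 x 28 pixels.
images = np.zeros((5, 28, 28), dtype=np.uint8)

# Indexing with None inserts a new axis of length 1 at that position,
# turning (samples, height, width) into (samples, channels, height, width).
images_with_channel = images[:, None]

print(images.shape)               # (5, 28, 28)
print(images_with_channel.shape)  # (5, 1, 28, 28)
```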

Your first model

Let's train our first model! We will use a very simple model, a support vector classifier (called svc in traintool). Training it requires only one line of code:

Note: We use the config parameter num_samples here to train only on a subset of the data to make it faster.

svc = traintool.train("svc", 
                      train_data=[train_images, train_labels], 
                      test_data=[test_images, test_labels], 
                      config={"num_samples": 500})

That looks very simple – but under the hood, a lot of stuff happened:

1) traintool printed some general information about the experiment: its ID, which model and configuration were used, where the model is saved, and how you can load it later.

2) Then, it preprocessed the data. It automatically converted all data to the correct format and applied some light preprocessing that makes sense with this model.

3) It created and trained the model. Under the hood, traintool uses different frameworks for this step (e.g. scikit-learn or pytorch) but as a user, you don't have to worry about any of this. After training, traintool printed the resulting accuracies (should be 80-85 % here).

4) traintool automatically saved the model, console output and tensorboard logs into a time-stamped folder (see below).

Making predictions

To make a prediction with this model, simply use its predict function:

svc.predict(test_images[0])

This gives you a dictionary with the predicted class and probabilities for each class. Note that for now, predict can only process a single image at a time. Like the train function, it works with numpy arrays and image files (see here).
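Because predict handles one image at a time, evaluating several images means looping. The sketch below shows one way to do that. The helpers predict_many and accuracy and the stand-in class FakeModel are hypothetical; they exist only so the example runs without a trained model. The dictionary key "predicted_class" is also an assumption, so check the keys your model actually returns.

```python
class FakeModel:
    """Hypothetical stand-in for a trained traintool model."""

    def predict(self, image):
        # Assumed result layout: a dict holding the predicted class and
        # per-class probabilities (the key names are an assumption).
        return {"predicted_class": sum(image) % 2, "probabilities": [0.5, 0.5]}


def predict_many(model, images, class_key="predicted_class"):
    """Call model.predict once per image and collect the predicted classes."""
    return [model.predict(image)[class_key] for image in images]


def accuracy(predicted, labels):
    """Fraction of predictions that match the labels."""
    if len(labels) == 0:
        raise ValueError("labels must not be empty")
    correct = sum(int(p == y) for p, y in zip(predicted, labels))
    return correct / len(labels)


# Toy "images" (flat lists of ints) so the example runs anywhere.
toy_images = [[1, 2], [2, 2], [3, 4]]
toy_labels = [1, 0, 0]
predictions = predict_many(FakeModel(), toy_images)
print(predictions, accuracy(predictions, toy_labels))
```

With a real model, the same helpers could be applied to a slice such as test_images[:100] and test_labels[:100].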

Using other models

Ok, now what if you want to train a different model? traintool makes this very easy: You only have to call the train function with a different model name – no need to rewrite the implementation or change the data just because you use a model from a different framework!

Let's train a residual network (resnet18), a deep neural network from pytorch (make sure to use a GPU!):

resnet = traintool.train("resnet18", 
                         train_data=[train_images, train_labels],
                         test_data=[test_images, test_labels],
                         config={"batch_size": 128, "print_every": 10, "num_epochs": 2, "num_samples": 10000})

And with this simple command, you can train all models supported by traintool! See here for a list of models.

As you may have noticed, we set some parameters with the config argument above. config is the central place to define hyperparameters for training. The supported hyperparameters vary from model to model – it's best to have a look at the overview page linked above.
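Since config is a plain dictionary, comparing several settings comes down to building several dictionaries. Here is a minimal sketch that uses only hyperparameter names appearing above. The helper make_configs is hypothetical, and the training call is left as a comment so the snippet runs without traintool.

```python
def make_configs(base, key, values):
    """Return one config dict per value, each a copy of base with key overridden."""
    return [{**base, key: value} for value in values]


base_config = {"batch_size": 128, "print_every": 10, "num_epochs": 2}
configs = make_configs(base_config, "num_samples", [1000, 5000, 10000])

for config in configs:
    print(config)
    # With traintool installed and the data loaded as above:
    # traintool.train("resnet18",
    #                 train_data=[train_images, train_labels],
    #                 test_data=[test_images, test_labels],
    #                 config=config)
```

Copying the base dictionary for each run keeps the shared settings in one place and avoids accidentally mutating them between experiments.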

Experiment tracking

traintool automatically keeps track of all experiments you run. Each experiment is stored in a time-stamped folder in ./traintool-experiments. Have a look at this folder now to see the experiments you ran above! (If you are in Colab, click on the folder icon on the top left).

Tip: You can disable saving with save=False.

Each experiment folder contains:

  • info.yml: General information about the experiment
  • stdout.log: The entire console output
  • model files and possibly checkpoints (e.g. the pytorch binary model.pt for resnet18)
  • tensorboard logs (see below)
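
To inspect the experiment folder from code rather than the file browser, a short pathlib sketch is enough. The helper list_experiments is hypothetical and assumes only the layout described above: one subfolder per experiment, each holding files such as info.yml and stdout.log.

```python
from pathlib import Path


def list_experiments(root="traintool-experiments"):
    """Map each experiment folder name to the sorted names of the entries inside it."""
    root = Path(root)
    if not root.is_dir():
        return {}  # nothing has been saved yet
    return {
        exp.name: sorted(child.name for child in exp.iterdir())
        for exp in sorted(root.iterdir())
        if exp.is_dir()
    }


for name, files in list_experiments().items():
    print(name, files)
```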

Visualizations

traintool writes all metrics and evaluations to tensorboard, a powerful visualization platform from tensorflow. Let's start tensorboard now: If you are on a local machine, open a terminal in this directory and run tensorboard --logdir traintool-experiments. If you are on Colab, just run the cell below:

%load_ext tensorboard
%tensorboard --logdir traintool-experiments/

Let's see what's going on here: On the bottom left, you can select individual experiments. On the right, you should by default see scalar metrics: The loss and accuracy for train and test set. You can also click on Images at the top to see some sample images from both datasets along with classification results (use the sliders to look at different epochs!).

Tip: You can also store metrics in comet.ml, see here.

Other functions

Before we end this quickstart, let's look at three other important functions:

  • Loading: To load a saved model, just pass its ID (or directory path) to traintool.load(...). Check out the line starting with Load via: in the console output above – it shows you directly which command to call.
  • Deployment: traintool can easily deploy your trained model through a REST API. Simply call model.deploy() to start the server (note that this call is blocking!). More information here.
  • Raw models: traintool models are implemented in different frameworks, e.g. scikit-learn or pytorch. You can get access to the raw models by calling model.raw().
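
As a small illustration of loading, the sketch below picks the most recently modified experiment folder so that its path can be passed to traintool.load(...). The helper latest_experiment is hypothetical, and using the modification time to find the newest run is an assumption. The Load via: line in the console output remains the reliable way to get the exact command.

```python
from pathlib import Path


def latest_experiment(root="traintool-experiments"):
    """Return the most recently modified experiment folder, or None if there is none."""
    root = Path(root)
    if not root.is_dir():
        return None
    folders = [p for p in root.iterdir() if p.is_dir()]
    return max(folders, key=lambda p: p.stat().st_mtime, default=None)


newest = latest_experiment()
print(newest)
# With traintool installed and at least one saved experiment:
# model = traintool.load(str(newest))
# raw_model = model.raw()  # the underlying scikit-learn or pytorch object
```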

That's it! You should now be able to start using traintool. Make sure to read the complete tutorial and documentation to learn more!

Please also consider leaving a ⭐ on our GitHub.