
Vision, image and machine learning (part AM)


Posture-guided image synthesis of a person (TP) #

This TP aims to implement, with PyTorch, the transfer of motion from a source video to a target person, following a paper by Chan et al. presented at ICCV 2019: Everybody Dance Now. This approach was chosen as an introduction to GANs, and the subject proceeds in several stages. To understand the principle, the TP first proposes an approach that searches for the closest data in the initial dataset (no network). Then comes a simple network that generates an image of the person from the array of points of the person's skeleton. And finally, a network that takes as input a "stick" image of the person, which is "boosted" by a GAN-like discriminator.

The paper proposes various methods for improving temporal continuity and faces, which we won’t be looking at.


Principle #

As input, you need a video of the target person performing a few movements. Note: recent approaches can work from a single image, but the idea here is to practice, not to reproduce the latest paper.

From a second video of a source person, the final goal is to make the target person perform the same movements as the source person. This is done by producing, frame by frame, a new video of the target person with the pose/skeleton extracted from the source video. To extract the skeletons from the videos, we use a pre-trained network from MediaPipe. The given code already does this job.

The machine learning model (NN) has to learn, from the images of the target video, how to produce a new image of this person in a new posture given as input. If the video of this person is rich enough and contains all possible postures, we could simply search for the image whose skeleton is "similar" (question 1). Next, we'll look for a network that generalizes: it will be able to produce an image even for a posture that has never been seen before. We'll try a direct network, then a GAN.

The starting code #

The starting code is here, or as a git repository here. You need to install the classics (numpy, pytorch), plus OpenCV (cv2) and mediapipe. You may have to change the path "tp/danse/data/…" to "data/…".

The various files are as follows.

  • VideoReader: basic functions for video playback and image retrieval (uses cv2).
  • Vec3: 3D points, based on a numpy array.
  • Skeleton: a class that stores the 3D positions of a skeleton. There are 33 joints given by mediapipe, so 99 floats in all. It’s possible to switch to reduced mode (reduced=True as a parameter to various functions) to have just 13 joints in 2D, so 26 floats.
  • VideoSkeleton: a class that associates a skeleton with each frame of a video. The skeleton is stored in memory, but the video frame is represented by the frame file name (storing all the frames of a video would take up too much memory if the video is long). This class slices a video into images saved on disk.
  • GenNearest, GenVanilla and GenGan: the 3 image generators to be written.
  • DanceDemo: the main class running a dance demo. The animation/posture from self.source is applied to the character defined by self.target using self.gen.

In the GenXXX classes, the heart of the problem is the function def generate(self, ske):, which returns the image of the target person in the skeleton pose ske received as a parameter. This image is generated from the dataset containing a set of (image, skeleton) pairs.


Setup the data #

First run the VideoSkeleton script, which will produce images from a video. With the default settings, the script produces 1400 frames from the video taichi.mp4, which contains 14000 frames.

Closest skeleton (nearest neighbor) #

The basic solution is to search the dataset for the image whose associated skeleton is closest to the queried one. This is coded in the GenNearest::generate function. It is not an efficient solution: it consumes a lot of memory, and the search can take a long time.

  1. Run DemoDance.py. This is the main program. The target image is white to start with.
  2. Code generate in class GenNearest.py.
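
The nearest-neighbour search can be sketched as follows. This is a minimal illustration with hypothetical names (`generate_nearest`, `dataset_skeletons`, `dataset_images`), not the TP's actual GenNearest class: each skeleton is flattened to a vector (26 floats in reduced mode) and compared with an L2 distance.

```python
import numpy as np

def generate_nearest(ske, dataset_skeletons, dataset_images):
    """Return the dataset image whose skeleton vector is closest (L2) to `ske`.
    Hypothetical sketch: names and data layout are illustrative."""
    # Each skeleton is a flat 1-D float vector (e.g. 26 values in reduced mode)
    dists = np.linalg.norm(dataset_skeletons - ske, axis=1)
    return dataset_images[int(np.argmin(dists))]

# Toy usage: 3 "skeletons" of 26 floats; images stand in as plain strings
skes = np.stack([np.full(26, v, dtype=np.float32) for v in (0.0, 1.0, 2.0)])
imgs = ["img0", "img1", "img2"]
print(generate_nearest(np.full(26, 0.9, dtype=np.float32), skes, imgs))  # img1
```

Note that this brute-force scan is linear in the dataset size on every query, which is exactly why the approach is slow.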

Key insights

  • Each image is realistic.
  • Searching for the nearest image is slow.
  • Temporal continuity is not guaranteed.
  • Generalization is poor.

Direct neural network #

The idea here is to train a basic network that produces an image from the skeleton. The skeleton is represented by an array of numbers. In the Skeleton code, you can choose to extract the skeleton in a reduced size: 13 joints in 2D. The network can therefore take 26 numbers as input and produce an image.

  1. Look at the class GenVanillaNN (option 1 in the constructor), which uses the network GenNNSke26ToImage.
  2. Code train.
  3. Code generate.

The problem is to go from a tensor of shape (_, 26, 1, 1) to a tensor of shape (_, 3, 64, 64) (RGB × 64 × 64). The simplest option is to use a linear layer nn.Linear(26, 3 * 64 * 64) and to reshape with view in the forward function:

x = x.view(x.size(0), 3, 64, 64)  # Reshape (batch_size, 3, 64, 64) 

Or you can use ConvTranspose2d. For instance:

nn.ConvTranspose2d(26, 128, kernel_size=4, stride=1, padding=0)  # (26, 1, 1) -> (128, 4, 4) 
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)  # (128, 4, 4) -> (64, 8, 8)
...

See [here](https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html) for an example.

Important remark: the tgt_transform of the dataloader provides the output image (target) in the range [-1, 1] after normalization. So the last layer of your network should be a tanh (and not a sigmoid).
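
Putting the pieces together, a full transposed-convolution decoder could look like this. The intermediate channel counts are illustrative, not the TP's exact network; only the input shape (26, 1, 1), the output shape (3, 64, 64) and the final tanh follow from the text above.

```python
import torch
import torch.nn as nn

# Minimal decoder sketch: (batch, 26, 1, 1) -> (batch, 3, 64, 64) in [-1, 1].
# Layer widths (128, 64, 32, 16) are arbitrary illustrative choices.
decoder = nn.Sequential(
    nn.ConvTranspose2d(26, 128, kernel_size=4, stride=1, padding=0),  # -> (128, 4, 4)
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # -> (64, 8, 8)
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),   # -> (32, 16, 16)
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),   # -> (16, 32, 32)
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(16, 3, kernel_size=4, stride=2, padding=1),    # -> (3, 64, 64)
    nn.Tanh(),  # outputs in [-1, 1], matching the normalized targets
)

x = torch.randn(8, 26, 1, 1)  # a batch of 8 reduced skeletons
out = decoder(x)
print(out.shape)              # torch.Size([8, 3, 64, 64])
```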

Do not spend much time on that approach, it is not the best one.


Neural network with the skeleton as input #

The paper suggests working with an intermediate image in which the skeleton is represented with sticks. This image is easy to obtain: in the Skeleton class, the draw_reduced function produces such an image. Change the previous network so that it takes this image as input instead of the skeleton vector.


  1. Look at the class GenVanillaNN (option 2 in the constructor), which uses the network GenNNSkeimToImage.
  2. Code train.
  3. Code generate.

In this version, src_transform adds SkeToImageTransform(image_size) to convert the 26-dimensional skeleton into an image of the skeleton.
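
The change of input turns the generator into an image-to-image network. The following encoder-decoder is an illustrative sketch (the class name `SkeImToImage` and the layer sizes are hypothetical, not the TP's GenNNSkeimToImage): the input is now the 3-channel stick image rather than 26 numbers.

```python
import torch
import torch.nn as nn

# Illustrative image-to-image generator sketch (hypothetical sizes):
# 64x64 stick image of the skeleton -> 64x64 RGB image of the person.
class SkeImToImage(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder: downsample the stick image to a compact feature map
        self.enc = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),   # -> (32, 32, 32)
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # -> (64, 16, 16)
            nn.ReLU(inplace=True),
        )
        # Decoder: upsample back to an RGB image in [-1, 1]
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # -> (32, 32, 32)
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),   # -> (3, 64, 64)
            nn.Tanh(),
        )

    def forward(self, x):
        return self.dec(self.enc(x))

g = SkeImToImage()
print(g(torch.randn(4, 3, 64, 64)).shape)  # torch.Size([4, 3, 64, 64])
```

A U-shaped variant would add skip connections between encoder and decoder stages, which is one of the improvement ideas listed later.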

GAN #

To improve the quality of the generator, the paper adds a discriminator network that detects whether an image is fake or real. This principle is similar to a GAN, although in a classical GAN the generator's input is noise.
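
A discriminator for this setting is just a convolutional classifier producing one score per image. The layer sizes below are an illustrative sketch, not the TP's actual network:

```python
import torch
import torch.nn as nn

# Minimal discriminator sketch (illustrative sizes): scores a 64x64 RGB image
# as real (target-video frame) or fake (generated). The last conv collapses
# the 16x16 feature map to a single logit per image.
discriminator = nn.Sequential(
    nn.Conv2d(3, 32, 4, stride=2, padding=1),   # -> (32, 32, 32)
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(32, 64, 4, stride=2, padding=1),  # -> (64, 16, 16)
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(64, 1, 16),                       # -> (1, 1, 1): one score per image
)
score = discriminator(torch.randn(4, 3, 64, 64)).view(-1)
print(score.shape)  # torch.Size([4])
```

Note that the output here is a raw logit; whether it passes through a sigmoid (classic GAN loss) or stays unbounded (Wasserstein critic) depends on the adversarial loss chosen.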


Train #

Usually, the training loop looks something like this:

criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')
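
For the GAN version, the generic loop above splits into two alternating updates per batch: one for the discriminator, one for the generator. A minimal sketch, assuming hypothetical names (`netG`, `netD`, `optimG`, `optimD`) and a BCE-with-logits criterion; the TP's GenGan class may organize this differently:

```python
import torch
import torch.nn as nn

def gan_step(netG, netD, optimG, optimD, ske_batch, real_images, criterion):
    """One GAN iteration (sketch). netG maps skeleton inputs to images;
    netD outputs one logit per image."""
    real_lbl = torch.ones(real_images.size(0), 1)
    fake_lbl = torch.zeros(real_images.size(0), 1)

    # --- Discriminator step: push real images towards 1, generated ones towards 0
    optimD.zero_grad()
    fake = netG(ske_batch)
    loss_d = criterion(netD(real_images), real_lbl) + \
             criterion(netD(fake.detach()), fake_lbl)  # detach: no grad into G here
    loss_d.backward()
    optimD.step()

    # --- Generator step: try to make the discriminator output 1 on generated images
    optimG.zero_grad()
    loss_g = criterion(netD(fake), real_lbl)
    loss_g.backward()
    optimG.step()
    return loss_d.item(), loss_g.item()
```

The `detach()` in the discriminator step matters: without it, the discriminator loss would also backpropagate into the generator.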

Important remarks #

  1. The tgt_transform of the dataloader provides the output image (target) in the range [-1, 1] after normalization. So the last layer of your network should be a tanh (and not a sigmoid).

  2. You are free to improve the Generator and the training as you want. Ideas:

  • BatchNorm, as in ResNet architectures (often good for classification); InstanceNorm (often better with GANs);
  • a ResidualBlock in the middle of the U-shaped generator;
  • self-attention;
  • WGAN-GP;
  • a least-squares adversarial loss (LSGAN) as a stabilized surrogate;
  • etc.

For instance, for WGAN-GP, see the file traning.py with:

def compute_gradient_penalty(self, real_samples, fake_samples):
    """Compute the gradient penalty for WGAN-GP."""
    device = real_samples.device
    batch_size = real_samples.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    alpha = alpha.expand_as(real_samples)

    # Interpolate between real and fake samples
    interpolated = alpha * real_samples + (1 - alpha) * fake_samples
    interpolated.requires_grad_(True)

    # Get the critic (discriminator) output for the interpolated images
    d_interpolated = self.netD(interpolated)

    # Gradients of the critic output with respect to the interpolated images
    gradients = torch.autograd.grad(outputs=d_interpolated, inputs=interpolated,
                                    grad_outputs=torch.ones_like(d_interpolated),
                                    create_graph=True, retain_graph=True)[0]

    # Penalize deviation of the gradient norm from 1
    gradients = gradients.view(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    return gradient_penalty
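
To see how this penalty is used, here is a sketch of a critic update, with the penalty restated in standalone form so the block is self-contained. `netD`, `optimD` and `critic_step` are hypothetical names; lambda_gp = 10 is the coefficient from the WGAN-GP paper (Gulrajani et al., 2017).

```python
import torch
import torch.nn as nn

def compute_gradient_penalty(netD, real_samples, fake_samples):
    """Standalone version of the WGAN-GP gradient penalty."""
    batch_size = real_samples.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1).expand_as(real_samples)
    interpolated = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_out = netD(interpolated)
    grads = torch.autograd.grad(d_out, interpolated, torch.ones_like(d_out),
                                create_graph=True)[0].view(batch_size, -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

def critic_step(netD, optimD, real, fake, lambda_gp=10.0):
    """One critic update (sketch). Note: no sigmoid on the critic output."""
    optimD.zero_grad()
    gp = compute_gradient_penalty(netD, real, fake.detach())
    # Wasserstein critic objective (to minimize): E[D(fake)] - E[D(real)] + lambda * GP
    loss = netD(fake.detach()).mean() - netD(real).mean() + lambda_gp * gp
    loss.backward()
    optimD.step()
    return loss.item()
```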