"Portrait photos" without portrait rights are generated by yourself, so please don't hesitate to use them!

Posted May 27, 2020 · 14 min read


In advertisements and marketing materials, we often see polished photos of models demonstrating a product in use, designed to spark a desire to buy. Anyone who has worked on such material knows how troublesome it is to prepare: every picture must be properly licensed, and if a picture contains a person's portrait, the rights clearance becomes even more complicated, with requirements that differ across countries and regions.

First, take a look at the following photo of a girl:

easily-build-pytorch-generative-adversarial-networks1.png

Quite beautiful, right? The image quality is also high, with rich detail and vivid colors. But here is the truth: this girl does not exist. She is a virtual character created by a machine learning model (picture taken from the Wikipedia entry on GANs).

A Generative Adversarial Network (GAN) is a generative machine learning model that has been widely used in advertising, gaming, entertainment, media, pharmaceuticals, and other industries. It can be used to create fictional characters and scenes, simulate the aging of human faces, change image styles, and even generate chemical formulas.

The following two pictures show the effect of image-to-image translation and of synthesizing a scene from a semantic layout:

1.png

From an engineering-practice perspective, the rest of this article will guide you through building a generative adversarial network with the PyTorch machine learning framework and AWS machine learning services, and thereby start a new and interesting machine learning and artificial intelligence experience.

Overview of topics and solutions

First, take a look at the two sets of handwritten digit images shown below. Can you tell which set is real handwriting and which set was generated by a computer?

2.png

The subject of this article is to "imitate handwritten digits" with machine learning methods. To complete this task, we will design and implement a generative adversarial network. The basic principles and engineering process behind imitating handwriting and generating portraits are essentially the same; although they differ in complexity and accuracy requirements, solving the handwriting-imitation problem lays the foundation for the principles and engineering practice of generative adversarial networks, after which we can gradually explore more complex and advanced network architectures and application scenarios.

The Generative Adversarial Network (GAN) was proposed by Ian Goodfellow et al. in 2014. It is a deep neural network architecture consisting of a generator network and a discriminator network. The generator produces "fake" data and tries to deceive the discriminator; the discriminator examines the generated data and tries to correctly identify all the "fake" samples. During training, the two networks keep evolving and competing against each other until they reach an equilibrium (see: Nash equilibrium), at which point the discriminator can no longer distinguish the "fake" data and training is complete.
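For reference, this adversarial process corresponds to the minimax objective given in the original 2014 paper, in which the discriminator D and the generator G play the following two-player game:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

D tries to maximize this value by labeling real and generated samples correctly, while G tries to minimize it by making D(G(z)) approach 1.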

In 2016, the paper "Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks" (DCGAN) by Alec Radford et al. pioneered the application of convolutional neural networks to the design of GANs, replacing the fully connected layers and improving training stability for image generation.

Amazon SageMaker is AWS's fully managed machine learning service. With Amazon SageMaker, data processing and model training can be completed quickly and easily, and the trained model can be deployed directly to a fully managed production environment.

Amazon SageMaker provides hosted Jupyter Notebook instances that integrate with various AWS cloud services through the SageMaker SDK, making it easy for you to access data sources for exploration and analysis. The SageMaker SDK is an open-source set of Amazon SageMaker development packages that helps us make better use of the managed container images provided by Amazon SageMaker, as well as other AWS cloud services such as compute and storage resources.

easily-build-pytorch-generative-adversarial-networks6.png

As shown in the figure above, the training data comes from an Amazon S3 bucket; the training framework and hosted algorithms are provided as container images, which are combined with the code at training time; the model code runs on compute instances managed by Amazon SageMaker and is combined with the data during training; the training output is written to a dedicated Amazon S3 bucket. In the following sections, we will learn how to use these resources through the SageMaker SDK.

The following operations will use Amazon SageMaker, Amazon S3, Amazon EC2 and other AWS services, which will incur certain cloud resource usage fees.

Model development environment

Create a Notebook instance

Open the Amazon SageMaker dashboard (click to open: Beijing Region | Ningxia Region), and then click the Notebook instances button to enter the list of notebook instances.

easily-build-pytorch-generative-adversarial-networks7.png

If you are using Amazon SageMaker for the first time, your Notebook instances list will be empty, and you need to click the Create notebook instance button to create a new Jupyter Notebook instance.

easily-build-pytorch-generative-adversarial-networks8.png

On the Create notebook instance page, enter the instance name in the Notebook instance name field; this article uses "MySageMakerInstance", but you can choose any name you think is appropriate. This article uses the default instance type, so the Notebook instance type option remains ml.t2.medium.
If you are using Amazon SageMaker for the first time, you also need to create an IAM role so that the notebook instance can access the Amazon S3 service. Click Create a new role under the IAM role option, and Amazon SageMaker will create a role with the necessary permissions and assign it to the instance being created. Alternatively, depending on your situation, you can choose an existing role.

easily-build-pytorch-generative-adversarial-networks9.png

In the Create an IAM role pop-up window, you can select Any S3 bucket, so that the notebook instance will be able to access all the buckets in your account. In addition, if necessary, you can also select Specific S3 buckets and enter the bucket name. Click the Create role button and this new role will be created.

easily-build-pytorch-generative-adversarial-networks10.png

You can then see that Amazon SageMaker has created a role with a name like AmazonSageMaker-ExecutionRole-****. For the other fields, you can keep the default values; click the Create notebook instance button to create the instance.

easily-build-pytorch-generative-adversarial-networks11.png

Go back to the Notebook instances page and you will see that the MySageMakerInstance notebook instance is displayed in the Pending state, which will last for about 2 minutes until it transitions to the InService state.

easily-build-pytorch-generative-adversarial-networks12.png

Write the first line of code

Click the Open JupyterLab link and you will see the familiar Jupyter Notebook loading interface on the new page. This article uses JupyterLab notebook as the engineering environment by default. You can also choose to use traditional Jupyter notebooks as needed.

easily-build-pytorch-generative-adversarial-networks13.png

Then click the conda_pytorch_p36 notebook icon to create a notebook called Untitled.ipynb; you can rename it later. Alternatively, you can create this notebook through the File > New > Notebook menu and select conda_pytorch_p36 as the Kernel.

easily-build-pytorch-generative-adversarial-networks14.png

Enter the first line of instructions in the newly created Untitled.ipynb notebook as follows:

import torch

print(f"Hello PyTorch {torch.__version__}")

Source code download

Please enter the following instruction in the notebook to download the code to the local file system of the instance:

!git clone "https://github.com/mf523/ml-on-aws.git" "ml-on-aws"

After downloading, you can browse the source code structure through File browser.

easily-build-pytorch-generative-adversarial-networks15.png

The code and notebooks in this article have been verified with Python 3.6, PyTorch 1.4, and JupyterLab hosted by Amazon SageMaker. The related code and notebooks can be obtained here.

Generative adversarial network model

Algorithm principle

The generator network of the DCGAN model contains 10 layers. It uses strided transposed convolution layers to increase the resolution of the tensor. The input shape is (batch_size, 100) and the output shape is (batch_size, 64, 64, 3). In other words, the generator network takes a noise vector and transforms it repeatedly until the final image is produced.

The discriminator network also contains 10 layers. It receives images of shape (64, 64, 3), downsamples them with 2D convolution layers, and finally passes the result to a fully connected layer for classification. The classification result is 1 or 0, i.e. real or fake.

easily-build-pytorch-generative-adversarial-networks16.png
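To make these two architectures concrete, here is a minimal PyTorch sketch of the two kinds of building blocks: a transposed-convolution upsampling step for the generator and a strided-convolution downsampling step for the discriminator. The layer counts, channel sizes (nz, ngf, ndf, nc), and exact structure in the accompanying model.py may differ; this is only an illustration of the pattern, not the project's code.

import torch
import torch.nn as nn

# Illustrative sizes: noise length, generator/discriminator feature maps, image channels
nz, ngf, ndf, nc = 100, 64, 64, 3

# Generator-style block: transposed convolutions progressively enlarge the tensor
generator_block = nn.Sequential(
    nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
    nn.BatchNorm2d(ngf * 8),
    nn.ReLU(True),
    nn.ConvTranspose2d(ngf * 8, nc, 4, 2, 1, bias=False),
    nn.Tanh(),
)

# Discriminator-style block: strided 2D convolutions progressively shrink the tensor
discriminator_block = nn.Sequential(
    nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(ndf, 1, 4, 2, 1, bias=False),
    nn.Sigmoid(),
)

noise = torch.randn(16, nz, 1, 1)
print(generator_block(noise).shape)                          # torch.Size([16, 3, 8, 8]) for this shortened stack
print(discriminator_block(torch.randn(16, nc, 64, 64)).shape)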

The training process of the DCGAN model can be roughly divided into three sub-processes.

easily-build-pytorch-generative-adversarial-networks17.png

First, the Generator network takes random noise as input and generates a "fake" picture; next, the "real" and "fake" pictures are used to train the Discriminator network and update its parameters; finally, the parameters of the Generator network are updated.
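The following is a simplified single-batch training step in PyTorch that follows these three sub-processes. It is only a sketch for illustration and assumes the discriminator outputs a probability through a final Sigmoid; the actual logic in this project lives in the train_step method of the DCGAN class in model.py and may differ in its details.

import torch
import torch.nn as nn

criterion = nn.BCELoss()

def train_step(netG, netD, optG, optD, real_images, nz, device):
    batch_size = real_images.size(0)
    real_label = torch.ones(batch_size, device=device)
    fake_label = torch.zeros(batch_size, device=device)

    # 1) Generator takes random noise as input and produces "fake" images
    noise = torch.randn(batch_size, nz, 1, 1, device=device)
    fake_images = netG(noise)

    # 2) Train the Discriminator on real and fake images, update its parameters
    optD.zero_grad()
    loss_real = criterion(netD(real_images).view(-1), real_label)
    loss_fake = criterion(netD(fake_images.detach()).view(-1), fake_label)
    (loss_real + loss_fake).backward()
    optD.step()

    # 3) Update the Generator: it wants the Discriminator to label fakes as real
    optG.zero_grad()
    loss_g = criterion(netD(fake_images).view(-1), real_label)
    loss_g.backward()
    optG.step()

    return (loss_real + loss_fake).item(), loss_g.item()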

Code analysis

The file structure of the project directory byos-pytorch-gan is as follows:

    data/            (empty)
    dcgan/
        entry_point.py
        model.py
    dcgan.ipynb
    helper.py
    model/           (empty)
    tmp/             (empty)

The file model.py contains three classes: the Generator and Discriminator networks, and a DCGAN wrapper class:

class Generator(nn.Module):
    ...

class Discriminator(nn.Module):
    ...

class DCGAN(object):
    """
    A wrapper class for Generator and Discriminator,
    the 'train_step' method is for single batch training.
    """
    ...

The file train.py trains the two neural networks, Generator and Discriminator, and mainly includes the following methods:

def parse_args():
    ...

def get_datasets(dataset_name, ...):
    ...

def train(dataloader, hps, ...):
    ...

Model debugging

During the development and debugging phase, you can run the train.py script directly from the Linux command line. Hyperparameters, input data channels, and the output directories for the model and other training artifacts can all be specified through command-line parameters.

python dcgan/train.py --dataset qmnist \
    --model-dir '/home/myhome/byom-pytorch-gan/model' \
    --output-dir '/home/myhome/byom-pytorch-gan/tmp' \
    --data-dir '/home/myhome/byom-pytorch-gan/data' \
    --hps '{"beta1": 0.5, "dataset": "qmnist", "epochs": 15, "learning-rate": 0.0002, "log-interval": 64, "nc": 1, "nz": 100, "sample-interval": 100}'

This design of the training-script parameters not only provides a convenient way to debug, it also meets the specifications required for integration with SageMaker Containers, balancing the freedom of model development with the portability of the training environment. A sketch of how such parameter handling typically looks is shown below.
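As a rough illustration of this integration, the command-line defaults can fall back to the environment variables that SageMaker training containers set, such as SM_MODEL_DIR, SM_OUTPUT_DATA_DIR, and SM_CHANNEL_<channel name>. The sketch below is an assumption of how parse_args might be written; the actual implementation in train.py may differ.

import argparse
import json
import os

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='qmnist')
    # JSON string of hyperparameters, matching the --hps example above
    parser.add_argument('--hps', type=json.loads, default='{}')
    # Fall back to the directories provided by the SageMaker training container
    parser.add_argument('--model-dir', type=str,
                        default=os.environ.get('SM_MODEL_DIR', './model'))
    parser.add_argument('--output-dir', type=str,
                        default=os.environ.get('SM_OUTPUT_DATA_DIR', './tmp'))
    parser.add_argument('--data-dir', type=str,
                        default=os.environ.get('SM_CHANNEL_QMNIST', './data'))
    return parser.parse_args()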

Model training and validation

Please find and open the notebook file named dcgan.ipynb; the training process is introduced and executed in this notebook. Some of the code is omitted from this section, so please refer to the notebook for details.

There are many public datasets available on the Internet, which are very helpful for machine learning engineering and research, for example for algorithm study and performance evaluation. We will use the QMNIST handwriting dataset to train the model and eventually generate realistic "handwritten" digit images.

Data preparation

The torchvision.datasets package of the PyTorch framework provides the QMNIST dataset. We can download it to local storage with the following instructions:

from torchvision import datasets

dataroot = './data'
trainset = datasets.QMNIST(root=dataroot, train=True, download=True)
testset = datasets.QMNIST(root=dataroot, train=False, download=True)
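For reference, inside the training script the dataset would typically be resized and normalized before being wrapped in a DataLoader, roughly as sketched below; the exact transforms and batch size used by get_datasets in train.py may differ.

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize(64),                   # the DCGAN described above works on 64x64 inputs
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),    # single channel, scaled to [-1, 1]
])
trainset = datasets.QMNIST(root='./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)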

Amazon SageMaker creates a default Amazon S3 bucket for storing the various files and data that may be needed in the machine learning workflow. We can get the name of this bucket through the default_bucket method of the sagemaker.session.Session class in the SageMaker SDK:

from sagemaker.session import Session

sess = Session()

# S3 bucket for saving code and model artifacts.
# Feel free to specify a different bucket here if you wish.
bucket = sess.default_bucket()

The SageMaker SDK provides classes for working with the Amazon S3 service: the S3Downloader class is used to access or download objects in S3, and S3Uploader is used to upload local files to S3. Please upload the downloaded data to Amazon S3 for model training; the training process should not download data from the Internet, both to avoid the network latency of fetching training data over the Internet and to avoid the security risks of giving the training environment direct Internet access.

from sagemaker.s3 import S3Uploader as s3up

s3_data_location = s3up.upload(f"{dataroot}/QMNIST", f"s3://{bucket}/data/qmnist")

Training execution

Through the sagemaker.get_execution_role() method, the notebook can obtain the role pre-assigned to the notebook instance. This role will be used to obtain training resources, such as downloading the training framework image and allocating Amazon EC2 compute resources.
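In the notebook this looks like the following; the role variable is the one passed to the estimator created below:

from sagemaker import get_execution_role

# Role assigned to the notebook instance, reused by the training task
role = get_execution_role()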

The hyperparameters used for training the model can be defined in the notebook to achieve separation from the algorithm code. When creating the training task, the hyperparameters are passed in and dynamically combined with the training task.

hps = {
    "learning-rate": 0.0002,
    "epochs": 15,
    "dataset": "qmnist",
    "beta1": 0.5,
    "sample-interval": 200,
    "log-interval": 64
}

The PyTorch class in the sagemaker.pytorch package is an estimator for the PyTorch framework; it can be used to create and run training tasks and to deploy the trained model. In the parameter list, train_instance_type specifies the CPU or GPU instance type; the directory containing the training script and model code is specified by source_dir; and the name of the training script file must be explicitly given by entry_point. These parameters are passed to the training task along with the remaining parameters, and together they determine the training task's runtime environment and the parameters used during model training.

from sagemaker.pytorch import PyTorch

estimator = PyTorch(role=role,
                    entry_point='train.py',
                    source_dir='dcgan',
                    output_path=s3_model_artifacts_location,
                    code_location=s3_custom_code_upload_location,
                    train_instance_count=1,
                    train_instance_type='ml.c5.xlarge',
                    train_use_spot_instances=True,
                    train_max_wait=86400,
                    framework_version='1.4.0',
                    py_version='py3',
                    hyperparameters=hps)

Please pay special attention to the train_use_spot_instances parameter: a value of True means you want to use Spot Instances preferentially. Since machine learning training usually requires large amounts of compute for long periods, making good use of Spot capacity is an effective way to control cost. Spot Instance prices may be 20% to 60% of the On-Demand price, depending on the instance type, region, and time.
After creating the PyTorch estimator, you can use it to fit the data already stored on Amazon S3. The following instruction starts the training task; the training data will be fed into the training environment as an input channel named QMNIST. When training runs, the training data on Amazon S3 is downloaded to the local file system of the training environment, and the training script train.py loads the data from the local disk.

# Start training
estimator.fit({'QMNIST': s3_data_location}, wait=False)

Depending on the training instance you choose, the training process may last from tens of minutes to several hours. It is recommended to set the wait parameter to False, which detaches the notebook from the training task. For long-running training with many logs, this prevents the notebook context from being lost due to network interruptions or session timeouts. After the training task is detached from the notebook, its output is temporarily invisible; you can execute the following code to re-attach the notebook to the previous training session:

%%time
from sagemaker.estimator import Estimator

# Attaching previous training session
training_job_name = estimator.latest_training_job.name
attached_estimator = Estimator.attach(training_job_name)

Because the model was designed to take advantage of GPU acceleration, training is faster on GPU instances than on CPU instances: for example, a p3.2xlarge instance takes about 15 minutes, while a c5.xlarge instance may need more than 6 hours. The current model does not support distributed or parallel training, so multiple instances or multiple CPUs/GPUs will not bring additional speedup.
After training completes, the model is uploaded to Amazon S3, at the location specified by the output_path parameter provided when the PyTorch estimator was created.

Validation of the model

To validate the model, we need to download the trained model from Amazon S3 to the local file system of the instance where the notebook is running. The following code loads the model, then feeds it random numbers to obtain inference results and displays them as images.
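A minimal sketch of this download step is shown below, assuming the training job's artifact is a model.tar.gz archive containing generator_state.pth; the local paths ./tmp and ./model are illustrative, and the notebook's actual code may differ.

import tarfile
from sagemaker.s3 import S3Downloader as s3down

# model_data points to the model.tar.gz written to the output_path bucket
s3down.download(attached_estimator.model_data, "./tmp")
with tarfile.open("./tmp/model.tar.gz") as tar:
    tar.extractall("./model")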

Run the following command to load the trained model, and use this model to generate a set of "handwritten" digital fonts:

from helper import *
import matplotlib.pyplot as plt
import numpy as np
import torch
from dcgan.model import Generator

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

params = {'nz': nz, 'nc': nc, 'ngf': ngf}
model = load_model(Generator, params, "./model/generator_state.pth", device=device)
img = generate_fake_handwriting(model, batch_size=batch_size, nz=nz, device=device)

plt.imshow(np.asarray(img))

easily-build-pytorch-generative-adversarial-networks18.png

Conclusion and summary

In recent years, the fast-growing PyTorch framework has been widely recognized and adopted. More and more new models are built with PyTorch, and some existing models have been migrated to or completely reimplemented in PyTorch. Its ecosystem keeps growing and its application areas keep expanding, and PyTorch has become one of the de facto mainstream frameworks.

Amazon SageMaker is tightly integrated with multiple AWS services, such as Amazon EC2 computing instances of various types and sizes, Amazon S3, Amazon ECR, etc., providing an end-to-end consistent experience for machine learning engineering practices. Amazon SageMaker continues to support mainstream machine learning frameworks, and PyTorch is one of them.

Machine learning algorithms and models developed with PyTorch can easily be ported to Amazon SageMaker's engineering and service environment, where fully managed Jupyter Notebooks, training container images, serving container images, training task management, managed deployment environments, and other features simplify machine learning engineering, improve productivity, and reduce operation and maintenance costs.

DCGAN is a landmark in the field of generative adversarial networks and the cornerstone of many of today's complex GANs. StyleGAN, which produced the photo at the beginning of this article, StackGAN, which synthesizes images from text, Pix2pix, which generates images from sketches, and the controversial DeepFakes all carry the imprint of DCGAN. I believe the introduction and engineering practice in this article will help you understand the principles and engineering methods of generative adversarial networks.
