Draft

  Draw Me a Simulator

Using neural networks to build more realistic simulation schemes for causal analysis

Creative Commons BY License ISSN 2824-7795

This study explores the use of Variational Auto-Encoders to construct simulators approximating the law of genuine observations, emphasizing the challenges of evaluating their quality through statistical and predictive comparisons.

Authors
Affiliations

Inserm, Université Paris Cité, Inria

Published

March 21, 2025

Modified

March 21, 2025

Keywords

simulations, variational auto-encoders, counterfactuals

Status

draft

Abstract

This study investigates the use of Variational Auto-Encoders to build a simulator that approximates the law of genuine observations. Using both simulated and real data in scenarios involving counterfactuality, we discuss the general task of evaluating a simulator’s quality, with a focus on comparisons of statistical properties and predictive performance. While the simulator built from simulated data shows minor discrepancies, the results with real data reveal more substantial challenges. Beyond the technical analysis, we reflect on the broader implications of simulator design, and consider its role in modeling reality.


Asher Rubin walks out of the starosta’s home and heads toward the market square. With evening, the sky has cleared, and now a million stars are shining, but their light is cold and brings down a frost upon the earth, upon Rohatyn. The first of this autumn. Rubin pulls his black wool coat tighter around him; tall and thin, he looks like a vertical line. (Tokarczuk 2021, I(3))

1 Introduction

1.1 Fiction as the original simulation

One of humanity’s oldest creative endeavors, fiction represents an early form of simulation. It extends the imaginative play where children create scenarios, roles, or worlds that are not constrained by the rules of reality, that is “childhood pretence” (Carruthers 2002) or “the make-believe games” of children (Walton 1993). Through stories, myths, and imagined worlds, humans construct alternative realities to explore ideas, express emotions, and reflect on their existence. By presenting hypothetical scenarios and posing “what if things had been different” questions (Pearl and Mackenzie 2018, 34), fiction empowers individuals to explore alternative histories, draw insights from the experiences of others, and engage with possibilities that extend beyond the confines of the physical world. At its core, fiction abstracts and reconstructs elements of reality. An author selectively includes, exaggerates, or omits aspects of the real world, creating models that serve their artistic or thematic intentions. From Homer’s Odyssey (Homère 2000) to speculative tales like Mary Shelley’s Frankenstein (Shelley 1818), fiction mirrors the complexities of human life, enabling readers to engage with an imagined reality that resonates with their own.

The relationship between fiction and reality has long been a subject of debate. Plato, in his critique of art, viewed fiction as a mere imitation of the physical world, itself a flawed reflection of the ideal “Forms”. By this reasoning, fiction is a “simulation of a simulation”, twice removed from truth (Platon 2002, Livre X). Aristotle, by contrast, argued that fiction, through “mimesis”, the imitation of action and life, can illuminate universal truths (Aristote 2006, Chapitres 1 à 5). By abstracting from the particular, fiction allows exploration of broader patterns and principles.

Following Aristotle’s perspective, this tradition of creating and interacting with imagined realities provides a natural foundation for distinguishing scientific theories from scientific models (Barberousse and Ludwig 2000) and understanding modern simulations. While they stem from the same drive to represent and explore, scientific theories, scientific models and modern simulations introduce a higher degree of mathematical rigor. Nevertheless, fiction remains their conceptual ancestor, reminding us that the human impulse to model and engage with alternate realities is as old as storytelling itself.

1.2 From modern simulations to computer simulations

The concept of simulation predates the modern era. Early instances include mechanical devices like the Antikythera mechanism, a sophisticated analog computer from the 2nd century BCE designed to simulate celestial movements (and the MacGuffin chased by Indiana Jones in the 2023 installment of the franchise; Solly 2023). The emergence of mathematical models in the works of Galileo and Newton introduced a new form of simulation, where equations were used to predict physical phenomena with increasing precision. By the 18th century, probabilistic experiments like Buffon’s Needle, designed to approximate the number \pi (Aigner and Ziegler 2018, sec. 24), demonstrated the power of simulating complex systems. However, the advent of computer simulations, as we understand them today, began during World War II with the work of J. von Neumann and S. Ulam (Metropolis and Ulam 1949).

While studying neutron behavior, they faced a challenge that was too complex for theoretical analysis and too hazardous, time-consuming, and costly to investigate experimentally. Fundamental properties (e.g., possible events and their probabilities) and basic quantities (e.g., the average distance a neutron would travel before colliding with an atomic nucleus, the likelihood of absorption or reflection, and energy loss after collisions) were known, but predicting the outcomes of entire event sequences was infeasible. To address this challenge, they devised a method of generating random sequences step by step using a computer, naming it “Monte Carlo” after the casino, a suggestion by N. Metropolis. Statistical analysis of the data produced by repeatedly applying this method provided sufficiently accurate solutions to better understand nuclear chain reactions, a crucial aspect of designing atomic bombs and later nuclear reactors. This breakthrough marked the birth of modern computer simulations.

Today, computer simulations, henceforth referred to simply as simulations, play a fundamental role in applied mathematics. Generally, conducting a simulation involves running a computer program (a “simulator”) designed to represent a “system of interest” at a problem-dependent level of abstraction (that is, with a specific degree of complexity) and collecting the numerical output for analysis.

Examples of systems of interest are virtually limitless and highly diverse. They can represent a real-world process in a holistic fashion, such as the regular functioning of a person’s heart at rest, or the medical trajectories of a cohort of patients undergoing chemotherapy. Alternatively, in a more focused fashion, they can consist of a hybrid pipeline that combines an upstream real-world process with downstream data processing of intermediary outputs, such as the estimation of peripheral oxygen saturation in a healthy patient using a pulse oximeter. Regardless of the context, determining the appropriate levels of abstraction and realism is always a significant challenge.

Here, we focus on simulations used to evaluate the performance of statistical procedures through simulation studies, as discussed by Morris, White, and Crowther (2019) in their excellent tutorial on the design and conduct of such studies. The interested reader will find in their work a carefully curated list of books on simulation methods in general and articles emphasizing rigor in specific aspects of simulation studies. Specifically, we consider scenarios where a statistician, interested in a real-world process, has developed an algorithm tailored to learning a particular feature of that process from collected data and seeks to assess the algorithm’s performance through simulations.

Once the simulator is devised, the following process is repeated multiple times. In each iteration, typically independently from previous iterations: first, the simulator generates a synthetic data set of size n; second, the algorithm is run on the generated data; third, the algorithm’s output is collected for further analysis. After completing these iterations, the next step is to compare the outcome from one run to the algorithm’s target. This is made possible by the design of the simulator. Finally, the overall performance of the algorithm is assessed by comparing all the results collectively to the algorithm’s target. Depending on the task, this evaluation can involve assessing the algorithm’s ability to estimate its target accurately, the validity of the confidence regions it constructs for its target, the algorithm’s ability to detect whether its target lies within a specified null domain (using an alternative domain as a reference), and more. This list is far from exhaustive. The entire process can be repeated multiple times, for example, to assess how the algorithm’s performance depends on n.
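The loop described above can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the scheme used later in the article: the simulator (an exponential law with mean 2), the algorithm (sample mean with a Wald-type 95% confidence interval) and all function names are hypothetical choices made only so that the target is known by design.

```python
import numpy as np

def simulator(n, rng):
    """Hypothetical simulator: n i.i.d. draws from the exponential law with mean 2."""
    return rng.exponential(scale=2.0, size=n)

def algorithm(data):
    """Hypothetical algorithm: sample mean and a Wald-type 95% confidence interval."""
    estimate = data.mean()
    half_width = 1.96 * data.std(ddof=1) / np.sqrt(data.size)
    return estimate, estimate - half_width, estimate + half_width

def simulation_study(n, n_iter=1000, target=2.0, seed=0):
    """Repeat (simulate, run the algorithm, collect), then compare to the target."""
    rng = np.random.default_rng(seed)
    outputs = np.array([algorithm(simulator(n, rng)) for _ in range(n_iter)])
    estimates, lower, upper = outputs.T
    return {
        "bias": estimates.mean() - target,
        "rmse": np.sqrt(np.mean((estimates - target) ** 2)),
        "coverage": np.mean((lower <= target) & (target <= upper)),
    }

# Repeat the whole process for several sample sizes n.
results = {n: simulation_study(n) for n in (50, 200, 800)}
for n, summary in results.items():
    print(n, {key: round(float(value), 3) for key, value in summary.items()})
```

The target (here, the mean 2) is available only because the simulator was designed by hand, which is precisely what makes the comparison possible.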

However, in order to carry out these steps, the statistician must first devise a simulator. This simulator should ideally generate synthetic data that resemble the real-world data in a meaningful way, a goal that is often difficult to achieve. So, how can one design a realistic simulator, and what does “realistic simulator” even mean in this context? These are the central questions we explore in this work.

1.3 A probabilistic stance

We adopt a probabilistic framework to model the data collected by observing a real-world process. Specifically, the data are represented as a random variable O^{n} (O as in observations) drawn from a probability law P^{n} (P as in probability). The law P^{n} is assumed to belong to a statistical model \mathcal{M}^{n} (\mathcal{M} as in model), which is a set of probability laws on the space \mathcal{O}^{n} where O^{n} takes its values. This model incorporates constraints that reflect known properties of the real-world process and, where necessary, minimal assumptions about it.

The superscript n indicates an amount of information. For example, in the context of this study, n typically represents the number of elementary observations drawn independently from a law P on \mathcal{O} and gathered in O^{n}. In this case, \mathcal{O}^{n} corresponds to the Cartesian product \mathcal{O} \times \cdots\times \mathcal{O} (repeated n times) and P^{n} to the product law P^{\otimes n}, with O^{n} decomposing as (O_{1}, \ldots, O_{n}).

The feature of interest is an element of a space \mathcal{F} (e.g., a subset of the real line, or a set of functions). It is modeled as the value \Psi(P^{n}) of a functional \Psi:\mathcal{M}^{n} \to \mathcal{F} evaluated at P^{n}. The algorithm developed to estimate this feature is modeled as a functional \mathcal{A}: \mathcal{O}^{n} \to \mathcal{F}. Training the algorithm involves applying \mathcal{A} to the observed data O^{n}, resulting in the estimator \mathcal{A}(O^{n}) for the estimand \Psi(P^{n}).

We emphasize that we address the questions closing Section 1.2 without focusing on the specific nature of the functional of interest \Psi: how can one design a realistic simulator, and what does “realistic simulator” even mean in this context?

1.4 Draw me a simulator

When constructing simulators, there is a spectrum of approaches, varying in complexity and flexibility. At one end of the spectrum, simulators are built upon relatively simple parametric models. While these models are sometimes more elaborate, they often rely on standard forms or recurring techniques, which streamlines their implementation. This approach is further reinforced by the common practice of using models proposed by others. Doing so not only saves effort but also facilitates meaningful comparisons between studies, as the same modeling framework is shared.

Regardless of the model’s simplicity, parametric simulators are inherently limited and unable to capture the complexity of real-world processes. The term “unnatural” aptly describes this shortcoming, as these models are simplifications that abstract away many intricacies of reality. Even with sophisticated parametrizations, it is fundamentally impossible for such simulators to convincingly replicate the multifaceted interactions and variability inherent in “nature”. Thus, parametric simulators, by their very essence, cannot achieve realism.

At the other end of the spectrum, one can also adopt a nonparametric approach through bootstrapping, which involves resampling data directly from the observed dataset. This method bypasses the need to specify a parametric model and instead leverages the structure of the real data to generate simulated samples.

Bootstrapping usually refers to a self-starting process that is supposed to continue or grow without external input. The term is sometimes attributed to the story where Baron Münchausen pulls himself and his horse out of a swamp by his pigtail, not by his bootstraps (Raspe 1866, chap. 4). In France, “bootstrap” is sometimes translated as “à la Cyrano”, in reference to the literary hero Cyrano de Bergerac, who imagined reaching the moon by standing on a metal plate and repeatedly using a magnet to propel himself (Rostand 2005, Act III, Scene 13).

When dealing with independent and identically distributed (i.i.d.) samples, bootstrapping generates data that closely resemble the observed data. However, the origin of the term “bootstrapping” suggests a measure of incompleteness, hence dissatisfaction, which is fitting in the context of this article. Indeed, a bootstrapped simulator can be viewed as both transparent and opaque, depending on the perspective. Conditionally on the real data, the simulator’s behavior is transparent, as understanding it reduces to understanding the sampling mechanism over the set of indices \{1, \ldots, n\}. Unconditionally, however, one is again confronted with the limitation of knowledge about P^{n}, beyond recognizing it as an element of \mathcal{M}^{n}.
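To make the "transparent" side concrete, here is a minimal sketch of a bootstrapped simulator, assuming i.i.d. observations stored as the rows of a NumPy array. The function name bootstrap_simulator and the placeholder data are hypothetical.

```python
import numpy as np

def bootstrap_simulator(observations, size, rng):
    """Draw `size` rows with replacement, uniformly over the row indices."""
    n = observations.shape[0]
    indices = rng.integers(low=0, high=n, size=size)  # uniform on {0, ..., n-1}
    return observations[indices]

rng = np.random.default_rng(0)
# Placeholder standing in for real data: 100 observations with 3 coordinates.
observations = rng.normal(size=(100, 3))
synthetic = bootstrap_simulator(observations, size=250, rng=rng)
```

Conditionally on the observations, the only randomness is the draw of indices. In particular, every synthetic row is a copy of an observed row, so the simulator never produces a value that was not already seen.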

cowsay -f sheep \
"I am a simulator. Press ENTER to run the synthetic experiment."
 ______________________________________
/ I am a simulator. Press ENTER to run \
\ the synthetic experiment.            /
 --------------------------------------
  \
   \
       __     
      UooU\.'@@@@@@`.
      \__/(@@@@@@@@@@)
           (@@@@@@@@)
           `YY~~~~YY'
            ||    ||

In Le Petit Prince (de Saint-Exupéry 1943), the Little Prince dismisses the pilot’s simple drawings of a sheep as unsatisfactory. Instead, he prefers a drawing of a box, imagining the perfect sheep inside.

echo \
"I am a simulator. Press ENTER to run the synthetic experiment." | \
boxes -d cc
/******************************************************************
 * I am a simulator. Press ENTER to run the synthetic experiment. *
 ******************************************************************/

Similarly, in simulations, straightforward simulators often fail to capture the complexity we seek, while black-box simulators, though opaque, can sometimes offer greater efficiency. Unlike the Little Prince, however, we are not content with the box alone – we want to look inside, to understand and refine the mechanisms driving our simulator.

1.5 Organization of the article

In this article, we explore an avenue to build more realistic simulators by using real data and neural networks, more specifically, Variational Auto-Encoders (VAEs). To illustrate our approach, we focus on a simple example rooted in causal analysis, as the causal framework presents particularly interesting challenges.

Section 2 outlines our objectives and introduces a running example that serves as a unifying thread throughout the study. Section 3 provides a concise overview of VAEs, including their formal definition and the key ideas behind their training. Section 4 offers an explanation of how VAEs are constructed, while Section 5 presents a comprehensive implementation tailored to the running example. Using this VAE, Section 6 describes the construction of a simulator designed to approximate the law of simulated data and discusses methods for evaluating the simulator’s performance. Section 7 extends this approach to a real-world dataset. Finally, Section 8 concludes the article with a literature review, a discussion of the challenges encountered, the limitations of the proposed approach, and some closing reflections.

Note that the online version of this article is preferable to the PDF version, as it allows readers to directly view the code. Throughout the article, we use a mix of Python (Van Rossum and Drake 2009) and R (R Core Team 2020) for implementation, leveraging commonly used libraries in both ecosystems.

2 Objective

Suppose that we have observed O_{1}, \ldots, O_{n}, O_{n+1}, \ldots, O_{n+n'} drawn independently from P, with P known to belong to a model \mathcal{P} consisting of laws on \mathcal{O}. For brevity, we will use the notation O^{1:n} = (O_{1}, \ldots, O_{n}) and O^{(n+1):(n+n')} = (O_{n+1}, \ldots, O_{n+n'}).

Suppose moreover that we are interested in a causal framework where each O_{i} is viewed as a piece of a complete data X_{i} \in \mathcal{X} drawn from a law Q that lives in a model \mathcal{Q}, with X_{1}, \ldots, X_{n}, X_{n+1}, \ldots, X_{n+n'} independent. The piece O_{i} is expressed \pi(X_{i}), with the function \pi projecting a complete data X \sim Q \in \mathcal{Q} onto a coarser real-world data O=\pi(X) \sim P \in \mathcal{P}.

Our objective is twofold. First, we aim to build a generator that approximates P, that is, an element of \mathcal{P} from which it is possible to sample independent data that exhibit statistical properties similar to (or, colloquially, “behave like”) O_{1}, \ldots, O_{n+n'}. In other words, we require that the generator produces data whose joint law approximates the law of the observed data, ensuring that the generated samples reflect the same underlying structure and dependencies as the real-world observations. Second, we require the generator to correspond to the law of \pi(X) with X drawn from an element of \mathcal{Q}.

We use a running example throughout the document.

Running example.

For example, \mathcal{P} can be the set of all laws on \mathcal{O}:= (\{0,1\}^{2} \times \mathbb{R}^{3}) \times \{0,1\} \times \mathbb{R} such that

O := (V,W,A,Y)\sim P\in \mathcal{P}

satisfies

c \leq P(A=1|W,V), P(A=0|W,V) \leq 1-c

P-almost surely for some P-specific constant c \in ]0,1/2], and Y is P-integrable.

Moreover, we view O as \pi(X) with \begin{aligned} X&:=(V, W, Y[0], Y[1], A) \in \mathcal{X}:= ( \{0,1\}^{2} \times \mathbb{R}^{3}) \times \mathbb{R}\times \mathbb{R}\times \{0,1\},\\ \pi&:(v,w,y[0],y[1],a) \mapsto (v,w, a, ay[1] + (1-a)y[0]), \end{aligned}

and \mathcal{Q} defined as the set of all laws on \mathcal{X} such that X\sim Q\in \mathcal{Q} satisfies

c' \leq Q(A=1|W,V), Q(A=0|W,V) \leq 1-c'

Q-almost surely for some Q-specific constant c' \in ]0,1/2], and Y[0] and Y[1] are Q-integrable.

We consider (V, W) as the context in which two possible actions a=0 and a=1 would yield the counterfactual rewards Y[0] and Y[1], respectively. One of these actions, A\in\{0,1\}, is factually carried out, resulting in the factual reward Y = A Y[1] + (1-A) Y[0], that is, Y[1] if A=1 and Y[0] otherwise. In the causal inference literature, this definition of Y is referred to as the consistency assumption.
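As a small illustration of the projection \pi, the following sketch maps complete data (V, W, Y[0], Y[1], A) to observed data (V, W, A, Y). The function name project and the toy values are hypothetical; only the formula Y = A Y[1] + (1-A) Y[0] comes from the text above.

```python
import numpy as np

def project(V, W, Y0, Y1, A):
    """pi: (v, w, y[0], y[1], a) -> (v, w, a, a * y[1] + (1 - a) * y[0])."""
    Y = A * Y1 + (1 - A) * Y0
    return V, W, A, Y

# Toy complete data for three units (values chosen arbitrarily).
V = np.array([[0, 1], [1, 0], [0, 0]])
W = np.zeros((3, 3))
Y0 = np.array([0.2, 0.5, 0.9])   # counterfactual rewards under a = 0
Y1 = np.array([0.7, 0.1, 0.4])   # counterfactual rewards under a = 1
A = np.array([1, 0, 1])          # actions factually carried out
_, _, _, Y = project(V, W, Y0, Y1, A)
print(Y)
```

For each unit, only one of the two counterfactual rewards survives the projection: here the factual rewards are 0.7, 0.5 and 0.4.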

Running example in action.

The Python function simulate defined in the next chunk of code operationalizes drawing independent data from a law P \in \mathcal{P}.

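import numpy as np
import random
from numpy import hstack, zeros, ones

def simulate(n, dimV, dimW):
  # dimV is not used below: V always consists of two binary components
  def expit(x):
    return 1 / (1 + np.exp(-x))
  # V = (V_1, V_2): independent Bernoulli draws with parameters 1/3 and 1/2
  p = np.hstack((1/3 * np.ones((n, 1)), 1/2 * np.ones((n, 1))))
  V = np.random.binomial(n = 1, p = p)
  # W: dimW independent standard Gaussian components
  W = np.random.normal(loc = 0, scale = 1, size = (n, dimW))
  WV = np.hstack((W, V))
  # A given (W, V): Bernoulli with probability clipped to [0.01, 0.99]
  pAgivenWV = np.clip(expit(0.8 * WV[:, 0]), 1e-2, 1 - 1e-2)
  A = np.random.binomial(n = 1, p = pAgivenWV)
  # Y given (A, W, V): Gaussian with the mean below and standard deviation 1/25
  meanYgivenAWV  = 0.5 * expit(-5 * A * (WV[:, 0] - 1)\
                               + 3 * (1 - A) * (WV[:, 1] + 0.5))\
                               + 0.5 * expit(WV[:, 2])
  Y = np.random.normal(loc = meanYgivenAWV, scale = 1/25, size = n)
  # Columns of the returned array: W_1, ..., W_dimW, V_1, V_2, A, Y
  dataset = np.vstack((np.transpose(WV), A, Y))
  dataset = np.transpose(dataset)
  return dataset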

Note that justifying the specific choices made while defining the function simulate is unnecessary. In the context of this study, we are free from the need for, or aspiration to, a realistic simulation scheme. Under the law P that simulate samples from, V and W are independent; V consists of two independent variables V_{1} and V_{2} that are drawn from the Bernoulli laws with parameters \tfrac{1}{3} and \tfrac{1}{2}; W is a three-dimensional standard Gaussian random vector. In addition, given (W,V), A is sampled from the Bernoulli law with parameter

\max\left\{0.01, \min\left[0.99, \mathop{\mathrm{expit}}(0.8 \times W_{1})\right]\right\}

and, given (A,W,V), Y is sampled from the Gaussian law with mean

\tfrac{1}{2} \mathop{\mathrm{expit}}\left[-5A\times(W_{1}-1) + 3(1-A)\times (\tfrac{1}{2} + W_{2})\right] + \tfrac{1}{2} \mathop{\mathrm{expit}}(W_{3})

and (small) standard deviation \tfrac{1}{25}. As noted in the introduction, these choices rely on standard forms and recurring techniques.

Running example, continued.

For future use, we sample in the next chunk of code n+n'=10^{4} independent observations from P. Observations O^{1:n} (gathered in train) will be used for training and observations O^{(n+1):(n+n')} (gathered in test) will be used for testing.

import random
random.seed(54321)
# simulate relies on NumPy's global generator, which random.seed does not
# affect; seeding NumPy as well makes the draws reproducible
np.random.seed(54321)
dimV,  dimW = 2, 3
n_train = int(5e3)
train = simulate(n_train, dimV,  dimW)
test = simulate(n_train, dimV,  dimW)
print("The three first observations in 'train':\n",
      "   V_1    V_2    W_1    W_2    W_3    A     Y\n",
      np.around(train[:3, [3, 4, 0, 1, 2, 5, 6]], decimals = 3))
The three first observations in 'train':
    V_1    V_2    W_1    W_2    W_3    A     Y
 [[ 0.     1.    -2.02  -0.85   1.349  0.     0.569]
 [ 1.     0.    -1.231  0.573  0.604  0.     0.79 ]
 [ 0.     0.     0.26  -1.032  0.923  1.     0.84 ]]

## np.savetxt("data/train.csv", train, delimiter = ",")
## np.savetxt("data/test.csv", test, delimiter = ",")

3 VAE in a nutshell

3.1 Formal definition

In the context of this article, a Variational Auto-Encoder (VAE) (Kingma and Welling 2014; Rezende, Mohamed, and Wierstra 2014) is an algorithm that, once trained, outputs a generator. The generator is the law of a random variable of the form

\mathop{\mathrm{Gen}}_{\theta} (Z) \tag{1}

where

  1. the source of randomness Z in Equation 1 writes as

     Z := (Z^{(0)}, \ldots, Z^{(d)}) \tag{2}

     with Z^{(0)}, \ldots, Z^{(d)} independently drawn from

     • the uniform law on \{1, \ldots, n\} for Z^{(0)},

     • the standard normal distribution for Z^{(1)}, \ldots, Z^{(d)};

  2. the function \mathop{\mathrm{Gen}}_{\theta} in Equation 1 is an element of a large collection, parametrized by the finite-dimensional set \Theta, of functions mapping \mathbb{R}^{d+1} to \mathcal{X}.

Because \mathop{\mathrm{Gen}}_{\theta}(Z) belongs to \mathcal{X}, we can evaluate \pi \circ \mathop{\mathrm{Gen}}_{\theta}(Z), hence the generator can also be used to generate random variables in \mathcal{O}. Figure 1 illustrates the architecture of the VAE used in this study. It shows the key components of the model, including the encoder, the latent space, and the decoder, along with the flow of information between them.

Figure 1: Architecture of the simulator. The figure depicts the flow of information through the encoder, latent space, and decoder components. It emphasizes how the input source of randomness Z is transformed into a latent representation and then reconstructed as a complete data, X=\mathop{\mathrm{Gen}}_{\theta}(Z), which can be mapped to a real-world data O=\pi(X).

The word “auto-encoder” reflects the nature of the parametric form of each \mathop{\mathrm{Gen}}_{\theta}. We begin with a formal presentation in four steps, which is then followed by a discussion of what each step implements. Specifically, each \mathop{\mathrm{Gen}}_{\theta} writes as a composition of four mappings J_{n}, \mathop{\mathrm{Enc}}_{\theta_{1}}, K and \mathop{\mathrm{Dec}}_{\theta_{2}} with \theta := (\theta_{1}, \theta_{2}) \in \Theta_{1} \times \Theta_{2} = \Theta:

\mathop{\mathrm{Gen}}_{\theta} = \mathop{\mathrm{Dec}}_{\theta_{2}} \circ K \circ \mathop{\mathrm{Enc}}_{\theta_{1}} \circ J_{n}.

Here,

  1. J_{n}: \{1, \ldots, n\} \times \mathbb{R}^{d} \to \mathcal{O}\times \mathbb{R}^{d} is such that J_{n} (Z) = (O_{i}, (Z^{(1)}, \ldots, Z^{(d)})) with i = Z^{(0)};

  2. \mathop{\mathrm{Enc}}_{\theta_{1}} : \mathcal{O}\times \mathbb{R}^{d} \to \mathbb{R}^{d} \times (\mathbb{R}_{+}^{*})^{d} \times \mathbb{R}^{d} is such that, if \mathop{\mathrm{Enc}}_{\theta_{1}}(o,z) = (\mu, \sigma, z'), then

    • z=z', and
    • \mathop{\mathrm{Enc}}_{\theta_{1}}(o,z'') = (\mu, \sigma, z'') for all z''\in\mathbb{R}^{d};
  3. K : \mathbb{R}^{d} \times (\mathbb{R}_{+}^{*})^{d} \times \mathbb{R}^{d} \to \mathbb{R}^{d} is given by K(\mu,\sigma,z) := \mu + \sigma \odot z, where \odot denotes the componentwise product;

  4. \mathop{\mathrm{Dec}}_{\theta_{2}} maps \mathbb{R}^{d} to \mathcal{X}.

Conditionally on O^{1:n} and Z, the computation of \mathop{\mathrm{Gen}}_{\theta} (Z) is deterministic. The process unfolds in four steps:

  1. Sampling and transfer. Compute J_{n}(Z), which involves sampling one observation O_{i} uniformly among all genuine observations and transferring (Z^{(1)}, \ldots, Z^{(d)}) unchanged.

  2. Encoding step. Compute \mathop{\mathrm{Enc}}_{\theta_{1}} \circ J_{n}(Z), which encodes O_{i} as a vector \mu \in \mathbb{R}^{d} and a d\times d covariance matrix \mathop{\mathrm{diag}}(\sigma)^2. This step does not modify (Z^{(1)}, \ldots, Z^{(d)}), which is transferred unchanged.

  3. Gaussian sampling. Compute K\circ \mathop{\mathrm{Enc}}_{\theta_{1}} \circ J_{n}(Z) by evaluating \mu + \sigma \odot (Z^{(1)}, \ldots, Z^{(d)}) \in \mathbb{R}^{d}. This amounts to sampling from the Gaussian law with mean \mu and covariance matrix \mathop{\mathrm{diag}}(\sigma)^2.

  4. Decoding step. Compute \mathop{\mathrm{Dec}}_{\theta_{2}}\circ K\circ \mathop{\mathrm{Enc}}_{\theta_{1}} \circ J_{n}(Z), which maps the encoded version of O_{i}, that is, \mu + \sigma \odot (Z^{(1)}, \ldots, Z^{(d)}), to an element of \mathcal{X}.
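The four steps can be mimicked with a toy generator written in plain NumPy. This is only a sketch under strong simplifying assumptions: the encoder and decoder are single random linear maps standing in for trained neural networks, the observations are placeholder Gaussian vectors, indices are 0-based, and all names (J, enc, K, dec, gen, theta_1, theta_2) are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n, dim_o, dim_x, d = 50, 7, 8, 4
observations = rng.normal(size=(n, dim_o))   # placeholder for O_1, ..., O_n

# Hypothetical parameters: theta_1 for the encoder, theta_2 for the decoder.
theta_1 = {"mu": rng.normal(size=(dim_o, d)),
           "log_sigma": 0.1 * rng.normal(size=(dim_o, d))}
theta_2 = rng.normal(size=(d, dim_x))

def J(z0, z):
    """Sampling and transfer: pick O_i with i = z0, pass z through unchanged."""
    return observations[z0], z

def enc(o, z):
    """Encoding: map o to (mu, sigma) with sigma > 0; z is passed through."""
    return o @ theta_1["mu"], np.exp(o @ theta_1["log_sigma"]), z

def K(mu, sigma, z):
    """Gaussian sampling: mu + sigma * z, with a componentwise product."""
    return mu + sigma * z

def dec(u):
    """Decoding: map the latent vector u to an element of X (here R^8)."""
    return u @ theta_2

def gen(z0, z):
    """Gen = Dec o K o Enc o J_n, deterministic given the data and (z0, z)."""
    return dec(K(*enc(*J(z0, z))))

z0 = rng.integers(low=0, high=n)   # Z^(0): uniform on the indices
z = rng.normal(size=d)             # Z^(1), ..., Z^(d): standard normal
x = gen(z0, z)
```

All the randomness sits in (z0, z): calling gen twice with the same arguments returns the same output, and setting z to zero decodes the mean \mu of the selected observation.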

3.2 Formal training

Formally, training the VAE involves maximizing the likelihood of O^{1:n} within a parametric model of laws by maximizing a lower bound of the likelihood. This process begins with the introduction of a working model of mixtures for P. The working model (undoubtedly flawed) postulates the existence of a latent random variable U\in \mathbb{R}^{d} and a parametric model of tractable conditional densities

\{o \mapsto p_{\theta_{2}} (o|u) : u \in \mathbb{R}^{d}, \theta_{2} \in \Theta_{2}\}

such that

  • U is drawn from the standard Gaussian law on \mathbb{R}^{d};

  • there exists \theta_{2} \in \Theta_{2} such that, given U, O is drawn from p_{\theta_{2}} (\cdot|U).

Here, tractable densities refer to those that can be easily worked with analytically, while in contrast, intractable densities are too complex to handle directly.

Therefore, the working model (undoubtedly flawed) postulates the existence of \theta_{2} \in \Theta_{2} such that P admits the generally intractable density

o \mapsto \int p_{\theta_{2}}(o|u) \phi_{d}(u)du

where \phi_{d} denotes the density of the standard Gaussian law on \mathbb{R}^{d}. As suggested by the use of the parameter \theta_{2}, the definition of the conditional densities p_{\theta_{2}}(\cdot|u) (u \in \mathbb{R}^{d}) involves the decoder \mathop{\mathrm{Dec}}_{\theta_{2}}.

Since directly maximizing the likelihood of O^{1:n} under the working model is infeasible, a secondary parametric model of tractable conditional densities is introduced:

\{u \mapsto g_{\theta_{1}} (u|O_{i}) : 1 \leq i \leq n, \theta_{1} \in \Theta_{1}\}

to model the conditional laws of U given O_{1}, given O_{2}, \ldots, given O_{n}. Here too, the use of the parameter \theta_{1} indicates that the definition of the conditional densities g_{\theta_{1}} (\cdot|O_{i}) (1 \leq i \leq n) involves the encoder \mathop{\mathrm{Enc}}_{\theta_{1}}.

Now, by Jensen’s inequality, for any 1 \leq i \leq n and all \theta = (\theta_{1}, \theta_{2})\in \Theta,

\begin{aligned} \log p_{\theta_{2}}(O_{i}) &= \log \int p_{\theta_{2}} (O_{i}|u) \frac{\phi_{d}(u)}{g_{\theta_{1}}(u|O_{i})} g_{\theta_{1}}(u|O_{i}) du\\ &\geq \int \log\left(p_{\theta_{2}} (O_{i}|u) \frac{\phi_{d}(u)}{g_{\theta_{1}}(u|O_{i})}\right) g_{\theta_{1}}(u|O_{i}) du\\ &= -\mathop{\mathrm{KL}}(g_{\theta_{1}}(\cdot|O_{i}); \phi_{d}) + E_{U\sim g_{\theta_{1}}(\cdot|O_{i})} [\log p_{\theta_{2}}(O_{i}|U)]=: \mathop{\mathrm{LB}}_{\theta}(O_{i}), \end{aligned} \tag{3}

where \mathop{\mathrm{KL}} denotes the Kullback-Leibler divergence and U in the expectation is drawn from the conditional law with density g_{\theta_{1}}(\cdot | O_{i}). The notation \mathop{\mathrm{LB}} is used to indicate that it represents a lower bound. Thus, the log-likelihood of O^{1:n} under \theta_{2} \in \Theta_{2} is lower-bounded by

\sum_{i=1}^{n} \mathop{\mathrm{LB}}_{\theta}(O_{i})

for all \theta_{1} \in \Theta_{1}. As suggested earlier, training the VAE formally consists of solving

\max_{\theta\in\Theta} \left\{\sum_{i=1}^{n}\mathop{\mathrm{LB}}_{\theta}(O_{i})\right\} \tag{4}

rather than solving

\max_{\theta_{2} \in \Theta_{2}} \left\{\sum_{i=1}^{n} \log p_{\theta_{2}} (O_{i})\right\}.
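The inequality in Equation 3 can be checked numerically in a toy model where everything is tractable. The sketch below assumes a one-dimensional latent U drawn from the standard Gaussian law, a conditional law of O given U=u that is Gaussian with mean \theta_{2} u and unit variance, and a Gaussian density g with mean m and standard deviation s. This toy model is an assumption made for illustration and is not the working model of the article.

```python
import numpy as np

def log_marginal(o, theta):
    """log p_theta(o): here O is Gaussian with mean 0 and variance 1 + theta^2."""
    var = 1.0 + theta ** 2
    return -0.5 * np.log(2 * np.pi * var) - 0.5 * o ** 2 / var

def lower_bound(o, theta, m, s):
    """LB_theta(o) when g(.|o) is the Gaussian density N(m, s^2)."""
    minus_kl = 0.5 * (1.0 + np.log(s ** 2) - s ** 2 - m ** 2)
    # E[log p_theta(o | U)] for U ~ N(m, s^2), in closed form in this toy model.
    expected_loglik = (-0.5 * np.log(2 * np.pi)
                       - 0.5 * ((o - theta * m) ** 2 + theta ** 2 * s ** 2))
    return minus_kl + expected_loglik

o, theta = 1.3, 0.8
# Exact conditional law of U given O = o in this toy model.
m_star = theta * o / (1.0 + theta ** 2)
s_star = np.sqrt(1.0 / (1.0 + theta ** 2))

gap_arbitrary = log_marginal(o, theta) - lower_bound(o, theta, m=0.0, s=1.0)
gap_exact = log_marginal(o, theta) - lower_bound(o, theta, m=m_star, s=s_star)
```

The gap is positive for an arbitrary choice of (m, s) and vanishes, up to rounding, when g coincides with the exact conditional law of U given O, which is why maximizing the lower bound over both \theta_{1} and \theta_{2} is a sensible surrogate.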

4 How to build the VAE

4.1 A formal description

We implement the classes of encoders and decoders, that is \{\mathop{\mathrm{Enc}}_{\theta_{1}} : \theta_{1} \in \Theta_{1}\} and \{\mathop{\mathrm{Dec}}_{\theta_{2}} : \theta_{2} \in \Theta_{2}\}, as neural network models. Each encoder \mathop{\mathrm{Enc}}_{\theta_{1}} and decoder \mathop{\mathrm{Dec}}_{\theta_{2}} consists of a stack of layers of two types: densely-connected layers and activation layers (linear: x \mapsto x; ReLU: x \mapsto \max(0, x); softmax: (x_{1}, x_{2}) \mapsto (e^{x_{1}}, e^{x_{2}})/(e^{x_{1}} + e^{x_{2}})). The neural networks are rather simple in design, but nevertheless (moderately) high-dimensional and arguably over-parametrized, as discussed in Section 4.2.
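For readers less familiar with these building blocks, the two types of layers can be written in a few lines of NumPy. The sketch is illustrative only: the layer sizes (7, 16 and 2) and the random weights are arbitrary assumptions and do not describe the architecture used in the article.

```python
import numpy as np

def linear(x):
    """Linear activation: x -> x."""
    return x

def relu(x):
    """ReLU activation: x -> max(0, x), componentwise."""
    return np.maximum(0.0, x)

def softmax(x):
    """Softmax along the last axis; subtracting the max avoids overflow."""
    shifted = x - np.max(x, axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

def dense(x, weights, bias):
    """Densely-connected layer: affine map x -> x @ weights + bias."""
    return x @ weights + bias

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 7))   # batch of 5 inputs of dimension 7
hidden = relu(dense(x, rng.normal(size=(7, 16)), np.zeros(16)))
probs = softmax(dense(hidden, rng.normal(size=(16, 2)), np.zeros(2)))
```

Each row of probs is a pair of nonnegative numbers summing to one, which is what makes the softmax convenient for parametrizing the law of a binary variable.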

The model \{u \mapsto g_{\theta_{1}}(u|O_{i}) : 1 \leq i \leq n, \theta_{1} \in \Theta_{1}\} is chosen such that U drawn from g_{\theta_{1}}(\cdot|O_{i}) is a Gaussian vector with mean \mu_{i} and covariance matrix \mathop{\mathrm{diag}}(\sigma_{i})^{2} where \mathop{\mathrm{Enc}}_{\theta_{1}}(O_{i}, \cdot) = (\mu_{i}, \sigma_{i}, \cdot), that is, when the \theta_{1}-specific encoding of O_{i} equals (\mu_{i}, \sigma_{i}). Remarkably, the first term in the definition of \mathop{\mathrm{LB}}_{\theta}(O_{i}) (Equation 3), that is, -\mathop{\mathrm{KL}}(g_{\theta_{1}}(\cdot|O_{i}); \phi_{d}), is then known in closed form:

-\mathop{\mathrm{KL}}(g_{\theta_{1}}(\cdot|O_{i}); \phi_{d}) = \tfrac{1}{2} \sum_{j=1}^{d} \left(1 + \log (\sigma_{i}^{2})_{j} - (\sigma_{i}^{2})_{j} - (\mu_{i}^{2})_{j} \right), \tag{5}

where (\mu_{i}^{2})_{j} and (\sigma_{i}^{2})_{j} are the j-th components of \mu_{i} \odot \mu_{i} and \sigma_{i} \odot \sigma_{i}, respectively. This is very convenient, because Equation 5 makes estimating the term \mathop{\mathrm{KL}}(g_{\theta_{1}}(\cdot|O_{i}); \phi_{d}) unnecessary, a task that would otherwise introduce more variability in the procedure.
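The following sketch compares the closed form of Equation 5 with a Monte Carlo estimate of the same quantity, obtained by averaging \log \phi_{d}(U) - \log g_{\theta_{1}}(U|O_{i}) over draws of U. The values of \mu and \sigma are arbitrary, and the function names are hypothetical.

```python
import numpy as np

def minus_kl_closed_form(mu, sigma):
    """Right-hand side of Equation 5."""
    return 0.5 * np.sum(1.0 + np.log(sigma ** 2) - sigma ** 2 - mu ** 2)

def minus_kl_monte_carlo(mu, sigma, n_draws, rng):
    """Average of log phi_d(U) - log g(U) over draws U ~ N(mu, diag(sigma)^2)."""
    u = mu + sigma * rng.normal(size=(n_draws, mu.size))
    log_phi = -0.5 * np.sum(u ** 2 + np.log(2 * np.pi), axis=1)
    log_g = -0.5 * np.sum(((u - mu) / sigma) ** 2
                          + np.log(2 * np.pi * sigma ** 2), axis=1)
    return np.mean(log_phi - log_g)

rng = np.random.default_rng(0)
mu = np.array([0.5, -1.0, 0.0])
sigma = np.array([1.0, 0.5, 2.0])
exact = minus_kl_closed_form(mu, sigma)
approx = minus_kl_monte_carlo(mu, sigma, n_draws=200_000, rng=rng)
```

With these values the closed form equals -1.75, whereas the Monte Carlo estimate fluctuates around that value from one seed to the next. This extra variability is what the closed form spares the training procedure.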

As for the model \{o \mapsto p_{\theta_{2}}(o|u) : u \in \mathbb{R}^{d}, \theta_{2} \in \Theta_{2}\}, the only requirement is that it must be chosen in such a way that \log p_{\theta_{2}}(O_{i} | u) be computable for all 1\leq i \leq n, \theta_{2} \in \Theta_{2} and u \in \mathbb{R}^{d}. This is not a tall order as soon as O can be decomposed as a sequence of (e.g., time-ordered) random variables that are vectors with categorical, integer, or real entries. Indeed, it then suffices (i) to decompose the likelihood accordingly as a product of conditional likelihoods, and (ii) to choose a tractable parametric model for each factor in the decomposition. We illustrate the construction of \{o \mapsto p_{\theta_{2}}(o|u) : u \in \mathbb{R}^{d}, \theta_{2} \in \Theta_{2}\} in the context of our running example.

Running example, cted.

In the context of this example, O=(V,W,A,Y) with V \in \{0,1\}^{2}, W \in \mathbb{R}^{3}, A\in\{0,1\} and Y\in\mathbb{R}. Since the source of randomness Z has dimension (d+1), d must satisfy d = d_{1} + 3 for some integer d_{1} \geq 1.

Set \theta = (\theta_{1}, \theta_{2})\in \Theta, u \in \mathbb{R}^{d}, and let \pi\circ \mathop{\mathrm{Dec}}_{\theta_{2}} (u) = (\tilde{v}, \tilde{w}, \tilde{a}, \tilde{y}) \in \mathcal{O}. The conditional likelihood p_{\theta_{2}}(O|u) (of O given U=u) equals

p_{\theta_{2}}(V,W|u) \times p_{\theta_{2}}(A|W,V,u) \times p_{\theta_{2}}(Y|A,W,V,u)

so it suffices to define the conditional likelihoods p_{\theta_{2}}(V,W|u) (of (V,W) given U=u), p_{\theta_{2}}(A|W,V,u) (of A given (W,V) and U=u) and p_{\theta_{2}}(Y|A,W,V,u) (of Y given (A,W,V) and U=u).

  • We decide that V and W are conditionally independent given U under p_{\theta_{2}}(\cdot|u). Therefore, it suffices to characterize the conditional likelihoods p_{\theta_{2}}(V|u) (of V given U=u) and p_{\theta_{2}}(W|u) (of W given U=u).

  • We choose w\mapsto p_{\theta_{2}}(w|u) to be the Gaussian density with mean \tilde{w} and identity covariance matrix.

Running example, cted.
  • The description of the conditional law of V given U=u under p_{\theta_{2}}(\cdot|u) is slightly more involved. It requires that we give more details on the encoders and decoders.

    • Like every encoder, \mathop{\mathrm{Enc}}_{\theta_{1}} actually maps \mathcal{O}\times \mathbb{R}^{d} to [\mathbb{R}^{d_{1}}\times \{0\}^{3}] \times [(\mathbb{R}_{+}^{*})^{d_{1}} \times \{1\}^{3}] \times \mathbb{R}^{d}. In words, if \mathop{\mathrm{Enc}}_{\theta_{1}}(o, \cdot) = (\mu, \sigma, \cdot), then it necessarily holds that the three last components of \mu and \sigma are 0 and 1, respectively. Therefore the three last components of the random vector K \circ \mathop{\mathrm{Enc}}_{\theta_{1}} \circ J_{n}(Z) equal Z^{(d-2)}, Z^{(d-1)}, Z^{(d)}, three independent standard normal random variables.

    • To compute \mathop{\mathrm{Dec}}_{\theta_{2}}(u) = (\tilde{v}, \tilde{w}, \tilde{y}_{0}, \tilde{y}_{1}, \tilde{a}) \in \mathcal{X}, we actually compute \tilde{w} then \tilde{v}, then (\tilde{y}_{0}, \tilde{y}_{1},\tilde{a}).

      • The output \tilde{w} is a \theta_{2}-specific deterministic function of the first d_{1} components of u.

      • The output \tilde{v} is a \theta_{2}-specific deterministic function of the (d_{1}+2) first components of u.

        More specifically, two (latent) probabilities \tilde{g}_{1}, \tilde{g}_{2} are first computed, as \theta_{2}-specific deterministic functions of the d_{1} first components of u. Then \tilde{v}_{1} and \tilde{v}_{2} are set to \textbf{1}\{\Phi(u^{(d_{1}+1)}) \leq \tilde{g}_{1}\} and \textbf{1}\{\Phi(u^{(d_{1}+2)}) \leq \tilde{g}_{2}\}, where \Phi denotes the standard normal cumulative distribution function (c.d.f.) and u^{(d_{1}+1)}, u^{(d_{1}+2)} are the (d_{1}+1)-th and (d_{1}+2)-th components of u.

        For instance, \tilde{v}_{1} is given the value 1 if \Phi(u^{(d_{1}+1)}) \leq \tilde{g}_{1} and 0 otherwise. Note that \textbf{1}\{\Phi(Z^{(d_{1}+1)}) \leq \tilde{g}_{1}\} follows the Bernoulli law with parameter \tilde{g}_{1} because Z^{(d_{1}+1)} is drawn from the standard normal law.

      • The output (\tilde{y}_{0}, \tilde{y}_{1}) is a \theta_{2}-specific deterministic function of (\tilde{v}, \tilde{w}) and the d_{1} first components of u.

      • The output \tilde{a} is a \theta_{2}-specific deterministic function of (\tilde{v}, \tilde{w}) and the last component of u.

        More specifically, a (latent) probability \tilde{h} is first computed, as a \theta_{2}-specific deterministic function of (\tilde{v}, \tilde{w}). Then \tilde{a} is set to \textbf{1}\{\Phi(u^{(d)}) \leq \tilde{h}\}.

        Note that \textbf{1}\{\Phi(Z^{(d)}) \leq \tilde{h}\} follows the Bernoulli law with parameter \tilde{h} because Z^{(d)} is drawn from the standard normal law.

    We are now in a position to describe the conditional law of V given U=u. We decide that, conditionally on U=u, under p_{\theta_{2}}(\cdot|u), V_{1} and V_{2} are independently drawn from the Bernoulli laws with parameters \tilde{g}_{1} and \tilde{g}_{2}. Thus, p_{\theta_{2}}(\cdot|u) is such that p_{\theta_{2}}(v|u) = [v_{1} \tilde{g}_{1} + (1-v_{1})(1-\tilde{g}_{1})] \times [v_{2} \tilde{g}_{2} + (1-v_{2})(1-\tilde{g}_{2})] for v=(v_{1}, v_{2}) \in \{0,1\}^{2}.
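The thresholding mechanism and the resulting conditional likelihood of V can be sketched as follows. This is a minimal illustration: the function names and the values chosen for \tilde{g}_{1} and \tilde{g}_{2} are hypothetical, and only the Python standard library is used.

```python
import random
from statistics import NormalDist

PHI = NormalDist().cdf  # standard normal c.d.f.

def threshold_draw(z, g):
    """Return 1 if Phi(z) <= g and 0 otherwise."""
    return 1 if PHI(z) <= g else 0

def p_v_given_u(v, g1, g2):
    """Conditional likelihood of v = (v1, v2) under independent Bernoulli laws."""
    v1, v2 = v
    return (v1 * g1 + (1 - v1) * (1 - g1)) * (v2 * g2 + (1 - v2) * (1 - g2))

# When z is standard normal, Phi(z) is uniform on [0, 1], so the indicator
# equals 1 with probability g.
rng = random.Random(1)
g1 = 0.3
draws = [threshold_draw(rng.gauss(0.0, 1.0), g1) for _ in range(50_000)]
freq = sum(draws) / len(draws)

# The four values of p(v | u) sum to one.
total_mass = sum(p_v_given_u((a, b), 0.3, 0.8) for a in (0, 1) for b in (0, 1))
```

The empirical frequency `freq` should be close to the chosen parameter, in line with the remark that the indicator follows the Bernoulli law with parameter \tilde{g}_{1}.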

Running example, cted.
  • The description of the conditional law of A given (W, V) and U=u under p_{\theta_{2}}(\cdot|u) is similar to that of V given U. We decide that, conditionally on (W, V) and U=u, under p_{\theta_{2}}(\cdot|W, V, u), A follows the Bernoulli law with parameter \tilde{\underline{h}}(V,W), where the probability \tilde{\underline{h}}(v,w) lies between \tilde{h} and \bar{A}_{n}:=\tfrac{1}{n}\sum_{i=1}^{n} A_{i} and is given, for any (v,w) \in \{0,1\}^{2} \times \mathbb{R}^{3}, by

    \begin{aligned} \tilde{\underline{h}}(v,w) :=& t(v,w) \tilde{h} + [1 - t(v,w)] \bar{A}_{n} \quad \text{with}\\ -10\log t(v,w) =& - \left[v_{1} \log \tilde{g}_{1} + (1-v_{1}) \log (1-\tilde{g}_{1})\right]\\ & - \left[v_{2} \log \tilde{g}_{2} + (1-v_{2}) \log (1-\tilde{g}_{2})\right]\\ &+\|w - \tilde{w}\|_{2}^{2}. \end{aligned} Thus, p_{\theta_{2}}(\cdot|W,V,u) is such that p_{\theta_{2}}(a|W,V,u) = a \tilde{\underline{h}}(V,W) + (1-a)(1-\tilde{\underline{h}}(V,W)) for a \in \{0,1\}.

  • Finally, we choose y \mapsto p_{\theta_{2}}(y|A,W,V,u) to be the two-regime density given by

    p_{\theta_{2}}(y|A,W,V,u) = \frac{\textbf{1}\{A=\tilde{a}\}}{\tilde{s}(W)} \phi_{1} \left(\frac{y - \tilde{y}}{\tilde{s}(W)}\right) + \textbf{1}\{A\neq \tilde{a}\} C^{-1} where \tilde{s}(w):= \tfrac{1}{\sqrt{5}} \|w - \tilde{w}\|_{2} for any w \in \mathbb{R}^{3} and C is the Lebesgue measure of the support of the marginal law of Y under P (it does not matter if C is unknown).

    Thus, two cases arise:

    • If A = \tilde{a}, then Y is conditionally drawn under p_{\theta_{2}}(\cdot|A,W,V,u) from the Gaussian law with mean \tilde{y} = a \tilde{y}_{1} + (1-a) \tilde{y}_{0} and variance \tilde{s}(W)^{2}.
    • Otherwise, Y is conditionally drawn under p_{\theta_{2}}(\cdot|A,W,V,u) from the uniform law on the support of the marginal law of Y under P.

    Therefore, the conditional likelihood p_{\theta_{2}}(Y|A,W,V,u) bears information only if A=\tilde{a} (that is, if the actions A and \tilde{a} undertaken when generating O=(V,W,A,Y) and computing \mathop{\mathrm{Dec}}_{\theta_{2}}(u) coincide), which can be interpreted as a necessary condition to justify the comparison of the rewards Y and \tilde{y}. Moreover, when A=\tilde{a}, the closer the contexts W and \tilde{w} are, the more relevant the comparison is and the larger p_{\theta_{2}}(Y|A,W,V,u) can be.
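The two conditional laws just described can be sketched numerically as follows. This is only an illustration written from the formulas above: the function and argument names are hypothetical, `support_length` stands for the constant C, and the degenerate case w = \tilde{w} (for which \tilde{s}(w) = 0) is not handled.

```python
import math

def shrunk_propensity(v, w, g, w_tilde, h_tilde, a_bar):
    """Return t(v, w) and the shrunk probability h_underline(v, w)."""
    # Negative log-likelihood of v under Bernoulli laws with parameters g.
    nll_v = -sum(vj * math.log(gj) + (1 - vj) * math.log(1 - gj)
                 for vj, gj in zip(v, g))
    sq_dist = sum((wj - wtj) ** 2 for wj, wtj in zip(w, w_tilde))
    t = math.exp(-(nll_v + sq_dist) / 10.0)  # from -10 log t = nll_v + sq_dist
    return t, t * h_tilde + (1 - t) * a_bar

def density_y(y, a, w, a_tilde, y0_tilde, y1_tilde, w_tilde, support_length):
    """Two-regime density: Gaussian if a == a_tilde, uniform otherwise."""
    if a != a_tilde:
        return 1.0 / support_length
    s = math.sqrt(sum((wj - wtj) ** 2 for wj, wtj in zip(w, w_tilde)) / 5.0)
    y_tilde = a * y1_tilde + (1 - a) * y0_tilde
    z = (y - y_tilde) / s
    return math.exp(-0.5 * z * z) / (s * math.sqrt(2.0 * math.pi))

# Hypothetical decoder outputs and observation.
t, h = shrunk_propensity(v=(1, 0), w=(0.1, 0.2, 0.3), g=(0.7, 0.2),
                         w_tilde=(0.0, 0.0, 0.0), h_tilde=0.6, a_bar=0.5)
uniform_case = density_y(1.0, 1, (1.0, 2.0, 0.0), 0, 0.0, 2.0,
                         (0.0, 0.0, 0.0), support_length=4.0)
gaussian_case = density_y(2.0, 1, (1.0, 2.0, 0.0), 1, 0.0, 2.0,
                          (0.0, 0.0, 0.0), support_length=4.0)
```

The weight t is close to 1 when (v, w) is well reconstructed, in which case the shrunk probability stays close to \tilde{h}; otherwise it is pulled toward the empirical mean \bar{A}_{n}.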

Running example, cted.

In summary, the right-hand side term in the definition of \mathop{\mathrm{LB}}_{\theta}(O_{i}) (Equation 3) equals, up to a term that does not depend on \theta,

\begin{aligned} \tfrac{1}{2}E_{U \sim g_{\theta_{1}}(\cdot|O_{i})} \Bigg[ & 2\left(V_{i,1} \log \tilde{G}_{1} + (1-V_{i,1}) \log (1 - \tilde{G}_{1})\right) \\ & +2\left(V_{i,2} \log \tilde{G}_{2} + (1-V_{i,2}) \log (1 - \tilde{G}_{2})\right) \\ & - \|W_{i} - \tilde{W}\|_{2}^{2}\\ & +2\left(A_{i} \log \tilde{\underline{H}} + (1-A_{i}) \log [1 - \tilde{\underline{H}}]\right)\\ & -\textbf{1}\{A_{i} = \tilde{A}\} \times \left(\log \tilde{S}(W_{i})^{2} + \frac{(Y_{i} - \tilde{Y})^{2}}{\tilde{S}(W_{i})^{2}}\right)\Bigg], \end{aligned} \tag{6}

with the notational conventions \pi \circ \mathop{\mathrm{Dec}}_{\theta_{2}} (U) = (\tilde{V}, \tilde{W}, \tilde{A}, \tilde{Y}), V_{i} = (V_{i,1}, V_{i,2}), and where \tilde{G}_{1}, \tilde{G}_{2}, \tilde{\underline{H}}, \tilde{S} are defined like the above latent quantities \tilde{g}_{1}, \tilde{g}_{2}, \tilde{\underline{h}}, \tilde{s} with U substituted for u. The expression is easily interpreted: the opposite of Equation 6 is an average risk that measures

  • the likelihood of V_{i,1} and V_{i,2} from the points of view of the Bernoulli laws with parameters \tilde{G}_{1} and \tilde{G}_{2} (first and second terms),

  • the average proximity between W_{i} and \tilde{W} (third term),

  • the likelihood of A_{i} from the point of view of the Bernoulli law with parameter \tilde{\underline{H}} (fourth term),

  • the average proximity between Y_{i} and \tilde{Y} (fifth term) only if A_{i} = \tilde{A} (otherwise, the comparison would be meaningless).

In other terms, the opposite of Equation 6 can be interpreted as a measure of the average faithfulness of the reconstruction of O_{i} under the form \pi \circ \mathop{\mathrm{Dec}}_{\theta_{2}} (U) with U drawn from g_{\theta_{1}}(\cdot|O_{i}). The larger Equation 6 is, the better this reconstruction is.

To conclude, note that the conditional laws of W and Y, both Gaussian, could easily be associated with diagonal covariance matrices different from the identity matrix. This adjustment would be particularly relevant in situations where \|W\|_{2} and |Y| are typically not of the same magnitude, with O=(V,W,A,Y) drawn from the law P of the experiment of interest. Alternatively, the genuine observations could be pre-processed to ensure that \|W\|_{2} and |Y| are brought to comparable magnitudes.

The hope is that, once the VAE is trained, yielding a parameter \widehat{\theta}_{n}= ((\widehat{\theta}_{n})_{1}, (\widehat{\theta}_{n})_{2}), the corresponding generator \mathop{\mathrm{Gen}}_{\widehat{\theta}_{n}} produces a synthetic complete data point X\in\mathcal{X} such that the law of \pi(X) \in\mathcal{O} closely approximates P. Naturally, this approximation is closely related to the conditional densities g_{(\widehat{\theta}_{n})_{1}}(\cdot|O_{i}) and p_{(\widehat{\theta}_{n})_{2}}(\cdot | u) (1 \leq i \leq n, u \in \mathbb{R}^{d}).

For instance, in the context of the running example, if O = (V,W, A, Y) = \pi \circ \mathop{\mathrm{Gen}}_{\widehat{\theta}_{n}}(Z) and if \xi, \zeta are independently drawn from the centered Gaussian laws with an identity covariance matrix on \mathbb{R}^{3} and variance 1 on \mathbb{R}, respectively, then (V, W + \xi, A, Y + \zeta) follows a law that admits the density

\begin{aligned} (v, w, a, y)\mapsto \int & p_{(\widehat{\theta}_{n})_{2}}(y|a, w, v, u) \times \textbf{1}\{a=\tilde{a}_{(\widehat{\theta}_{n})_{2}}(u)\}\\ & \times p_{(\widehat{\theta}_{n})_{2}}(w,v|u)\left(\frac{1}{n}\sum_{i=1}^{n} g_{(\widehat{\theta}_{n})_{1}}(u |O_{i}) \right) du, \end{aligned} where \tilde{a}_{(\widehat{\theta}_{n})_{2}}(u) is defined as the A-coefficient of \pi\circ \mathop{\mathrm{Dec}}_{(\widehat{\theta}_{n})_{2}}(u).

4.2 About the over-parametrization

In Section 4.1 we acknowledged that the models \{\mathop{\mathrm{Enc}}_{\theta_{1}} : \theta_{1} \in \Theta_{1}\} and \{\mathop{\mathrm{Dec}}_{\theta_{2}} : \theta_{2} \in \Theta_{2}\} are over-parametrized in the sense that the dimension of the parameter set \Theta_{1}\times \Theta_{2} is potentially large. For instance, the dimension of the model that we build in the next section is 1157. This is a common feature of neural networks.

Our models are also over-parametrized in the sense that they are not identifiable. This is obviously the case because of the loss of information that governs the derivation of an observation O as a piece \pi(X) of a complete data X that we are not given to observe in its entirety.

Running example, cted.

In particular, in the context of this example, it is well known that we cannot learn from O_{1}, \ldots, O_{n} any feature of the joint law of the counterfactual random variables (Y[0], Y[1]) that does not reduce to a feature of the marginal laws of Y[0] or Y[1], unless we make very strong assumptions on this joint law (e.g., that Y[0] and Y[1] are independent).

This is not a source of concern. First, it is generally recognized that the fitting of neural networks often benefits from the high dimensionality of the optimization space and the presence of numerous equivalently good local optima, resulting in a redundant optimization landscape (Choromanska et al. 2015; Arora, Cohen, and Hazan 2018). Second, our objective is to construct a generator that approximates the law P of O_{1}, \ldots, O_{n}, generating O\in\mathcal{O} by first producing X\in\mathcal{X} (via \mathop{\mathrm{Gen}}_{\theta}(Z)) and then providing \pi(X). The fact that two different generators \mathop{\mathrm{Gen}}_{\theta} and \mathop{\mathrm{Gen}}_{\theta'} can perform equally well is not problematic. Identifying one generator \mathop{\mathrm{Gen}}_{\theta} that performs well is sufficient.

It is possible to search for generators that satisfy user-supplied constraints, provided these can be expressed as a real-valued criterion F(E[\mathcal{C}(\mathop{\mathrm{Gen}}_{\theta}(Z))]). For example, one may wish to construct a generator \mathop{\mathrm{Gen}}_{\theta} such that the components of X under \mathop{\mathrm{law}}(\mathop{\mathrm{Gen}}_{\theta}) exhibit a pre-specified correlation pattern (as demonstrated in the simple example below).

To focus the optimization procedure on generators that approximately meet these constraints, one can modify the original criterion Equation 4 by adding a penalty term. Specifically, given a user-supplied hyper-parameter \lambda > 0, we can substitute

\max_{\theta\in\Theta} \left\{\sum_{i=1}^{n}\mathop{\mathrm{LB}}_{\theta}(O_{i}) + \lambda F(E_{Z \sim \mathop{\mathrm{Unif}}\{1, \ldots, n\} \otimes N(0,1)^{\otimes d}}[\mathcal{C}(\mathop{\mathrm{Gen}}_{\theta}(Z))])\right\} \tag{7}

for Equation 4. From a computational perspective, this adjustment simply involves adding the term

\lambda F\left(\frac{1}{m}\sum_{i=1}^{m} \mathcal{C}(\mathop{\mathrm{Gen}}_{\theta} (Z_{m+i}))\right) \tag{8}

to the expressions within the curly brackets in the definition of g in the algorithm described in Section 5.5.

Running example, cted.

In particular, in the context of this example, we could look for generators \mathop{\mathrm{Gen}}_{\theta} such that the correlation of Y[0] and Y[1] under \mathop{\mathrm{law}}(\mathop{\mathrm{Gen}}_{\theta} (Z)) be close to a target correlation r \in ]-1, 1[. In that case, we could choose \mathcal{C}(\mathop{\mathrm{Gen}}_{\theta} (Z)) := (Y[0]Y[1], Y[0]^{2}, Y[1]^{2}, Y[0], Y[1]) and F : (a,b,c,d,e) \mapsto |(a - de)/\sqrt{(b - d^2)(c - e^2)} - r|.
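For concreteness, the functions \mathcal{C} and F of this example, and the empirical term of Equation 8, can be sketched as follows. The function names and the toy batch of counterfactual pairs are hypothetical; in practice the pairs (Y[0], Y[1]) would be extracted from the outputs of \mathop{\mathrm{Gen}}_{\theta}.

```python
import math

def features(y0, y1):
    """C for one synthetic unit: (Y[0]Y[1], Y[0]^2, Y[1]^2, Y[0], Y[1])."""
    return (y0 * y1, y0 * y0, y1 * y1, y0, y1)

def correlation_gap(mean_features, r):
    """F(a, b, c, d, e) = |(a - d e) / sqrt((b - d^2)(c - e^2)) - r|."""
    a, b, c, d, e = mean_features
    return abs((a - d * e) / math.sqrt((b - d * d) * (c - e * e)) - r)

def penalty(pairs, r, lam):
    """lam * F evaluated at the empirical mean of C over a batch (Equation 8)."""
    m = len(pairs)
    sums = [0.0] * 5
    for y0, y1 in pairs:
        for k, value in enumerate(features(y0, y1)):
            sums[k] += value
    return lam * correlation_gap([s / m for s in sums], r)

# Toy batch in which Y[1] is an affine function of Y[0], so the empirical
# correlation equals 1 and the gap to the target r = 0.5 equals 0.5.
pairs = [(y, 2.0 * y + 1.0) for y in (0.0, 1.0, 2.0, 3.0)]
value = penalty(pairs, r=0.5, lam=1.0)
```

Since F only involves first and second empirical moments, it remains differentiable w.r.t. the generator's outputs wherever the gap is nonzero, which is what makes it usable inside a gradient-based procedure.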

5 Implementation of the VAE in the context of the running example

We now show how to implement the classes of encoders and decoders, hence of generators, in the context of our running example. We will also define other loss functions that are needed to train the model.

We implemented our approach using the TensorFlow package, but also experimented with PyTorch. Both frameworks yielded similar results, with no noticeable differences in performance. However, we found TensorFlow to be slightly more beginner-friendly, which might make it easier for readers new to neural network frameworks to follow our implementation.

5.1 Implementing the encoder

The first chunk of code defines a function, namely build_encoder, to build \mathop{\mathrm{Enc}}_{\theta_{1}}. The parameter latent_dim is the Python counterpart of d_{1}. The parameters nlayers_encoder and nneurons_encoder are the numbers of layers and of neurons in each layer, respectively. The parameter L will be discussed later.

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
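import numpy as np  # used below by np.array, np.repeat, np.reshape and np.arange
import tensorflow_probability as tfp  # used by as_sample (tfp.distributions.Normal)
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Lambda, Dense, Activation, concatenate,\
  Reshape, Concatenate
import tensorflow.keras.backend as KK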

def build_encoder(dimW,
                  dimV,
                  latent_dim,
                  L,
                  nlayers_encoder,
                  nneurons_encoder):
  if isinstance(nneurons_encoder, int):
    nneurons_encoder = np.array([nneurons_encoder])
  if not (nlayers_encoder is None):
    if len(nneurons_encoder) == 1:
      nneurons_encoder = np.repeat(nneurons_encoder, nlayers_encoder)
    elif not (len(nneurons_encoder) == nlayers_encoder):
      print("Inconsistency of 'nneurons' and 'nlayers' in the encoder's definition...\n")
  dimWV = dimW + dimV
  nlayers_encoder = len(nneurons_encoder)
  e_obs_input = Input(shape = (dimWV + 2,),
                      name = "e_obs_input")
  e_alea_input = Input(shape = (L, latent_dim + dimV + 1, ),
                        name = "e_alea_input")
  e = Dense(nneurons_encoder[0],
            name = "dense_e_" + str(0),
            activation = "relu")(e_obs_input)
  for i in range(nlayers_encoder - 2):
    e = Dense(nneurons_encoder[i + 1],
              name = "dense_e_" + str(i + 1),
              activation = "relu")(e)
  mu = Dense(latent_dim,
              name = "mu")(e)
  zeros = KK.zeros_like(mu)
  zeros = Lambda(lambda x: x[:, :(dimV + 1)], name="zeros")(zeros)
  mu = Concatenate(axis = 1,
                   name = "concatenate_mu_zeros")([mu, zeros])
  log_sigma = Dense(latent_dim,
                    name = "log_sigma")(e)
  log_sigma = Lambda(lambda x: -tf.math.exp(x), name = "minus_exponential")(log_sigma)
  log_sigma = Concatenate(axis = 1,
                          name = "concatenate_log_sigma_zeros")([log_sigma, zeros])
  param = Concatenate(name = "param")([mu, log_sigma])
  encoder = Model(inputs = [e_obs_input, e_alea_input],
                  outputs = [param, e_alea_input])
  return encoder

The code related to encoding is complete.
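The padding performed at the end of build_encoder can be illustrated without TensorFlow. The sketch below assumes that the second half of the encoder's output is interpreted as \log\sigma and that K computes \mu + \sigma \odot Z componentwise; both points are assumptions made for the illustration, and the function names are hypothetical.

```python
import math

def pad_encoder_output(mu_head, raw_head, dim_v):
    """Mimic the padding of mu and log_sigma with (dim_v + 1) zeros."""
    zeros = [0.0] * (dim_v + 1)
    mu = list(mu_head) + zeros
    # As in the layer "minus_exponential": log_sigma = -exp(x) <= 0.
    log_sigma = [-math.exp(x) for x in raw_head] + zeros
    return mu, log_sigma

def reparametrize(mu, log_sigma, z):
    """ASSUMPTION: U = mu + exp(log_sigma) * Z, componentwise."""
    return [m + math.exp(ls) * zj for m, ls, zj in zip(mu, log_sigma, z)]

# latent_dim = 2 and dimV = 2, hence d = 2 + 3 = 5.
mu, log_sigma = pad_encoder_output([0.4, -1.2], [0.0, -2.0], dim_v=2)
z = [0.3, -0.7, 1.1, -0.2, 0.9]
u = reparametrize(mu, log_sigma, z)
```

Under these assumptions, the last three components of `u` coincide with those of `z`, which matches the statement that the three last components of \mu and \sigma are 0 and 1, respectively.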

5.2 Implementing the decoder

The first chunk of code defines the component of \mathop{\mathrm{Dec}}_{\theta_{2}}, namely build_WV_decoder, that generates (V,W) based on U. It also defines a function, as_sample, that allows one to draw approximately from a discrete distribution. The parameters nlayers_WV_decoder and nneurons_WV_decoder are the numbers of layers and of neurons in each layer, respectively. The parameter L will be discussed later.

We say that as_sample samples only approximately from a discrete distribution because exact sampling is not differentiable with respect to (w.r.t.) the parameters of the neural network, whereas training requires this operation to be differentiable. Instead, we use the fact that, for \beta>0 a large constant and Z a standard normal random variable, the law of the random variable

\mathop{\mathrm{expit}}(-\beta(\Phi(Z) - p)) \tag{9}

(recall that \Phi is the c.d.f. of the standard normal law) is concentrated around \{0,1\}, a small neighborhood of 1 having mass approximately p and a small neighborhood of 0 having mass approximately (1-p). For instance Figure 2 shows a histogram of 1000 independent copies of the random variable defined in Equation 9 with \beta=30 and p=1/3:

library(magrittr)
beta <- 30
p <- 1/3
tibble::tibble(Z = stats::rnorm(1e3)) %>%
  dplyr::mutate(U = stats::pnorm(Z)) %>%
  dplyr::mutate(V = stats::plogis(-beta * (U - p))) %>%
  ggplot2::ggplot() +
  ggplot2::geom_histogram(ggplot2::aes(x = V,
                                       y = ggplot2::after_stat(count/sum(count)))) +
  ggplot2::scale_y_continuous(labels = scales::percent) +
  ggplot2::xlab(quote(expit(-beta(Phi(Z) - p)))) +
  ggplot2::ylab("empirical proportion")
Figure 2: Histogram of 1000 independent copies of the random variable defined in Equation 9 with \beta=30 and p=1/3. The law is close to the Bernoulli law with parameter \tfrac{1}{3}.
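The same smoothing extends to a discrete law with more than two categories, by differencing smoothed steps located at the cumulative probabilities. The following dependency-free sketch illustrates the idea for a single draw; it is not the tensor-based implementation, and the name smooth_one_hot is hypothetical.

```python
import math
from statistics import NormalDist

def expit(x):
    return 1.0 / (1.0 + math.exp(-x))

def smooth_one_hot(probs, z, beta=30.0):
    """Smooth surrogate of the one-hot encoding of a draw from `probs`.

    The underlying draw is the category whose cumulative-probability
    interval contains Phi(z); `probs` must have at least two entries.
    """
    u = NormalDist().cdf(z)
    cum, running = [], 0.0
    for p in probs[:-1]:
        running += p
        cum.append(running)
    # steps[k] is close to 1 if u < cum[k] and close to 0 otherwise.
    steps = [expit(-beta * (u - c)) for c in cum]
    weights = [steps[0]]
    weights += [steps[k] - steps[k - 1] for k in range(1, len(steps))]
    weights.append(1.0 - steps[-1])
    return weights

probs = [0.2, 0.5, 0.3]
low = smooth_one_hot(probs, NormalDist().inv_cdf(0.05))   # Phi(z) = 0.05
mid = smooth_one_hot(probs, 0.0)                          # Phi(z) = 0.50
high = smooth_one_hot(probs, NormalDist().inv_cdf(0.95))  # Phi(z) = 0.95
```

By construction the weights sum to one, and for two categories the first weight reduces to Equation 9 with p equal to the first probability.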
import sys

def as_sample(tensors, beta = 30): 
  probs = tensors[0] # (pi_v)_{v in {0,...,nb_cat - 1}} where pi_v = P(V = v)
  nb_cat = probs.shape[2]
  Z = tensors[1] # the source of randomness
  U = tfp.distributions.Normal(loc = 0, scale = 1).cdf(Z)
  Pi = KK.cumsum(probs, axis = 2) # computing (Pi_v)_{v in {0,...,nb_cat - 1}}
  Pi_prev = Pi[..., 0]
  as_V = KK.sigmoid(-beta * (U - Pi_prev)) # as_V[..., 0] ~= 1{U < Pi_0}
  as_V = tf.expand_dims(as_V, -1)
  for v in range(nb_cat - 2):
    Pi_next = Pi[..., v + 1]
    as_V_v = KK.sigmoid(- beta * (U - Pi_next)) - KK.sigmoid(- beta * (U - Pi_prev))
    as_V_v = tf.expand_dims(as_V_v, -1)
    Pi_prev = Pi_next
    as_V = concatenate([as_V, as_V_v])
  Pi_last = Pi[..., nb_cat - 2]
  as_V_v = 1 - KK.sigmoid(- beta * (U - Pi_last))
  as_V_v = tf.expand_dims(as_V_v, -1)
  as_V = concatenate([as_V, as_V_v])
  return as_V

def build_WV_decoder(dimW,
                     dimV,
                     nb_cat_V,
                     latent_dim,
                     L,
                     nlayers_WV_decoder,
                     nneurons_WV_decoder,
                     activation_W = "linear"):
  if isinstance(nneurons_WV_decoder, int):
    nneurons_WV_decoder = np.array([nneurons_WV_decoder])
  if not (nlayers_WV_decoder is None):
    if len(nneurons_WV_decoder) == 1:
      nneurons_WV_decoder = np.repeat(nneurons_WV_decoder, nlayers_WV_decoder)
    elif not (len(nneurons_WV_decoder) == nlayers_WV_decoder):
      print("Inconsistency of 'nneurons' and 'nlayers' in the WV decoder's definition...\n")
  nlayers_WV_decoder = len(nneurons_WV_decoder)
  random_latent_vectors = Input(shape = (L, latent_dim + dimV, ),
                                name = "input_WV")
  alea = tf.slice(random_latent_vectors, (0, 0, latent_dim), (-1, -1, dimV))
  WV = tf.slice(random_latent_vectors, (0, 0, 0), (-1, -1, latent_dim))
  WV = Dense(nneurons_WV_decoder[0],
             name = "dense_WV_" + str(0),
             activation = "relu")(WV)
  for i in range(nlayers_WV_decoder - 1):
    WV = Dense(nneurons_WV_decoder[i + 1],
               name = "dense_WV_" + str(i + 1),
               activation = "relu")(WV)
  WV = Dense(dimW + dimV,
             name = "dense_WV_final")(WV)
  if dimW > 0:
    W = tf.slice(WV, (0, 0, 0), (-1, -1, dimW))
    W = Activation(activation_W,
                   name = "activation_W")(W)
  if dimV > 0:
    V_slice = tf.slice(WV, (0, 0, dimW), (-1, -1, 1))
    V_law = Dense(nb_cat_V[0],
                  activation = "softmax",
                  name = "softmax_V_0")(V_slice)
    V_random_latent_vectors = Lambda(lambda x: x[..., 0],
                                     name = "extr_V_alea_0")(alea)
    V = Lambda(as_sample,
               name = "as_V_0")([V_law, V_random_latent_vectors])
    V = Lambda(lambda x, y: tf.linalg.matmul(x, np.reshape(np.arange(y), (y, 1))),
               arguments={'y': nb_cat_V[0]},
               name = "lambda_V_0")(V)
    V = KK.cast(V, dtype = "float32")
    if dimV > 1:
      for v in range(dimV - 1):
        V_slice_v = tf.slice(WV, (0, 0, dimW + v + 1), (-1, -1, 1))
        V_law_v = Dense(nb_cat_V[v + 1],
                        activation = "softmax",
                        name = "softmax_V_" + str(v + 1))(V_slice_v)
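        # the index is passed through `arguments` so that the layer does not
        # depend on the value taken by the loop variable v at a later time
        V_random_latent_vectors = Lambda(lambda x, idx: x[..., idx],
                                         arguments = {'idx': v + 1},
                                         name = "extr_V_alea_" + str(v + 1))(alea)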
        V_v = Lambda(as_sample,
                     name = "as_V_" + str(v + 1))([V_law_v, V_random_latent_vectors])
        V_v = Lambda(lambda x,y: tf.linalg.matmul(x, np.reshape(np.arange(y), (y, 1))),
                     arguments={'y': nb_cat_V[v+1]},
                     name = "lambda_V_" + str(v + 1))(V_v)
        V_v = KK.cast(V_v, dtype = "float32")
        V_law = concatenate([V_law, V_law_v])
        V = concatenate([V, V_v])

  if dimW > 0 and dimV > 0:
    WV = concatenate([W, V])
  elif dimW == 0:
    WV = V
  elif dimV == 0:
    WV = W
  else:
    sys.exit("One at least of 'dimW' and 'dimV' must be positive...\n")
  if dimV > 0:
    WV_decoder = Model(inputs = random_latent_vectors,
                       outputs = [WV, V_law],
                       name = "WV_decoder")
  else:
     WV_decoder = Model(inputs = random_latent_vectors,
                        outputs = WV,
                        name = "WV_decoder")
  return WV_decoder

The second chunk of code defines the component of \mathop{\mathrm{Dec}}_{\theta_{2}}, namely build_Alaw_decoder, that generates a conditional law for A given (V,W). The parameters nlayers_Alaw_decoder and nneurons_Alaw_decoder are the numbers of layers and of neurons in each layer, respectively. The parameter L will be discussed later.

def build_Alaw_decoder(dimW,
                       dimV,
                       L,
                       nlayers_Alaw_decoder,
                       nneurons_Alaw_decoder):
  if isinstance(nneurons_Alaw_decoder, int):
    nneurons_Alaw_decoder = np.array([nneurons_Alaw_decoder])
  if not (nlayers_Alaw_decoder is None):
    if len(nneurons_Alaw_decoder) == 1:
      nneurons_Alaw_decoder = np.repeat(nneurons_Alaw_decoder, nlayers_Alaw_decoder)
    elif not (len(nneurons_Alaw_decoder) == nlayers_Alaw_decoder):
      print("Inconsistency of 'nneurons' and 'nlayers' in  the Alaw decoder's definition...\n")
  dimWV = dimW + dimV
  nlayers_Alaw_decoder = len(nneurons_Alaw_decoder)
  WV = Input(shape = (L, dimWV,), name = "input_Alaw")
  Alaw = Dense(nneurons_Alaw_decoder[0],
               name = "dense_Alaw_0",
               activation = "relu")(WV)
  if nlayers_Alaw_decoder > 1:
    for i in range(nlayers_Alaw_decoder-1):
      Alaw = Dense(nneurons_Alaw_decoder[i + 1],
                   name = "dense_Alaw_" + str(i + 1),
                   activation = "relu")(Alaw)
  Alaw = Dense(2,
               activation = "softmax",
               name = "dense_Alaw_final")(Alaw)
  Alaw_decoder = Model(inputs = WV,
                       outputs = Alaw,
                       name = "Alaw_decoder")
  return Alaw_decoder

The third chunk of code defines the component of \mathop{\mathrm{Dec}}_{\theta_{2}}, namely build_AYaY_decoder, that generates the counterfactual outcomes Y[0] and Y[1], the action carried out A and the corresponding reward Y. The parameters nlayers_AYaY_decoder and nneurons_AYaY_decoder are the numbers of layers and of neurons in each layer, respectively. The parameter L will be discussed later.

def get_Y(tensors): # tensors[0]: (Y_{0}, Y_{1})
                    # tensors[1]: (1-A, A)
  Y = tensors[0] * tensors[1]
  Y = KK.sum(Y, axis = -1, keepdims = True)
  return Y

def trick_with_relu(x):
  out = tf.keras.activations.relu(x - 1/2)/(x - 1/2)
  out = tf.gather(out, 1, axis = 1)
  out = tf.expand_dims(out, -1)
  return out

def build_AYaY_decoder(dimW,
                       dimV,
                       latent_dim,
                       L,
                       nlayers_AYaY_decoder,
                       nneurons_AYaY_decoder,
                       activation_Ya = "linear"):
  if isinstance(nneurons_AYaY_decoder, int):
    nneurons_AYaY_decoder = np.array([nneurons_AYaY_decoder])
  if not (nlayers_AYaY_decoder is None):
    if len(nneurons_AYaY_decoder) == 1:
      nneurons_AYaY_decoder = np.repeat(nneurons_AYaY_decoder, nlayers_AYaY_decoder)
    elif not (len(nneurons_AYaY_decoder) == nlayers_AYaY_decoder):
      print("Inconsistency of 'nneurons' and 'nlayers' in the AYaY decoder's definition...\n")
  dimWV = dimW + dimV
  nlayers_AYaY_decoder = len(nneurons_AYaY_decoder)
  random_latent_vectors = Input(shape = (L, latent_dim + 1, ),
                                name = "alea_input_AYaY")
  WV = Input(shape = (L, dimWV, ),
             name = "WV_input_AYaY")
  Alaw = Input(shape = (L, 2,),
               name = "Alaw_input_AYaY")
  Arandom_latent_vectors = Lambda(lambda x: x[..., latent_dim],
                                  name = "extr_alea_AYaY")(random_latent_vectors)
  as_A = Lambda(as_sample,
                name = "as_A")([Alaw, Arandom_latent_vectors])
  A = Lambda(lambda x: tf.map_fn(trick_with_relu, x),
             name = "lambda_A")(as_A)
  A = Lambda(lambda x: KK.cast(x, dtype = "float32"),
             name = "cast_A_float32")(A)
  Ya = Lambda(lambda x: x[..., 0:latent_dim],
              name = "extr_Ya_alea")(random_latent_vectors)
  Ya = Dense(nneurons_AYaY_decoder[0],
             name = "dense_Ya",
             activation = "relu")(concatenate([WV, Ya]))
  if nlayers_AYaY_decoder > 1:
    for i in range(nlayers_AYaY_decoder - 1):
      Ya = Dense(nneurons_AYaY_decoder[i + 1],
                 name = "dense_Ya_" + str(i + 1),
                 activation = "relu")(Ya)
  Ya = Dense(2,
             activation = activation_Ya,
             name = "Ya")(Ya)
  Y = Lambda(get_Y,
             name = "Y")([Ya, as_A])
  AYaY_decoder = Model(inputs = [random_latent_vectors, WV, Alaw],
                       outputs = concatenate([Ya, A, Y]),
                       name = "AYaY_decoder")
  return AYaY_decoder

Two comments are in order:

  • Given the counterfactual rewards Y[0] and Y[1] (outputs of the layer 'Ya' in build_AYaY_decoder) and the approximate action A^{\flat} (output of the layer 'as_A' in build_AYaY_decoder), the actual reward Y (output of the layer 'Y' in build_AYaY_decoder) is defined as the weighted mean

    Y = A^{\flat} Y[1] + (1 - A^{\flat}) Y[0] with A^{\flat} close to 0 or 1 (see the above comment on as_sample).

  • The actual action A (output of the layer lambda_A in build_AYaY_decoder) is derived from A^{\flat} under the form A = \frac{\mathop{\mathrm{ReLU}}\left(A^{\flat} - \tfrac{1}{2}\right)}{A^{\flat} - \tfrac{1}{2}}, assuming that A^{\flat} never takes on the value \tfrac{1}{2}. By doing so, A is (almost everywhere) differentiable w.r.t. the parameters of the neural network.
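The two formulas above can be checked on scalars. This is a plain-Python illustration with hypothetical function names, not the tensor-based code of build_AYaY_decoder.

```python
def relu(x):
    return max(0.0, x)

def harden(a_flat):
    """Map a_flat (different from 1/2) to 0 or 1 via ReLU(a - 1/2) / (a - 1/2)."""
    return relu(a_flat - 0.5) / (a_flat - 0.5)

def mix_rewards(a_flat, y0, y1):
    """Weighted mean Y = a_flat * Y[1] + (1 - a_flat) * Y[0]."""
    return a_flat * y1 + (1.0 - a_flat) * y0

a_high = harden(0.97)             # approximate action close to 1
a_low = harden(0.02)              # approximate action close to 0
y = mix_rewards(0.97, 1.0, 3.0)   # close to Y[1] = 3 because a_flat is close to 1
```

The ratio is constant on each side of 1/2, so its derivative w.r.t. A^{\flat} is zero wherever it is defined, while the reward Y keeps a nontrivial dependence on A^{\flat} through the weighted mean.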

The code related to decoding is complete.

5.3 Implementing the coarsening functions

The first chunk of code defines a function used to build the coarsening function \pi.

def build_pi(dimW,
             dimV,
             L):
  dimWV = dimW + dimV
  dim_data = dimWV + 6
  input_data = Input(shape = (L, dim_data, ),
                     name = "input_data")
  li = list(range(0, dimWV)) + list(range(dim_data - 2, dim_data))
  WVAY = Lambda(lambda x, li: tf.gather(x, li, axis = 2),
                arguments = {"li": li})(input_data)
  pi = Model(inputs = input_data,
             outputs = WVAY,
             name = "pi")
  return pi

The next chunk of code defines a function used to extract the conditional probability that A=1 given W (denoted earlier as \tilde{G}).

def build_extract_G(dimW,
                    dimV,
                    L):
  dimWV = dimW + dimV
  dim_data = dimWV + 6
  input_data = Input(shape = (L, dim_data, ),
                     name = "input_data")
  Alaw = Lambda(lambda x: tf.gather(x, dimWV + 1, axis = 2))(input_data)
  extract_G = Model(inputs = input_data,
                    outputs = Alaw,
                    name = "extract_G")
  return extract_G
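
To see which columns the two functions keep, the sketch below replays their index arithmetic on a list of column labels. The layout of the dimWV + 6 columns is inferred from the way build_generator concatenates its pieces, so it should be read as an assumption of this sketch; the labels Alaw_0 and Alaw_1 are hypothetical names for the two columns describing the law of A.

```python
dimW, dimV = 3, 2
dimWV = dimW + dimV
dim_data = dimWV + 6

# Assumed column layout: [W, V | law of A (2 columns) | Y[0], Y[1] | A | Y]
labels = (["W%d" % (j + 1) for j in range(dimW)]
          + ["V%d" % (j + 1) for j in range(dimV)]
          + ["Alaw_0", "Alaw_1", "Y[0]", "Y[1]", "A", "Y"])

# Same index list as in build_pi: keep (W, V) and the last two columns
li = list(range(0, dimWV)) + list(range(dim_data - 2, dim_data))

def pi_row(row):
    """Counterpart of pi for a single row."""
    return [row[j] for j in li]

def extract_G_row(row):
    """Counterpart of extract_G for a single row."""
    return row[dimWV + 1]

print(pi_row(labels))
print(extract_G_row(labels))
```

Under this layout, pi discards the law of A and the two counterfactual rewards, which are exactly the quantities that are not observed in genuine data.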

5.4 Implementing the generator

At long last we are in a position to define a function, namely build_generator, whose purpose is to build the generator \mathop{\mathrm{Gen}}_{\theta}. The chunk of code also defines the function K which is the Python counterpart of K introduced in Section 3.1 and Section 4.1.

def K(x):
  param = x[0]
  alea = x[1]
  d = tf.cast(param.shape[1]/2, tf.int32)
  L = tf.cast(alea.shape[1], tf.int32)
  batch_size = tf.shape(alea)[0]
  mu = tf.slice(param, [0, 0], [-1, d])
  mu = tf.repeat(mu, L, axis = 0)
  mu = tf.reshape(mu, (batch_size, L, latent_dim + dimV + 1))
  log_sigma = tf.slice(param, [0, d], [-1, -1])
  log_sigma = tf.repeat(log_sigma, L, axis = 0)
  log_sigma = tf.reshape(log_sigma, (batch_size, L, latent_dim + dimV + 1))
  out = KK.cast(mu, dtype = "float32") \
        + KK.cast(tf.exp(log_sigma / 2), dtype = "float32") * alea
  return out

def build_generator(dimW,
                    dimV,
                    encoder,
                    latent_dim,
                    L,
                    WV_decoder,
                    Alaw_decoder,
                    AYaY_decoder,
                    pi,
                    extract_G):

  dimWV = dimW + dimV
  e_obs_input = Input(shape = (dimWV + 2,))
  e_alea_input = Input(shape = (L, latent_dim + dimV + 1, ))
  code = encoder([e_obs_input, e_alea_input])
  param = code[0]
  e = Lambda(lambda x: K(x), name = "K")(code)

  wv_all = Lambda(lambda x: x[..., 0:(latent_dim + dimV)],
                  name = "extr_WV_random_latent_vectors")(e)
  wv_all = WV_decoder(wv_all)
  if dimV > 0:
    wv = wv_all[0]
    V_law_d = wv_all[1]
  else:
    wv = wv_all

  alaw = Alaw_decoder(wv)
  e_1 = Lambda(lambda x: x[..., 0:latent_dim],
               name = "extr_AYaY_decoder_random_latent_vectors_1")(e)
  e_2 = Lambda(lambda x: x[..., (latent_dim + dimV):(latent_dim + dimV + 1)],
               name = "extr_AYaY_decoder_random_latent_vectors_2")(e)
  main_e = concatenate([e_1, e_2])
  main_data = AYaY_decoder([main_e, wv, alaw])
  d = concatenate([wv, alaw, main_data])

  wvay_d  = pi(d)
  alaw_d = extract_G(d)

  dim_data = dimWV + 6
  y0 = Lambda(lambda x: x[..., dimWV + 2])(d)
  y1 = Lambda(lambda x: x[..., dimWV + 3])(d)
  y_pred = Lambda(lambda x:
                  (x[0] - tf.math.reduce_mean(x[0])) \
                  * (x[1] - tf.math.reduce_mean(x[1]))/ \
                  (tf.math.reduce_std(x[0]) \
                    * tf.math.reduce_std(x[1])))([y0, y1])
  if dimV > 0:
    generator = Model(inputs = [e_obs_input, e_alea_input],
                      outputs = [param,
                                 wvay_d,
                                 V_law_d,
                                 concatenate([d, V_law_d]),
                                 concatenate([d, V_law_d]),
                                 y_pred])
  else:
    generator = Model(inputs = [e_obs_input, e_alea_input],
                      outputs = [param,
                                 wvay_d,
                                 d,
                                 d,
                                 y_pred])
  return generator
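
The function K implements the usual reparameterization step: each of the L sources of randomness is scaled and shifted by the parameters computed by the encoder. The plain-Python sketch below mirrors that computation on nested lists; the name reparameterize is hypothetical, and the second half of each parameter row is treated as a log-variance, which is what the factor exp(log_sigma / 2) in K suggests.

```python
import math
import random

def reparameterize(param, alea):
    """Plain-Python analogue of K.

    param: one row [mu_1, ..., mu_d, log_var_1, ..., log_var_d] per observation
    alea:  for each observation, L standard normal draws of dimension d
    """
    out = []
    for row, draws in zip(param, alea):
        d = len(row) // 2
        mu, log_var = row[:d], row[d:]
        out.append([[mu[j] + math.exp(log_var[j] / 2) * z[j] for j in range(d)]
                    for z in draws])
    return out

random.seed(0)
d, L = 3, 8
param = [[1.0, -2.0, 0.5, 0.0, math.log(4.0), math.log(0.25)]]
alea = [[[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(L)]]
u = reparameterize(param, alea)
# One observation, L draws, d coordinates per draw
print(len(u), len(u[0]), len(u[0][0]))
```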

5.5 Implementing the loss functions and training algorithm

The last step of the implementation consists of defining the loss functions and the optimization algorithm that drive the training of the VAE involved in Equation 6 by solving Equation 7. The definitions of the loss functions follow directly from the equations. As for the optimization algorithm, we rely on the minibatch stochastic gradient ascent algorithm presented below:

\begin{algorithm}
\caption{Minibatch stochastic gradient ascent training.}
\begin{algorithmic}
\Require number of epochs EPOCH, batch size $m$, number of repetitions $L$, learning rate $\alpha$, exponential decay rates $\beta_1$ and $\beta_2$ for the 1st and 2nd moments estimates, small constant $\epsilon$, initial parameter $\theta^{(0)} = (\theta_1^{(0)}, \theta_2^{(0)}) \in \Theta$
\State Initialize $D \leftarrow \{O_{1}, \ldots, O_{n}\}$
\State Initialize $t \leftarrow 0$, $\text{first}^{(0)} \leftarrow 0_{\mathbb{R}^{\dim(\Theta)}}$, $\text{second}^{(0)} \leftarrow 0_{\mathbb{R}^{\dim(\Theta)}}$
\While{$t<$ EPOCH}
\State Sample uniformly without replacement a minibatch of $m$ genuine observations $\tilde{O}_{1}, \ldots, \tilde{O}_{m}$ from $D$
\State Sample a minibatch of $m \times L$ independent sources of randomness $Z_{1,1}, \ldots, Z_{1,L}, Z_{2,1},\ldots, Z_{2,L}, \dots, Z_{m,1},\ldots, Z_{m,L}$ from $(\mathcal N(0,1))^{\otimes d}$
\For{$i=1,\cdots,m$}
\State Compute $\mathop{\mathrm{Enc}}_{\theta_{1}^{(t)}} (\tilde{O}_{i}, Z_{1,1}) = ((\mu_i)^{(t)}, (\sigma^2_i)^{(t)}, Z_{1,1})$
\For{$\ell=1,\cdots,L$}
\State $U_{i,\ell} \leftarrow (\mu_i)^{(t)} + \sqrt{(\sigma^2_i)^{(t)}} \odot (Z_{i,\ell}^{(1)},\cdots,Z_{i,\ell}^{(d)})$
\EndFor
\EndFor
\State Update the encoder and decoder by performing one step of stochastic gradient ascent:\\
$g \leftarrow \nabla_{\theta} \left.\left\{\dfrac{1}{m}\sum_{i=1}^{m} \left(-\mathop{\mathrm{KL}}(g_{\theta_{1}} (\cdot | \tilde{O}_{i});\phi_{d}) + \dfrac{1}{L} \sum_{\ell=1}^L \log p_{\theta_{2}}(\tilde O_{i}|U_{i,\ell}) \right)\right\}\right|_{\theta=\theta^{(t)}}$\\
where, for each $1 \leq i \leq m$, \\
$-\mathop{\mathrm{KL}}(g_{\theta_{1}^{(t)}} (\cdot | \tilde{O}_{i});\phi_{d}) = \frac{1}{2} \sum_{j=1}^{d} \left(1 + \log (\sigma_{i}^{2})_{j}^{(t)} - (\sigma_{i}^{2})_{j}^{(t)} - [(\mu_{i})_{j}^{(t)}]^{2} \right)$\\
$\text{first}^{(t+1)} \leftarrow \beta_1 \text{first}^{(t)} + (1-\beta_1) g$\\
$\text{second}^{(t+1)} \leftarrow \beta_2 \text{second}^{(t)} + (1-\beta_2) g\odot g$\\
$\widehat{\text{first}}^{(t+1)} \leftarrow \dfrac{\text{first}^{(t+1)}}{1-\beta_1^{t+1}}$\\
$\widehat{\text{second}}^{(t+1)} \leftarrow \dfrac{\text{second}^{(t+1)}}{1-\beta_2^{t+1}}$\\
$\theta^{(t+1)} \leftarrow \theta^{(t)} + \alpha \dfrac{\widehat{\text{first}}^{(t+1)}}{\sqrt{\widehat{\text{second}}^{(t+1)}}+\epsilon}$ \\
\State Update $t \leftarrow t+1$
\EndWhile
\end{algorithmic}
\end{algorithm}
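
The parameter update at the end of each iteration is an Adam-type step applied in the ascent direction. As a rough illustration, the sketch below writes that step in plain Python and applies it to a toy concave objective, f(\theta) = -(\theta - 3)^2, which stands in for the minibatch objective of the algorithm. The function name adam_ascent_step and the toy objective are assumptions of this sketch, and the bias correction uses the freshly updated moment estimates, as in the standard formulation of Adam.

```python
import math

def adam_ascent_step(theta, grad, first, second, t,
                     alpha=0.01, beta1=0.9, beta2=0.999, eps=1e-7):
    """One update of theta; t is the 0-based iteration counter."""
    first = [beta1 * f + (1 - beta1) * g for f, g in zip(first, grad)]
    second = [beta2 * s + (1 - beta2) * g * g for s, g in zip(second, grad)]
    # Bias-corrected moment estimates
    first_hat = [f / (1 - beta1 ** (t + 1)) for f in first]
    second_hat = [s / (1 - beta2 ** (t + 1)) for s in second]
    # Ascent: the step is added to theta
    theta = [th + alpha * fh / (math.sqrt(sh) + eps)
             for th, fh, sh in zip(theta, first_hat, second_hat)]
    return theta, first, second

# Toy objective f(theta) = -(theta - 3)^2, maximized at theta = 3
theta, first, second = [0.0], [0.0], [0.0]
for t in range(2000):
    grad = [-2.0 * (theta[0] - 3.0)]
    theta, first, second = adam_ascent_step(theta, grad, first, second, t)
print(theta[0])
```

In the actual training, this bookkeeping is delegated to the optimizer passed to the VAE class.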

In our experiments, we set \text{EPOCH}=10, m=10^{3}, L=8, \alpha=0.01, \beta_1=0.9, \beta_2 = 0.999, \epsilon = 10^{-7}. The value of L is chosen to be small for computational efficiency and to help the algorithm avoid getting stuck in local minima. The initial parameter \theta^{(0)} is drawn randomly as follows: each component corresponding to a bias term in a densely-connected layer is set to 0; each component corresponding to a kernel coefficient is drawn independently of the others from the Glorot uniform initializer (Glorot and Bengio 2010) (that is, from the uniform law on \sqrt{6/\ell} \times [-1, 1] where \ell is the sum of the number of input units in the weight tensor and of the number of output units).
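
For concreteness, the initialization just described can be sketched in plain Python as follows. The function name glorot_uniform_kernel and the layer sizes are hypothetical; in the actual implementation the initialization is presumably handled by the default settings of the Dense layers.

```python
import math
import random

def glorot_uniform_kernel(fan_in, fan_out, rng):
    """Draw a fan_in x fan_out kernel uniformly on sqrt(6 / (fan_in + fan_out)) * [-1, 1]."""
    limit = math.sqrt(6.0 / (fan_in + fan_out))
    kernel = [[rng.uniform(-limit, limit) for _ in range(fan_out)]
              for _ in range(fan_in)]
    return kernel, limit

rng = random.Random(0)
# Hypothetical dense layer with 5 input units and 8 output units
kernel, limit = glorot_uniform_kernel(fan_in=5, fan_out=8, rng=rng)
bias = [0.0] * 8   # bias terms are initialized at 0
print(limit)
```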

The next chunk of code defines the loss functions, optimization algorithm, and the VAE class which wraps up the implementation. The so-called penalization_loss is the counterpart of Equation 8.

# from tensorflow.keras.optimizers.legacy import Adam

def mse_W(x_true, x_pred):
  w_true = x_true[:, 0:dimW]
  batch_size = tf.shape(w_true)[0]
  w_true = tf.repeat(w_true, L, axis = 0)
  w_true = tf.reshape(w_true, (batch_size, L, dimW))
  w_pred = x_pred[..., 0:dimW]
  loss = tf.keras.losses.mse(w_true, w_pred)
  return loss

def reconstruction_loss_W(x_true, x_pred):
  loss = mse_W(x_true, x_pred)
  loss = tf.reduce_mean(loss)
  return loss

def categorical_crossentropy_V(x_true, x_pred):
  x_true_0 = x_true[..., 0:nb_cat_V[0]]
  batch_size = tf.shape(x_true_0)[0]
  x_true_0 = tf.repeat(x_true_0, L, axis = 0)
  x_true_0 = tf.reshape(x_true_0, (batch_size, L, nb_cat_V[0]))
  x_pred_0 = x_pred[..., 0:nb_cat_V[0]]
  loss = tf.keras.losses.categorical_crossentropy(x_true_0, x_pred_0)
  if dimV > 1:
    for v in range(dimV - 1):
      # the one-hot block of the (v+2)-th component of V starts
      # right after the blocks of the previous components
      start = int(np.sum(nb_cat_V[0:(v + 1)]))
      stop = start + int(nb_cat_V[v + 1])
      x_true_v = x_true[..., start:stop]
      x_true_v = tf.repeat(x_true_v, L, axis = 0)
      x_true_v = tf.reshape(x_true_v, (batch_size, L, nb_cat_V[v+1]))
      x_pred_v = x_pred[..., start:stop]
      loss = loss + tf.keras.losses.categorical_crossentropy(x_true_v, x_pred_v)
  return loss

def reconstruction_loss_V(x_true, x_pred):
  loss = categorical_crossentropy_V(x_true, x_pred)
  loss = tf.reduce_mean(loss)
  return loss

def reconstruction_loss_Y(x_true, x_pred):
  dimWV = dimW + dimV
  dim_data = dimWV + 6
  mse_WV = mse_W(x_true, x_pred)
  a_true = x_true[:, dimWV]
  batch_size = tf.shape(a_true)[0]
  a_true = tf.repeat(a_true, L, axis = 0)
  a_true = tf.reshape(a_true, (batch_size, L))
  a_pred = x_pred[..., dimWV + 2 + 2]
  y_true = x_true[:, (dimWV + 1):(dimWV + 2)]
  y_true = tf.repeat(y_true, L, axis = 0)
  y_true = tf.reshape(y_true, (batch_size, L))
  y_pred = x_pred[..., (dimWV + 2 + 2 + 1):(dimWV + 2 + 2 + 2)]
  y_pred = tf.reshape(y_pred, (batch_size, L))
  indicator = (a_true * a_pred + (1 - a_true) \
               * (1 - a_pred))
  loss = indicator * (1/2 * tf.math.log(mse_WV / 5)\
                      + tf.square(y_true - y_pred)/(2 * mse_WV / 5))
  loss = loss + (1-indicator) * 10
  loss = tf.reduce_mean(loss)
  return loss

def reconstruction_loss_A(x_true, x_pred):
  dimWV = dimW + dimV
  dim_data = dimWV + 6
  mse_WV = mse_W(x_true, x_pred)
  if dimV > 0:
    mse_WV = mse_WV + categorical_crossentropy_V(x_true[..., (dimWV + 2):],
                                                 x_pred[..., dim_data:])
  pi = tf.math.exp(-mse_WV / 10)
  g_pred = x_pred[..., dimWV + 1]
  batch_size = tf.shape(x_true)[0]
  a_true = x_true[:, dimWV]
  a_true = tf.repeat(a_true, L, axis = 0)
  a_true = tf.reshape(a_true, (batch_size, L))
  g_tilde = pi * g_pred + (1 - pi) * tf.reduce_mean(a_true)
  loss = tf.keras.losses.binary_crossentropy(a_true, g_tilde, label_smoothing = 0.1)
  loss = tf.reduce_mean(loss)
  return loss

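def kl_loss(y_true, y_pred):
  # y_pred stacks (mu, log_sigma), each of width latent_dim + dimV + 1;
  # the KL term only involves the first latent_dim coordinates of each half
  extra = dimV + 1
  d = tf.cast(y_pred.shape[1]/2, tf.int32) - extra
  mu = tf.slice(y_pred, [0, 0], [-1, d])
  log_sigma = tf.slice(y_pred, [0, d + extra], [-1, d])
  loss = -1 - log_sigma + tf.exp(log_sigma) + tf.square(mu)
  loss = tf.reduce_sum(loss, axis = 1)
  loss = tf.reduce_mean(loss/2)
  return loss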

def penalization_loss(y_true, y_pred):
  batch_size = tf.shape(y_true)[0]
  y_true = tf.repeat(y_true, L, axis = 0)
  y_true = tf.reshape(y_true, (batch_size, L))
  loss = KK.abs(KK.mean(y_pred - y_true))
  return loss

def one_hot(indices, depth):
  out = zeros((indices.size, depth))
  out[np.arange(indices.size), indices.astype(int)] = 1
  return out

# optimizer = Adam(learning_rate = 0.01)
class VAE():
  def __init__(self,
               generator,
               latent_dim, L,
               optimizer):
    self.generator = generator
    self.latent_dim = latent_dim
    self.L = L
    self.optimizer = optimizer
    if dimV > 0:
      generator.compile(optimizer,
                          loss = [kl_loss,
                                  reconstruction_loss_W,
                                  reconstruction_loss_V,
                                  reconstruction_loss_Y,
                                  reconstruction_loss_A,
                                  penalization_loss],
                        loss_weights = [1, 200, 200, 200, 200, 10])
    else:
      generator.compile(optimizer,
                          loss = [kl_loss,
                                  reconstruction_loss_W,
                                  reconstruction_loss_Y,
                                  reconstruction_loss_A,
                                  penalization_loss],
                        loss_weights = [1, 200, 200, 200, 10])
  def train(self, data_train, epochs, batch_size):
    param_dummy = zeros((batch_size, self.latent_dim, 6))
    for i in range(epochs):
      np.random.shuffle(data_train)
      data_real = data_train
      loss = 0.0
      while data_real.shape[0] >= batch_size:
        # subsample real observations
        real_obs = data_real[range(np.min((batch_size, np.shape(data_real)[0]))), :]
        if dimV > 0:
          real_V = real_obs[:, dimW : (dimW + dimV)]
          real_one_hot_V = one_hot(real_V[:,0], nb_cat_V[0])
          if dimV > 1:
            for v in range(dimV - 1):
              real_one_hot_V = np.concatenate((real_one_hot_V,
                                               one_hot(real_V[:,v+1],
                                                       nb_cat_V[v+1])), axis = 1)

        data_real = data_real[batch_size:, :]
        # generate alea
        alea = np.random.normal(loc = 0, scale = 1,
                                size = (np.shape(real_obs)[0], L, latent_dim + dimV + 1))
        # train the generator
        y_true = 0.5 * ones((np.shape(real_obs)[0], 1))

        if dimV > 0:
          trained_model = self.generator.train_on_batch([real_obs, alea],
                                                        [param_dummy[range(np.shape(real_obs)[0])],
                                                         real_obs,
                                                         real_one_hot_V,
                                                         np.concatenate((real_obs,
                                                                         real_one_hot_V), axis = 1),
                                                         np.concatenate((real_obs,
                                                                         real_one_hot_V), axis = 1),
                                                         y_true])
        else:
          trained_model = self.generator.train_on_batch([real_obs, alea],
                                                        [param_dummy[range(np.shape(real_obs)[0])],
                                                         real_obs,
                                                         real_obs,
                                                         real_obs,
                                                         y_true])
        loss += trained_model[0]
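
As a sanity check on the formula used in kl_loss, the following sketch evaluates the closed-form Kullback-Leibler divergence between a Gaussian law with diagonal covariance and the standard Gaussian law, for a single observation. The function name kl_to_standard_normal is hypothetical, and log_var plays the role of log_sigma in the code above.

```python
import math

def kl_to_standard_normal(mu, log_var):
    """KL(N(mu, diag(exp(log_var))) ; N(0, I)) in closed form."""
    return 0.5 * sum(-1.0 - lv + math.exp(lv) + m * m
                     for m, lv in zip(mu, log_var))

# The divergence vanishes when the two laws coincide...
print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))
# ...and is positive otherwise
print(kl_to_standard_normal([1.0, 0.0], [0.0, 0.0]))
```

This is the opposite of the quantity -KL appearing in the algorithm, which is consistent with the fact that the losses are minimized whereas the algorithm is stated as an ascent.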

6 Illustration on simulated data

In Section 2, in the context of the running example, we defined a simulation law P and simulated from P a training data set train and a testing data set test using the function simulate. The two independent data sets consist of n=5000 mutually independent realizations O_{i} = (V_{i},W_{i},A_{i},Y_{i}) \in \mathcal{O}. We present here how to use train and the VAE coded in Section 5 to learn a function \mathop{\mathrm{Gen}}_{\theta} so that, if Z is sampled as in Equation 2, then \mathop{\mathrm{Gen}}_{\theta}(Z) is a random element of \mathcal{X} and \pi \circ \mathop{\mathrm{Gen}}_{\theta}(Z) is a random element of \mathcal{O} whose law approximates P.

6.1 Training the VAE

By running the next chunk of code, we set the VAE’s configuration.

dimW = 3
dimV = 2
nb_cat_V = np.array([len(np.unique(train[:,3])),
                     len(np.unique(train[:,4]))])
latent_dim = 10
L = 8

nlayers_encoder = 2
nneurons_encoder = 8

nlayers_WV_decoder = 2
nneurons_WV_decoder = 8

nlayers_Alaw_decoder = 2
nneurons_Alaw_decoder = 8

nlayers_AYaY_decoder = 2
nneurons_AYaY_decoder = 16

The next chunk of code repeatedly generates and initializes a VAE then trains it.

epochs = 10
batch_size = np.round(0.2 * np.shape(train)[0]).astype(int)
synth_size = 1000
nb_tries = 100
for i in range(nb_tries):
  print("try #", i, ", ")
  encoder = build_encoder(dimW = dimW,
                          dimV = dimV,
                          latent_dim = latent_dim,
                          L = L,
                          nlayers_encoder = nlayers_encoder,
                          nneurons_encoder = nneurons_encoder)
  WV_decoder = build_WV_decoder(dimW = dimW,
                                dimV = dimV,
                                nb_cat_V = nb_cat_V,
                                latent_dim = latent_dim,
                                L = L,
                                nlayers_WV_decoder = nlayers_WV_decoder,
                                nneurons_WV_decoder = nneurons_WV_decoder)
  Alaw_decoder = build_Alaw_decoder(dimW = dimW,
                                    dimV = dimV,
                                    L = L,
                                    nlayers_Alaw_decoder = nlayers_Alaw_decoder,
                                    nneurons_Alaw_decoder = nneurons_Alaw_decoder)
  AYaY_decoder = build_AYaY_decoder(dimW = dimW,
                                    dimV = dimV,
                                    latent_dim = latent_dim,
                                    L = L,
                                    nlayers_AYaY_decoder = nlayers_AYaY_decoder,
                                    nneurons_AYaY_decoder = nneurons_AYaY_decoder)
  pi = build_pi(dimW = dimW,
                dimV = dimV,
                L = L)
  extract_G = build_extract_G(dimW = dimW,
                              dimV = dimV,
                              L = L)
  generator = build_generator(dimW = dimW,
                              dimV = dimV,
                              encoder = encoder,
                              latent_dim = latent_dim,
                              L = L,
                              WV_decoder = WV_decoder,
                              Alaw_decoder = Alaw_decoder,
                              AYaY_decoder = AYaY_decoder,
                              pi = pi,
                              extract_G = extract_G)
  vae = VAE(generator = generator,
            latent_dim = latent_dim,
            L = L,
            optimizer = optimizer)
  vae.train(data_train = train,
            epochs = epochs,
            batch_size = batch_size)
  random_sample = np.random.choice(n_train, synth_size, replace = True)
  O_test = train[random_sample, :]
  alea = np.random.normal(loc = 0, scale = 1, size = (synth_size, L, latent_dim + dimV + 1))
  synth = vae.generator.predict([O_test, alea])[1]
  np.savetxt("synth_try#" + str(i) + ".csv", synth[:, 0, :], delimiter = ",")

Because running the chunk is time-consuming, we stored one trained VAE that we considered good enough. We explain what we mean by good enough in the next section.

6.2 A formal view on how to evaluate the quality of the generator

Suppose that we have built a generator \mathop{\mathrm{Gen}}_{\widehat{\theta}_{n}} based on the genuine observations O_{1}, \ldots, O_{n} drawn from P. How can we assess how well the generator approximates P? In other words, how can we assess how convincingly synthetic observations drawn from \mathop{\mathrm{law}}(\mathop{\mathrm{Gen}}_{\widehat{\theta}_{n}}(Z)) imitate observations drawn from P?

We propose three ways to address this question. Each of them uses the genuine observations O_{n+1}, \ldots, O_{n+n'} that were not used to build \mathop{\mathrm{Gen}}_{\widehat{\theta}_{n}} and N synthetic observations O^{\sharp}_{1}, \ldots, O^{\sharp}_{N} drawn independently from \mathop{\mathrm{Gen}}_{\widehat{\theta}_{n}}.

6.2.1 Criterion 1

The overly faithful replication (a form of overfitting) by \mathop{\mathrm{Gen}}_{\widehat{\theta}_{n}} of O_{1}, \ldots, O_{n}, the genuine observations upon which its construction is based, is a pitfall that we aim to avoid. As a side note, the simplest generator that one can build from O_{1}, \ldots, O_{n} is the empirical measure based on them, which corresponds to the bootstrap approach (see Section 1.4).

The first criterion we propose is inspired by a commonly used machine learning metric for comparing synthetic images generated by a neural network to the original training images. To assess the potential over-faithfulness of the replication process, we suggest comparing two empirical distributions:

  • \mu_{1:n}, the empirical law of the distance to the nearest neighbor within \{O^{\sharp}_{1}, \ldots, O^{\sharp}_{N}\} of each O_{i} (1 \leq i \leq n);
  • \mu_{(n+1):(n+n')}, the empirical law of the distance to the nearest neighbor within \{O^{\sharp}_{1}, \ldots, O^{\sharp}_{N}\} of each O_{n+i} (1 \leq i \leq n').

Ideally, \mu_{1:n} and \mu_{(n+1):(n+n')} should be similar, indicating that the training and testing performances align well. However, if \mathop{\mathrm{Gen}}_{\widehat{\theta}_{n}} replicates O_{1}, \ldots, O_{n} too faithfully, then \mu_{1:n} will become very concentrated around 0 while \mu_{(n+1):(n+n')} will not exhibit the same behavior. Note that within a bootstrap approach, the generator that merely samples uniformly from \{O_{1}, \ldots, O_{n}\} would result in a \mu_{1:n} having all its mass at 0 if we let N go to infinity, according to the law of large numbers.
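
The behavior of the bootstrap generator is easy to reproduce on a toy example. The sketch below, in plain Python, draws hypothetical two-dimensional training and testing points, uses resampling of the training points as the generator, and computes the Euclidean distance of each genuine point to its nearest synthetic neighbor. The function name nn_distances, the toy law and the sample sizes are assumptions of this sketch.

```python
import math
import random

def nn_distances(points, reference):
    """Distance from each point to its nearest neighbor within reference."""
    return [min(math.dist(p, r) for r in reference) for p in points]

rng = random.Random(0)

def draw(n):
    """Toy stand-in for sampling n observations from P."""
    return [(rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)) for _ in range(n)]

train = draw(50)
test = draw(50)
# Bootstrap generator: sample uniformly, with replacement, from the training points
bootstrap_synth = [rng.choice(train) for _ in range(2000)]

d_train = nn_distances(train, bootstrap_synth)   # sample underlying mu_{1:n}
d_test = nn_distances(test, bootstrap_synth)     # sample underlying mu_{(n+1):(n+n')}
print(max(d_train), min(d_test))
```

With N much larger than n, every training point is its own nearest synthetic neighbor, whereas the testing points stay at a positive distance: the two empirical laws differ sharply, which is the signature of over-faithful replication.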

6.2.2 Criterion 2

The second criterion involves comparing the marginal distributions of each real-valued component of O under sampling from P and from \mathop{\mathrm{law}}(\mathop{\mathrm{Gen}}_{\widehat{\theta}_{n}}(Z)). This comparison can be conducted visually, by plotting the empirical distribution functions, or numerically, by computing p-values of hypothesis tests. Depending on the nature of the components of O, appropriate tests include the binomial, multinomial, \chi^{2} or Kolmogorov-Smirnov tests.
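
For a continuous component, the numerical comparison can rely on the two-sample Kolmogorov-Smirnov statistic, that is, the largest gap between the two empirical distribution functions. A minimal plain-Python sketch follows; the function name ks_statistic is hypothetical, and the computation of the p-value is left to a statistical library.

```python
def ks_statistic(x, y):
    """Largest absolute gap between the empirical c.d.f. of x and y."""
    xs, ys = sorted(x), sorted(y)
    i = j = 0
    gap = 0.0
    while i < len(xs) and j < len(ys):
        t = min(xs[i], ys[j])
        # Move both empirical c.d.f. past the threshold t (ties included)
        while i < len(xs) and xs[i] <= t:
            i += 1
        while j < len(ys) and ys[j] <= t:
            j += 1
        gap = max(gap, abs(i / len(xs) - j / len(ys)))
    return gap

# Identical samples yield 0, well-separated samples yield 1
print(ks_statistic([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))
print(ks_statistic([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]))
```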

6.2.3 Criterion 3

The third criterion aims to capture discrepancies between P and \mathop{\mathrm{law}}(\mathop{\mathrm{Gen}}_{\widehat{\theta}_{n}}(Z)) beyond marginal comparisons. To do so in general we propose, for a user-specified collection of prediction algorithms \mathcal{A}_{1}, \ldots, \mathcal{A}_{K}, to compare their outputs when trained on \{O_{1}, \ldots, O_{n}\} versus \{O^{\sharp}_{1}, \ldots, O^{\sharp}_{n}\}, using the predictions they make for each O_{n+1}, \ldots, O_{n+n'}.

For instance, \mathcal{A}_{1} could be an algorithm that learns to predict A given (V,W) based on the logistic regression model

\{(v,w) \mapsto m_{\gamma}(v,w) := \mathop{\mathrm{expit}}(\gamma^{0} + \gamma^{1} (v,w)) : \gamma = (\gamma^{0}, \gamma^{1}) \in \mathbb{R}\times\mathbb{R}^{5}\}.

Training \mathcal{A}_{1} on \{O_{1}, \ldots, O_{n}\} (respectively, \{O^{\sharp}_{1}, \ldots, O^{\sharp}_{n}\}) yields \gamma_{n} (respectively, \gamma_{n}^{\sharp}), hence the predictions m_{\gamma_{n}}(V_{n+i},W_{n+i}) and m_{\gamma_{n}^{\sharp}}(V_{n+i}, W_{n+i}) (1 \leq i \leq n'). The more closely \mathop{\mathrm{Gen}}_{\widehat{\theta}_{n}} approximates P, the nearer the points (m_{\gamma_{n}}(V_{n+1},W_{n+1}), m_{\gamma_{n}^{\sharp}}(V_{n+1},W_{n+1})), \ldots, (m_{\gamma_{n}}(V_{n+n'},W_{n+n'}), m_{\gamma_{n}^{\sharp}}(V_{n+n'},W_{n+n'})) are to the y=x line in the xy-plane.

Importantly, the algorithms need not rely on parametric working models. For instance, \mathcal{A}_{2} could learn to predict A given (V,W) using a nonparametric algorithm such as a random forest.
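
The logic of the third criterion can be illustrated on a deliberately simple case. In the plain-Python sketch below, a one-dimensional covariate stands in for (V,W), the toy law implemented in draw stands in for both P and a perfect generator, and the logistic regression is fitted by gradient ascent on the log-likelihood. All names (fit_logistic, draw) and all numerical choices are assumptions of this sketch.

```python
import math
import random

def expit(t):
    return 1.0 / (1.0 + math.exp(-t))

def fit_logistic(xs, ys, n_iter=300, lr=1.0):
    """Fit P(A = 1 | x) = expit(g0 + g1 * x) by gradient ascent."""
    g0 = g1 = 0.0
    n = len(xs)
    for _ in range(n_iter):
        resid = [a - expit(g0 + g1 * x) for x, a in zip(xs, ys)]
        g0 += lr * sum(resid) / n
        g1 += lr * sum(r * x for r, x in zip(resid, xs)) / n
    return g0, g1

def draw(n, rng):
    """Toy law: x is standard normal and P(A = 1 | x) = expit(0.5 + x)."""
    xs = [rng.gauss(0.0, 1.0) for _ in range(n)]
    ys = [1 if rng.random() < expit(0.5 + x) else 0 for x in xs]
    return xs, ys

rng = random.Random(0)
genuine = draw(1000, rng)
synthetic = draw(1000, rng)    # a perfect generator, for illustration only
x_test, _ = draw(200, rng)

gamma = fit_logistic(*genuine)
gamma_sharp = fit_logistic(*synthetic)
pairs = [(expit(gamma[0] + gamma[1] * x),
          expit(gamma_sharp[0] + gamma_sharp[1] * x)) for x in x_test]
# Mean absolute distance to the y = x line, up to a constant factor
gap = sum(abs(p - q) for p, q in pairs) / len(pairs)
print(gap)
```

With an imperfect generator, the pairs would drift away from the diagonal and the summary gap would increase.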

6.3 Implementing an evaluation of the quality of the generator

We now show how to implement the three criteria presented in Section 6.2. The next chunk of code loads the data into R: train and test are the R counterparts of the Python objects train and test (keeping only the first 1000 observations) and synth is the collection of 1000 synthetic observations drawn from the generator associated with the VAE that we stored in Section 6.1. For later use (while implementing Criterion 1) we add a dummy column named Z.

dimW <- 3
dimV <- 2
train <- readr::read_csv("data/train.csv",
                         col_names = c(paste0("W", 1:dimW),
                                       paste0("V", 1:dimV),
                                       "A", "Y")) %>%
  dplyr::mutate(A = as.integer(A),
                Z = stats::rnorm(dplyr::n())) %>%
  dplyr::slice_head(n = 1e3)
test <- readr::read_csv("data/test.csv",
                        col_names = c(paste0("W", 1:dimW),
                                      paste0("V", 1:dimV),
                                      "A", "Y")) %>%
  dplyr::mutate(A = as.integer(A),
                Z = stats::rnorm(dplyr::n())) %>%
  dplyr::slice_head(n = 1e3)
synth <- readr::read_csv("data/synth_try#77.csv",
                         col_names = c(paste0("W", 1:dimW),
                                       paste0("V", 1:dimV),
                                       "A", "Y")) %>%
  dplyr::mutate(A = as.integer(A),
                V1 = round(V1),
                V2 = round(V2),
                Z = stats::rnorm(dplyr::n()))

6.3.1 Criterion 1

The next chunk of code implements the first criterion.

frml <- paste0("Z ~", 
               paste0(c(paste0("W", 1:dimW),
                        paste0("V", 1:dimV),
                        "A", "Y"),
                     collapse = " + ")) %>%
  stats::as.formula()
fig <- tibble::tibble(d = c(kknn::kknn(frml, train, synth, k = 1)$D,
                            kknn::kknn(frml, test, synth, k = 1)$D)) %>%
  dplyr::mutate(type = c(rep("training data", dplyr::n()/2),
                         rep("testing data", dplyr::n()/2))) %>%
  ggplot2::ggplot() + 
  ggplot2::stat_ecdf(ggplot2::aes(x = d, color = type)) +
  ggplot2::scale_x_log10() +
  ggplot2::xlab("distance to nearest neighbor") + 
  ggplot2::ylab("empirical cumulative distribution function")
print(fig)
Figure 3: Empirical c.d.f. of the distance to the nearest neighbor within the synthetic observations of the training and of the testing data points (logarithmic scale). The two c.d.f. are quite close.

The two empirical c.d.f. shown in Figure 3 are quite similar, suggesting that \mu_{1:n} and \mu_{(n+1):(n+n')} are close. To quantify this proximity, we rely on statistical tests.

The Directed Acyclic Graph (DAG) in Figure 4 represents the experiment of law \Pi that consists successively of

  • drawing O_{1}, \ldots, O_{n}, O_{n+1}, \ldots, O_{n+n'} independently from P;
  • learning \widehat{\theta}_{n};
  • sampling O^{\sharp}_{1}, \ldots, O^{\sharp}_{N} independently from \mathop{\mathrm{law}}(\mathop{\mathrm{Gen}}_{\widehat{\theta}_{n}}(Z));
  • determining, for each 1 \leq i \leq n+n', the nearest neighbor f^{\sharp}(O_{i}) of O_{i} among O^{\sharp}_{1}, \ldots, O^{\sharp}_{N}.
Figure 4: DAG representing how the random variables produced by \Pi depend on each other.

The DAG is very useful to unravel how the random variables produced by \Pi depend on each other. In particular, by d-separation (Lauritzen 1996), we learn from the DAG that the distances to the nearest neighbor within \{O^{\sharp}_{1}, \ldots, O^{\sharp}_{N}\} of O_{1}, \ldots, O_{n+n'} are pairwise dependent. This dependency prevents the use of a Kolmogorov-Smirnov test to compare \mu_{1:n} and \mu_{(n+1):(n+n')}.

Moreover, conditionally on O^{\sharp}_{1}, \ldots, O^{\sharp}_{N},

  • O_{1}, \ldots, O_{n} are not independent (because, for any 1 \leq i < j \leq n, O^{\sharp}_{1} is a collider on the path O_{i} \to O^{\sharp}_{1} \leftarrow O_{j});

  • f^{\sharp}(O_{1}), \ldots, f^{\sharp}(O_{n+n'}) are independent (because, for any 1 \leq i < j \leq n+n', all paths leading from f^{\sharp}(O_{i}) to f^{\sharp}(O_{j}) are blocked);

  • the distances to the nearest neighbor within \{O^{\sharp}_{1}, \ldots, O^{\sharp}_{N}\} of O_{n+1}, \ldots, O_{n+n'} are mutually independent.

Therefore, conditionally on O^{\sharp}_{1}, \ldots, O^{\sharp}_{N} and \mu_{1:n}, we can use t-tests to compare the first three moments of \mu_{(n+1):(n+n')} to those of \mu_{1:n}. By the central limit theorem and Slutsky’s lemma (van der Vaart 1998, Example 2.1 and Lemma 2.8), the tests are asymptotically valid as n' goes to infinity.
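
The mechanics of one such test can be sketched in a few lines of plain Python. The function name moment_test is hypothetical, the target value is the corresponding moment of \mu_{1:n} (treated as fixed), and the p-value relies on the Gaussian limit rather than on the Student law, an approximation that is harmless for large n'.

```python
import math

def moment_test(d_test, k, target):
    """Two-sided asymptotic test that the k-th moment of the distances equals target.

    Returns the studentized statistic and the p-value under the Gaussian limit.
    """
    x = [d ** k for d in d_test]
    n = len(x)
    mean = sum(x) / n
    var = sum((xi - mean) ** 2 for xi in x) / (n - 1)
    t = (mean - target) / math.sqrt(var / n)
    p = math.erfc(abs(t) / math.sqrt(2.0))   # equals 2 * (1 - Phi(|t|))
    return t, p

# Toy distances: the first moment equals 2, so testing against 2 gives t = 0
print(moment_test([1.0, 2.0, 3.0], k=1, target=2.0))
print(moment_test([1.0, 2.0, 3.0], k=1, target=0.0))
```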

The next chunk of code retrieves the p-values of the three tests using all 1000 synthetic observations.

(moments_test <- fig$data %>%
   dplyr::mutate(`1st_moment` = mean(fig$data %>%
                                     dplyr::filter(type == "training data") %>%
                                     dplyr::pull(d)),
                 `2nd_moment` = mean(fig$data %>%
                                     dplyr::filter(type == "training data") %>%
                                     dplyr::pull(d) %>% .^2),
                 `3rd_moment` = mean(fig$data %>%
                                     dplyr::filter(type == "training data") %>%
                                     dplyr::pull(d) %>% .^3)) %>%
   dplyr::filter(type == "testing data") %>%
   tidyr::nest() %>%
   dplyr::mutate(
              `1st_moment_test` =
                  purrr::map(data,
                             function(df)
                                 stats::t.test(df$d, mu = df$`1st_moment`[1])$p.val),
              `2nd_moment_test` =
                  purrr::map(data,
                             function(df)
                                 stats::t.test(df$d^2, mu = df$`2nd_moment`[1])$p.val),
              `3rd_moment_test` =
                  purrr::map(data,
                             function(df)
                                 stats::t.test(df$d^3, mu = df$`3rd_moment`[1])$p.val)
          ) %>%
   dplyr::select(-data) %>%
   tidyr::unnest(cols = c(`1st_moment_test`, `2nd_moment_test`, `3rd_moment_test`)))
# A tibble: 1 × 3
  `1st_moment_test` `2nd_moment_test` `3rd_moment_test`
              <dbl>             <dbl>             <dbl>
1           0.00706            0.0178             0.142

The p-values from the first two tests are small, but not strikingly so, especially when accounting for multiple testing. This indicates only moderate evidence of a discrepancy.

It is tempting to investigate what happens when only 100 synthetic observations are used.

(tibble::tibble(d = c(kknn::kknn(frml, train, synth %>% dplyr::slice_head(n = 100), k = 1)$D,
                      kknn::kknn(frml, test, synth %>% dplyr::slice_head(n = 100), k = 1)$D)) %>%
  dplyr::mutate(type = c(rep("training data", dplyr::n()/2),
                         rep("testing data", dplyr::n()/2))) %>%
  dplyr::mutate(`1st_moment` = mean(fig$data %>%
                                     dplyr::filter(type == "training data") %>%
                                     dplyr::pull(d)),
                 `2nd_moment` = mean(fig$data %>%
                                     dplyr::filter(type == "training data") %>%
                                     dplyr::pull(d) %>% .^2),
                 `3rd_moment` = mean(fig$data %>%
                                     dplyr::filter(type == "training data") %>%
                                     dplyr::pull(d) %>% .^3)) %>%
   dplyr::filter(type == "testing data") %>%
   tidyr::nest() %>%
   dplyr::mutate(
              `1st_moment_test` =
                  purrr::map(data,
                             function(df)
                                 stats::t.test(df$d, mu = df$`1st_moment`[1])$p.val),
              `2nd_moment_test` =
                  purrr::map(data,
                             function(df)
                                 stats::t.test(df$d^2, mu = df$`2nd_moment`[1])$p.val),
              `3rd_moment_test` =
                  purrr::map(data,
                             function(df)
                                 stats::t.test(df$d^3, mu = df$`3rd_moment`[1])$p.val)
          ) %>%
   dplyr::select(-data) %>%
   tidyr::unnest(cols = c(`1st_moment_test`, `2nd_moment_test`, `3rd_moment_test`)))
# A tibble: 1 × 3
  `1st_moment_test` `2nd_moment_test` `3rd_moment_test`
              <dbl>             <dbl>             <dbl>
1             0.156            0.0865            0.0282

This time, only the p-value from the third test is small, but not markedly so when accounting for multiple testing. The evidence of a discrepancy is noticeably weaker when 100 synthetic observations are used compared to 1000. This highlights that distinguishing N synthetic observations from genuine observations becomes increasingly difficult as N decreases.
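
For readers who prefer Python, the distance computation behind the first criterion can be sketched as follows: for each genuine observation, compute the distance to its nearest neighbor among the synthetic observations, then compare the distances obtained from training points and from testing points. The toy Gaussian data, the sample sizes, the artificial "over-replicating" generator and the function names are assumptions made for illustration only. The sketch uses the plain Euclidean distance on raw features, which need not coincide with the distance computed by kknn.

```python
import math
import random


def nn_distances(points, reference):
    """Distance from each point to its nearest neighbor in `reference`."""
    return [min(math.dist(p, r) for r in reference) for p in points]


random.seed(0)
dim = 3


def draw(n):
    """Draw n independent standard Gaussian points in dimension `dim`."""
    return [[random.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n)]


train, test = draw(200), draw(200)
# A hypothetical over-replicating generator: training points plus small noise
synth = [[x + random.gauss(0.0, 0.05) for x in p] for p in train]

d_train = nn_distances(train, synth)
d_test = nn_distances(test, synth)
print("mean distance, training points:", sum(d_train) / len(d_train))
print("mean distance, testing points: ", sum(d_test) / len(d_test))
```

In this toy setting the training points lie much closer to the synthetic points than the testing points do, which is the signature of over-replication that the first criterion is designed to detect.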

6.3.2 Criterion 2

The next chunk of code implements the second criterion, in its visual form.

Hide/Show the code
fig <- dplyr::bind_rows(test, synth) %>%
  dplyr::select(-Z) %>%
  dplyr::mutate(type = c(rep("testing data", dplyr::n()/2),
                         rep("synthetic data", dplyr::n()/2))) %>%
  tidyr::pivot_longer(-type, names_to = "what", values_to = "values") %>%
  dplyr::mutate(what = dplyr::case_when(
           what == "W1" ~ "W[1]",
           what == "W2" ~ "W[2]",
           what == "W3" ~ "W[3]",
           what == "V1" ~ "V[1]",
           what == "V2" ~ "V[2]",
           TRUE ~ what
  )) %>%
  dplyr::mutate(what = factor(what,
                              levels = c(paste0("W[", 1:dimW, "]"), 
                                         paste0("V[", 1:dimV, "]"),
                                         "A", "Y"))) %>%
  ggplot2::ggplot() +
  ggplot2::stat_ecdf(ggplot2::aes(x = values, color = type)) +
  ggplot2::facet_wrap(~ what,
                      scales = "free_x",
                      labeller = ggplot2::label_parsed) +
  ggplot2::ylab("empirical cumulative distribution function")
print(fig)
Figure 5: Empirical c.d.f. of each covariate based on either the synthetic or the testing data sets.

Firstly, inspecting the first row of Figure 5 suggests that the marginal laws of W_{1}, W_{2}, W_{3} under the synthetic law do not align very well with their counterparts under P, although the locations and ranges of the true marginal laws are reasonably well approximated. The restriction to \mathbb{R}_{+} of the marginal law of W_{1} under the synthetic law is very similar to its counterpart under P, but its restriction to \mathbb{R}_{-} is too thin-tailed. As for the marginal laws of W_{2}, W_{3} under the synthetic law, they are too thin-tailed compared to their counterparts under P.

Secondly, inspecting the second row of Figure 5 reveals that the marginal laws of V_{1}, V_{2} under the synthetic law align perfectly (V_{1}) and reasonably well (V_{2}) with their counterparts under P. However, the marginal law of A under the synthetic law assigns more weight to the event [A=1] than its counterpart under P.

Lastly, inspecting the third row of Figure 5 reveals that the marginal law of Y under the synthetic law does not align very well with its counterpart under P. While the location and range of the true marginal law are reasonably well approximated, the overall shape of the true density is not faithfully reproduced.

The next chunk of code implements the version of the second criterion based on hypothesis testing. Conditionally on the training data set, the testing procedures are valid because (i) the synthetic and testing data sets are independent, (ii) the testing data are drawn independently from P, (iii) the synthetic data are drawn independently from \mathop{\mathrm{law}}(\mathop{\mathrm{Gen}}_{\widehat{\theta}_{n}}(Z)).

We first address the continuous covariates (W_{1}, W_{2}, W_{3} and Y) and then the binary covariates (V_{1}, V_{2} and A). For the former, we use Kolmogorov-Smirnov tests. For the latter, we use Fisher's exact tests.

Hide/Show the code
(ks_tests <- fig$data %>%
  dplyr::filter(!what %in% c(paste0("V[", 1:dimV, "]"), "A")) %>%
  tidyr::nest(data = -what) %>%
  dplyr::mutate(p.val = purrr::map(data,
                                   ~ stats::ks.test(stats::formula(values ~ type),
                                                    data = .x)$p.val)) %>%
  dplyr::select(-data) %>%
  tidyr::unnest(cols = p.val))
# A tibble: 4 × 2
  what     p.val
  <fct>    <dbl>
1 W[1]  6.06e- 5
2 W[2]  2.43e-14
3 W[3]  6.14e-10
4 Y     2.55e- 7
Hide/Show the code
(fisher_test <- fig$data %>%
  dplyr::filter(what %in% c(paste0("V[", 1:dimV, "]"), "A")) %>%
  dplyr::group_by(what, type) %>%
  dplyr::summarize(`=1` = sum(values),
                   `=0` = dplyr::n()-sum(values)) %>%
  dplyr::select(-type) %>%
  tidyr::nest(data = -what) %>%
  dplyr::mutate(p.val = purrr::map(data,
                                   ~ stats::fisher.test(.x)$p.val)) %>%
  dplyr::select(-data) %>%
  tidyr::unnest(cols = p.val))
# A tibble: 3 × 2
# Groups:   what [3]
  what     p.val
  <fct>    <dbl>
1 V[1]  0.924   
2 V[2]  0.117   
3 A     0.000235

Most p-values are very small, supporting the conclusions drawn from inspecting Figure 5. Unlike the marginal laws of V_{1}, V_{2}, the marginal laws of W_{1}, W_{2}, W_{3}, A, Y are not well approximated, as the tests detect discrepancies when both the synthetic and testing data sets contain 1000 data points.
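
To make the test used for the binary covariates explicit, the following sketch computes a two-sided Fisher exact p-value for a 2×2 table from the hypergeometric law, by summing the probabilities of all tables that are no more likely than the observed one. It is a plain-Python illustration written for this purpose (the function name and the toy counts are assumptions), not the implementation behind stats::fisher.test.

```python
from math import comb


def fisher_exact_2x2(a, b, c, d):
    """Two-sided Fisher exact p-value for the table [[a, b], [c, d]]."""
    row1, row2, col1 = a + b, c + d, a + c
    denom = comb(row1 + row2, col1)

    def prob(x):
        # Hypergeometric probability that the top-left cell equals x
        return comb(row1, x) * comb(row2, col1 - x) / denom

    p_obs = prob(a)
    lo, hi = max(0, col1 - row2), min(row1, col1)
    # Small relative tolerance guards against floating-point ties
    total = sum(prob(x) for x in range(lo, hi + 1)
                if prob(x) <= p_obs * (1 + 1e-9))
    return min(1.0, total)


# Toy counts: rows = (synthetic, testing), columns = (value 1, value 0)
print(fisher_exact_2x2(3, 1, 1, 3))
```

In the analysis above, each row of the table would hold the counts of ones and zeros of a binary covariate in the synthetic and in the testing data sets.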

Naturally, one might wonder whether this result still holds when comparing smaller synthetic and testing data sets. The following chunk of code reproduces the same statistical analysis as before, but now using two samples of 100 data points each.

Hide/Show the code
smaller_dataset <- fig$data %>%
   dplyr::group_by(type) %>%
   dplyr::slice_head(n = 100) %>%
   dplyr::ungroup()
(ks_tests <- smaller_dataset %>% 
  dplyr::filter(!what %in% c(paste0("V[", 1:dimV, "]"), "A")) %>%
  tidyr::nest(data = -what) %>%
  dplyr::mutate(p.val = purrr::map(data,
                                   ~ stats::ks.test(stats::formula(values ~ type),
                                                    data = .x)$p.val)) %>%
  dplyr::select(-data) %>%
  tidyr::unnest(cols = p.val))
# A tibble: 4 × 2
  what   p.val
  <fct>  <dbl>
1 W[1]  0.386 
2 W[2]  0.678 
3 W[3]  0.0590
4 Y     0.921 
Hide/Show the code
(fisher_test <- smaller_dataset %>%
  dplyr::filter(what %in% c(paste0("V[", 1:dimV, "]"), "A")) %>%
  dplyr::group_by(what, type) %>%
  dplyr::summarize(`=1` = sum(values),
                   `=0` = dplyr::n()-sum(values)) %>%
  dplyr::select(-type) %>%
  tidyr::nest(data = -what) %>%
  dplyr::mutate(p.val = purrr::map(data,
                                   ~ stats::fisher.test(.x)$p.val)) %>%
  dplyr::select(-data) %>%
  tidyr::unnest(cols = p.val))
# A tibble: 3 × 2
# Groups:   what [3]
  what   p.val
  <fct>  <dbl>
1 V[1]  0.0542
2 V[2]  0.706 
3 A     1     

This time, none of the p-values falls below 0.05 (the smallest, for W_{3} and V_{1}, are just above it), indicating that the tests no longer clearly detect discrepancies when the synthetic and testing data sets contain only 100 data points. Surprisingly, essentially the same conclusion holds when comparing a synthetic data set of 100 data points with a testing data set of 1000 data points, as demonstrated in the next chunk of code.

Hide/Show the code
smaller_synth_dataset <- fig$data %>%
  dplyr::filter(type == "synthetic data") %>%
  dplyr::slice_head(n = 100) %>%
  dplyr::bind_rows(fig$data %>%
                   dplyr::filter(type == "testing data"))
(ks_tests <- smaller_synth_dataset %>% 
  dplyr::filter(!what %in% c(paste0("V[", 1:dimV, "]"), "A")) %>%
  tidyr::nest(data = -what) %>%
  dplyr::mutate(p.val = purrr::map(data,
                                   ~ stats::ks.test(stats::formula(values ~ type),
                                                    data = .x)$p.val)) %>%
  dplyr::select(-data) %>%
  tidyr::unnest(cols = p.val))
# A tibble: 4 × 2
  what   p.val
  <fct>  <dbl>
1 W[1]  0.718 
2 W[2]  0.849 
3 W[3]  0.0466
4 Y     0.677 
Hide/Show the code
(fisher_test <- smaller_synth_dataset %>%
  dplyr::filter(what %in% c(paste0("V[", 1:dimV, "]"), "A")) %>%
  dplyr::group_by(what, type) %>%
  dplyr::summarize(`=1` = sum(values),
                   `=0` = dplyr::n()-sum(values)) %>%
  dplyr::select(-type) %>%
  tidyr::nest(data = -what) %>%
  dplyr::mutate(p.val = purrr::map(data,
                                   ~ stats::fisher.test(.x)$p.val)) %>%
  dplyr::select(-data) %>%
  tidyr::unnest(cols = p.val))
# A tibble: 3 × 2
# Groups:   what [3]
  what  p.val
  <fct> <dbl>
1 V[1]  0.568
2 V[2]  0.592
3 A     0.601

In conclusion, while a large synthetic data set can be shown to differ in law from a large testing data set, a smaller synthetic data set does not exhibit noticeable differences in marginal laws when compared to either a small or a large testing data set.
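
The role played by the sample sizes can be seen directly on the Kolmogorov-Smirnov test. The sketch below computes the two-sample statistic D and an asymptotic p-value, then evaluates the p-value attached to one fixed value of D for two pairs of sample sizes. The functions and the value D = 0.1 are illustrative assumptions; the p-value relies on the classical asymptotic series and may differ from the exact p-values that stats::ks.test can return.

```python
import math


def ks_statistic(x, y):
    """Largest gap between the empirical c.d.f. of samples x and y."""
    xs, ys = sorted(x), sorted(y)
    i = j = 0
    d = 0.0
    while i < len(xs) and j < len(ys):
        v = min(xs[i], ys[j])
        # Move past every observation equal to v in both samples
        while i < len(xs) and xs[i] == v:
            i += 1
        while j < len(ys) and ys[j] == v:
            j += 1
        d = max(d, abs(i / len(xs) - j / len(ys)))
    return d


def ks_asymptotic_pvalue(d, n, m, terms=100):
    """Asymptotic two-sided p-value based on the Kolmogorov series."""
    lam = math.sqrt(n * m / (n + m)) * d
    if lam < 0.2:
        # The survival function is numerically equal to 1 in this range
        return 1.0
    s = sum((-1) ** (k - 1) * math.exp(-2.0 * (k * lam) ** 2)
            for k in range(1, terms + 1))
    return max(0.0, min(1.0, 2.0 * s))


d = 0.1  # the same gap between empirical c.d.f. in both scenarios
print("n = m = 100: ", ks_asymptotic_pvalue(d, 100, 100))
print("n = m = 1000:", ks_asymptotic_pvalue(d, 1000, 1000))
```

With the gap held fixed, the p-value is well above 0.05 for two samples of size 100 and far below it for two samples of size 1000: the same discrepancy in law is detectable only when enough observations are compared.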

6.3.3 Criterion 3

The next chunk of code builds a super learning algorithm to estimate either the conditional probability that A=1 given (W,V) or the conditional mean of Y given (A,W,V) by aggregating 5 base learners. We use the SuperLearner package in R. Specifically, the 5 base learners estimate the above conditional means by a constant (SL.mean), or based on generalized linear models (SL.glm and SL.glm.interaction), or by a random forest (SL.ranger), or based on a single-hidden-layer neural network (SL.nnet).

Hide/Show the code
library(SuperLearner)
algo <- function(learning_data, testing_data,
                 outcome, covariates) {
  SL.lib <- c("SL.mean", "SL.glm", "SL.glm.interaction", "SL.ranger", "SL.nnet")
  cvControl <- SuperLearner::SuperLearner.CV.control(V = 10)
  family <- switch(outcome,
                   A = stats::binomial(),
                   Y = stats::gaussian())
  sl <- SuperLearner::SuperLearner(
                        Y = dplyr::pull(learning_data, outcome),
                        X = learning_data[, covariates],
                        newX = testing_data[, covariates], 
                        family = family,
                        SL.library = SL.lib,
                        method = "method.NNLS",
                        cvControl = cvControl)
  list(as.vector(sl$SL.predict))
}
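
As a rough illustration of what the aggregation step does, the Python sketch below combines the predictions of base learners through non-negative least squares followed by a normalization of the weights, which is our understanding of the principle behind method.NNLS. The coordinate-descent solver, the function name and the toy predictions are assumptions made for this example; SuperLearner applies the idea to cross-validated predictions and relies on its own solver.

```python
def nnls_weights(preds, y, sweeps=200):
    """Non-negative least squares weights via cyclic coordinate descent.

    preds: list of prediction vectors, one per base learner
    y:     observed outcomes
    Returns weights normalized to sum to one (unless all are zero).
    """
    k, n = len(preds), len(y)
    w = [0.0] * k
    for _ in range(sweeps):
        for j in range(k):
            # Residual when learner j is left out of the combination
            resid = [y[i] - sum(w[l] * preds[l][i] for l in range(k) if l != j)
                     for i in range(n)]
            num = sum(preds[j][i] * resid[i] for i in range(n))
            den = sum(p * p for p in preds[j])
            # Unconstrained minimizer in coordinate j, clipped at zero
            w[j] = max(0.0, num / den) if den > 0 else 0.0
    total = sum(w)
    return [wj / total for wj in w] if total > 0 else w


# Toy example: the outcome is an exact mixture of two hypothetical learners
learner_1 = [0.0, 1.0, 2.0, 3.0]
learner_2 = [1.0, 0.0, 1.0, 0.0]
y = [0.7 * a + 0.3 * b for a, b in zip(learner_1, learner_2)]
print(nnls_weights([learner_1, learner_2], y))
```

The non-negativity constraint means that a base learner that does not help is simply given weight zero, so adding a weak learner to the library should not degrade the aggregate much.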

We train the super learning algorithm three times: once on each of two distinct halves of the training data set, and once on half of the synthetic data set. This results in three estimators of the conditional probability that A=1 given (W,V) and three estimators of the conditional mean of Y given (A,W,V). The next chunk of code prepares the three training data sets.

Hide/Show the code
(tib <- dplyr::bind_rows(train, test, synth) %>%
   dplyr::select(-Z) %>%
   dplyr::mutate(type = c(rep("using training data a", dplyr::n()/6),
                          rep("using training data b", dplyr::n()/6),
                          rep("testing data a", dplyr::n()/6),
                          rep("testing data b", dplyr::n()/6),
                          rep("using synthetic data", dplyr::n()/3))) %>%
   dplyr::group_by(type) %>%
   dplyr::mutate(id = 1:dplyr::n()) %>%
   dplyr::ungroup() %>%
   dplyr::filter(id <= dplyr::n()/6) %>%
   dplyr::mutate(type = dplyr::case_when(
                                 stringr::str_detect(type, "testing") ~ "testing data",
                                 TRUE ~ type)) %>%
   dplyr::select(-id) %>%
   tidyr::nest(data = -type) %>%
   dplyr::mutate(testing = data[3]) %>%
     dplyr::filter(type != "testing data"))
# A tibble: 3 × 3
  type                  data               testing             
  <chr>                 <list>             <list>              
1 using training data a <tibble [500 × 7]> <tibble [1,000 × 7]>
2 using training data b <tibble [500 × 7]> <tibble [1,000 × 7]>
3 using synthetic data  <tibble [500 × 7]> <tibble [1,000 × 7]>

The following chunk of code trains the super learning algorithm and evaluates the six resulting estimators on the testing data points. To compare the estimators, we use scatter plots. Specifically, denoting by \widehat{\text{pr}}_{1}, \widehat{\text{pr}}_{2}, \widehat{\text{pr}}_{3} the estimators of the conditional probability that A=1 given (W,V) obtained by training the super learning algorithm on each of the two distinct halves of the training data set (\widehat{\text{pr}}_{1} and \widehat{\text{pr}}_{2}), and on half of the synthetic data set (\widehat{\text{pr}}_{3}), we plot in the left-hand side panel \{(\widehat{\text{pr}}_{1}(W_{n+i},V_{n+i}), \widehat{\text{pr}}_{2}(W_{n+i},V_{n+i})) : 1 \leq i \leq n\} (in red) and \{(\widehat{\text{pr}}_{1}(W_{n+i},V_{n+i}), \widehat{\text{pr}}_{3}(W_{n+i},V_{n+i})) : 1 \leq i \leq n\} (in blue).

Therein, the spread of the red scatter plot around the y=x line in the xy-plane is evidence of the inherent and irreducible randomness that one faces when learning the conditional probability that A=1 given (W,V). By comparison, the blue scatter plot is more widely spread around the line, revealing a measure of discrepancy between the training and synthetic data.

The right-hand side panel is obtained analogously. The red scatter plot is more concentrated around the y=x line than its counterpart in the left-hand side panel. The blue scatter plot is more widely spread than the red one, which again reveals a measure of discrepancy between the training and synthetic data. In summary, we consider that the red and blue scatter plots do not strongly differ in their bulks. However, it seems that the blue scatter plots feature more outliers than their red counterparts, revealing that the estimators may be quite different in some parts of the space of covariates.

Hide/Show the code
tib %>%
  dplyr::group_by(type) %>%
  dplyr::mutate(AgivenWV = algo(data[[1]], testing[[1]],
                                "A",
                                c(paste0("W", 1:dimW),
                                  paste0("V", 1:dimV))),
                YgivenAWV = algo(data[[1]], testing[[1]],
                                 "Y",
                                 c(paste0("W", 1:dimW),
                                   paste0("V", 1:dimV),
                                   "A"))) %>%
  dplyr::ungroup() %>%
  dplyr::select(type, AgivenWV, YgivenAWV) %>%
  tidyr::unnest(cols = c(AgivenWV, YgivenAWV)) %>%
  tidyr::pivot_longer(!type, names_to = "what", values_to = "preds") %>%
  dplyr::group_by(type, what) %>%
  dplyr::mutate(id = 1:dplyr::n()) %>%
  tidyr::pivot_wider(names_from = type,
                     values_from = "preds",
                     names_prefix = "preds ") %>%
  dplyr::ungroup() %>%
  dplyr::mutate(what = dplyr::case_when(
                                what == "AgivenWV" ~ "A given (W,V)",
                                what == "YgivenAWV" ~ "Y given (A, W, V)"
                              )) %>%
  ggplot2::ggplot() +
  ggplot2::geom_point(ggplot2::aes(`preds using training data a`,
                                   `preds using synthetic data`,
                                   ),
                      color = "#00BFC4", alpha = 0.5) +
  ggplot2::geom_point(ggplot2::aes(`preds using training data a`,
                                   `preds using training data b`,
                                   ),
                      color = "#F8766D", alpha = 0.5) +
  ggplot2::geom_abline() +
  ggplot2::facet_wrap(dplyr::vars(what),
                      labeller = ggplot2::as_labeller(\(str) paste("predicting", str))) +
  ggplot2::xlab("algorithm trained on 1st half of training data set") +
    ggplot2::ylab("algorithm trained either\n
on 2nd half of training data set (red) or\n
on synthetic data set (blue)")
Figure 6: Comparing predicted conditional probabilities that A=1 given (W,V) (left) or predicted conditional means of Y given (A,W,V) (right) when a super learning algorithm is trained twice on two distinct halves of the training data set (red points) or on the first half of the training data set, x-axis, and on half of the synthetic data set, y-axis (blue points).
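
The visual comparison in Figure 6 can be complemented with a simple numerical summary, for instance the root mean squared gap between two vectors of predictions evaluated at the same testing points. This summary is an illustrative assumption, not part of the analysis above, and the prediction vectors below are made-up placeholders standing for \widehat{\text{pr}}_{1}, \widehat{\text{pr}}_{2}, \widehat{\text{pr}}_{3}.

```python
import math


def rms_gap(preds_a, preds_b):
    """Root mean squared difference between two prediction vectors."""
    if len(preds_a) != len(preds_b):
        raise ValueError("prediction vectors must have the same length")
    squared = sum((a - b) ** 2 for a, b in zip(preds_a, preds_b))
    return math.sqrt(squared / len(preds_a))


# Placeholder predictions at four testing points (hypothetical values)
pr_1 = [0.20, 0.50, 0.70, 0.90]  # trained on 1st half of the training data
pr_2 = [0.25, 0.45, 0.70, 0.85]  # trained on 2nd half of the training data
pr_3 = [0.40, 0.30, 0.75, 0.60]  # trained on synthetic data

print("genuine vs genuine:  ", rms_gap(pr_1, pr_2))
print("genuine vs synthetic:", rms_gap(pr_1, pr_3))
```

The first quantity reflects the variability of the learning procedure itself; a markedly larger second quantity would point to a discrepancy between genuine and synthetic data.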

6.3.4 Summary

We implemented three criteria to evaluate the synthetic observations. The first criterion compared the empirical distributions of distances between synthetic observations and genuine observations, both involved and not involved in the generator’s construction, detecting minor over-replication in the synthetic data set. The second criterion assessed marginal distributions of individual features, revealing discrepancies, particularly in continuous variables, which often exhibited overly thin tails. The third criterion compared predictions from an algorithm trained on synthetic versus genuine observations, showing good replication for predicting Y given (A,W,V) but less so for predicting A given (W,V). Overall, while the synthetic observations show some discrepancies from the genuine ones, these differences are not overly substantial. Moreover, detecting significant differences becomes much harder with smaller synthetic data sets (100 versus 1000 synthetic observations).

7 Illustration on real data

In this section, we extend the analysis conducted in the previous section to real data. We use a subset of the International Warfarin Pharmacogenetics Consortium (IWPC) data set (The International Warfarin Pharmacogenetics Consortium 2009). Warfarin is a commonly prescribed anticoagulant used to treat thrombosis and thromboembolism.

7.1 The International Warfarin Pharmacogenetics Consortium data set

In order to limit the number of incomplete observations, we keep only the following variables:

  • height, in centimeters;
  • weight, in kilograms;
  • indicator of whether or not VKORC1 consensus (obtained from genotype data) is “A/A”;
  • indicator of whether or not CYP2C9 consensus (obtained from genotype data) is “*1/*1”;
  • indicator of whether or not ethnicity is Asian;
  • indicator of whether or not therapeutic dose of Warfarin is greater than or equal to 21 mg per week;
  • international normalized ratio on reported therapeutic dose of Warfarin (INR, a measure of blood clotting function).
Hide/Show the code
dimW <- 2 
dimV <- 3
dataset <- readr::read_csv("data/IWPC_data.csv",
                           skip = 1,
                           col_names = c("id",
                                         paste0("W", 1:dimW),
                                         paste0("V", 1:dimV),
                                         "A", "Y"))

The original database includes 3193 patients with complete observations for these variables. We refer to the table below for a brief description of the data.

Hide/Show the code
dataset %>%
    dplyr::select(V1, V2, V3, W1, W2, A, Y) %>%
    gtsummary::tbl_summary(
                   label = list(W1 = "Height (cm)",
                                W2 = "Weight (kg)", 
                                V1 = "VKORC1 consensus is A/A", 
                                V2 = "CYP2C9 consensus is *1/*1", 
                                V3 = "Ethnicity is Asian", 
                                A = "Therapeutic dose >= 21 mg per week", 
                                Y = "INR"),
                   statistic = list(
                       gtsummary::all_continuous() ~ "{min}", 
                       gtsummary::all_dichotomous() ~ "{n} ({p}%)"),
                   digits = list(
                       gtsummary::all_continuous() ~ rep(0, 5),
                       gtsummary::all_categorical() ~ rep(0, 2)),
                   type = list(W1 = "continuous",
                               W2 = "continuous", 
                               V1 = "dichotomous",
                               V2 = "dichotomous",
                               V3 = "dichotomous",
                               A = "dichotomous",
                               Y = "continuous"),
                   value = list(V1 ~ 0, V2 ~ 0, V3 ~ 1)
               ) %>% 
    gtsummary::add_stat(
                   fns = gtsummary::all_continuous() ~
                       \(data, variable, ...) stats::median(data[[variable]])
               ) %>%
    gtsummary::add_stat(
                   fns = gtsummary::all_continuous() ~
                       \(data, variable, ...) mean(data[[variable]])
               ) %>%
    gtsummary::add_stat(
                   fns = gtsummary::all_continuous() ~
                       \(data, variable, ...) max(data[[variable]])
               ) %>%
    gtsummary::add_stat(
                   fns = gtsummary::all_continuous() ~
                       \(data, variable, ...) stats::sd(data[[variable]])
               )  %>%
    gtsummary::modify_spanning_header(
                   c("variable", "label") ~
                       "**Variables**", 
                   c("stat_0", "add_stat_1", "add_stat_2", "add_stat_3", "add_stat_4") ~
                       "**Descriptive statistics**") %>%
    gtsummary::modify_header(
                   variable = "", label = "", stat_0 = "**n (%) or Min**",
                   add_stat_1 = "**Median**", add_stat_2 = "**Mean**",
                   add_stat_3 = "**Max**", add_stat_4 = "**SD**") %>%
    gtsummary::modify_footnote(stat_0 = NA)
Variables                               Descriptive statistics
                                        n (%) or Min  Median  Mean  Max  SD
V1  VKORC1 consensus is A/A             1,150 (36%)
V2  CYP2C9 consensus is *1/*1           2,467 (77%)
V3  Ethnicity is Asian                  1,087 (34%)
W1  Height (cm)                         125           168     168   202  10.9
W2  Weight (kg)                         30            73.0    76.9  238  22.0
A   Therapeutic dose >= 21 mg per week  2,336 (73%)
Y   INR                                 4             65.3    74.1  680  47.1

Let us load the data set into Python.

Hide/Show the code
import numpy as np
import pandas as pd

dataset = pd.read_csv("data/IWPC_data.csv")
dataset = dataset.drop(['Unnamed: 0'], axis = 1)
dataset = dataset.values
n = np.shape(dataset)[0]

dimV,  dimW = 3, 2
dimWV = dimW + dimV
nb_cat_V = np.array([len(np.unique(dataset[:, 2])),
                     len(np.unique(dataset[:, 3])),
                     len(np.unique(dataset[:, 4]))])

It is convenient to rescale the continuous variables.

Hide/Show the code
from sklearn.preprocessing import MinMaxScaler

scalerW = MinMaxScaler(feature_range = (0, 1))
scalerW.fit(dataset[:, 0:2]);
dataset[:,0:2] = scalerW.transform(dataset[:, 0:2])

scalerY = MinMaxScaler(feature_range = (0, 1))
scalerY.fit(dataset[:, dimWV + 1].reshape(-1, 1));
dataset[:, dimWV + 1] = scalerY.transform(dataset[:, dimWV + 1].reshape(-1, 1)).reshape(1,-1)
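
For reference, the transformation applied by the scaler and its inverse, which is needed later to map synthetic observations back to the original units, can be sketched in plain Python as follows. The helper names and the toy heights are assumptions made for illustration; the code above relies on scikit-learn's MinMaxScaler.

```python
def fit_min_max(values):
    """Return the (min, max) pair used to rescale values to [0, 1]."""
    return min(values), max(values)


def transform(values, lo, hi):
    """Map values from [lo, hi] to [0, 1] (assumes hi > lo)."""
    return [(v - lo) / (hi - lo) for v in values]


def inverse_transform(scaled, lo, hi):
    """Map values from [0, 1] back to the original scale."""
    return [lo + s * (hi - lo) for s in scaled]


heights = [125.0, 168.0, 202.0]  # toy values, in centimeters
lo, hi = fit_min_max(heights)
scaled = transform(heights, lo, hi)
print(scaled)
print(inverse_transform(scaled, lo, hi))
```

Rescaling to [0, 1] presumably pairs well with the sigmoid activations chosen for W and Y in the decoders of Section 7.2, since a sigmoid output takes its values in (0, 1).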

We finally define the training and testing data sets.

Hide/Show the code
n_train = int(n/2)
train = dataset[0:n_train, :]
test = dataset[n_train:n, :]
print("The three first observations in 'train':\n",
      "  V_1   V_2   V_3   W_1   W_2   A     Y\n",
      np.around(train[:3, [2, 3, 4, 0, 1, 5, 6]], decimals = 3))
The three first observations in 'train':
   V_1   V_2   V_3   W_1   W_2   A     Y
 [[1.    0.    0.    0.688 0.515 1.    0.121]
 [0.    0.    1.    0.584 0.13  0.    0.029]
 [0.    0.    1.    0.402 0.169 0.    0.04 ]]

7.2 Training the VAE

By running the next chunk of code, we set the VAE’s configuration.

Hide/Show the code
latent_dim = 10
L = 8

nlayers_encoder = 2
nneurons_encoder = 8

nlayers_WV_decoder = 2
nneurons_WV_decoder = 8

nlayers_Alaw_decoder = 2
nneurons_Alaw_decoder = 8

nlayers_AYaY_decoder = 2
nneurons_AYaY_decoder = 16

The next chunk of code repeatedly generates and initializes a VAE then trains it.

Hide/Show the code
epochs = 40
batch_size = np.round(0.2 * np.shape(train)[0]).astype(int)
synth_size = 1000
nb_tries = 100
for i in range(nb_tries):
  print("try #", i, ", ")
  encoder = build_encoder(dimW = dimW,
                          dimV = dimV,
                          latent_dim = latent_dim,
                          L = L,
                          nlayers_encoder = nlayers_encoder,
                          nneurons_encoder = nneurons_encoder)
  WV_decoder = build_WV_decoder(dimW = dimW,
                                dimV = dimV,
                                nb_cat_V = nb_cat_V,
                                latent_dim = latent_dim,
                                L = L,
                                nlayers_WV_decoder = nlayers_WV_decoder,
                                nneurons_WV_decoder = nneurons_WV_decoder,
                                activation_W = "sigmoid")
  Alaw_decoder = build_Alaw_decoder(dimW = dimW,
                                    dimV = dimV,
                                    L = L,
                                    nlayers_Alaw_decoder = nlayers_Alaw_decoder,
                                    nneurons_Alaw_decoder = nneurons_Alaw_decoder)
  AYaY_decoder = build_AYaY_decoder(dimW = dimW,
                                    dimV = dimV,
                                    latent_dim = latent_dim,
                                    L = L,
                                    nlayers_AYaY_decoder = nlayers_AYaY_decoder,
                                    nneurons_AYaY_decoder = nneurons_AYaY_decoder,
                                    activation_Ya = "sigmoid")
  pi = build_pi(dimW = dimW,
                dimV = dimV,
                L = L)
  extract_G = build_extract_G(dimW = dimW,
                              dimV = dimV,
                              L = L)
  generator = build_generator(dimW = dimW,
                              dimV = dimV,
                              encoder = encoder,
                              latent_dim = latent_dim,
                              L = L,
                              WV_decoder = WV_decoder,
                              Alaw_decoder = Alaw_decoder,
                              AYaY_decoder = AYaY_decoder,
                              pi = pi,
                              extract_G = extract_G)
  vae = VAE(generator = generator,
            latent_dim = latent_dim,
            L = L,
            optimizer = optimizer)
  vae.train(data_train = train,
            epochs = epochs,
            batch_size = batch_size)
  random_sample = np.random.choice(n_train, synth_size, replace = True)
  O_test = train[random_sample, :]
  alea = np.random.normal(loc = 0, scale = 1, size = (synth_size, L, latent_dim + dimV + 1))
  synth = vae.generator.predict([O_test, alea])[1]

  synth = synth[:, 0, :]
  synth[:, 0:2] = scalerW.inverse_transform(synth[:, 0:2])
  synth[:, dimWV + 1]\
    = scalerY.inverse_transform(synth[:, dimWV + 1].reshape(-1, 1)).reshape(1,-1)
  synth[:, 2:5] = np.round(synth[:, 2:5])
  np.savetxt("IWPC-synth_try#" + str(i) + ".csv", synth, delimiter = ",") 

Because running the chunk is time-consuming, we stored one trained VAE that we considered good enough. We now turn to its evaluation based on the three criteria discussed in Section 6.2 and Section 6.3.

7.3 Evaluating the quality of the generator

The next chunk of code defines in R the counterparts train and test of the Python objects train and test (keeping only 1000 observations from each), and synth, the collection of 1000 synthetic observations drawn from the generator associated with the VAE that we stored in Section 7.2. For later use (while implementing Criterion 1) we add a dummy column named Z.

Hide/Show the code
dataset <- dataset %>%
    dplyr::select(-1) %>%
    dplyr::mutate(A = as.integer(A),
                  Z = stats::rnorm(dplyr::n()))
train <- dataset %>%
    dplyr::slice_head(n = 1e3)
test <- dataset %>%
    dplyr::slice_tail(n = 1e3)
synth <- readr::read_csv("data/IWPC-synth_try#20.csv",
                         col_names = c(paste0("W", 1:dimW),
                                       paste0("V", 1:dimV),
                                       "A", "Y")) %>%
  dplyr::mutate(A = as.integer(A),
                V1 = round(V1),
                V2 = round(V2),
                V3 = round(V3),
                Z = stats::rnorm(dplyr::n()))

7.3.1 Criterion 1

The next chunk of code implements the first criterion.

Hide/Show the code
frml <- paste0("Z ~", 
               paste0(c(paste0("W", 1:dimW),
                        paste0("V", 1:dimV),
                        "A", "Y"),
                     collapse = " + ")) %>%
  stats::as.formula()
fig <- tibble::tibble(d = c(kknn::kknn(frml, train, synth, k = 1)$D,
                            kknn::kknn(frml, test, synth, k = 1)$D)) %>%
  dplyr::mutate(type = c(rep("training data", dplyr::n()/2),
                         rep("testing data", dplyr::n()/2))) %>%
  ggplot2::ggplot() + 
  ggplot2::stat_ecdf(ggplot2::aes(x = d, color = type)) +
  ggplot2::scale_x_log10() +
  ggplot2::xlab("distance to nearest neighbor") + 
  ggplot2::ylab("empirical cumulative distribution function")
print(fig)
Figure 7: Empirical c.d.f. of the distance to the nearest neighbor within the synthetic observations of the training and of the testing IWPC data points (logarithmic scale). The two c.d.f. are not as close as those in Figure 3.

The two empirical c.d.f. shown in Figure 7 are not as similar as those in Figure 3. The next chunk of code implements the t-tests comparing the first three moments of \mu_{(n+1):(n+n')} to those of \mu_{1:n}.

Hide/Show the code
(moments_test <- fig$data %>%
   dplyr::mutate(`1st_moment` = mean(fig$data %>%
                                     dplyr::filter(type == "training data") %>%
                                     dplyr::pull(d)),
                 `2nd_moment` = mean(fig$data %>%
                                     dplyr::filter(type == "training data") %>%
                                     dplyr::pull(d) %>% .^2),
                 `3rd_moment` = mean(fig$data %>%
                                     dplyr::filter(type == "training data") %>%
                                     dplyr::pull(d) %>% .^3)) %>%
   dplyr::filter(type == "testing data") %>%
   tidyr::nest() %>%
   dplyr::mutate(
              `1st_moment_test` =
                  purrr::map(data,
                             function(df)
                                 stats::t.test(df$d, mu = df$`1st_moment`[1])$p.val),
              `2nd_moment_test` =
                  purrr::map(data,
                             function(df)
                                 stats::t.test(df$d^2, mu = df$`2nd_moment`[1])$p.val),
              `3rd_moment_test` =
                  purrr::map(data,
                             function(df)
                                 stats::t.test(df$d^3, mu = df$`3rd_moment`[1])$p.val)
          ) %>%
   dplyr::select(-data) %>%
   tidyr::unnest(cols = c(`1st_moment_test`, `2nd_moment_test`, `3rd_moment_test`)))
# A tibble: 1 × 3
  `1st_moment_test` `2nd_moment_test` `3rd_moment_test`
              <dbl>             <dbl>             <dbl>
1          7.59e-13          2.47e-18          8.18e-26

The numerical evidence of discrepancy is compelling. But is it still as compelling when only 100 synthetic observations are used? The next chunk of code addresses this question.

Hide/Show the code
(tibble::tibble(d = c(kknn::kknn(frml, train, synth %>% dplyr::slice_head(n = 100), k = 1)$D,
                      kknn::kknn(frml, test, synth %>% dplyr::slice_head(n = 100), k = 1)$D)) %>%
  dplyr::mutate(type = c(rep("training data", dplyr::n()/2),
                         rep("testing data", dplyr::n()/2))) %>%
  dplyr::mutate(`1st_moment` = mean(fig$data %>%
                                     dplyr::filter(type == "training data") %>%
                                     dplyr::pull(d)),
                 `2nd_moment` = mean(fig$data %>%
                                     dplyr::filter(type == "training data") %>%
                                     dplyr::pull(d) %>% .^2),
                 `3rd_moment` = mean(fig$data %>%
                                     dplyr::filter(type == "training data") %>%
                                     dplyr::pull(d) %>% .^3)) %>%
   dplyr::filter(type == "testing data") %>%
   tidyr::nest() %>%
   dplyr::mutate(
              `1st_moment_test` =
                  purrr::map(data,
                             function(df)
                                 stats::t.test(df$d, mu = df$`1st_moment`[1])$p.val),
              `2nd_moment_test` =
                  purrr::map(data,
                             function(df)
                                 stats::t.test(df$d^2, mu = df$`2nd_moment`[1])$p.val),
              `3rd_moment_test` =
                  purrr::map(data,
                             function(df)
                                 stats::t.test(df$d^3, mu = df$`3rd_moment`[1])$p.val)
          ) %>%
   dplyr::select(-data) %>%
   tidyr::unnest(cols = c(`1st_moment_test`, `2nd_moment_test`, `3rd_moment_test`)))
# A tibble: 1 × 3
  `1st_moment_test` `2nd_moment_test` `3rd_moment_test`
              <dbl>             <dbl>             <dbl>
1            0.0330            0.0135           0.00313

The strength of evidence has dropped considerably: the p-values, all below 1e-12 in the previous analysis, now range between 0.003 and 0.033. As in Section 6.3, distinguishing N synthetic observations from genuine observations is more challenging when N=100 than when N=1000.
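To make the role of the sample size concrete, here is a minimal Python sketch of a moment comparison of this kind. It is not the code used above: the distances are hypothetical, the function name `moment_test_pvalue` is ours, and the Student law of the t-test is replaced by its Gaussian limit so that the example only needs the standard library.

```python
import math
from statistics import mean, stdev


def moment_test_pvalue(sample, reference, power):
    """Two-sided p-value for H0: E[d^power] equals the reference moment.

    Same statistic as a one-sample t-test applied to d^power, but the Student
    law is replaced by its Gaussian limit (assumption: large enough sample).
    """
    target = mean([x ** power for x in reference])  # treated as a fixed value
    powered = [x ** power for x in sample]
    standard_error = stdev(powered) / math.sqrt(len(powered))
    z = (mean(powered) - target) / standard_error
    return math.erfc(abs(z) / math.sqrt(2))  # equals 2 * (1 - Phi(|z|))


# Hypothetical distances: the tested ones are shifted by 0.03.
pattern = [0.7, 0.9, 1.1, 1.3]
reference = pattern * 250
shifted = [x + 0.03 for x in pattern]

for n in (100, 1000):
    sample = shifted * (n // 4)
    p_values = [moment_test_pvalue(sample, reference, k) for k in (1, 2, 3)]
    print(n, ["%.2e" % p for p in p_values])
```

In this toy setting the shift is identical for both sample sizes, yet it is not flagged at the 5% level with 100 points whereas it is clearly flagged with 1000 points, because the standard error shrinks like the inverse square root of the sample size.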

7.3.2 Criterion 2

The next chunk of code implements the second criterion, in its visual form.

Hide/Show the code
fig <- dplyr::bind_rows(test, synth) %>%
  dplyr::select(-Z) %>%
  dplyr::mutate(type = c(rep("testing data", dplyr::n()/2),
                         rep("synthetic data", dplyr::n()/2))) %>%
  tidyr::pivot_longer(-type, names_to = "what", values_to = "values") %>%
  dplyr::mutate(what = dplyr::case_when(
           what == "W1" ~ "W[1]",
           what == "W2" ~ "W[2]",
           what == "V1" ~ "V[1]",
           what == "V2" ~ "V[2]",
           what == "V3" ~ "V[3]",
           TRUE ~ what
  )) %>%
  dplyr::mutate(what = factor(what,
                              levels = c(paste0("W[", 1:dimW, "]"), 
                                         paste0("V[", 1:dimV, "]"),
                                         "A", "Y"))) %>%
  ggplot2::ggplot() +
  ggplot2::stat_ecdf(ggplot2::aes(x = values, color = type)) +
  ggplot2::facet_wrap(~ what,
                      scales = "free_x",
                      labeller = ggplot2::label_parsed) +
  ggplot2::ylab("empirical cumulative distribution function")
print(fig)
Figure 8: Empirical c.d.f. of each covariate based on either the synthetic observations or the testing IWPC data set.

Figure 8 suggests that, except for V_2, A and, to a lesser extent, V_1, the marginal laws under the synthetic law do not align well with their counterparts under P. This is confirmed by the following hypothesis tests (Kolmogorov-Smirnov tests for W_1, W_2 and Y, Fisher’s exact tests for V_1, V_2, V_3 and A):

Hide/Show the code
(ks_tests <- fig$data %>%
  dplyr::filter(!what %in% c(paste0("V[", 1:dimV, "]"), "A")) %>%
  tidyr::nest(data = -what) %>%
  dplyr::mutate(p.val = purrr::map(data,
                                   ~ stats::ks.test(stats::formula(values ~ type),
                                                    data = .x)$p.val)) %>%
  dplyr::select(-data) %>%
  tidyr::unnest(cols = p.val))
# A tibble: 3 × 2
  what      p.val
  <fct>     <dbl>
1 W[1]  1.93e- 25
2 W[2]  5.53e-180
3 Y     4.37e- 32
Hide/Show the code
(fisher_test <- fig$data %>%
  dplyr::filter(what %in% c(paste0("V[", 1:dimV, "]"), "A")) %>%
  dplyr::group_by(what, type) %>%
  dplyr::summarize(`=1` = sum(values),
                   `=0` = dplyr::n()-sum(values)) %>%
  dplyr::select(-type) %>%
  tidyr::nest(data = -what) %>%
  dplyr::mutate(p.val = purrr::map(data,
                                   ~ stats::fisher.test(.x)$p.val)) %>%
  dplyr::select(-data) %>%
  tidyr::unnest(cols = p.val))
# A tibble: 4 × 2
# Groups:   what [4]
  what     p.val
  <fct>    <dbl>
1 V[1]  1.08e- 3
2 V[2]  5.99e- 1
3 V[3]  2.23e-22
4 A     3.59e- 1

One might again question whether this result persists when comparing smaller synthetic and testing datasets. The next chunk of code replicates the previous statistical analysis, this time using two samples of 100 data points each.

Hide/Show the code
smaller_dataset <- fig$data %>%
   dplyr::group_by(type) %>%
   dplyr::slice_head(n = 100) %>%
   dplyr::ungroup()
(ks_tests <- smaller_dataset %>% 
  dplyr::filter(!what %in% c(paste0("V[", 1:dimV, "]"), "A")) %>%
  tidyr::nest(data = -what) %>%
  dplyr::mutate(p.val = purrr::map(data,
                                   ~ stats::ks.test(stats::formula(values ~ type),
                                                    data = .x)$p.val)) %>%
  dplyr::select(-data) %>%
  tidyr::unnest(cols = p.val))
# A tibble: 3 × 2
  what    p.val
  <fct>   <dbl>
1 W[1]  0.373  
2 W[2]  0.00184
3 Y     0.0590 
Hide/Show the code
(fisher_test <- smaller_dataset %>%
  dplyr::filter(what %in% c(paste0("V[", 1:dimV, "]"), "A")) %>%
  dplyr::group_by(what, type) %>%
  dplyr::summarize(`=1` = sum(values),
                   `=0` = dplyr::n()-sum(values)) %>%
  dplyr::select(-type) %>%
  tidyr::nest(data = -what) %>%
  dplyr::mutate(p.val = purrr::map(data,
                                   ~ stats::fisher.test(.x)$p.val)) %>%
  dplyr::select(-data) %>%
  tidyr::unnest(cols = p.val))
# A tibble: 4 × 2
# Groups:   what [4]
  what  p.val
  <fct> <dbl>
1 V[1]  1    
2 V[2]  0.596
3 V[3]  1    
4 A     1    

This time, except for W_{2}, the p-values are large, indicating that the tests cannot detect discrepancies when the synthetic and testing data sets each contain only 100 data points. As observed in Section 6.3, a similar conclusion holds when comparing a synthetic dataset of 100 data points with a testing dataset of 1000 data points, with Y now also associated with a small p-value. This is demonstrated by the next chunk of code.

Hide/Show the code
smaller_synth_dataset <- fig$data %>%
  dplyr::filter(type == "synthetic data") %>%
  dplyr::slice_head(n = 100) %>%
  dplyr::bind_rows(fig$data %>%
                   dplyr::filter(type == "testing data"))
(ks_tests <- smaller_synth_dataset %>% 
  dplyr::filter(!what %in% c(paste0("V[", 1:dimV, "]"), "A")) %>%
  tidyr::nest(data = -what) %>%
  dplyr::mutate(p.val = purrr::map(data,
                                   ~ stats::ks.test(stats::formula(values ~ type),
                                                    data = .x)$p.val)) %>%
  dplyr::select(-data) %>%
  tidyr::unnest(cols = p.val))
# A tibble: 3 × 2
  what        p.val
  <fct>       <dbl>
1 W[1]  0.129      
2 W[2]  0.000000179
3 Y     0.00730    
Hide/Show the code
(fisher_test <- smaller_synth_dataset %>%
  dplyr::filter(what %in% c(paste0("V[", 1:dimV, "]"), "A")) %>%
  dplyr::group_by(what, type) %>%
  dplyr::summarize(`=1` = sum(values),
                   `=0` = dplyr::n()-sum(values)) %>%
  dplyr::select(-type) %>%
  tidyr::nest(data = -what) %>%
  dplyr::mutate(p.val = purrr::map(data,
                                   ~ stats::fisher.test(.x)$p.val)) %>%
  dplyr::select(-data) %>%
  tidyr::unnest(cols = p.val))
# A tibble: 4 × 2
# Groups:   what [4]
  what   p.val
  <fct>  <dbl>
1 V[1]  1     
2 V[2]  1     
3 V[3]  0.0830
4 A     1     

In conclusion, although a large synthetic dataset can be shown to differ in distribution from a large testing dataset, a smaller synthetic dataset does not display clear differences in marginal distributions (apart from W_{2} and potentially Y) when compared to a small testing dataset.
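The two kinds of tests used in this subsection can be sketched in a few lines of Python. This is an illustration under stated assumptions rather than a reimplementation of the R functions: the Kolmogorov-Smirnov p-value relies on the asymptotic Kolmogorov series only, the two-sided Fisher p-value sums the probabilities of all tables that are no more likely than the observed one, the function names are ours, and the data are hypothetical.

```python
import math
from bisect import bisect_right


def ks_statistic(xs, ys):
    """Largest gap between the two empirical c.d.f.s."""
    xs, ys = sorted(xs), sorted(ys)
    return max(abs(bisect_right(xs, v) / len(xs) - bisect_right(ys, v) / len(ys))
               for v in xs + ys)


def ks_pvalue(d, n, m, terms=100):
    """Asymptotic p-value based on the Kolmogorov distribution."""
    lam = d * math.sqrt(n * m / (n + m))
    if lam < 0.2:  # the alternating series is unstable here; the limit is 1
        return 1.0
    total = 2 * sum((-1) ** (j - 1) * math.exp(-2 * (j * lam) ** 2)
                    for j in range(1, terms + 1))
    return min(1.0, max(0.0, total))


def fisher_pvalue(a, b, c, d):
    """Two-sided Fisher's exact test for the 2x2 table [[a, b], [c, d]]."""
    row1, row2, col1 = a + b, c + d, a + c
    denom = math.comb(row1 + row2, col1)

    def prob(k):  # hypergeometric probability of k in the top-left cell
        return math.comb(row1, k) * math.comb(row2, col1 - k) / denom

    observed = prob(a)
    support = range(max(0, col1 - row2), min(row1, col1) + 1)
    return min(1.0, sum(p for p in map(prob, support)
                        if p <= observed * (1 + 1e-7)))


for n in (100, 1000):
    # Continuous feature: the second sample is shifted by 10% of the range.
    d = ks_statistic(range(n), [i + n // 10 for i in range(n)])
    # Binary feature: 55% of ones in one sample versus 45% in the other.
    ones, zeros = 55 * n // 100, 45 * n // 100
    print(n, ks_pvalue(d, n, n), fisher_pvalue(ones, zeros, zeros, ones))
```

The discrepancy between the two hypothetical samples is the same for both sizes. Only the sample size changes, and with it the ability of either test to detect the discrepancy.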

7.3.3 Criterion 3

The next chunk of code builds a super learning algorithm to estimate either the conditional probability that A=1 given (W,V) or the conditional mean of Y given (A,W,V) by aggregating the same 5 base learners as in Section 6.3.

Hide/Show the code
library(SuperLearner)
algo <- function(learning_data, testing_data,
                 outcome, covariates) {
  SL.lib <- c("SL.mean", "SL.glm", "SL.glm.interaction", "SL.ranger", "SL.nnet")
  cvControl <- SuperLearner::SuperLearner.CV.control(V = 10)
  family <- switch(outcome,
                   A = stats::binomial(),
                   Y = stats::gaussian())
  sl <- SuperLearner::SuperLearner(
                        Y = dplyr::pull(learning_data, outcome),
                        X = learning_data[, covariates],
                        newX = testing_data[, covariates], 
                        family = family,
                        SL.library = SL.lib,
                        method = "method.NNLS",
                        cvControl = cvControl)
  list(as.vector(sl$SL.predict))
}
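
The option method = "method.NNLS" asks for the base learners to be combined through non-negative least squares. As far as we understand it, the outcome is regressed on the cross-validated predictions of the base learners under a non-negativity constraint, and the resulting coefficients are rescaled to sum to one. The Python sketch below illustrates this aggregation step only. It is a simplified stand-in: a plain projected gradient loop replaces a dedicated NNLS solver, the function name `nnls_weights` is ours, and the predictions are hypothetical.

```python
def nnls_weights(preds, y, n_iter=5000):
    """Non-negative least squares weights, rescaled to sum to one.

    preds: one list of (cross-validated) predictions per base learner.
    Plain projected gradient descent stands in for a dedicated NNLS solver.
    """
    k = len(preds)
    gram = [[sum(a * b for a, b in zip(preds[i], preds[j])) for j in range(k)]
            for i in range(k)]
    cross = [sum(a * b for a, b in zip(preds[i], y)) for i in range(k)]
    step = 1.0 / sum(gram[i][i] for i in range(k))  # trace >= largest eigenvalue
    w = [0.0] * k
    for _ in range(n_iter):
        grad = [sum(gram[i][j] * w[j] for j in range(k)) - cross[i]
                for i in range(k)]
        w = [max(0.0, w[i] - step * grad[i]) for i in range(k)]  # keep w >= 0
    total = sum(w)
    return [wi / total for wi in w] if total > 0 else [1.0 / k] * k


# Hypothetical cross-validated predictions from three base learners.
learner_a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
learner_b = [2.0, 1.0, 4.0, 3.0, 6.0, 5.0]
learner_c = [3.5] * 6  # constant predictions, in the spirit of SL.mean
y = [0.25 * a + 0.75 * b for a, b in zip(learner_a, learner_b)]

weights = nnls_weights([learner_a, learner_b, learner_c], y)
print([round(w, 3) for w in weights])
```

Since the hypothetical outcome is built as a 25%/75% mix of the first two learners, the weights should come out close to (0.25, 0.75, 0): the constant learner brings nothing and is discarded.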

We train the super learning algorithm three times: once on each of two distinct halves of the training data set, and once on half of the synthetic data set. This results in three estimators of the conditional probability that A=1 given (W,V) and three estimators of the conditional mean of Y given (A,W,V). The next chunk of code prepares the three training data sets.

Hide/Show the code
(tib <- dplyr::bind_rows(train, test, synth) %>%
   dplyr::select(-Z) %>%
   dplyr::mutate(type = c(rep("using training data a", dplyr::n()/6),
                          rep("using training data b", dplyr::n()/6),
                          rep("testing data a", dplyr::n()/6),
                          rep("testing data b", dplyr::n()/6),
                          rep("using synthetic data", dplyr::n()/3))) %>%
   dplyr::group_by(type) %>%
   dplyr::mutate(id = 1:dplyr::n()) %>%
   dplyr::ungroup() %>%
   dplyr::filter(id <= dplyr::n()/6) %>%
   dplyr::mutate(type = dplyr::case_when(
                                 stringr::str_detect(type, "testing") ~ "testing data",
                                 TRUE ~ type)) %>%
   dplyr::select(-id) %>%
   tidyr::nest(data = -type) %>%
   dplyr::mutate(testing = data[3]) %>%
   dplyr::filter(type != "testing data"))
# A tibble: 3 × 3
  type                  data               testing             
  <chr>                 <list>             <list>              
1 using training data a <tibble [500 × 7]> <tibble [1,000 × 7]>
2 using training data b <tibble [500 × 7]> <tibble [1,000 × 7]>
3 using synthetic data  <tibble [500 × 7]> <tibble [1,000 × 7]>

The following chunk of code trains the super learning algorithm and evaluates the six resulting estimators on the testing data points. To compare the estimators, we use scatter plots in the same manner as in Section 6.3.

Hide/Show the code
tib %>%
  dplyr::group_by(type) %>%
  dplyr::mutate(AgivenWV = algo(data[[1]], testing[[1]],
                                "A",
                                c(paste0("W", 1:dimW),
                                  paste0("V", 1:dimV))),
                YgivenAWV = algo(data[[1]], testing[[1]],
                                 "Y",
                                 c(paste0("W", 1:dimW),
                                   paste0("V", 1:dimV),
                                   "A"))) %>%
  dplyr::ungroup() %>%
  dplyr::select(type, AgivenWV, YgivenAWV) %>%
  tidyr::unnest(cols = c(AgivenWV, YgivenAWV)) %>%
  tidyr::pivot_longer(!type, names_to = "what", values_to = "preds") %>%
  dplyr::group_by(type, what) %>%
  dplyr::mutate(id = 1:dplyr::n()) %>%
  tidyr::pivot_wider(names_from = type,
                     values_from = "preds",
                     names_prefix = "preds ") %>%
  dplyr::ungroup() %>%
  dplyr::mutate(what = dplyr::case_when(
                                what == "AgivenWV" ~ "A given (W,V)",
                                what == "YgivenAWV" ~ "Y given (A, W, V)"
                              )) %>%
  ggplot2::ggplot() +
  ggplot2::geom_point(ggplot2::aes(`preds using training data a`,
                                   `preds using synthetic data`,
                                   ),
                      color = "#00BFC4", alpha = 0.5) +
  ggplot2::geom_smooth(ggplot2::aes(`preds using training data a`,
                                    `preds using synthetic data`,
                                   ),
                      color = "blue", method = "lm", se = FALSE) +  
  ggplot2::geom_point(ggplot2::aes(`preds using training data a`,
                                   `preds using training data b`,
                                   ),
                      color = "#F8766D", alpha = 0.5) +
  ggplot2::geom_smooth(ggplot2::aes(`preds using training data a`,
                                    `preds using training data b`,
                                   ),
                      color = "red", method = "lm", se = FALSE) +  
  ggplot2::geom_abline() +
    ggplot2::facet_wrap(
                 dplyr::vars(what),
                 scales = "free",
                 labeller = ggplot2::as_labeller(\(str) paste("predicting", str))
             ) +
  ggplot2::xlab("algorithm trained on 1st half of training data set") +
    ggplot2::ylab("algorithm trained either\n
on 2nd half of training data set (red) or\n
on synthetic data set (blue)")
Figure 9: Predicted conditional probabilities that A=1 given (W,V) (left) and predicted conditional means of Y given (A,W,V) (right). The x-axis shows the predictions of the super learning algorithm trained on the first half of the training IWPC data set. The y-axis shows the predictions of the algorithm trained either on the second half of the training IWPC data set (red points) or on half of the synthetic data set (blue points).

In the left-hand side panel of Figure 9, the spread and asymmetry of the red scatter plot around the y=x line are evidence of how difficult it is to estimate the conditional probability that A=1 given (W,V): two estimators trained on distinct halves of the same training data set already disagree noticeably. To ease comparisons, we also superimpose the regression lines obtained by fitting two separate linear models on the blue and red data points. The blue scatter plot is less widely spread around its regression line than the red one, but the blue line deviates more from the y=x line than the red line does.

The right-hand side panel is obtained analogously. The red scatter plot is more concentrated around the y=x line than its counterpart in the left-hand side panel. This indicates that it is less difficult to estimate the conditional mean of Y given (A,W,V) than the probability that A=1 given (W,V). The blue scatter plot is more widely spread than the red one, which again reveals a measure of discrepancy between the training and synthetic data. This is counterbalanced by the fact that the blue regression line almost coincides with the y=x line, whereas the red one deviates from it.
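The two visual cues used above, the spread of a scatter plot around its regression line and the deviation of that line from the y=x line, can be quantified separately. The Python sketch below is a minimal illustration on hypothetical predictions, not on the predictions displayed in Figure 9; the function name `line_and_spread` and the two toy scatter plots are ours.

```python
import math
from statistics import mean


def line_and_spread(x, y):
    """Least-squares line of y on x, spread around it, and distance to y = x."""
    mx, my = mean(x), mean(y)
    slope = (sum((a - mx) * (b - my) for a, b in zip(x, y))
             / sum((a - mx) ** 2 for a in x))
    intercept = my - slope * mx
    # Root mean square deviation from the fitted line.
    spread = math.sqrt(mean([(b - intercept - slope * a) ** 2
                             for a, b in zip(x, y)]))
    # Root mean square deviation from the y = x line.
    gap = math.sqrt(mean([(b - a) ** 2 for a, b in zip(x, y)]))
    return slope, intercept, spread, gap


# Hypothetical predictions of a reference estimator (x-axis).
x = [0.1 * i for i in range(1, 9)]
signs = [1, -1, -1, 1, 1, -1, -1, 1]

# Tight around its own line, but the line is far from y = x.
tight_but_tilted = [0.3 + 0.4 * a for a in x]
# Widely spread, but the fitted line coincides with y = x.
wide_but_aligned = [a + 0.1 * s for a, s in zip(x, signs)]

for label, y in (("tight but tilted", tight_but_tilted),
                 ("wide but aligned", wide_but_aligned)):
    print(label, [round(v, 3) for v in line_and_spread(x, y)])
```

The first toy scatter plot mimics the situation where the points concentrate around a line that departs from y=x, the second one the situation where the points are scattered around a line that coincides with y=x. Neither summary alone tells the whole story, which is why both cues are discussed.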

7.4 Summary

We implemented the same three criteria as in Section 6.3. Overall, the synthetic observations showed more substantial discrepancies from the IWPC (genuine) ones compared to the analysis on simulated data. Furthermore, detecting significant differences remained challenging with smaller synthetic datasets (100 versus 1000 synthetic observations), but the gaps were more evident in the real-data context.

8 Conclusion

This final section contextualizes our study by reviewing related works, discussing the challenges and limitations encountered, and offering a closing reflection on the broader implications of our approach and findings.

8.2 Challenges and limitations

The results of our study, while informative, are somewhat disappointing. Increasing the quantity of genuine data did not substantially improve the simulator’s performance in this context. This is in stark contrast to fields like image generation, where the abundance of inherent regularities in visual patterns enables models to learn effectively from larger datasets. In our case, the limited improvement may stem from a lack of rich regularities in the genuine data, which constrains the simulator’s ability to capture meaningful structures.

Another challenge lies in the question of sharing the simulator. While it would be appealing to make the simulator widely available, doing so raises concerns about the genuine data required to run the code. This dependency could potentially compromise the privacy or utility of the original dataset, creating additional barriers to adoption.

It is also worth noting that, in this article, we deliberately disabled the repeated training of the VAE from random initializations because of the high computational time it requires. This is telling: the computational cost is itself a practical limitation of the approach, forcing a trade-off between feasibility and the potential benefits of repeated and extended training cycles, which might otherwise improve the simulator’s performance.

Looking ahead, addressing some of these limitations requires practical and theoretical advances. For instance, future efforts could focus on effectively handling missing data (NA values) within the simulation framework. Additionally, establishing general design principles for simulator architectures could improve their robustness and adaptability across a variety of datasets and applications.

8.3 Closing reflection

As we reflect on the limitations and implications of simulators, it is worth revisiting the paragraph in Section 1.4 where we state that parametric simulators “cannot convincingly replicate the multifaceted interactions and variability inherent in ‘nature’”. The lexical field surrounding “nature” itself warrants reflection.

Historically, the notion of “nature” has evolved significantly. In ancient Greece, philosophers used the term “physis”, nowadays often translated as “nature”, to explore the inherent essence or intrinsic qualities of things. “Natura”, the Roman adaptation, extended these ideas, while medieval thought integrated nature into theological frameworks, portraying it as divine creation.

Often regarded as a figure of the late Renaissance and an early architect of the Scientific Revolution, Francis Bacon emphasized in 1620 the idea of conquering nature, viewing it as an object to be studied, understood, and controlled, “for nature is only to be commanded by obeying her” (Bacon 1854). During the Enlightenment, the concept of “nature” shifted further, as nature was increasingly separated from humanity and framed as an object of scientific study and exploitation. These developments have frequently served as a conceptual tool to justify humanity’s dominion over, and looting of, the non-human world.

In this context, referring to “nature” as something simulations seek to imitate is a testament to the evolving notion of “nature,” now encompassing phenomena like human health. This shift should be questioned, especially if it follows the Enlightenment logic of treating “nature” as an object to be understood, controlled, and exploited. Applying such a framework to humans risks reducing individuals to abstract data points or exploitable systems, ignoring their intrinsic complexity and moral agency.

Recognizing these risks invites us to critically examine not only the limitations of simulators but also their ethical and philosophical implications. Among these are the challenges posed by a lack of fair representability in the data used to train algorithms, which can perpetuate existing inequities or create new ones. Fairness becomes a central issue, particularly when simulations influence decisions that affect diverse populations, as the underlying models may not account for all relevant perspectives or experiences. Furthermore, the use of advanced simulations can contribute to elitism, as access to the expertise and computational resources needed to develop and deploy such systems is far from being universal, with numerous countries facing more urgent challenges. Finally, the environmental and financial cost of training complex algorithms, particularly those based on generative AI, raises questions about sustainability and the trade-offs between progress and resource consumption.

Fiction has always provided a space to explore hypothetical scenarios that might be impractical, impossible, or even unethical in reality. Utopian and dystopian literature, for example, simulates alternative societies to test ideas about governance, morality, and human behavior. Similarly, speculative fiction pushes boundaries by imagining futures shaped by scientific and technological advancements. In doing so, fiction serves as a conceptual laboratory, allowing its creators and audiences to investigate possibilities and their consequences. This creative exploration, which has long shaped human understanding, should continue to inform and inspire the design and purpose of computer simulations.

9 Acknowledgements

A number of open-source libraries made this work possible. In Python, we relied on the packages numpy (Harris et al. 2020), pandas (McKinney 2010), random (Van Rossum 2020), sklearn (Pedregosa et al. 2011) and tensorflow (Abadi et al. 2015). In R, we used the packages gtsummary (Sjoberg et al. 2021), kknn (Schliep and Hechenbichler 2016), SuperLearner (Polley et al. 2024) and tidyverse (Wickham et al. 2019). We are grateful to the developers and maintainers of these packages for their contributions to the research community.

The authors warmly thank Isabelle Drouet (Sorbonne Université) and Alexander Reisach (Université Paris Cité) for their valuable feedback and insightful comments on this project.

Bibliography

Abadi, Martín, Ashish Agarwal, Paul Barham, Eugene Brevdo, Zhifeng Chen, Craig Citro, Greg S. Corrado, et al. 2015. “TensorFlow: Large-Scale Machine Learning on Heterogeneous Systems.” https://www.tensorflow.org/.
Aggarwal, Alankrita, Mamta Mittal, and Gopi Battineni. 2021. “Generative Adversarial Network: An Overview of Theory and Applications.” International Journal of Information Management Data Insights 1 (1): 100004. https://doi.org/10.1016/j.jjimei.2020.100004.
Aigner, Martin, and Günter M. Ziegler. 2018. Proofs from THE BOOK. Springer.
Alaa, Ahmed M., Boris van Breugel, Evgeny Saveliev, and Mihaela van der Schaar. 2022. “How Faithful is your Synthetic Data? Sample-level Metrics for Evaluating and Auditing Generative Models.” https://arxiv.org/abs/2102.08921.
Aristote. 2006. Poétique. Paris: Édition Mille et une nuits.
Arjovsky, Martin, Soumith Chintala, and Léon Bottou. 2017. “Wasserstein Generative Adversarial Networks.” In Proceedings of the 34th International Conference on Machine Learning, edited by Doina Precup and Yee Whye Teh, 70:214–23. Proceedings of Machine Learning Research. PMLR.
Arora, Sanjeev, Nadav Cohen, and Elad Hazan. 2018. “On the Optimization of Deep Networks: Implicit Acceleration by Overparameterization.” In International Conference on Machine Learning, 244–53. PMLR.
Athey, Susan, Guido W. Imbens, Jonas Metzger, and Evan Munro. 2024. “Using Wasserstein Generative Adversarial Networks for the Design of Monte Carlo Simulations.” Journal of Econometrics 240 (2): 105076. https://doi.org/10.1016/j.jeconom.2020.09.013.
Bacon, Francis. 1854. Novum Organum 1620. Vol. 3. Philadelphia: Parry & MacMillan. https://history.hanover.edu/texts/bacon/novorg.html.
Baowaly, Mrinal Kanti, Chia-Ching Lin, Chao-Lin Liu, and Kuan-Ta Chen. 2018. “Synthesizing Electronic Health Records Using Improved Generative Adversarial Networks.” Journal of the American Medical Informatics Association 26: 228–41.
Barberousse, Anouk, and Pascal Ludwig. 2000. “Les Modèles Comme Fiction.” Philosophie 68: 16–43.
Carruthers, Peter. 2002. “Human Creativity: Its Cognitive Basis, Its Evolution, and Its Connections with Childhood Pretence.” British Journal for the Philosophy of Science 53: 225–49.
Che, Zhengping, Yu Cheng, Shuangfei Zhai, Zhaonan Sun, and Yan Liu. 2017. “Boosting Deep Learning Risk Prediction with Generative Adversarial Networks for Electronic Health Records.” In 2017 IEEE International Conference on Data Mining (ICDM), 787–92. https://doi.org/10.1109/ICDM.2017.93.
Choi, Edward, Siddharth Biswal, Bradley Malin, Jon Duke, Walter F. Stewart, and Jimeng Sun. 2017. “Generating Multi-Label Discrete Patient Records Using Generative Adversarial Networks.” In Proceedings of the 2nd Machine Learning for Healthcare Conference, edited by Finale Doshi-Velez, Jim Fackler, David Kale, Rajesh Ranganath, Byron Wallace, and Jenna Wiens, 68:286–305. Proceedings of Machine Learning Research. PMLR. https://proceedings.mlr.press/v68/choi17a.html.
Choromanska, Anna, Mikael Henaff, Michael Mathieu, Gérard Ben Arous, and Yann LeCun. 2015. “The Loss Surfaces of Multilayer Networks.” In Artificial Intelligence and Statistics, 192–204. PMLR.
Creswell, Antonia, Tom White, Vincent Dumoulin, Kai Arulkumaran, Biswa Sengupta, and Anil A. Bharath. 2018. “Generative Adversarial Networks: An Overview.” IEEE Signal Processing Magazine 35 (1): 53–65. https://doi.org/10.1109/MSP.2017.2765202.
de Saint-Exupéry, Antoine. 1943. Le Petit Prince. New York: Reynal & Hitchcock.
Figueira, Alvaro, and Bruno Vaz. 2022. “Survey on Synthetic Data Generation, Evaluation Methods and GANs.” Mathematics 10 (15). https://doi.org/10.3390/math10152733.
Glorot, Xavier, and Yoshua Bengio. 2010. “Understanding the Difficulty of Training Deep Feedforward Neural Networks.” In Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics, edited by Yee Whye Teh and Mike Titterington, 9:249–56. Proceedings of Machine Learning Research. Chia Laguna Resort, Sardinia, Italy: PMLR. https://proceedings.mlr.press/v9/glorot10a.html.
Goodfellow, Ian, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. 2014. “Generative Adversarial Nets.” In Advances in Neural Information Processing Systems, edited by Z. Ghahramani, M. Welling, C. Cortes, N. Lawrence, and K. Q. Weinberger. Vol. 27. Curran Associates, Inc.
Gui, Jie, Zhenan Sun, Yonggang Wen, Dacheng Tao, and Jieping Ye. 2023. “A Review on Generative Adversarial Networks: Algorithms, Theory, and Applications.” IEEE Transactions on Knowledge and Data Engineering 35 (4): 3313–32. https://doi.org/10.1109/TKDE.2021.3130191.
Gulrajani, Ishaan, Faruk Ahmed, Martin Arjovsky, Vincent Dumoulin, and Aaron C Courville. 2017. “Improved Training of Wasserstein GANs.” In Advances in Neural Information Processing Systems, edited by I. Guyon, U. Von Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett. Vol. 30. Curran Associates, Inc.
Harris, Charles R., K. Jarrod Millman, Stéfan J. van der Walt, Ralf Gommers, Pauli Virtanen, David Cournapeau, Eric Wieser, et al. 2020. “Array Programming with NumPy.” Nature 585 (7825): 357–62. https://doi.org/10.1038/s41586-020-2649-2.
Homère. 2000. L’Odyssée. Paris: La Découverte.
Kingma, Diederik P., and Max Welling. 2014. “Auto-Encoding Variational Bayes.” In 2nd International Conference on Learning Representations, ICLR 2014, Banff, AB, Canada, April 14-16, 2014, Conference Track Proceedings, edited by Yoshua Bengio and Yann LeCun.
Kocaoglu, Murat, Christopher Snyder, Alexandros G. Dimakis, and Sriram Vishwanath. 2017. “CausalGAN: Learning Causal Implicit Generative Models with Adversarial Training.” CoRR abs/1709.02023. http://arxiv.org/abs/1709.02023.
Lauritzen, Steffen L. 1996. Graphical Models. Vol. 17. Oxford Statistical Science Series. The Clarendon Press, Oxford University Press, New York.
Lee, Scott H. 2018. “Natural Language Generation for Electronic Health Records.” Npj Digital Medicine 1 (1). https://doi.org/10.1038/s41746-018-0070-0.
Lu, Yingzhou, Minjie Shen, Huazheng Wang, Xiao Wang, Capucine van Rechem, Tianfan Fu, and Wenqi Wei. 2024. “Machine Learning for Synthetic Data Generation: A Review.” https://arxiv.org/abs/2302.04062.
McKinney, Wes. 2010. “Data Structures for Statistical Computing in Python.” In Proceedings of the 9th Python in Science Conference, edited by Stéfan van der Walt and Jarrod Millman, 56–61. https://doi.org/10.25080/Majora-92bf1922-00a.
Metropolis, Nicholas, and Stanislaw Ulam. 1949. “The Monte Carlo Method.” Journal of the American Statistical Association 44 (247): 335–41.
Morris, T. P., I. R. White, and M. J. Crowther. 2019. “Using Simulation Studies to Evaluate Statistical Methods.” Statistics in Medicine 38. https://doi.org/10.1002/sim.8086.
Neal, Brady, Chin-Wei Huang, and Sunand Raghupathi. 2021. “RealCause: Realistic Causal Inference Benchmarking.” https://arxiv.org/abs/2011.15007.
Parikh, Harsh, Carlos Varjao, Louise Xu, and Eric Tchetgen Tchetgen. 2022. “Validating Causal Inference Methods.” In Proceedings of the 39th International Conference on Machine Learning, edited by Kamalika Chaudhuri, Stefanie Jegelka, Le Song, Csaba Szepesvari, Gang Niu, and Sivan Sabato, 162:17346–58. Proceedings of Machine Learning Research. PMLR.
Pearl, Judea, and Dana Mackenzie. 2018. The Book of Why: The New Science of Cause and Effect. New York: Basic Books.
Pedregosa, Fabian, Gaël Varoquaux, Alexandre Gramfort, Vincent Michel, Bertrand Thirion, Olivier Grisel, Mathieu Blondel, et al. 2011. “Scikit-Learn: Machine Learning in Python.” Journal of Machine Learning Research 12 (Oct): 2825–30.
Petrakos, Niki Z., Erica E. M. Moodie, and Nicolas Savy. 2025. “A Framework for Generating Realistic Synthetic Tabular Data in a Randomized Controlled Trial Setting.” https://arxiv.org/abs/2501.17719.
Platon. 2002. La République. Paris: Flammarion.
Polley, Eric, Erin LeDell, Chris Kennedy, and Mark van der Laan. 2024. SuperLearner: Super Learner Prediction. https://CRAN.R-project.org/package=SuperLearner.
R Core Team. 2020. R: A Language and Environment for Statistical Computing. Vienna, Austria: R Foundation for Statistical Computing. https://www.R-project.org/.
Raspe, Rudolf E. 1866. Aventures Du Baron de Münchausen. Paris: Furne, Jouvet et Cie. https://gallica.bnf.fr/ark:/12148/bpt6k6582615r.
Rezende, Danilo Jimenez, Shakir Mohamed, and Daan Wierstra. 2014. “Stochastic Backpropagation and Approximate Inference in Deep Generative Models.” In Proceedings of the 31st International Conference on Machine Learning, edited by Eric P. Xing and Tony Jebara, 32:1278–86. Proceedings of Machine Learning Research 2. Beijing, China: PMLR. http://proceedings.mlr.press/v32/rezende14.html.
Rostand, Edmond. 2005. Cyrano de Bergerac. Paris: E. Fasquelle. https://gallica.bnf.fr/ark:/12148/bpt6k64960772.
Schliep, Klaus, and Klaus Hechenbichler. 2016. kknn: Weighted k-Nearest Neighbors. https://CRAN.R-project.org/package=kknn.
Shelley, Mary. 1818. Frankenstein; or, The Modern Prometheus. London: Lackington, Hughes, Harding, Mavor & Jones.
Sjoberg, Daniel D., Karissa Whiting, Michael Curry, Jessica A. Lavery, and Joseph Larmarange. 2021. “Reproducible Summary Tables with the gtsummary Package.” The R Journal 13: 570–80. https://doi.org/10.32614/RJ-2021-053.
Solly, Meilan. 2023. “The Real History Behind the Archimedes Dial in ‘Indiana Jones and the Dial of Destiny’.” Smithsonian Magazine. https://www.smithsonianmag.com/history/the-real-history-behind-archimedes-dial-in-indiana-jones-and-the-dial-of-destiny-180982435/.
The International Warfarin Pharmacogenetics Consortium. 2009. “Estimation of the Warfarin Dose with Clinical and Pharmacogenetic Data.” New England Journal of Medicine 360 (8): 753–64. https://doi.org/10.1056/NEJMoa0809329.
Tokarczuk, Olga. 2021. The Books of Jacob. Melbourne: The Text Publishing Company.
van der Vaart, Aad W. 1998. Asymptotic Statistics. Vol. 3. Cambridge Series in Statistical and Probabilistic Mathematics. Cambridge University Press, Cambridge.
Van Rossum, Guido. 2020. The Python Library Reference, Release 3.8.2. Python Software Foundation.
Van Rossum, Guido, and Fred L. Drake. 2009. Python 3 Reference Manual. Scotts Valley, CA: CreateSpace.
Vondrick, Carl, Hamed Pirsiavash, and Antonio Torralba. 2016. “Generating Videos with Scene Dynamics.” In Proceedings of the 30th International Conference on Neural Information Processing Systems, 613–21. NIPS’16. Red Hook, NY, USA: Curran Associates Inc.
Walton, Kendall L. 1993. Mimesis as Make-Believe: On the Foundations of the Representational Arts. Harvard University Press.
Wickham, Hadley, Mara Averick, Jennifer Bryan, Winston Chang, Lucy D’Agostino McGowan, Romain François, Garrett Grolemund, et al. 2019. “Welcome to the tidyverse.” Journal of Open Source Software 4 (43): 1686. https://doi.org/10.21105/joss.01686.
Xu, Lei, Maria Skoularidou, Alfredo Cuesta-Infante, and Kalyan Veeramachaneni. 2019. “Modeling Tabular Data Using Conditional GAN.” In Proceedings of the 33rd International Conference on Neural Information Processing Systems. Red Hook, NY, USA: Curran Associates Inc.
Yi, Xin, Ekta Walia, and Paul Babyn. 2019. “Generative Adversarial Network in Medical Imaging: A Review.” Medical Image Analysis 58: 101552. https://doi.org/10.1016/j.media.2019.101552.

Citation

BibTeX citation:
@article{boulet2025,
  author = {Boulet, Sandrine and Chambaz, Antoine},
  publisher = {Société Française de Statistique},
  title = {Draw {Me} a {Simulator}},
  journal = {Computo},
  date = {2025-03-21},
  url = {https://computo.sfds.asso.fr/template-computo-quarto},
  doi = {xxxx},
  issn = {2824-7795},
  langid = {en},
  abstract = {This study investigates the use of Variational
    Auto-Encoders to build a simulator that approximates the law of
    genuine observations. Using both simulated and real data in
    scenarios involving counterfactuality, we discuss the general task
    of evaluating a simulator’s quality, with a focus on comparisons of
    statistical properties and predictive performance. While the
    simulator built from simulated data shows minor discrepancies, the
    results with real data reveal more substantial challenges. Beyond
    the technical analysis, we reflect on the broader implications of
    simulator design, and consider its role in modeling reality.}
}
For attribution, please cite this work as:
Boulet, Sandrine, and Antoine Chambaz. 2025. “Draw Me a Simulator.” Computo, March. https://doi.org/xxxx.