Generative Models & EM

Understanding EM

Notes by Harrison Wang (2022), Viet Vu (2021)

You can also download a PDF copy of the slides from Friday section.

1. Generative Model: parameters and variables

1.1. Types of parameters and variables

The parameters and variables of any inference model can be grouped as follows:

  • Unknown parameters: parameters that we don't know the exact values of. We use them to generate variables.

  • Known parameters: parameters that we know the exact values of. We also use them to generate variables.

  • Hidden variables: these variables are generated from the parameters (unknown/known), but we do not observe them.

  • Observed data: these variables are generated from the hidden variables and the parameters, and we observe them (i.e. we know their values).

Example 1 (RNA-Seq) Let's say that there are \(M\) transcripts, and \(N\) reads.

  • The nucleotide abundances \(\nu_{1}, \nu_{2},\cdots, \nu_{M}\) are unknown parameters. Recall that in week 2, kallisto attempts to estimate these parameters using the reads. However, they are not variables, because they do not change from experiment to experiment: we do not know the values of \(\boldsymbol{\nu}\), but they are fixed.

  • The lengths of the transcripts, \(L_{1}, L_{2},\cdots, L_{M}\), are known parameters. We simply know the transcript lengths: either they are given in a problem statement or, more realistically, we can look them up once we know which transcripts we are dealing with.

  • To make every read \(R_{n}\) (\(1\leq n\leq N\)), we have to know 1) which transcript it comes from, \(G_{n}\), 2) the starting point of the read on that transcript, \(S_{n}\), and 3) the orientation \(O_{n}\). We call \(G_{n}\), \(S_{n}\), and \(O_{n}\) hidden variables: we do not know what the variables are by just looking at \(R_{n}\), but \(\{G_{n}, S_{n}, O_{n}\}\) generate \(R_{n}\). In the joint factorization, this dependence is encoded in the term \(P(R_{n}|G_{n}, S_{n}, O_{n})\). In the graphical model, we see the arrows pointing to \(R_{n}\) from \(G_{n}\), \(S_{n}\), and \(O_{n}\).

  • The reads, \(R_{n}\) (\(1\leq n\leq N\)) are observed data. The output of the RNA-seq experiment is the reads.

  • It turns out that we also observe the actual nucleotide sequences of the transcripts--but that's often dealt with using alignment algorithms before being fed into a generative model.

Example 2 (mixture negative binomial fitting) Let's say we have Wiggins' scRNA-seq dataset from week 5, with \(N\) points belonging to \(K=5\) groups that are lognormally distributed, which we tried to fit on pset 5 using a mixture negative binomial model (and which you all surely did very well on).

  • We didn't know the centroids of the five groups, \(\mu_{k}\), or the mixture coefficients, \(\pi_{k}\). They're unknown parameters, which we spent a while trying to get.

  • The dispersion of each group (same for all groups), \(\phi\), is a known parameter. We were told to use 0.3. This value is empirically determined and general, which means we can know the parameter value from other people's RNA-seq data.

  • The group identity of each data point, \(G_n\), is a hidden variable. When we look at our raw data, we do not know which group each data point belongs to. But when we simulate our raw data, we know it.

  • The data points \(x_{1}, x_{2},\cdots, x_{N}\), are observed data.

1.2. Generative Model

The generative model specifies the relationships between unknown parameters, known parameters, hidden variables, and observed data. Using it is like doing a simulation.

In detail, a generative model explains the following:

  • Which parameters are known and which parameters are unknown
  • How the hidden variables are generated from the parameters
  • How the observed data are generated from the parameters and the hidden variables

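The most general flowchart of a generative model is as follows: parameters (known and unknown) → hidden variables → observed data, where the parameters also feed directly into the observed data.
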
Example (mixture negative binomial fitting) The generative model of mixture negative binomial fitting meets the above specs:

  • Parameters: we specified them in the previous section.

  • Generating the hidden variables: For the \(n\)-th point, its group \(G_{n}\) is randomly chosen according to the probabilities \(\pi_{k}\). The normalized "size" of group \(k\) is \(\pi_{k}\).

  • Generating the observed data: With \(G_{n}=j\), the data point \(x_{n}\) is generated from a negative binomial distribution with mean \(\mu_{j}\) and dispersion \(\phi\):

    $$x_{n}\mid(G_{n}=j)\sim\mathcal{NB}(\mu_{j}, \phi)$$

  • We can see that the hidden variable \(G_{n}\) specifies the group that the \(n\)-th point belongs to, and once we know the group, \(x_{n}\) is generated using the parameters of that group. In mixture NB fitting, these are the centroid \(\mu_{G_{n}}\) of group \(G_{n}\) and the shared dispersion \(\phi\); the mixture coefficients \(\pi_k\) enter one step earlier, when \(G_{n}\) is chosen.
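The two generating steps above can be sketched as a short simulation. This is only an illustration under stated assumptions: the function name `simulate_mixture_nb` is hypothetical, and the negative binomial is assumed to be parameterized so that its variance is \(\mu+\phi\mu^{2}\), which corresponds to NumPy's `(n, p)` parameterization with \(n=1/\phi\) and \(p=n/(n+\mu)\).

```python
import numpy as np

def simulate_mixture_nb(n_points, mu, pi, phi, seed=None):
    """Simulate from a mixture negative binomial model (illustrative sketch).

    Assumes the NB variance is mu + phi * mu**2, i.e. NumPy's (n, p)
    parameterization with n = 1/phi and p = n / (n + mu).
    """
    rng = np.random.default_rng(seed)
    mu = np.asarray(mu, dtype=float)
    pi = np.asarray(pi, dtype=float)

    # Hidden variables: choose group G_n with probabilities pi_k.
    groups = rng.choice(len(pi), size=n_points, p=pi)

    # Observed data: draw x_n from NB with the mean of its group.
    r = 1.0 / phi
    p = r / (r + mu[groups])
    counts = rng.negative_binomial(r, p)
    return groups, counts

# Example: two groups with means 10 and 200, dispersion 0.3.
groups, counts = simulate_mixture_nb(1000, mu=[10, 200], pi=[0.3, 0.7], phi=0.3, seed=0)
```

In a simulation we get to keep `groups`; with real data only `counts` would be available.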

Question: For RNA-seq, can you specify how we generate hidden variables from the parameters, and how we generate observed data from the parameters and the hidden variables?

2. Expectation-Maximization (EM)

2.1. The goal of the game

In this setting, we have access to the observed data. For example, in the RNA-seq experiment, we observe the reads \(\{R_{n}\}_{n=1}^{N}\). However, we have access to neither the unknown parameters nor the hidden variables.

We would like to use the generative model over and over again, each time updating the unknown parameters until the model outputs match the observations.


2.2. General scheme for EM

We denote by \(G_{n}\) the hidden variable (group assignment) of the \(n\)-th data point, and by \(x_{n}\) the \(n\)-th data point. In this section, the unknown parameters are denoted as a vector \(\boldsymbol{\theta}\).

  1. Initialization: make a guess for the unknown parameters.

    • In the RNA-Seq example, we would make a guess for the nucleotide abundance vector, \(\boldsymbol{\nu}\). Since the elements of \(\boldsymbol{\nu}\) must sum to \(1\), we sample a random probability vector (for example from a Dirichlet distribution, as in lecture).

    • In the mixture negative binomial example, we would make a guess for the centroids and the mixture coefficients.

  2. Expectation Step: infer the hidden variables, given the observed data and the guessed parameters.

    • In words, given a point \(x_{n}\) and guessed parameters \(\boldsymbol{\theta}\), what is the probability that the point belongs to group \(k\)? For example, given a read \(R_{n}\) and nucleotide abundances \(\boldsymbol{\nu}\), what is the probability that this read belongs to transcript \(k\)?

    • Mathematically, we calculate

    $$P(G_{n}=k|x_{n}, \boldsymbol{\theta})$$
    for all points \(n\) from \(1\) to \(N\), and for all groups \(k\) from \(1\) to \(K\).

    • At the expectation step, we always use Bayes' rule to calculate these probabilities:

    $$q_{nk}=P(G_{n}=k|x_{n}, \boldsymbol{\theta})=\dfrac{P(x_{n}|G_{n}=k,\boldsymbol{\theta})P(G_{n}=k|\boldsymbol{\theta})}{\displaystyle\sum_{j=1}^{K}P(x_{n}|G_{n}=j,\boldsymbol{\theta})P(G_{n}=j|\boldsymbol{\theta})}$$

    • The probabilities \(P(G_{n}=j|\boldsymbol{\theta})\) and \(P(x_{n}|G_{n}=j, \boldsymbol{\theta})\) are specified by the generative model.

      • The first probability represents how to assign a group for point \(n\), given the guessed parameters.
      • The second probability represents how to generate the observed data point \(n\), given the group \(G_{n}\) and the guessed parameters.
      • Remind yourself of the three specifications of a generative model!

  3. Maximization Step: given the probabilities of the hidden variables \(\{G_{n}\}_{n=1}^{N}\), estimate the parameters \(\boldsymbol{\theta}^{\text{ML}}\) through maximum likelihood. For the purposes of this week, the only parameters we will have to infer are the nucleotide abundances, \(\boldsymbol{\nu}\). To get the maximum-likelihood estimate for \(\boldsymbol{\nu}\), we need the expected number of reads assigned to transcript \(k\), \(\hat{c}_{k}\), for all \(1\leq k\leq M\) (there are \(M\) transcripts, or \(M\) groups to assign reads to). The formula for the estimated counts is

    $$\hat{c}_{k}=\sum_{n=1}^{N}q_{nk}$$
    and the estimated nucleotide abundance for transcript \(k\) is
    $$\hat{\nu}_{k}=\dfrac{\hat{c}_{k}}{N}$$
    where \(N\) is the number of reads.

  4. Iterate: We let \(\boldsymbol{\theta}^{\text{ML}}\) be our next guessed parameters, and go back to the expectation step.

  5. Convergence Criterion: We cannot iterate forever, so we need a rule for stopping. Here we introduce the log-likelihood of the data:

    $$\mathcal{L}=\sum_{n=1}^{N} \log P(x_{n}|\boldsymbol{\theta})=\sum_{n=1}^{N}\log\sum_{k=1}^{K}P(x_{n}|G_{n}=k,\boldsymbol{\theta})P(G_{n}=k|\boldsymbol{\theta})$$

    • We have to calculate using logarithms, because there will be 100000 reads! The likelihood itself, a product of 100000 probabilities, would underflow to 0 on any computer, so we work with the logarithm of the likelihood, or the log-likelihood.

    • We iterate until \(\mathcal{L}\) stops changing. In practice, a better strategy is to stop iterating when the change in \(\mathcal{L}\) between consecutive iterations is negligibly small:

    $$\left|\mathcal{L}^{(t+1)}-\mathcal{L}^{(t)}\right|<\epsilon$$
    where \(t\) indexes the iteration. The threshold \(\epsilon\) is usually set to around \(0.01\).
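To make the scheme concrete, here is a minimal sketch of the whole loop for the case where the only unknown parameters are the group probabilities \(\boldsymbol{\nu}\), with \(P(G_{n}=k|\boldsymbol{\nu})=\nu_{k}\). It is not the problem set solution. The function name `em_abundances` and the input `lik` are hypothetical: `lik[n, k]` is assumed to hold \(P(x_{n}|G_{n}=k)\), precomputed from the generative model and independent of \(\boldsymbol{\nu}\), with at least one nonzero entry per row. The log-likelihood is computed as the sum over points of \(\log P(x_{n}|\boldsymbol{\nu})\), where \(P(x_{n}|\boldsymbol{\nu})\) is the joint probability summed over groups.

```python
import numpy as np

def em_abundances(lik, eps=0.01, max_iter=1000, seed=None):
    """EM for group probabilities nu, given fixed lik[n, k] = P(x_n | G_n = k).

    Illustrative sketch: assumes every row of lik has a nonzero entry.
    Returns the estimated nu and the log-likelihood at each iteration.
    """
    rng = np.random.default_rng(seed)
    n_points, n_groups = lik.shape

    # Initialization: a random probability vector from a flat Dirichlet.
    nu = rng.dirichlet(np.ones(n_groups))

    prev_ll = -np.inf
    history = []
    for _ in range(max_iter):
        # Numerator of Bayes' rule: P(x_n | G_n = k) * P(G_n = k | nu).
        joint = lik * nu
        marginal = joint.sum(axis=1)          # P(x_n | nu)

        # Convergence check on the log-likelihood of the current guess.
        ll = np.log(marginal).sum()
        history.append(ll)
        if abs(ll - prev_ll) < eps:
            break
        prev_ll = ll

        # Expectation step: q[n, k] = P(G_n = k | x_n, nu).
        q = joint / marginal[:, None]

        # Maximization step: expected counts, then normalize by N.
        expected_counts = q.sum(axis=0)
        nu = expected_counts / n_points
    return nu, history

# Three points that can only come from group 0, one only from group 1.
lik = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
nu, history = em_abundances(lik, seed=1)
```

Keeping the log-likelihood history is a cheap sanity check: for EM it should never decrease from one iteration to the next.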

2.3. Strategies for implementing EM

  • Initialization: The unknown parameters \(\boldsymbol{\theta}\) can be a list, or a NumPy array. NumPy arrays are recommended.

  • Expectation Step: Encode \(\{q_{nk}\}\) (\(1\leq n\leq N, 1\leq k\leq K\)) in a list of lists, or an \(N\times K\) matrix.

    • For every \(n\), make a list (array) of \(K\) elements, where the \(j\)-th element is \(P(x_{n}|G_{n}=j,\boldsymbol{\theta})P(G_{n}=j|\boldsymbol{\theta})\). Then, normalize the list (array); in other words, divide each element of the list (array) by the sum of that list (array). The output of this will be \(\{q_{nk}\}_{k=1}^{K}\) for a fixed value of \(n\).
  • Maximization Step: No particular strategies; follow the formulas in Section 2.2, part 3.

  • Convergence Criterion: To calculate the log-likelihood of the data, we should not loop from \(n=1\) to \(N\) as the formula in Section 2.2, part 5 suggests. Note that if \(x_{i}=x_{j}\), then

    $$P(x_{i}|\boldsymbol{\theta})=P(x_{j}|\boldsymbol{\theta})$$
    In other words, if the \(i\)-th observation and the \(j\)-th observation are the same, then their probabilities will be the same. Recall that
    $$\mathcal{L}=\sum_{n=1}^{N} \log P(x_{n}|\boldsymbol{\theta})$$
    so do we need to go through \(x_{j}\) in our for loop if we already know \(x_{i}=x_{j}\)?

When doing the problem set, you should consider this when computing the log-likelihood.
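One possible way to act on this hint, sketched for the mixture negative binomial example: evaluate \(\log P(x|\boldsymbol{\theta})\) once per distinct observed value and weight it by how many times that value occurs. The function names `nb_logpmf` and `log_likelihood_unique` are hypothetical, and the negative binomial is assumed to be parameterized by mean \(\mu\) and dispersion \(\phi\) with variance \(\mu+\phi\mu^{2}\); check this against the parameterization you actually use.

```python
import math
import numpy as np

def nb_logpmf(x, mu, phi):
    """Log P(x) for a negative binomial with mean mu and dispersion phi.

    Assumed parameterization: variance = mu + phi * mu**2 (r = 1/phi).
    """
    r = 1.0 / phi
    return (math.lgamma(x + r) - math.lgamma(r) - math.lgamma(x + 1)
            + r * math.log(r / (r + mu)) + x * math.log(mu / (r + mu)))

def log_likelihood_unique(x, mu, pi, phi):
    """Mixture NB log-likelihood, computed once per distinct value of x."""
    values, counts = np.unique(x, return_counts=True)
    total = 0.0
    for value, count in zip(values, counts):
        # log[ P(x | G = k) P(G = k) ] for each group k.
        terms = [math.log(p) + nb_logpmf(int(value), m, phi)
                 for m, p in zip(mu, pi)]
        # Log of the sum over groups, computed stably in log space.
        total += count * np.logaddexp.reduce(terms)
    return total

x = np.array([0, 3, 3, 7, 7, 7, 120])
ll = log_likelihood_unique(x, mu=[5.0, 100.0], pi=[0.4, 0.6], phi=0.3)
```

With count data, the number of distinct values is often far smaller than \(N\), so the loop does correspondingly less work while returning the same sum.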