Important: The generative model presented in this article comes from an assignment in USC EE660, a course taught by Professor Stephen Lyle Tu. Since I skipped some prerequisite ML courses, my foundation in ML principles is not particularly strong, and I am not especially interested in ML theory, so my grade in this course was relatively low. Professor Tu nevertheless put a great deal of effort into the course, teaching the theoretical foundations of many algorithms. Theory can be dry, but it is a crucial part of understanding algorithms! I highly recommend taking this course; don't worry about your GPA, just learn what you want to learn.
It is worth noting that the generative model presented in this article is very simple. You don't even need to implement it yourself; calling the corresponding open-source libraries is much more convenient, and even prompting GPT can produce similar code. The aim of this article is to highlight some of the finer details of generative algorithms.
First, a generic MLP that serves as the backbone of the model:

```python
from typing import Callable, Sequence

import jax
import jax.numpy as jnp
from flax import linen as nn


class MLP(nn.Module):
    features: Sequence[int]
    activation: Callable = nn.relu  # default assumed here; swap in your own activation

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        # Hidden layers: Dense followed by the activation
        for f in self.features[:-1]:
            x = nn.Dense(f)(x)
            x = self.activation(x)
        # Output layer, with no activation applied
        x = nn.Dense(self.features[-1])(x)
        return x
```
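As a quick sanity check, a Flax module like this is initialized and applied as follows. The layer sizes and input shape below are arbitrary placeholders for illustration, not values from the assignment:

```python
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
model = MLP(features=[64, 64, 10])   # arbitrary example layer sizes
x = jnp.ones((4, 32))                # a batch of 4 inputs of dimension 32
params = model.init(rng, x)          # build the parameter pytree
y = model.apply(params, x)           # forward pass; y.shape == (4, 10)
```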
```python
# Define the encoder
class Encoder(nn.Module):
    input_dim: int
    latent_dim: int

    @nn.compact
    def __call__(self, data: jax.Array):
        # Output includes mean and log variance
        mlp = MLP(features=[512, 512, 512, self.latent_dim * 2])
        x = mlp(data)
        mean, log_var = jnp.split(x, 2, axis=-1)
        return mean, log_var
```
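Note that the encoder returns a mean and a log-variance rather than a latent sample: to draw z while keeping the sampling step differentiable, a VAE uses the reparameterization trick, z = mean + sigma * eps with eps ~ N(0, I). Below is a minimal sketch of that step; the `reparameterize` helper is my own illustration rather than code from the assignment:

```python
import jax
import jax.numpy as jnp

def reparameterize(rng: jax.Array, mean: jax.Array, log_var: jax.Array) -> jax.Array:
    # sigma = exp(log_var / 2); written this way, gradients
    # flow through mean and log_var but not through eps
    std = jnp.exp(0.5 * log_var)
    eps = jax.random.normal(rng, mean.shape)  # eps ~ N(0, I)
    return mean + std * eps
```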
```python
# Define the decoder
class Decoder(nn.Module):
    output_dim: int

    @nn.compact
    def __call__(self, z: jax.Array) -> jax.Array:
        # Assumed completion: a decoder mirroring the encoder's
        # hidden sizes, mapping the latent z back to output_dim
        mlp = MLP(features=[512, 512, 512, self.output_dim])
        return mlp(z)
```