DSM

Important: The generative model presented in this article comes from an assignment in USC EE660, a course taught by Professor Stephen Lyle Tu. Because I skipped some prerequisite ML courses, my grounding in ML fundamentals is not particularly solid, and I am not especially interested in ML theory, so my grade in the course ended up relatively low. Even so, Professor Tu put a great deal of effort into the course and taught the theoretical foundations of many algorithms. Theory can be dry, but it really is a crucial part of understanding these algorithms! I highly recommend taking the course: don't worry about your GPA; just learn what you want to learn.

Note that the generative model shown here is very simple. You do not even need to implement it yourself; calling an existing open-source library is far more convenient. The point of this article is to walk through some of the finer details of generative algorithms. For that matter, simply asking GPT can produce similar code.


Model


from typing import Callable, Sequence

import jax
import flax.linen as nn

Activation = Callable[[jax.Array], jax.Array]


class MLP(nn.Module):
    """Small fully connected network that serves as the score model s_theta(x)."""

    features: Sequence[int]
    activation: Activation = nn.swish

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        for f in self.features[:-1]:
            x = nn.Dense(f)(x)
            x = self.activation(x)
        x = nn.Dense(self.features[-1])(x)  # output width matches the data dimension
        return x


model = MLP(features=[64, 64, 2])

# prng_seq, X_train, optimizer, and BatchManager are defined elsewhere in the assignment code.
params = model.init(next(prng_seq), X_train[:1, ...])
opt_state = optimizer.init(params)
bm = BatchManager(data=X_train, batch_size=128, key=next(prng_seq))
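As a quick, hypothetical sanity check (not part of the assignment code): the final Dense layer has width 2, so for 2-D data the network maps each point to a 2-D vector, which is what lets it act as an estimate of the score, the gradient of log p(x).

# Hypothetical smoke test, assuming X_train holds 2-D points with shape (N, 2):
scores = model.apply(params, X_train[:8])
print(scores.shape)  # (8, 2): one estimated score vector per input point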

Loss Function


import chex
import jax.numpy as jnp


def dsm_loss(params: chex.ArrayTree, batch: jax.Array, sigma: float, k: int,
             key: jax.Array) -> jax.Array:
    n, d = batch.shape

    # k noise samples per data point: w_i^{j} ~ N(0, I).
    # (The original snippet used a fixed jax.random.PRNGKey(0) here, which draws
    # the same noise on every call; threading a fresh key in fixes that.)
    noise = jax.random.normal(key, shape=(n, k, d))
    perturbed_inputs = batch[:, None, :] + sigma * noise  # x_tilde = x + sigma * w

    # Evaluate the score model on all n * k perturbed points.
    model_outputs = jax.vmap(lambda x_j: model.apply(params, x_j))(perturbed_inputs.reshape(-1, d))
    model_outputs = model_outputs.reshape(n, k, d)

    # ||x_tilde + sigma^2 * s_theta(x_tilde) - x||^2 for every perturbed point.
    term1 = jnp.linalg.norm(perturbed_inputs + sigma**2 * model_outputs - batch[:, None, :], axis=-1) ** 2

    loss = jnp.sum(term1) / (n * k * sigma**4)
    return loss
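For reference, this is my reading of what the function computes; the constants match the code above, though the assignment handout may write them differently. With x_tilde = x + sigma * w and the model output interpreted as a score estimate s_theta, the summand satisfies ||x_tilde + sigma^2 * s_theta(x_tilde) - x||^2 = sigma^4 * ||s_theta(x_tilde) + w / sigma||^2, so after dividing by n * k * sigma^4 the loss is the usual denoising score matching objective

$$
\mathcal{L}_{\mathrm{DSM}}(\theta)
= \frac{1}{nk}\sum_{i=1}^{n}\sum_{j=1}^{k}
\left\lVert s_\theta\big(x_i + \sigma w_i^{j}\big) + \frac{w_i^{j}}{\sigma} \right\rVert^{2},
\qquad w_i^{j} \sim \mathcal{N}(0, I).
$$

Minimizing it pushes s_theta(x_tilde) toward -(x_tilde - x) / sigma^2, the score of the Gaussian-smoothed data distribution.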

Generate Samples through Langevin Sampling
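Before the code, a quick summary in my own words (the step size and number of iterations are whatever the assignment specifies, not anything stated here): with the trained score model s_theta approximating the gradient of log p, unadjusted Langevin dynamics starts from an arbitrary x_0 and repeats

$$
x_{t+1} = x_t + \frac{\epsilon}{2}\, s_\theta(x_t) + \sqrt{\epsilon}\, z_t,
\qquad z_t \sim \mathcal{N}(0, I),
$$

for a small step size epsilon; after enough steps the iterates are approximately samples from the learned distribution.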