Implementation of DLRM: Embedding Operations #4227
Unanswered
Sir-NoChill
asked this question in Q&A
Replies: 1 comment
-
Still working on this problem. I switched from nnx back to flax.linen just to have more historical code examples to draw on. My current model looks like the following, which (I think) is correct:

```python
from typing import List, Optional

import jax.numpy as jnp
from flax import linen as nn


class DLRM_Net(nn.Module):
    m_spa: int                 # embedding dimension (sparse feature size)
    ln_emb: List[int]          # number of rows per embedding table
    ln_bot: List[int]          # bottom MLP layer sizes
    ln_top: List[int]          # top MLP layer sizes
    arch_interaction_op: str
    arch_interaction_itself: bool = False
    sigmoid_bot: int = -1
    sigmoid_top: int = -1
    loss_threshold: float = 0.0
    weighted_pooling: Optional[str] = None

    def setup(self):
        self.embeddings = [
            nn.Embed(num_embeddings=n, features=self.m_spa) for n in self.ln_emb
        ]
        self.bot_mlp = self.create_mlp(self.ln_bot, self.sigmoid_bot)
        self.top_mlp = self.create_mlp(self.ln_top, self.sigmoid_top)

    def create_mlp(self, ln, sigmoid_layer):
        layers = []
        for i in range(len(ln) - 1):
            layers.append(nn.Dense(features=ln[i + 1]))
            # Sigmoid on the requested layer, ReLU everywhere else.
            if i == sigmoid_layer:
                layers.append(nn.sigmoid)
            else:
                layers.append(nn.relu)
        return nn.Sequential(layers)

    def apply_embedding(self, lS_o, lS_i, embeddings):
        """Embedding lookup for the sparse features, followed by sum pooling."""
        ly = []
        for k in range(len(embeddings)):
            E = embeddings[k]
            # Submodules created in setup are called directly; E.apply() would
            # expect a separate variable dict as its first argument.
            embeds = E(lS_i[k])  # (batch, bag_size, m_spa), assuming fixed-size bags
            # Sum over the bag dimension (axis 1), not the feature dimension,
            # mimicking torch.nn.EmbeddingBag(mode="sum"). The offsets lS_o are
            # only needed for variable-length bags; see the sketch below.
            V = jnp.sum(embeds, axis=1)
            ly.append(V)
        return ly

    def interact_features(self, x, ly):
        if self.arch_interaction_op == "dot":
            # Stack the dense output and all pooled embeddings as
            # (batch, num_vectors, d), then take all pairwise dot products.
            T = jnp.concatenate([x] + ly, axis=1).reshape(x.shape[0], -1, x.shape[1])
            Z = jnp.matmul(T, jnp.transpose(T, axes=(0, 2, 1)))
            # Keep only the (strictly) lower-triangular part of Z.
            offset = 1 if self.arch_interaction_itself else 0
            li = jnp.array([i for i in range(Z.shape[1]) for j in range(i + offset)])
            lj = jnp.array([j for i in range(Z.shape[2]) for j in range(i + offset)])
            Zflat = Z[:, li, lj]
            R = jnp.concatenate([x, Zflat], axis=1)
        elif self.arch_interaction_op == "cat":
            R = jnp.concatenate([x] + ly, axis=1)
        else:
            raise ValueError(f"Unsupported interaction op: {self.arch_interaction_op}")
        return R

    def __call__(self, dense_x, lS_o, lS_i):
        x = self.bot_mlp(dense_x)
        ly = self.apply_embedding(lS_o, lS_i, self.embeddings)
        z = self.interact_features(x, ly)
        p = self.top_mlp(z)
        if 0.0 < self.loss_threshold < 1.0:
            p = jnp.clip(p, self.loss_threshold, 1.0 - self.loss_threshold)
        return p
```

though I am having trouble with the …
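For the offset-based pooling that torch.nn.EmbeddingBag provides, one option is jax.ops.segment_sum. A minimal sketch, assuming lS_i[k] is a flat 1-D index array and lS_o[k] holds the per-bag start offsets as in the PyTorch reference; embedding_bag_sum is a hypothetical helper, not a Flax API:

```python
import jax
import jax.numpy as jnp


def embedding_bag_sum(embed, indices, offsets, num_bags):
    """Sum-pool variable-length bags, like torch.nn.EmbeddingBag(mode="sum").

    indices: flat 1-D array of row ids for all bags concatenated.
    offsets: 1-D array of bag start positions, offsets[0] == 0, len == num_bags.
    """
    embeds = embed(indices)  # (total_indices, m_spa)
    # Bag b owns flat positions [offsets[b], offsets[b + 1]); searchsorted
    # maps each flat position to the id of the bag that contains it.
    positions = jnp.arange(indices.shape[0])
    segment_ids = jnp.searchsorted(offsets, positions, side="right") - 1
    return jax.ops.segment_sum(embeds, segment_ids, num_segments=num_bags)
```

Inside apply_embedding this would replace the fixed-size pooling, e.g. `V = embedding_bag_sum(E, lS_i[k], lS_o[k], lS_o[k].shape[0])`; note that num_segments must be static under jit, so in practice the batch size is passed as a Python int.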
-
Hello community!
I am currently trying to reimplement Meta's DLRM algorithm, specifically the architecture discussed in this paper, for profiling and performance research. I am having trouble writing a Flax implementation of the sparse-feature embedding code.
In Meta's implementation, they initialize a torch.EmbeddingBag (see this line in the original code), but they subsequently call it with both an index tensor and an offsets tensor (refer to this line).
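For concreteness, EmbeddingBag(mode="sum") fuses the table lookup with a per-bag reduction: a flat index vector is split into bags by an offsets vector, and each bag is summed down to a single vector. A tiny sketch of those semantics in plain jax.numpy, with made-up values:

```python
import jax.numpy as jnp

table = jnp.arange(12.0).reshape(4, 3)  # 4 embedding rows of dimension 3
indices = jnp.array([0, 2, 1, 1, 3])    # flat indices for all bags concatenated
offsets = jnp.array([0, 2])             # bag 0 = indices[0:2], bag 1 = indices[2:5]

# EmbeddingBag(mode="sum") returns one pooled vector per bag:
pooled = jnp.stack([
    table[indices[0:2]].sum(axis=0),    # bag 0
    table[indices[2:5]].sum(axis=0),    # bag 1
])                                      # shape (2, 3)
```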
However, I cannot find a way to duplicate this functionality using the Flax nnx.Embed or linen.Embed class. I am also relatively new to JAX/Flax, so I apologize in advance for my further questions :)
My current model is as follows (using nnx):