TL;DR: we can use any intermediate LM representation to prove that a subset of next-token candidates has nonzero true probability.
In my paper “Closing the Curious Case of Neural Text Degeneration”, we show that when a LM outputs the embedding $h$, and we assume that the model outputs the distribution $q = \mathrm{softmax}(Wh)$ (where $W$ is the unembedding matrix) that minimizes cross entropy with the true distribution $p^*$, we get the resulting relationship $$W^\top q = W^\top p^*.$$ This is useful because we can use it as a linear constraint in order to tell whether a particular token $x$ has nonzero true probability, i.e., if there is no probability distribution $p^*$ satisfying $W^\top p^* = W^\top q$ with $p^*_x = 0$, then $p^*_x > 0$.
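To make the feasibility test concrete, here is a minimal sketch (not the implementation from the paper) using `scipy.optimize.linprog`; `W`, `q`, and the token index are placeholders you would pull from a real model:

```python
import numpy as np
from scipy.optimize import linprog

def provably_nonzero(W: np.ndarray, q: np.ndarray, token: int) -> bool:
    """Return True if `token` provably has nonzero true probability.

    We ask: is there any distribution p with p[token] = 0 that satisfies
    W^T p = W^T q? If that linear program is infeasible, every distribution
    consistent with the constraints puts nonzero mass on `token`.
    """
    vocab = W.shape[0]
    # Equality constraints: W^T p = W^T q, plus sum(p) = 1.
    A_eq = np.vstack([W.T, np.ones((1, vocab))])
    b_eq = np.concatenate([W.T @ q, [1.0]])
    # Bounds: p >= 0 everywhere, with p[token] pinned to exactly 0.
    bounds = [(0.0, None)] * vocab
    bounds[token] = (0.0, 0.0)
    # Pure feasibility problem, so the objective is all zeros.
    result = linprog(c=np.zeros(vocab), A_eq=A_eq, b_eq=b_eq, bounds=bounds)
    return not result.success  # infeasible => p*_x > 0
```

In practice the solve is sensitive to numerical tolerances, since $q$ only satisfies $W^\top q = W^\top p^*$ up to the accuracy of the optimality assumption.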
We were able to get this linear set of constraints by considering the gradient with respect to the final embedding $h$, but what about earlier representations in the model? For instance, what if we consider the representation $h'$ from before the final layer norm? As part of the derivation in the paper, we obtain the equality $$\nabla_h \mathcal{L} = W^\top (q - p^*),$$ which is simply the first chain rule expansion of the model gradient with respect to $h$. Importantly, $W$ is the Jacobian of the vector function $h \mapsto Wh$ that maps the final embedding to the logits. If we swap out the final embedding $h$ for the pre-layernorm representation $h'$, we can obtain the relation $$(W J_{\mathrm{LN}}(h'))^\top q = (W J_{\mathrm{LN}}(h'))^\top p^*,$$ where $J_{\mathrm{LN}}(h')$ is the Jacobian of the layer norm at $h'$. More generally, with a slight abuse of notation, let $f$ be the mapping from any intermediate model representation $h_i$ to the model logits. For every intermediate representation we will have $$J_f(h_i)^\top q = J_f(h_i)^\top p^*.$$
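As a sketch of how one might compute these Jacobians in practice, `torch.autograd.functional.jacobian` can differentiate the map from a chosen intermediate representation to the logits. The tiny layer-norm-plus-unembedding tail below is a stand-in for a real model; only the shapes matter:

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
d_model, vocab = 16, 100

# Stand-in for the tail of a transformer:
# pre-layernorm representation h' -> layer norm -> unembedding -> logits.
ln = torch.nn.LayerNorm(d_model)
W = torch.randn(vocab, d_model)

def logits_from_pre_ln(h_prime: torch.Tensor) -> torch.Tensor:
    return ln(h_prime) @ W.T

h_prime = torch.randn(d_model)

# J has shape (vocab, d_model): the Jacobian W @ J_LN(h') of the map
# from the pre-layernorm representation to the logits.
J = jacobian(logits_from_pre_ln, h_prime)

q = torch.softmax(logits_from_pre_ln(h_prime), dim=-1)
# Analogue of W^T q = W^T p*: each column of J yields one linear
# constraint on the true distribution p*.
rhs = J.T @ q  # shape (d_model,), i.e. J^T q, to be matched by J^T p*
print(J.shape, rhs.shape)
```

The same pattern works for any earlier representation: wrap the rest of the forward pass in a function of that representation and take its Jacobian.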
What does this mean? It means we have LOTS of linear constraints that we can add to our program. I am curious which ones will be useful, and whether we could use this to make our program more efficient. If we used ALL of the constraints, we would have an over-constrained program, potentially meaning fewer or no token rejections. If there is some specific structure to the constraints, we could perhaps find efficient approximations. Instead of approximating the Jacobian with SVD, could we take the Jacobian w.r.t. a subset of the representations? Instead of going straight backward through the model's token embeddings, we could also take the Jacobian w.r.t. representations of previous tokens. Does the Jacobian w.r.t. an earlier representation contain all the information from later representation Jacobians? Following that logic, the input embedding should contain ALL the useful information, but that seems wrong since the input embedding is static. There is a lot to think about here, and I'm not sure which pieces will be useful yet.