Negative sampling · Issue #21 · intfloat/SimKGC · GitHub

Negative sampling #21


Open · whistle9 opened this issue Feb 28, 2023 · 6 comments

@whistle9

I want to ask a question that is simple but very important for me: as a novice, I can't accurately locate the three negative sampling parts in the code. Could you point them out? I would really appreciate an answer.

@intfloat
Owner

The in-batch negatives simply use samples from the same batch:

logits = hr_vector.mm(tail_vector.t())
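
For intuition, here is a minimal standalone sketch (toy sizes, not the repository code) of what that matrix product gives you: a batch_size x batch_size score matrix where entry [i, j] scores (h_i, r_i) against tail t_j, so the diagonal holds the positives and every off-diagonal entry serves as an in-batch negative.

import torch
import torch.nn.functional as F

batch_size, dim = 4, 8
# Stand-ins for the encoder outputs, L2-normalized as in typical contrastive setups.
hr_vector = F.normalize(torch.randn(batch_size, dim), dim=1)
tail_vector = F.normalize(torch.randn(batch_size, dim), dim=1)

# batch_size x batch_size: row i scores (h_i, r_i) against every tail in the batch.
logits = hr_vector.mm(tail_vector.t())

# The positive for row i sits in column i, so a cross-entropy over the rows
# treats the other batch_size - 1 columns as negatives.
labels = torch.arange(batch_size)
loss = F.cross_entropy(logits, labels)

(The actual training code additionally scales the logits by the learned temperature log_inv_t, as in the snippets below.)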

The pre-batch negatives use cached negatives from previous batches:

SimKGC/models.py

Lines 117 to 133 in 97cc43e

def _compute_pre_batch_logits(self, hr_vector: torch.tensor,
                              tail_vector: torch.tensor,
                              batch_dict: dict) -> torch.tensor:
    assert tail_vector.size(0) == self.batch_size
    batch_exs = batch_dict['batch_data']
    # batch_size x num_neg
    pre_batch_logits = hr_vector.mm(self.pre_batch_vectors.clone().t())
    pre_batch_logits *= self.log_inv_t.exp() * self.args.pre_batch_weight
    if self.pre_batch_exs[-1] is not None:
        pre_triplet_mask = construct_mask(batch_exs, self.pre_batch_exs).to(hr_vector.device)
        pre_batch_logits.masked_fill_(~pre_triplet_mask, -1e4)

    self.pre_batch_vectors[self.offset:(self.offset + self.batch_size)] = tail_vector.data.clone()
    self.pre_batch_exs[self.offset:(self.offset + self.batch_size)] = batch_exs
    self.offset = (self.offset + self.batch_size) % len(self.pre_batch_exs)

    return pre_batch_logits
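
Conceptually, the method above keeps a fixed-size rolling buffer of tail vectors from the most recent batches and scores the current (h, r) vectors against it. Here is a simplified standalone sketch (toy sizes, hypothetical class name, omitting the temperature scaling and the false-negative mask used in the real code):

import torch

class PreBatchCache:
    """Rolling buffer of tail vectors from the last few batches, reused as extra negatives."""

    def __init__(self, num_pre_batches: int, batch_size: int, dim: int):
        self.batch_size = batch_size
        self.vectors = torch.zeros(num_pre_batches * batch_size, dim)
        self.offset = 0

    def step(self, hr_vector: torch.Tensor, tail_vector: torch.Tensor) -> torch.Tensor:
        # batch_size x (num_pre_batches * batch_size) extra negative logits
        pre_batch_logits = hr_vector.mm(self.vectors.t())
        # Overwrite the oldest slot with the current batch's tails (detached, no gradient).
        self.vectors[self.offset:self.offset + self.batch_size] = tail_vector.detach()
        self.offset = (self.offset + self.batch_size) % self.vectors.size(0)
        return pre_batch_logits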

The self-negatives use the head entities of each example:

SimKGC/models.py

Lines 104 to 109 in 97cc43e

if self.args.use_self_negative and self.training:
    head_vector = output_dict['head_vector']
    self_neg_logits = torch.sum(hr_vector * head_vector, dim=1) * self.log_inv_t.exp()
    self_negative_mask = batch_dict['self_negative_mask']
    self_neg_logits.masked_fill_(~self_negative_mask, -1e4)
    logits = torch.cat([logits, self_neg_logits.unsqueeze(1)], dim=-1)
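
In other words, each example gets exactly one extra negative logit: the score of its (h, r) embedding against its own head entity embedding, appended as one additional column to the in-batch logits matrix (and suppressed via self_negative_mask when it should not count as a negative).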

@whistle9
Author
whistle9 commented Mar 3, 2023

Thank you for your detailed answer. Are all the parameters used in the self-negative sampling method fixed? If I want to use this method in another framework, do the parameters used here also have to stay the same?

@intfloat
Owner
intfloat commented Mar 5, 2023

Not sure what you are referring to... Self-negative sampling does not add any new parameters.

@whistle9
Author

Sorry, I didn't make it clear, and some issues kept me from replying promptly.
I mean the parameters in the self-negative sampling code at lines 104-109 of models.py, such as head_vector, self_neg_logits, and logits. Are these all necessary?
In addition, I found that the compute_logits function, which contains the negative sampling part, is only called twice in trainer.py (lines 114/154), from the eval_epoch and train_epoch functions respectively. I don't quite understand this part. Could you explain it? Thank you, and I look forward to your reply.

@qiaoxiaoqiao1

Not sure what you are referring to... Self-negative sampling does not add any new parameters.

May I know why you set -1e4 for masked_fill?

@intfloat
Owner
intfloat commented Apr 22, 2023

Not sure what you are referring to... Self-negative sampling does not add any new parameters.

May I know why you set -1e4 for masked_fill?

Any number that is small enough should be able to mask the probability to 0.
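
For example (a quick standalone check, not repository code): after softmax, a logit of -1e4 contributes essentially zero probability, which is all the mask needs to achieve.

import torch

logits = torch.tensor([2.0, 1.0, -1e4])  # last entry is "masked"
probs = torch.softmax(logits, dim=0)
print(probs)  # tensor([0.7311, 0.2689, 0.0000]) -- the masked entry gets ~0 probability

A large negative finite value also avoids some numerical edge cases that float('-inf') can hit, and -1e4 still fits comfortably in float16 (whose minimum is about -65504).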
