
You need to change CLossHook's loss function to torch.nn.BCEWithLogitsLoss. The hook you're using should let you pass in a custom c_hook (or a similarly named argument). I can help more with this if you tell me which hook you're using.

An alternative is to "monkey patch" CLossHook's __init__ method:

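```python
import functools

import torch
from pytorch_adapt.hooks import CLossHook

def init_modifier(method):
    @functools.wraps(method)
    def modify(self, *args, **kwargs):
        # Run the original __init__ so the hook is set up normally...
        method(self, *args, **kwargs)
        # ...then replace the loss function it assigned.
        self.loss_fn = torch.nn.BCEWithLogitsLoss()
    return modify

# Apply the patch before any hooks are constructed.
CLossHook.__init__ = init_modifier(CLossHook.__init__)
```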

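After this patch, every CLossHook created from then on, including those built internally by other hooks, will use torch.nn.BCEWithLogitsLoss. Instances created before the patch keep their original loss.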
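To see how the __init__ wrapping behaves without installing pytorch_adapt, here is a minimal torch-free sketch. The Hook class and the string loss names are hypothetical stand-ins for CLossHook and the torch loss modules; only the patching pattern mirrors the snippet above.

```python
import functools


class Hook:
    """Hypothetical stand-in for CLossHook: picks a default loss in __init__."""

    def __init__(self, name="c_loss", loss_fn=None):
        self.name = name
        self.loss_fn = loss_fn if loss_fn is not None else "CrossEntropyLoss"


def init_modifier(method, new_loss):
    @functools.wraps(method)  # keep the original __name__/__doc__
    def modify(self, *args, **kwargs):
        method(self, *args, **kwargs)  # run the original initializer first
        self.loss_fn = new_loss        # then override the attribute it set
    return modify


before = Hook()
Hook.__init__ = init_modifier(Hook.__init__, "BCEWithLogitsLoss")
after = Hook(name="patched")

print(before.loss_fn)  # CrossEntropyLoss (created before the patch)
print(after.loss_fn)   # BCEWithLogitsLoss
print(after.name)      # patched (other arguments still handled by the original)
```

Note that the override runs unconditionally, so it also replaces any loss_fn passed explicitly to the constructor. That is why passing the loss through the hook's own arguments is the cleaner option when the hook supports it.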

You might get a type error du…

Answer selected by AlexandrByzov
Category
Q&A
2 participants