def _call_impl(self, *input, **kwargs):
    forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
    # If we don't have any hooks, we want to skip the rest of the logic in
    # this function, and just call forward.
    if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks
            or _global_backward_hooks or _global_forward_hooks or _global_forward_pre_hooks):
        return forward_call(*input, **kwargs)
    # Do not call functions when jit is used
    full_backward_hooks, non_full_backward_hooks = [], []
    if self._backward_hooks or _global_backward_hooks:
        full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
    if _global_forward_pre_hooks or self._forward_pre_hooks:
        for hook in (*_global_forward_pre_hooks.values(), *self._forward_pre_hooks.values()):
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result

    bw_hook = None
    if full_backward_hooks:
        bw_hook = hooks.BackwardHook(self, full_backward_hooks)
        input = bw_hook.setup_input_hook(input)

    result = forward_call(*input, **kwargs)
    if _global_forward_hooks or self._forward_hooks:
        for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result

    if bw_hook:
        result = bw_hook.setup_output_hook(result)

    # Handle the non-full backward hooks
    if non_full_backward_hooks:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in non_full_backward_hooks:
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
            self._maybe_warn_non_full_backward_hook(input, result, grad_fn)

    return result
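What makes this body so long is mostly hook plumbing. Here is a minimal sketch of that machinery from the caller's side; the hook functions and shapes are my own, but register_forward_pre_hook and register_forward_hook are the real Module APIs that populate the registries _call_impl iterates over:

import torch
import torch.nn as nn

layer = nn.Linear(4, 2)

def pre_hook(module, inputs):
    # Runs before forward; returning a tuple replaces `input` in _call_impl.
    return (inputs[0] * 2,)

def post_hook(module, inputs, output):
    # Runs after forward; returning a value replaces `result` in _call_impl.
    return output + 1

layer.register_forward_pre_hook(pre_hook)
layer.register_forward_hook(post_hook)

y = layer(torch.randn(3, 4))  # __call__ -> _call_impl -> pre-hook, forward, hook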
Quite a long chunk, isn't it? Don't worry, we only need to focus on the self.forward at the very end of its first line:
forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
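If you want to watch this dispatch happen, here is a quick sketch; Probe is a throwaway module of my own, and torch._C._get_tracing_state() is the same internal call the line above consults:

import torch
import torch.nn as nn

class Probe(nn.Module):
    def forward(self, x):
        # None in ordinary eager calls, a TracingState object under torch.jit.trace
        print("tracing state:", torch._C._get_tracing_state())
        return x * 2

probe = Probe()
probe(torch.randn(1))                   # tracing state: None  -> self.forward
torch.jit.trace(probe, torch.randn(1))  # tracing state: <...> -> self._slow_forward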