Storing model parameters for NTK analysis in Colibri#2433
Storing model parameters for NTK analysis in Colibri#2433
Conversation
n3fit/src/n3fit/io/writer.py
Outdated
| epoch = self.stopping_object.would_stop_epoch | ||
| with open(out_path, "w", encoding="utf-8") as f: | ||
| f.write(str(epoch) if epoch is not None else "None") | ||
| f.write("\n") |
There was a problem hiding this comment.
I know you said this is not meant to be merged, but I think if instead of adding a new file you add this to the final .json file, so that normally would_stop_epoch == stop_epoch, unless you have some ntk flag which then makes them different, this could very well be on the standard n3fit.
(if you store it all in the same object that also means that it should work with parallel replicas ootb)
There was a problem hiding this comment.
Thanks, this seems much better.
| -1 if self._history.final_epoch is None else self._history.final_epoch + 1 | ||
| ) | ||
| if not self._dont_stop: | ||
| self._restore_best_weights() |
There was a problem hiding this comment.
I would remove this condition. If the rest is correct, then the best weights should be the last ones. If not, there's something missing / not working as intended.
It is a good colibri down the mine
There was a problem hiding this comment.
You mean if not self._dont_stop:? In other words, always call self._restore_best_weights?
Ah no, you meant the condition for self._would_stop_epoch.
This PR implements checkpointing of the model parameters during training. Parameters are serialised in npz format as a single flattened array, which is what the n3fit module in Colibri expects.
There are other things that are meant as workarounds in order to make the serialised objects compatible with I've already implemented in Colibri for the NTK. Thus, this is a temporary solution until the n3fit module in colibri is ready. This is not meant to be merged!
I'll perform tests and post the results here.