Modify 🖌

Info

It is advised to read the abstractions page before making changes.

Parameters & Loop

Our code uses google/python-fire to parse command-line parameters and to iterate over the generator returned by the algorithm interface. To show what fire does here, the first snippet is the fire entry point and the second is a roughly equivalent implementation using argparse.

if __name__ == "__main__":
    # ...
    fire.Fire(
        Trainer,
        serialize=lambda gen: (log_data for log_data in gen if "logs" in log_data and log_data["log_type"] != "train"),
    )
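
# Roughly equivalent implementation using argparse
import argparse
import os


def parse_args() -> argparse.Namespace:
    # fmt: off
    parser = argparse.ArgumentParser()
    # splitext removes the ".py" suffix; rstrip(".py") would strip any trailing ".", "p" or "y" characters
    parser.add_argument("--exp-name", type=str, default=os.path.splitext(os.path.basename(__file__))[0])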
    parser.add_argument("--seed", type=int, default=1)
    # ...
    args = parser.parse_args()
    # fmt: on
    return args


if __name__ == "__main__":
    # ...
    kwargs = vars(parse_args())
    trainer = Trainer(**kwargs)
    serialize = lambda gen: (log_data for log_data in gen if "logs" in log_data and log_data["log_type"] != "train")
    for log_data in serialize(trainer(**kwargs)):
        print(log_data)
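
The filter passed as `serialize` can be tried on its own, without fire or a real algorithm. The following is a minimal sketch under assumptions: `fake_trainer` is a hypothetical stand-in for `Trainer.__call__`, and the `"evaluate"` log type and the `sample_step` key are illustrative names, not taken from the project.

```python
from typing import Any, Dict, Generator, Iterable, Iterator


def fake_trainer(steps: int = 4) -> Generator[Dict[str, Any], None, None]:
    """Hypothetical stand-in for Trainer.__call__: yields log dicts step by step."""
    for step in range(steps):
        # A training entry is produced at every step.
        yield {"log_type": "train", "sample_step": step, "logs": {"loss": 1.0 / (step + 1)}}
        # An evaluation entry is produced at every second step (arbitrary choice).
        if step % 2 == 1:
            yield {"log_type": "evaluate", "sample_step": step, "logs": {"episodic_return": float(step)}}


def serialize(gen: Iterable[Dict[str, Any]]) -> Iterator[Dict[str, Any]]:
    """Same filter as in the snippets above: keep entries that carry logs and are not training logs."""
    return (log_data for log_data in gen if "logs" in log_data and log_data["log_type"] != "train")


kept = list(serialize(fake_trainer()))
for log_data in kept:
    print(log_data)
```

Because the filter is a generator expression, entries are consumed lazily: the loop drives the algorithm forward one yielded item at a time, which is the "repeatedly call" behaviour described above.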

Modify Algorithm

Each algorithm is implemented entirely in a single file, so you can directly modify its four classes: Model📦, Algorithm👣, Agent🤖, Trainer🔁.

The modular design does not prescribe a strict interface, so you are free to change these four classes in any way that still runs. To keep using the provided features (e.g. logger, model saving, model evaluation), you need to preserve the Trainer🔁 interface.
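
As a rough illustration of what preserving the Trainer🔁 interface could mean, here is a minimal sketch inferred only from the snippets on this page: the constructor takes keyword parameters, and `__call__` is a generator that yields dicts with a `log_type` key and a `logs` key. The parameter names, the `"evaluate"` log type and the `sample_step` key are assumptions for illustration; check the abstractions page for the actual contract.

```python
from typing import Any, Dict, Generator


class Trainer:
    """Hypothetical minimal Trainer; only the shape of the interface matters here."""

    def __init__(self, exp_name: str = "example", seed: int = 1) -> None:
        # Constructor arguments are what an entry point such as fire or argparse would fill in.
        self.exp_name = exp_name
        self.seed = seed

    def __call__(self, num_steps: int = 3) -> Generator[Dict[str, Any], None, None]:
        for step in range(num_steps):
            # Your own Model / Algorithm / Agent code would run here.
            yield {"log_type": "train", "sample_step": step, "logs": {"loss": 0.0}}
        # A final non-training entry, which the filters shown on this page let through.
        yield {"log_type": "evaluate", "sample_step": num_steps, "logs": {"episodic_return": 0.0}}


entries = list(Trainer(seed=1)(num_steps=3))
```

Everything inside the loop can be replaced freely; the decorators only see the dicts that `__call__` yields.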

Modify Feature

Writing Decorator

Generic features are implemented as decorators. To implement a new feature and apply it to all algorithms, refer to the code below and to the abcdrl/utils/wrapper_*.py files.

from typing import Any, Callable, Generator

from combine_signatures.combine_signatures import combine_signatures


def wrapper_example(
    wrapped: Callable[..., Generator[dict[str, Any], None, None]]
) -> Callable[..., Generator[dict[str, Any], None, None]]:
    @combine_signatures(wrapped)
    def _wrapper(*args, new_arg: int = 1, **kwargs) -> Generator[dict[str, Any], None, None]:  # Add extra parameters here
        # Runs after the Trainer is initialized and before the algorithm starts
        gen = wrapped(*args, **kwargs)
        for log_data in gen:
            if "logs" in log_data and log_data["log_type"] != "train":
                # Modify control flow and handle log data here
                yield log_data  # Each step of the algorithm
        # Runs after the algorithm finishes
    return _wrapper
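
To make the template concrete, here is a sketch of one possible feature: stopping a run after a fixed number of evaluation entries. It is an illustration under assumptions, not one of the project's wrappers. `wrapper_early_stop`, `max_eval_logs`, `toy_run` and the `"evaluate"` log type are hypothetical names, and `functools.wraps` is swapped in for `combine_signatures` so the example has no third-party dependency (it copies metadata but does not merge the new keyword into the reported signature, which is presumably what `combine_signatures` is used for).

```python
import functools
from typing import Any, Callable, Dict, Generator

LogGen = Generator[Dict[str, Any], None, None]


def wrapper_early_stop(wrapped: Callable[..., LogGen]) -> Callable[..., LogGen]:
    """Hypothetical feature: end the run after `max_eval_logs` evaluation entries."""

    @functools.wraps(wrapped)
    def _wrapper(*args: Any, max_eval_logs: int = 2, **kwargs: Any) -> LogGen:
        seen = 0
        gen = wrapped(*args, **kwargs)
        for log_data in gen:
            yield log_data  # Pass every entry through unchanged
            if log_data.get("log_type") == "evaluate":
                seen += 1
                if seen >= max_eval_logs:
                    gen.close()  # Let the wrapped generator run its cleanup code
                    return

    return _wrapper


def toy_run(steps: int = 10) -> LogGen:
    """Hypothetical algorithm loop: one train and one evaluate entry per step."""
    for step in range(steps):
        yield {"log_type": "train", "logs": {"step": step}}
        yield {"log_type": "evaluate", "logs": {"step": step}}


limited = wrapper_early_stop(toy_run)
out = list(limited(steps=10, max_eval_logs=3))
```

The new keyword `max_eval_logs` is consumed by the wrapper and never reaches the wrapped function, which mirrors how `new_arg` is handled in the template above.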

Using Decorator

# Step 1: Copy the decorators you need
def wrapper_example(
    wrapped: Callable[..., Generator[dict[str, Any], None, None]]
) -> Callable[..., Generator[dict[str, Any], None, None]]:
    @combine_signatures(wrapped)
    def _wrapper(*args, new_arg: int = 1, **kwargs) -> Generator[dict[str, Any], None, None]:
        gen = wrapped(*args, **kwargs)
        for log_data in gen:
            if "logs" in log_data and log_data["log_type"] != "train":
                yield log_data
    return _wrapper


if __name__ == "__main__":
    SEED = 1234
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    Trainer.__call__ = wrapper_logger(Trainer.__call__)  # type: ignore[assignment]
    # Step 2: Decorate the Trainer.__call__ function
    Trainer.__call__ = wrapper_example(Trainer.__call__)  # type: ignore[assignment]
    fire.Fire(
        Trainer,
        serialize=lambda gen: (log_data for log_data in gen if "logs" in log_data and log_data["log_type"] != "train"),
    )
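
When several decorators are applied by reassigning `Trainer.__call__`, the one applied last becomes the outermost layer and sees entries only after the earlier ones have processed them. The sketch below shows that ordering with hypothetical names (`tag`, `ToyTrainer`, `seen_by`); it uses `functools.wraps` rather than `combine_signatures` to stay dependency-free.

```python
import functools
from typing import Any, Callable, Dict, Generator

LogGen = Generator[Dict[str, Any], None, None]


def tag(name: str) -> Callable[[Callable[..., LogGen]], Callable[..., LogGen]]:
    """Hypothetical decorator factory: records the order in which wrappers see an entry."""

    def decorator(wrapped: Callable[..., LogGen]) -> Callable[..., LogGen]:
        @functools.wraps(wrapped)
        def _wrapper(*args: Any, **kwargs: Any) -> LogGen:
            for log_data in wrapped(*args, **kwargs):
                yield {**log_data, "seen_by": log_data.get("seen_by", []) + [name]}

        return _wrapper

    return decorator


class ToyTrainer:
    """Hypothetical Trainer that yields a single evaluation entry."""

    def __call__(self) -> LogGen:
        yield {"log_type": "evaluate", "logs": {}}


# Same pattern as above: each assignment wraps the previous __call__.
ToyTrainer.__call__ = tag("first")(ToyTrainer.__call__)  # type: ignore[assignment]
ToyTrainer.__call__ = tag("second")(ToyTrainer.__call__)  # type: ignore[assignment]

result = list(ToyTrainer()())
print(result[0]["seen_by"])
```

In the snippet above this means `wrapper_logger` receives every entry first, and `wrapper_example` then filters what the logger passes on.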

Last update: 2023-01-11