diff --git a/fire/helptext.py b/fire/helptext.py index 347278da..a59ed99b 100644 --- a/fire/helptext.py +++ b/fire/helptext.py @@ -33,6 +33,7 @@ import collections import itertools +import typing from fire import completion from fire import custom_descriptions @@ -524,14 +525,37 @@ def _GetArgType(arg, spec): """ if arg in spec.annotations: arg_type = spec.annotations[arg] + return _FormatType(arg_type) + return '' + + +def _FormatType(tp): + """Format a type annotation for display in help text. + + Handles Optional, Union, and other typing constructs properly, + displaying e.g. Optional[str] instead of just 'Optional'. + """ + origin = typing.get_origin(tp) + args = typing.get_args(tp) + + if origin is None: + # Simple type like str, int, etc. try: - return arg_type.__qualname__ + return tp.__qualname__ except AttributeError: - # Some typing objects, such as typing.Union do not have either a __name__ - # or __qualname__ attribute. - # repr(typing.Union[int, str]) will return ': typing.Union[int, str]' - return repr(arg_type) - return '' + return repr(tp) + + if origin is typing.Union: + # Display Optional[str] as 'str | None' + non_none = [a for a in args if a is not type(None)] + if len(non_none) == 1 and len(args) == 2: + return f'{_FormatType(non_none[0])} | None' + return ' | '.join(_FormatType(a) for a in args) + + # For generic types like List[str], Dict[str, int] + args_str = ', '.join(_FormatType(a) for a in args) + origin_name = getattr(origin, '__qualname__', str(origin)) + return f'{origin_name}[{args_str}]' def _GetArgDefault(flag, spec):