Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions fire/helptext.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

import collections
import itertools
import typing

from fire import completion
from fire import custom_descriptions
Expand Down Expand Up @@ -524,14 +525,37 @@ def _GetArgType(arg, spec):
"""
if arg in spec.annotations:
arg_type = spec.annotations[arg]
return _FormatType(arg_type)
return ''


def _FormatType(tp):
"""Format a type annotation for display in help text.

Handles Optional, Union, and other typing constructs properly,
displaying e.g. Optional[str] instead of just 'Optional'.
"""
origin = typing.get_origin(tp)
args = typing.get_args(tp)

if origin is None:
# Simple type like str, int, etc.
try:
return arg_type.__qualname__
return tp.__qualname__
except AttributeError:
# Some typing objects, such as typing.Union do not have either a __name__
# or __qualname__ attribute.
# repr(typing.Union[int, str]) will return ': typing.Union[int, str]'
return repr(arg_type)
return ''
return repr(tp)

if origin is typing.Union:
# Display Optional[str] as 'str | None'
non_none = [a for a in args if a is not type(None)]
if len(non_none) == 1 and len(args) == 2:
return f'{_FormatType(non_none[0])} | None'
return ' | '.join(_FormatType(a) for a in args)

# For generic types like List[str], Dict[str, int]
args_str = ', '.join(_FormatType(a) for a in args)
origin_name = getattr(origin, '__qualname__', str(origin))
return f'{origin_name}[{args_str}]'


def _GetArgDefault(flag, spec):
Expand Down