diff options
-rw-r--r-- | git/objects/tree.py | 2 | ||||
-rw-r--r-- | git/objects/util.py | 6 | ||||
-rw-r--r-- | git/types.py | 10 |
3 files changed, 11 insertions, 7 deletions
diff --git a/git/objects/tree.py b/git/objects/tree.py index 804554d8..e168c6c4 100644 --- a/git/objects/tree.py +++ b/git/objects/tree.py @@ -323,7 +323,7 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable): super(Tree, self).traverse(predicate, prune, depth, # type: ignore branch_first, visit_once, ignore_self)) - def list_traverse(self, *args: Any, **kwargs: Any) -> IterableList[Union['Tree', 'Submodule', 'Blob']]: + def list_traverse(self, *args: Any, **kwargs: Any) -> IterableList[IndexObjUnion]: return super(Tree, self).list_traverse(* args, **kwargs) # List protocol diff --git a/git/objects/util.py b/git/objects/util.py index 4dce0aee..1c266563 100644 --- a/git/objects/util.py +++ b/git/objects/util.py @@ -19,13 +19,11 @@ import time import calendar from datetime import datetime, timedelta, tzinfo -from git.objects.base import IndexObject # just for an isinstance check - # typing ------------------------------------------------------------ from typing import (Any, Callable, Deque, Iterator, NamedTuple, overload, Sequence, TYPE_CHECKING, Tuple, Type, TypeVar, Union, cast) -from git.types import Literal +from git.types import Has_id_attribute, Literal if TYPE_CHECKING: from io import BytesIO, StringIO @@ -319,7 +317,7 @@ class Traversable(object): """ # Commit and Submodule have id.__attribute__ as IterableObj # Tree has id.__attribute__ inherited from IndexObject - if isinstance(self, (TraversableIterableObj, IndexObject)): + if isinstance(self, (TraversableIterableObj, Has_id_attribute)): id = self._id_attribute_ else: id = "" # shouldn't reach here, unless Traversable subclass created with no _id_attribute_ diff --git a/git/types.py b/git/types.py index ac1bb2c8..b107c2e1 100644 --- a/git/types.py +++ b/git/types.py @@ -11,9 +11,9 @@ if TYPE_CHECKING: from git.repo import Repo if sys.version_info[:2] >= (3, 8): - from typing import Final, Literal, SupportsIndex, TypedDict, Protocol # noqa: F401 + from typing import Final, Literal, SupportsIndex, TypedDict, Protocol, runtime_checkable # noqa: F401 else: - from typing_extensions import Final, Literal, SupportsIndex, TypedDict, Protocol # noqa: F401 + from typing_extensions import Final, Literal, SupportsIndex, TypedDict, Protocol, runtime_checkable # noqa: F401 if sys.version_info[:2] >= (3, 10): from typing import TypeGuard # noqa: F401 @@ -73,5 +73,11 @@ class HSH_TD(TypedDict): files: Dict[PathLike, Files_TD] +@runtime_checkable class Has_Repo(Protocol): repo: 'Repo' + + +@runtime_checkable +class Has_id_attribute(Protocol): + _id_attribute_: str |