summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--git/objects/tree.py2
-rw-r--r--git/objects/util.py6
-rw-r--r--git/types.py10
3 files changed, 11 insertions, 7 deletions
diff --git a/git/objects/tree.py b/git/objects/tree.py
index 804554d8..e168c6c4 100644
--- a/git/objects/tree.py
+++ b/git/objects/tree.py
@@ -323,7 +323,7 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable):
super(Tree, self).traverse(predicate, prune, depth, # type: ignore
branch_first, visit_once, ignore_self))
- def list_traverse(self, *args: Any, **kwargs: Any) -> IterableList[Union['Tree', 'Submodule', 'Blob']]:
+ def list_traverse(self, *args: Any, **kwargs: Any) -> IterableList[IndexObjUnion]:
return super(Tree, self).list_traverse(* args, **kwargs)
# List protocol
diff --git a/git/objects/util.py b/git/objects/util.py
index 4dce0aee..1c266563 100644
--- a/git/objects/util.py
+++ b/git/objects/util.py
@@ -19,13 +19,11 @@ import time
import calendar
from datetime import datetime, timedelta, tzinfo
-from git.objects.base import IndexObject # just for an isinstance check
-
# typing ------------------------------------------------------------
from typing import (Any, Callable, Deque, Iterator, NamedTuple, overload, Sequence,
TYPE_CHECKING, Tuple, Type, TypeVar, Union, cast)
-from git.types import Literal
+from git.types import Has_id_attribute, Literal
if TYPE_CHECKING:
from io import BytesIO, StringIO
@@ -319,7 +317,7 @@ class Traversable(object):
"""
# Commit and Submodule have id.__attribute__ as IterableObj
# Tree has id.__attribute__ inherited from IndexObject
- if isinstance(self, (TraversableIterableObj, IndexObject)):
+ if isinstance(self, (TraversableIterableObj, Has_id_attribute)):
id = self._id_attribute_
else:
id = "" # shouldn't reach here, unless Traversable subclass created with no _id_attribute_
diff --git a/git/types.py b/git/types.py
index ac1bb2c8..b107c2e1 100644
--- a/git/types.py
+++ b/git/types.py
@@ -11,9 +11,9 @@ if TYPE_CHECKING:
from git.repo import Repo
if sys.version_info[:2] >= (3, 8):
- from typing import Final, Literal, SupportsIndex, TypedDict, Protocol # noqa: F401
+ from typing import Final, Literal, SupportsIndex, TypedDict, Protocol, runtime_checkable # noqa: F401
else:
- from typing_extensions import Final, Literal, SupportsIndex, TypedDict, Protocol # noqa: F401
+ from typing_extensions import Final, Literal, SupportsIndex, TypedDict, Protocol, runtime_checkable # noqa: F401
if sys.version_info[:2] >= (3, 10):
from typing import TypeGuard # noqa: F401
@@ -73,5 +73,11 @@ class HSH_TD(TypedDict):
files: Dict[PathLike, Files_TD]
+@runtime_checkable
class Has_Repo(Protocol):
repo: 'Repo'
+
+
+@runtime_checkable
+class Has_id_attribute(Protocol):
+ _id_attribute_: str