summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Include/setobject.h3
-rw-r--r--Lib/test/test_set.py8
-rw-r--r--Objects/dictobject.c18
-rw-r--r--Objects/setobject.c20
4 files changed, 44 insertions, 5 deletions
diff --git a/Include/setobject.h b/Include/setobject.h
index a16c2f7cdc..750a2a8a21 100644
--- a/Include/setobject.h
+++ b/Include/setobject.h
@@ -82,7 +82,8 @@ PyAPI_FUNC(int) PySet_Clear(PyObject *set);
PyAPI_FUNC(int) PySet_Contains(PyObject *anyset, PyObject *key);
PyAPI_FUNC(int) PySet_Discard(PyObject *set, PyObject *key);
PyAPI_FUNC(int) PySet_Add(PyObject *set, PyObject *key);
-PyAPI_FUNC(int) _PySet_Next(PyObject *set, Py_ssize_t *pos, PyObject **entry);
+PyAPI_FUNC(int) _PySet_Next(PyObject *set, Py_ssize_t *pos, PyObject **key);
+PyAPI_FUNC(int) _PySet_NextEntry(PyObject *set, Py_ssize_t *pos, PyObject **key, long *hash);
PyAPI_FUNC(PyObject *) PySet_Pop(PyObject *set);
PyAPI_FUNC(int) _PySet_Update(PyObject *set, PyObject *iterable);
diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py
index 45f61b2e8d..b46cac4a44 100644
--- a/Lib/test/test_set.py
+++ b/Lib/test/test_set.py
@@ -285,10 +285,14 @@ class TestJointOps(unittest.TestCase):
s = self.thetype(d)
self.assertEqual(sum(elem.hash_count for elem in d), n)
s.difference(d)
- self.assertEqual(sum(elem.hash_count for elem in d), n)
+ self.assertEqual(sum(elem.hash_count for elem in d), n)
if hasattr(s, 'symmetric_difference_update'):
s.symmetric_difference_update(d)
- self.assertEqual(sum(elem.hash_count for elem in d), n)
+ self.assertEqual(sum(elem.hash_count for elem in d), n)
+ d2 = dict.fromkeys(set(d))
+ self.assertEqual(sum(elem.hash_count for elem in d), n)
+ d3 = dict.fromkeys(frozenset(d))
+ self.assertEqual(sum(elem.hash_count for elem in d), n)
class TestSet(TestJointOps):
thetype = set
diff --git a/Objects/dictobject.c b/Objects/dictobject.c
index 1cb3ee6ad8..acf5ae3159 100644
--- a/Objects/dictobject.c
+++ b/Objects/dictobject.c
@@ -1175,6 +1175,24 @@ dict_fromkeys(PyObject *cls, PyObject *args)
if (d == NULL)
return NULL;
+ if (PyDict_CheckExact(d) && PyAnySet_CheckExact(seq)) {
+ dictobject *mp = (dictobject *)d;
+ Py_ssize_t pos = 0;
+ PyObject *key;
+ long hash;
+
+ if (dictresize(mp, PySet_GET_SIZE(seq)))
+ return NULL;
+
+ while (_PySet_NextEntry(seq, &pos, &key, &hash)) {
+ Py_INCREF(key);
+ Py_INCREF(Py_None);
+ if (insertdict(mp, key, hash, Py_None))
+ return NULL;
+ }
+ return d;
+ }
+
it = PyObject_GetIter(seq);
if (it == NULL){
Py_DECREF(d);
diff --git a/Objects/setobject.c b/Objects/setobject.c
index 07ba99641c..a896d937fa 100644
--- a/Objects/setobject.c
+++ b/Objects/setobject.c
@@ -2137,7 +2137,7 @@ PySet_Add(PyObject *set, PyObject *key)
}
int
-_PySet_Next(PyObject *set, Py_ssize_t *pos, PyObject **entry)
+_PySet_Next(PyObject *set, Py_ssize_t *pos, PyObject **key)
{
setentry *entry_ptr;
@@ -2147,7 +2147,23 @@ _PySet_Next(PyObject *set, Py_ssize_t *pos, PyObject **entry)
}
if (set_next((PySetObject *)set, pos, &entry_ptr) == 0)
return 0;
- *entry = entry_ptr->key;
+ *key = entry_ptr->key;
+ return 1;
+}
+
+int
+_PySet_NextEntry(PyObject *set, Py_ssize_t *pos, PyObject **key, long *hash)
+{
+ setentry *entry;
+
+ if (!PyAnySet_Check(set)) {
+ PyErr_BadInternalCall();
+ return -1;
+ }
+ if (set_next((PySetObject *)set, pos, &entry) == 0)
+ return 0;
+ *key = entry->key;
+ *hash = entry->hash;
return 1;
}