summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorZebedee Nicholls <zebedee.nicholls@climate-energy-college.org>2018-09-21 19:10:07 +0200
committerZebedee Nicholls <zebedee.nicholls@climate-energy-college.org>2018-09-21 19:10:07 +0200
commit8b9d272b7cfbc8760e4930ff305e04bcf453515f (patch)
tree1242c138525e3faa756a51ff861ff5b609bbea6b
parent09102a3898fea1c4f13c11984696c3f6630e8d0d (diff)
downloadpint-8b9d272b7cfbc8760e4930ff305e04bcf453515f.tar.gz
Pass index maintenance test
Annoyingly I had to make our test csv all floats as I couldn't work out how to get the dataframe testing to work without this
-rw-r--r--pint/pandas_interface/pint_array.py27
-rw-r--r--pint/testsuite/test-data/pandas_test.csv8
-rw-r--r--pint/testsuite/test_pandas_interface.py23
3 files changed, 44 insertions, 14 deletions
diff --git a/pint/pandas_interface/pint_array.py b/pint/pandas_interface/pint_array.py
index 42ebe65..c75c81b 100644
--- a/pint/pandas_interface/pint_array.py
+++ b/pint/pandas_interface/pint_array.py
@@ -500,21 +500,34 @@ class PintDataFrameAccessor(object):
unit_col_name = df_columns.columns[level]
units = df_columns[unit_col_name]
df_columns = df_columns.drop(columns=unit_col_name)
- df_columns.values
- df_new = DataFrame({i: PintArray(Q_(df.values[:, i], unit))
+
+ df_new = DataFrame({
+ i: PintArray(Q_(df.values[:, i], unit))
for i, unit in enumerate(units.values)
})
+
df_new.columns = df_columns.index.droplevel(unit_col_name)
df_new.index = df.index
+
return df_new
def dequantify(self):
- df=self._obj
- df_columns=df.columns.to_frame()
- df_columns['units']=[str(df[col].values.data.units) for col in df.columns]
- df_new=DataFrame({ tuple(df_columns.iloc[i]) : df[col].values.data.magnitude
- for i,col in enumerate(df.columns)
+ df = self._obj
+
+ df_columns = df.columns.to_frame()
+ df_columns['units'] = [
+ str(df[col].values.data.units)
+ for col in df.columns
+ ]
+
+ df_new = DataFrame({
+ tuple(df_columns.iloc[i]): df[col].values.data.magnitude
+ for i, col in enumerate(df.columns)
})
+
+ df_new.columns.names = df.columns.names + ['unit']
+ df_new.index = df.index
+
return df_new
def to_base_units(self):
diff --git a/pint/testsuite/test-data/pandas_test.csv b/pint/testsuite/test-data/pandas_test.csv
index 93d912c..5301e81 100644
--- a/pint/testsuite/test-data/pandas_test.csv
+++ b/pint/testsuite/test-data/pandas_test.csv
@@ -1,6 +1,6 @@
speed,mech power,torque,rail pressure,fuel flow rate,fluid power
rpm,kW,N m,bar,l/min,kW
-1000,,10,1000,10,
-1100,,10,1000000000000,10,
-1200,,10,1000,10,
-1200,,10,1000,10,
+1000.0,,10.0,1000.0,10.0,
+1100.0,,10.0,1000000000000.0,10.0,
+1200.0,,10.0,1000.0,10.0,
+1200.0,,10.0,1000.0,10.0,
diff --git a/pint/testsuite/test_pandas_interface.py b/pint/testsuite/test_pandas_interface.py
index 5d607e4..0da478e 100644
--- a/pint/testsuite/test_pandas_interface.py
+++ b/pint/testsuite/test_pandas_interface.py
@@ -443,7 +443,6 @@ else:
"test-data", "pandas_test.csv"
)
-
df = pd.read_csv(test_csv, header=[0, 1])
df.columns = pd.MultiIndex.from_arrays(
[
@@ -462,9 +461,27 @@ else:
)
- df_ = df.pint.quantify(ureg, level=-1).pint.dequantify()
- pd.testing.assert_frame_equal(df, df_)
+ expected = df.copy()
+
+ # we expect the result to come back with pint names, not input
+ # names
+ def get_pint_value(in_str):
+ return str(ureg.Quantity(1, in_str).units)
+
+ units_level = [
+ i for i, name in enumerate(df.columns.names) if name == 'unit'
+ ][0]
+
+ expected.columns = df.columns.set_levels(
+ df.columns.levels[units_level].map(get_pint_value),
+ level='unit'
+ )
+
+
+ result = df.pint.quantify(ureg, level=-1).pint.dequantify()
+
+ pd.testing.assert_frame_equal(result, expected)
class TestSeriesAccessors(object):