summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJens Geyer <jensg@apache.org>2020-06-24 23:51:01 +0200
committerJens Geyer <jensg@apache.org>2020-06-25 22:00:52 +0200
commit6e16c2bc542657954966f5fde98d16398853582c (patch)
tree00e0bd8bc1c59efd3f9e5861c6763c842461bee5
parent283410126ccb3ac4990045e07cccb5df11ee2a16 (diff)
downloadthrift-6e16c2bc542657954966f5fde98d16398853582c.tar.gz
THRIFT-5238 GetHashCode can throw NullReferenceException
Client: netstd Patch: Jens Geyer This closes #2187
-rw-r--r--compiler/cpp/src/thrift/generate/t_netstd_generator.cc127
-rw-r--r--compiler/cpp/src/thrift/generate/t_netstd_generator.h2
-rw-r--r--lib/netstd/Tests/Thrift.Tests/DataModel/NullValuesSet.cs101
3 files changed, 158 insertions, 72 deletions
diff --git a/compiler/cpp/src/thrift/generate/t_netstd_generator.cc b/compiler/cpp/src/thrift/generate/t_netstd_generator.cc
index 0373faf48..e9c579c9f 100644
--- a/compiler/cpp/src/thrift/generate/t_netstd_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_netstd_generator.cc
@@ -1173,17 +1173,8 @@ void t_netstd_generator::generate_netstd_deepcopy_method(ostream& out, t_struct*
t_type* ttype = (*m_iter)->get_type();
string copy_op = get_deep_copy_method_call(ttype, needs_typecast);
- bool have_indent = false;
- if (!field_is_required(*m_iter)) {
- out << indent() << "if( this.__isset." << normalize_name((*m_iter)->get_name()) << ")" << endl;
- indent_up();
- have_indent = true;
- }
- else if( type_can_be_null(ttype)) {
- out << indent() << "if( this." << prop_name(*m_iter) << " != null)" << endl;
- indent_up();
- have_indent = true;
- }
+ bool is_required = field_is_required(*m_iter);
+ generate_null_check_begin( out, *m_iter);
out << indent() << tmp_instance << "." << prop_name(*m_iter) << " = ";
if( needs_typecast) {
@@ -1191,8 +1182,10 @@ void t_netstd_generator::generate_netstd_deepcopy_method(ostream& out, t_struct*
}
out << "this." << prop_name(*m_iter) << copy_op << ";" << endl;
- if (have_indent) {
- indent_down();
+ generate_null_check_end( out, *m_iter);
+ if( !is_required) {
+ out << indent() << tmp_instance << ".__isset." << normalize_name((*m_iter)->get_name())
+ << " = this.__isset." << normalize_name((*m_iter)->get_name()) << ";" << endl;
}
}
@@ -1306,6 +1299,44 @@ void t_netstd_generator::generate_netstd_struct_reader(ostream& out, t_struct* t
out << indent() << "}" << endl << endl;
}
+
+void t_netstd_generator::generate_null_check_begin(ostream& out, t_field* tfield) {
+ bool is_required = field_is_required(tfield);
+ bool null_allowed = type_can_be_null(tfield->get_type());
+
+ if( null_allowed || (!is_required)) {
+ bool first = true;
+ out << indent() << "if(";
+
+ if( null_allowed) {
+ out << "(" << prop_name(tfield) << " != null)";
+ first = false;
+ }
+
+ if( !is_required) {
+ if( !first) {
+ out << " && ";
+ }
+ out << "__isset." << normalize_name(tfield->get_name());
+ }
+
+ out << ")" << endl
+ << indent() << "{" << endl;
+ indent_up();
+ }
+}
+
+
+void t_netstd_generator::generate_null_check_end(ostream& out, t_field* tfield) {
+ bool is_required = field_is_required(tfield);
+ bool null_allowed = type_can_be_null(tfield->get_type());
+
+ if( null_allowed || (!is_required)) {
+ indent_down();
+ out << indent() << "}" << endl;
+ }
+}
+
void t_netstd_generator::generate_netstd_struct_writer(ostream& out, t_struct* tstruct)
{
out << indent() << "public async Task WriteAsync(TProtocol oprot, CancellationToken cancellationToken)" << endl
@@ -1329,23 +1360,7 @@ void t_netstd_generator::generate_netstd_struct_writer(ostream& out, t_struct* t
out << indent() << "var field = new TField();" << endl;
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter)
{
- bool is_required = field_is_required(*f_iter);
- if (!is_required)
- {
- bool null_allowed = type_can_be_null((*f_iter)->get_type());
- if (null_allowed)
- {
- out << indent() << "if (" << prop_name(*f_iter) << " != null && __isset." << normalize_name((*f_iter)->get_name()) << ")" << endl
- << indent() << "{" << endl;
- indent_up();
- }
- else
- {
- out << indent() << "if (__isset." << normalize_name((*f_iter)->get_name()) << ")" << endl
- << indent() << "{" << endl;
- indent_up();
- }
- }
+ generate_null_check_begin( out, *f_iter);
out << indent() << "field.Name = \"" << (*f_iter)->get_name() << "\";" << endl
<< indent() << "field.Type = " << type_to_enum((*f_iter)->get_type()) << ";" << endl
<< indent() << "field.ID = " << (*f_iter)->get_key() << ";" << endl
@@ -1354,11 +1369,7 @@ void t_netstd_generator::generate_netstd_struct_writer(ostream& out, t_struct* t
generate_serialize_field(out, *f_iter);
out << indent() << "await oprot.WriteFieldEndAsync(cancellationToken);" << endl;
- if (!is_required)
- {
- indent_down();
- out << indent() << "}" << endl;
- }
+ generate_null_check_end(out, *f_iter);
}
}
@@ -1482,23 +1493,8 @@ void t_netstd_generator::generate_netstd_struct_tostring(ostream& out, t_struct*
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter)
{
- bool is_required = field_is_required((*f_iter));
- if (!is_required)
- {
- bool null_allowed = type_can_be_null((*f_iter)->get_type());
- if (null_allowed)
- {
- out << indent() << "if (" << prop_name((*f_iter)) << " != null && __isset." << normalize_name((*f_iter)->get_name()) << ")" << endl
- << indent() << "{" << endl;
- indent_up();
- }
- else
- {
- out << indent() << "if (__isset." << normalize_name((*f_iter)->get_name()) << ")" << endl
- << indent() << "{" << endl;
- indent_up();
- }
- }
+ bool is_required = field_is_required(*f_iter);
+ generate_null_check_begin(out, *f_iter);
if (useFirstFlag && (!had_required))
{
@@ -1512,13 +1508,8 @@ void t_netstd_generator::generate_netstd_struct_tostring(ostream& out, t_struct*
out << indent() << prop_name(*f_iter) << ".ToString(sb);" << endl;
- if (!is_required)
- {
- indent_down();
- out << indent() << "}" << endl;
- }
- else
- {
+ generate_null_check_end(out, *f_iter);
+ if (is_required) {
had_required = true; // now __count must be > 0, so we don't need to check it anymore
}
}
@@ -1859,26 +1850,18 @@ void t_netstd_generator::generate_netstd_struct_hashcode(ostream& out, t_struct*
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter)
{
t_type* ttype = (*f_iter)->get_type();
- if (!field_is_required((*f_iter)))
- {
- out << indent() << "if(__isset." << normalize_name((*f_iter)->get_name()) << ")" << endl;
- indent_up();
- }
+
+ generate_null_check_begin(out, *f_iter);
out << indent() << "hashcode = (hashcode * 397) + ";
- if (ttype->is_container())
- {
+ if (ttype->is_container()) {
out << "TCollections.GetHashCode(" << prop_name((*f_iter)) << ")";
}
- else
- {
+ else {
out << prop_name((*f_iter)) << ".GetHashCode()";
}
out << ";" << endl;
- if (!field_is_required((*f_iter)))
- {
- indent_down();
- }
+ generate_null_check_end(out, *f_iter);
}
indent_down();
diff --git a/compiler/cpp/src/thrift/generate/t_netstd_generator.h b/compiler/cpp/src/thrift/generate/t_netstd_generator.h
index ccbd90235..94ad1619b 100644
--- a/compiler/cpp/src/thrift/generate/t_netstd_generator.h
+++ b/compiler/cpp/src/thrift/generate/t_netstd_generator.h
@@ -171,4 +171,6 @@ private:
void collect_extensions_types(t_type* ttype);
void generate_extensions(ostream& out, map<string, t_type*> types);
void reset_indent();
+ void generate_null_check_begin(ostream& out, t_field* tfield);
+ void generate_null_check_end(ostream& out, t_field* tfield);
};
diff --git a/lib/netstd/Tests/Thrift.Tests/DataModel/NullValuesSet.cs b/lib/netstd/Tests/Thrift.Tests/DataModel/NullValuesSet.cs
new file mode 100644
index 000000000..693b68ecc
--- /dev/null
+++ b/lib/netstd/Tests/Thrift.Tests/DataModel/NullValuesSet.cs
@@ -0,0 +1,101 @@
+// Licensed to the Apache Software Foundation(ASF) under one
+// or more contributor license agreements.See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Text;
+using Microsoft.VisualStudio.TestPlatform.ObjectModel;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using OptReqDefTest;
+using Thrift.Collections;
+
+namespace Thrift.Tests.DataModel
+{
+ // ReSharper disable once InconsistentNaming
+ [TestClass]
+ public class Thrift_5238
+ {
+ private void CheckInstance(RaceDetails instance)
+ {
+ // object
+ Assert.IsTrue(instance.__isset.def_nested);
+ Assert.IsTrue(instance.__isset.opt_nested);
+ Assert.IsNull(instance.Def_nested);
+ Assert.IsNull(instance.Opt_nested);
+
+ // string
+ Assert.IsTrue(instance.__isset.def_four);
+ Assert.IsTrue(instance.__isset.opt_four);
+ Assert.IsNull(instance.Req_four);
+ Assert.IsNull(instance.Def_four);
+ Assert.IsNull(instance.Opt_four);
+
+ // byte[]
+ Assert.IsTrue(instance.__isset.def_five);
+ Assert.IsTrue(instance.__isset.opt_five);
+ Assert.IsNull(instance.Req_five);
+ Assert.IsNull(instance.Def_five);
+ Assert.IsNull(instance.Opt_five);
+
+ // list<>
+ Assert.IsTrue(instance.__isset.def_six);
+ Assert.IsTrue(instance.__isset.opt_six);
+ Assert.IsNull(instance.Req_six);
+ Assert.IsNull(instance.Opt_six);
+ Assert.IsNull(instance.Def_six);
+ }
+
+ [TestMethod]
+ public void Thrift_5238_ProperNullChecks()
+ {
+ var instance = new OptReqDefTest.RaceDetails();
+
+ // object
+ instance.Def_nested = null;
+ instance.Opt_nested = null;
+
+ // string
+ instance.Req_four = null;
+ instance.Def_four = null;
+ instance.Opt_four = null;
+
+ // byte[]
+ instance.Req_five = null;
+ instance.Def_five = null;
+ instance.Opt_five = null;
+
+ // list<>
+ instance.Req_six = null;
+ instance.Opt_six = null;
+ instance.Def_six = null;
+
+ // test the setup
+ CheckInstance(instance);
+
+ // validate proper null checks , any of these throws if not
+ instance.ToString();
+ instance.GetHashCode();
+
+ // validate proper null checks , any of these throws if not
+ var copy = instance.DeepCopy();
+ CheckInstance(copy);
+ }
+
+ }
+}