diff options
author | Jens Geyer <jensg@apache.org> | 2020-06-24 23:51:01 +0200 |
---|---|---|
committer | Jens Geyer <jensg@apache.org> | 2020-06-25 22:00:52 +0200 |
commit | 6e16c2bc542657954966f5fde98d16398853582c (patch) | |
tree | 00e0bd8bc1c59efd3f9e5861c6763c842461bee5 | |
parent | 283410126ccb3ac4990045e07cccb5df11ee2a16 (diff) | |
download | thrift-6e16c2bc542657954966f5fde98d16398853582c.tar.gz |
THRIFT-5238 GetHashCode can throw NullReferenceException
Client: netstd
Patch: Jens Geyer
This closes #2187
3 files changed, 158 insertions, 72 deletions
diff --git a/compiler/cpp/src/thrift/generate/t_netstd_generator.cc b/compiler/cpp/src/thrift/generate/t_netstd_generator.cc index 0373faf48..e9c579c9f 100644 --- a/compiler/cpp/src/thrift/generate/t_netstd_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_netstd_generator.cc @@ -1173,17 +1173,8 @@ void t_netstd_generator::generate_netstd_deepcopy_method(ostream& out, t_struct* t_type* ttype = (*m_iter)->get_type(); string copy_op = get_deep_copy_method_call(ttype, needs_typecast); - bool have_indent = false; - if (!field_is_required(*m_iter)) { - out << indent() << "if( this.__isset." << normalize_name((*m_iter)->get_name()) << ")" << endl; - indent_up(); - have_indent = true; - } - else if( type_can_be_null(ttype)) { - out << indent() << "if( this." << prop_name(*m_iter) << " != null)" << endl; - indent_up(); - have_indent = true; - } + bool is_required = field_is_required(*m_iter); + generate_null_check_begin( out, *m_iter); out << indent() << tmp_instance << "." << prop_name(*m_iter) << " = "; if( needs_typecast) { @@ -1191,8 +1182,10 @@ void t_netstd_generator::generate_netstd_deepcopy_method(ostream& out, t_struct* } out << "this." << prop_name(*m_iter) << copy_op << ";" << endl; - if (have_indent) { - indent_down(); + generate_null_check_end( out, *m_iter); + if( !is_required) { + out << indent() << tmp_instance << ".__isset." << normalize_name((*m_iter)->get_name()) + << " = this.__isset." << normalize_name((*m_iter)->get_name()) << ";" << endl; } } @@ -1306,6 +1299,44 @@ void t_netstd_generator::generate_netstd_struct_reader(ostream& out, t_struct* t out << indent() << "}" << endl << endl; } + +void t_netstd_generator::generate_null_check_begin(ostream& out, t_field* tfield) { + bool is_required = field_is_required(tfield); + bool null_allowed = type_can_be_null(tfield->get_type()); + + if( null_allowed || (!is_required)) { + bool first = true; + out << indent() << "if("; + + if( null_allowed) { + out << "(" << prop_name(tfield) << " != null)"; + first = false; + } + + if( !is_required) { + if( !first) { + out << " && "; + } + out << "__isset." << normalize_name(tfield->get_name()); + } + + out << ")" << endl + << indent() << "{" << endl; + indent_up(); + } +} + + +void t_netstd_generator::generate_null_check_end(ostream& out, t_field* tfield) { + bool is_required = field_is_required(tfield); + bool null_allowed = type_can_be_null(tfield->get_type()); + + if( null_allowed || (!is_required)) { + indent_down(); + out << indent() << "}" << endl; + } +} + void t_netstd_generator::generate_netstd_struct_writer(ostream& out, t_struct* tstruct) { out << indent() << "public async Task WriteAsync(TProtocol oprot, CancellationToken cancellationToken)" << endl @@ -1329,23 +1360,7 @@ void t_netstd_generator::generate_netstd_struct_writer(ostream& out, t_struct* t out << indent() << "var field = new TField();" << endl; for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { - bool is_required = field_is_required(*f_iter); - if (!is_required) - { - bool null_allowed = type_can_be_null((*f_iter)->get_type()); - if (null_allowed) - { - out << indent() << "if (" << prop_name(*f_iter) << " != null && __isset." << normalize_name((*f_iter)->get_name()) << ")" << endl - << indent() << "{" << endl; - indent_up(); - } - else - { - out << indent() << "if (__isset." << normalize_name((*f_iter)->get_name()) << ")" << endl - << indent() << "{" << endl; - indent_up(); - } - } + generate_null_check_begin( out, *f_iter); out << indent() << "field.Name = \"" << (*f_iter)->get_name() << "\";" << endl << indent() << "field.Type = " << type_to_enum((*f_iter)->get_type()) << ";" << endl << indent() << "field.ID = " << (*f_iter)->get_key() << ";" << endl @@ -1354,11 +1369,7 @@ void t_netstd_generator::generate_netstd_struct_writer(ostream& out, t_struct* t generate_serialize_field(out, *f_iter); out << indent() << "await oprot.WriteFieldEndAsync(cancellationToken);" << endl; - if (!is_required) - { - indent_down(); - out << indent() << "}" << endl; - } + generate_null_check_end(out, *f_iter); } } @@ -1482,23 +1493,8 @@ void t_netstd_generator::generate_netstd_struct_tostring(ostream& out, t_struct* for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { - bool is_required = field_is_required((*f_iter)); - if (!is_required) - { - bool null_allowed = type_can_be_null((*f_iter)->get_type()); - if (null_allowed) - { - out << indent() << "if (" << prop_name((*f_iter)) << " != null && __isset." << normalize_name((*f_iter)->get_name()) << ")" << endl - << indent() << "{" << endl; - indent_up(); - } - else - { - out << indent() << "if (__isset." << normalize_name((*f_iter)->get_name()) << ")" << endl - << indent() << "{" << endl; - indent_up(); - } - } + bool is_required = field_is_required(*f_iter); + generate_null_check_begin(out, *f_iter); if (useFirstFlag && (!had_required)) { @@ -1512,13 +1508,8 @@ void t_netstd_generator::generate_netstd_struct_tostring(ostream& out, t_struct* out << indent() << prop_name(*f_iter) << ".ToString(sb);" << endl; - if (!is_required) - { - indent_down(); - out << indent() << "}" << endl; - } - else - { + generate_null_check_end(out, *f_iter); + if (is_required) { had_required = true; // now __count must be > 0, so we don't need to check it anymore } } @@ -1859,26 +1850,18 @@ void t_netstd_generator::generate_netstd_struct_hashcode(ostream& out, t_struct* for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { t_type* ttype = (*f_iter)->get_type(); - if (!field_is_required((*f_iter))) - { - out << indent() << "if(__isset." << normalize_name((*f_iter)->get_name()) << ")" << endl; - indent_up(); - } + + generate_null_check_begin(out, *f_iter); out << indent() << "hashcode = (hashcode * 397) + "; - if (ttype->is_container()) - { + if (ttype->is_container()) { out << "TCollections.GetHashCode(" << prop_name((*f_iter)) << ")"; } - else - { + else { out << prop_name((*f_iter)) << ".GetHashCode()"; } out << ";" << endl; - if (!field_is_required((*f_iter))) - { - indent_down(); - } + generate_null_check_end(out, *f_iter); } indent_down(); diff --git a/compiler/cpp/src/thrift/generate/t_netstd_generator.h b/compiler/cpp/src/thrift/generate/t_netstd_generator.h index ccbd90235..94ad1619b 100644 --- a/compiler/cpp/src/thrift/generate/t_netstd_generator.h +++ b/compiler/cpp/src/thrift/generate/t_netstd_generator.h @@ -171,4 +171,6 @@ private: void collect_extensions_types(t_type* ttype); void generate_extensions(ostream& out, map<string, t_type*> types); void reset_indent(); + void generate_null_check_begin(ostream& out, t_field* tfield); + void generate_null_check_end(ostream& out, t_field* tfield); }; diff --git a/lib/netstd/Tests/Thrift.Tests/DataModel/NullValuesSet.cs b/lib/netstd/Tests/Thrift.Tests/DataModel/NullValuesSet.cs new file mode 100644 index 000000000..693b68ecc --- /dev/null +++ b/lib/netstd/Tests/Thrift.Tests/DataModel/NullValuesSet.cs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation(ASF) under one +// or more contributor license agreements.See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using Microsoft.VisualStudio.TestPlatform.ObjectModel; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using OptReqDefTest; +using Thrift.Collections; + +namespace Thrift.Tests.DataModel +{ + // ReSharper disable once InconsistentNaming + [TestClass] + public class Thrift_5238 + { + private void CheckInstance(RaceDetails instance) + { + // object + Assert.IsTrue(instance.__isset.def_nested); + Assert.IsTrue(instance.__isset.opt_nested); + Assert.IsNull(instance.Def_nested); + Assert.IsNull(instance.Opt_nested); + + // string + Assert.IsTrue(instance.__isset.def_four); + Assert.IsTrue(instance.__isset.opt_four); + Assert.IsNull(instance.Req_four); + Assert.IsNull(instance.Def_four); + Assert.IsNull(instance.Opt_four); + + // byte[] + Assert.IsTrue(instance.__isset.def_five); + Assert.IsTrue(instance.__isset.opt_five); + Assert.IsNull(instance.Req_five); + Assert.IsNull(instance.Def_five); + Assert.IsNull(instance.Opt_five); + + // list<> + Assert.IsTrue(instance.__isset.def_six); + Assert.IsTrue(instance.__isset.opt_six); + Assert.IsNull(instance.Req_six); + Assert.IsNull(instance.Opt_six); + Assert.IsNull(instance.Def_six); + } + + [TestMethod] + public void Thrift_5238_ProperNullChecks() + { + var instance = new OptReqDefTest.RaceDetails(); + + // object + instance.Def_nested = null; + instance.Opt_nested = null; + + // string + instance.Req_four = null; + instance.Def_four = null; + instance.Opt_four = null; + + // byte[] + instance.Req_five = null; + instance.Def_five = null; + instance.Opt_five = null; + + // list<> + instance.Req_six = null; + instance.Opt_six = null; + instance.Def_six = null; + + // test the setup + CheckInstance(instance); + + // validate proper null checks , any of these throws if not + instance.ToString(); + instance.GetHashCode(); + + // validate proper null checks , any of these throws if not + var copy = instance.DeepCopy(); + CheckInstance(copy); + } + + } +} |