// Copyright (C) 2020-2023 Free Software Foundation, Inc.
// This file is part of GCC.
// GCC is free software; you can redistribute it and/or modify it under
// the terms of the GNU General Public License as published by the Free
// Software Foundation; either version 3, or (at your option) any later
// version.
// GCC is distributed in the hope that it will be useful, but WITHOUT ANY
// WARRANTY; without even the implied warranty of MERCHANTABILITY or
// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
// for more details.
// You should have received a copy of the GNU General Public License
// along with GCC; see the file COPYING3. If not see
// <http://www.gnu.org/licenses/>.
#include "rust-autoderef.h"
#include "rust-hir-path-probe.h"
#include "rust-hir-dot-operator.h"
#include "rust-hir-trait-resolve.h"
namespace Rust {
namespace Resolver {
// Forward declaration: the definition appears further below in this file.
// Adjuster::try_deref_type uses it to find an overloaded deref method for a
// type before that definition is reached.
static bool
resolve_operator_overload_fn (
Analysis::RustLangItem::ItemType lang_item_type, const TyTy::BaseType *ty,
TyTy::FnType **resolved_fn, HIR::ImplItem **impl_item,
Adjustment::AdjustmentType *requires_ref_adjustment);
// Return the type produced by applying ADJUSTMENTS to this Adjuster's base
// type: a clone of the last adjustment's expected type, or a clone of the
// base type itself when ADJUSTMENTS is empty.
TyTy::BaseType *
Adjuster::adjust_type (const std::vector<Adjustment> &adjustments)
{
  if (adjustments.empty ())
    return base->clone ();

  return adjustments.back ().get_expected ()->clone ();
}
// Try to dereference TY through an overloaded deref operator, identified by
// the lang item DEREF_LANG_ITEM (DEREF or DEREF_MUT).
//
// Succeeds only when resolve_operator_overload_fn finds a method for TY and
// that method's return type is a reference.  On success the returned
// adjustment records the method, its impl item, the reference type the
// method returns, and the kind of the first adjustment the method probe
// applied to TY.  Returns Adjustment::get_error () otherwise.
Adjustment
Adjuster::try_deref_type (const TyTy::BaseType *ty,
                          Analysis::RustLangItem::ItemType deref_lang_item)
{
  HIR::ImplItem *impl_item = nullptr;
  TyTy::FnType *fn = nullptr;
  Adjustment::AdjustmentType requires_ref_adjustment
    = Adjustment::AdjustmentType::ERROR;
  bool operator_overloaded
    = resolve_operator_overload_fn (deref_lang_item, ty, &fn, &impl_item,
                                    &requires_ref_adjustment);
  if (!operator_overloaded)
    return Adjustment::get_error ();

  // Only deref methods that return a reference type are usable here.
  auto resolved_base = fn->get_return_type ()->clone ();
  bool is_valid_type = resolved_base->get_kind () == TyTy::TypeKind::REF;
  if (!is_valid_type)
    return Adjustment::get_error ();

  TyTy::ReferenceType *ref_base
    = static_cast<TyTy::ReferenceType *> (resolved_base);

  // Map the lang item onto the matching adjustment kind; any other lang
  // item leaves the kind as ERROR.
  Adjustment::AdjustmentType adjustment_type
    = Adjustment::AdjustmentType::ERROR;
  switch (deref_lang_item)
    {
    case Analysis::RustLangItem::ItemType::DEREF:
      adjustment_type = Adjustment::AdjustmentType::DEREF;
      break;

    case Analysis::RustLangItem::ItemType::DEREF_MUT:
      adjustment_type = Adjustment::AdjustmentType::DEREF_MUT;
      break;

    default:
      break;
    }

  return Adjustment::get_op_overload_deref_adjustment (adjustment_type, ty,
                                                       ref_base, fn, impl_item,
                                                       requires_ref_adjustment);
}
// Built-in dereference of a reference type.  If TY is a reference
// (TypeKind::REF), return an INDIRECTION adjustment from TY to a clone of
// the referenced type.  For any other type, return
// Adjustment::get_error ().
Adjustment
Adjuster::try_raw_deref_type (const TyTy::BaseType *ty)
{
  bool is_valid_type = ty->get_kind () == TyTy::TypeKind::REF;
  if (!is_valid_type)
    return Adjustment::get_error ();

  const TyTy::ReferenceType *ref_base
    = static_cast<const TyTy::ReferenceType *> (ty);
  auto infered = ref_base->get_base ()->clone ();

  return Adjustment (Adjustment::AdjustmentType::INDIRECTION, ty, infered);
}
// Unsize an array type into a slice type with the same element type.
// If TY is an array (TypeKind::ARRAY), build a new SliceType under a fresh
// HIR id at TY's locus and register it with the type-check context as an
// implicit type.  Then return an UNSIZE adjustment from TY to that slice.
// For any other type, return Adjustment::get_error ().
Adjustment
Adjuster::try_unsize_type (const TyTy::BaseType *ty)
{
  bool is_valid_type = ty->get_kind () == TyTy::TypeKind::ARRAY;
  if (!is_valid_type)
    return Adjustment::get_error ();

  auto mappings = Analysis::Mappings::get ();
  auto context = TypeCheckContext::get ();

  const auto array_base = static_cast<const TyTy::ArrayType *> (ty);
  auto slice_elem = array_base->get_element_type ();

  auto slice
    = new TyTy::SliceType (mappings->get_next_hir_id (), ty->get_ident ().locus,
                           TyTy::TyVar (slice_elem->get_ref ()));
  context->insert_implicit_type (slice);

  return Adjustment (Adjustment::AdjustmentType::UNSIZE, ty, slice);
}
// Resolve the method that overloads the operator identified by
// LANG_ITEM_TYPE for receiver type TY.
//
// The lang item's name is also used as the method name to probe TY for.
// On success, returns true and writes:
//   *RESOLVED_FN             - the method's function type.  If the method is
//                              generic, it has substitutions applied
//                              (inherited from an ADT receiver) or inferred
//                              and unified against the adjusted Self type.
//   *IMPL_ITEM               - the impl item providing the method.
//   *REQUIRES_REF_ADJUSTMENT - the kind of the first adjustment the method
//                              probe applied to TY.  It is left untouched
//                              when the probe applied none.
// Returns false when:
//   - the lang item is not defined;
//   - the probe finds zero candidates, or more than one;
//   - we are currently type-checking that same lang-item method for the
//     same Self type.  Resolving it there would recurse.
static bool
resolve_operator_overload_fn (
  Analysis::RustLangItem::ItemType lang_item_type, const TyTy::BaseType *ty,
  TyTy::FnType **resolved_fn, HIR::ImplItem **impl_item,
  Adjustment::AdjustmentType *requires_ref_adjustment)
{
  auto context = TypeCheckContext::get ();
  auto mappings = Analysis::Mappings::get ();

  // Look up the lang item; its name is also the method name we probe for.
  std::string associated_item_name
    = Analysis::RustLangItem::ToString (lang_item_type);
  DefId respective_lang_item_id = UNKNOWN_DEFID;
  bool lang_item_defined
    = mappings->lookup_lang_item (lang_item_type, &respective_lang_item_id);
  if (!lang_item_defined)
    return false;

  auto segment = HIR::PathIdentSegment (associated_item_name);
  auto candidates = MethodResolver::Probe (ty, segment, true);

  bool have_implementation_for_lang_item = !candidates.empty ();
  if (!have_implementation_for_lang_item)
    return false;

  // Multiple candidates: rather than reporting an error, simply do not treat
  // this as an operator overload.
  if (candidates.size () > 1)
    return false;

  // Get the adjusted self: TY with the probe's adjustments applied.
  auto candidate = *candidates.begin ();
  Adjuster adj (ty);
  TyTy::BaseType *adjusted_self = adj.adjust_type (candidate.adjustments);

  // Handle the case where we are inside the impl block that implements this
  // very lang item for the same Self type (e.g. the i32 operator overload
  // trait impl).  Resolving the operator here would recurse into itself, so
  // refuse.
  TypeCheckContextItem &fn_context = context->peek_context ();
  if (fn_context.get_type () == TypeCheckContextItem::ItemType::IMPL_ITEM)
    {
      auto &current_impl = fn_context.get_impl_item ();
      HIR::ImplBlock *parent = current_impl.first;
      HIR::Function *fn = current_impl.second;

      if (parent->has_trait_ref ()
          && fn->get_function_name ().compare (associated_item_name) == 0)
        {
          TraitReference *trait_reference
            = TraitResolver::Lookup (*parent->get_trait_ref ().get ());
          if (!trait_reference->is_error ())
            {
              TyTy::BaseType *lookup = nullptr;
              bool ok = context->lookup_type (fn->get_mappings ().get_hirid (),
                                              &lookup);
              rust_assert (ok);
              rust_assert (lookup->get_kind () == TyTy::TypeKind::FNDEF);

              TyTy::FnType *fntype = static_cast<TyTy::FnType *> (lookup);
              rust_assert (fntype->is_method ());

              bool is_lang_item_impl
                = trait_reference->get_mappings ().get_defid ()
                  == respective_lang_item_id;
              bool self_is_lang_item_self
                = fntype->get_self_type ()->is_equal (*adjusted_self);
              bool recursive_operator_overload
                = is_lang_item_impl && self_is_lang_item_self;
              if (recursive_operator_overload)
                return false;
            }
        }
    }

  TyTy::BaseType *lookup_tyty = candidate.candidate.ty;

  // Rust only supports impl-item deref operator overloading, i.e. there
  // must be an impl block providing the method.
  rust_assert (candidate.candidate.type
               == PathProbeCandidate::CandidateType::IMPL_FUNC);
  *impl_item = candidate.candidate.item.impl.impl_item;

  rust_assert (lookup_tyty->get_kind () == TyTy::TypeKind::FNDEF);
  TyTy::BaseType *lookup = lookup_tyty;
  TyTy::FnType *fn = static_cast<TyTy::FnType *> (lookup);
  rust_assert (fn->is_method ());

  if (fn->needs_substitution ())
    {
      if (ty->get_kind () == TyTy::TypeKind::ADT)
        {
          const TyTy::ADTType *adt = static_cast<const TyTy::ADTType *> (ty);

          auto s = fn->get_self_type ()->get_root ();
          rust_assert (s->can_eq (adt, false));
          rust_assert (s->get_kind () == TyTy::TypeKind::ADT);
          const TyTy::ADTType *self_adt
            = static_cast<const TyTy::ADTType *> (s);

          // We need to grab the Self substitutions as the inherited type
          // parameters for this method.
          if (self_adt->needs_substitution ())
            {
              rust_assert (adt->was_substituted ());

              TyTy::SubstitutionArgumentMappings used_args_in_prev_segment
                = GetUsedSubstArgs::From (adt);

              TyTy::SubstitutionArgumentMappings inherit_type_args
                = self_adt->solve_mappings_from_receiver_for_self (
                  used_args_in_prev_segment);

              // There may or may not be inherited type arguments; if there
              // are, apply them to the function.
              if (!inherit_type_args.is_error ())
                lookup = fn->handle_substitions (inherit_type_args);
            }
        }
      else
        {
          rust_assert (candidate.adjustments.size () < 2);

          // Infer the params for this.  We could probably do this by just
          // substituting a single param, but inference is more generic.
          //
          // This is the case where we had, say, Foo<&Bar>> and have derefed
          // to the &Bar, and we are trying to match a method self of Bar.
          // That requires another deref, which is matched to the deref trait
          // impl of &&T, so it needs another reference and deref call.
          lookup = fn->infer_substitions (Location ());
          rust_assert (lookup->get_kind () == TyTy::TypeKind::FNDEF);
          fn = static_cast<TyTy::FnType *> (lookup);

          Location unify_locus = mappings->lookup_location (ty->get_ref ());
          TypeCheckBase::unify_site (
            ty->get_ref (), TyTy::TyWithLocation (fn->get_self_type ()),
            TyTy::TyWithLocation (adjusted_self), unify_locus);

          lookup = fn;
        }
    }

  if (candidate.adjustments.size () > 0)
    *requires_ref_adjustment = candidate.adjustments.at (0).get_type ();

  *resolved_fn = static_cast<TyTy::FnType *> (lookup);
  return true;
}
// Construct a cycle and store FLAG as autoderef_flag.  When the flag is set,
// cycle () tries only the receiver itself (plus &/&mut of it) and then stops.
AutoderefCycle::AutoderefCycle (bool flag) : autoderef_flag (flag) {}

// No resources of our own to release.
AutoderefCycle::~AutoderefCycle () = default;

// Default per-candidate hook, invoked by try_autoderefed before selection.
// It intentionally does nothing.
void
AutoderefCycle::try_hook (const TyTy::BaseType & /* candidate */)
{}
// Walk the autoderef chain starting at RECEIVER.  Each candidate type is
// handed to try_autoderefed, which tries it as-is, then &T, then &mut T.
//
// Adjustments used to reach a candidate are pushed onto `adjustments`.
// They are popped again when that candidate fails, so on success the vector
// holds the path from RECEIVER to the selected form.
//
// Returns true as soon as a candidate is selected.  Returns false when:
//   - autoderef_flag is set and the receiver itself did not match; or
//   - the current type is not a reference and so cannot be raw-derefed.
//
// NOTE(review): there is no iteration limit here.  A Deref chain that
// cycles back to an earlier type would presumably loop forever; confirm
// whether callers guard against that.
bool
AutoderefCycle::cycle (const TyTy::BaseType *receiver)
{
const TyTy::BaseType *r = receiver;
while (true)
{
rust_debug ("autoderef try 1: {%s}", r->debug_str ().c_str ());
if (try_autoderefed (r))
return true;
// 4. Nothing matched at this level.  With autoderef_flag set we stop here;
// otherwise try the unsize/overloaded-deref steps below, then deref r and go
// back to step 1.  If r cannot be derefed, quit.
if (autoderef_flag)
return false;
// try unsize
// ([T; N] -> [T]); this step does not advance r.
Adjustment unsize = Adjuster::try_unsize_type (r);
if (!unsize.is_error ())
{
adjustments.push_back (unsize);
auto unsize_r = unsize.get_expected ();
rust_debug ("autoderef try unsize: {%s}",
unsize_r->debug_str ().c_str ());
if (try_autoderefed (unsize_r))
return true;
adjustments.pop_back ();
}
// Overloaded Deref: try the reference type the deref method returns.
Adjustment deref
= Adjuster::try_deref_type (r, Analysis::RustLangItem::ItemType::DEREF);
if (!deref.is_error ())
{
auto deref_r = deref.get_expected ();
adjustments.push_back (deref);
rust_debug ("autoderef try lang-item DEREF: {%s}",
deref_r->debug_str ().c_str ());
if (try_autoderefed (deref_r))
return true;
adjustments.pop_back ();
}
// Overloaded DerefMut: same as above using the DEREF_MUT lang item.
Adjustment deref_mut = Adjuster::try_deref_type (
r, Analysis::RustLangItem::ItemType::DEREF_MUT);
if (!deref_mut.is_error ())
{
auto deref_r = deref_mut.get_expected ();
adjustments.push_back (deref_mut);
rust_debug ("autoderef try lang-item DEREF_MUT: {%s}",
deref_r->debug_str ().c_str ());
if (try_autoderefed (deref_r))
return true;
adjustments.pop_back ();
}
// DerefMut followed by a built-in deref of the returned reference.
// try_deref_type only succeeds when the deref method returns a reference
// type, so the raw deref here has a reference to work on.
if (!deref_mut.is_error ())
{
auto deref_r = deref_mut.get_expected ();
adjustments.push_back (deref_mut);
Adjustment raw_deref = Adjuster::try_raw_deref_type (deref_r);
adjustments.push_back (raw_deref);
deref_r = raw_deref.get_expected ();
if (try_autoderefed (deref_r))
return true;
adjustments.pop_back ();
adjustments.pop_back ();
}
// Advance r for the next iteration.  If an overloaded Deref exists, commit
// to it first; then apply a built-in deref (quitting when r is not a
// reference).  These adjustments stay on the stack.
if (!deref.is_error ())
{
r = deref.get_expected ();
adjustments.push_back (deref);
}
Adjustment raw_deref = Adjuster::try_raw_deref_type (r);
if (raw_deref.is_error ())
return false;
r = raw_deref.get_expected ();
adjustments.push_back (raw_deref);
}
// Not reached: the loop above only exits via return.
return false;
}
// Try to select a candidate for R in three forms, in order:
//   1. R itself,
//   2. &R (pushing an IMM_REF adjustment),
//   3. &mut R (pushing a MUT_REF adjustment).
// try_hook is called on R first.  A pushed borrow adjustment is kept when
// its form is selected and popped again when it is not.
bool
AutoderefCycle::try_autoderefed (const TyTy::BaseType *r)
{
  try_hook (*r);

  // 1. the receiver as-is
  if (select (*r))
    return true;

  // 2./3. borrow the receiver immutably first, then mutably
  const Mutability borrow_kinds[] = {Mutability::Imm, Mutability::Mut};
  for (const Mutability mut : borrow_kinds)
    {
      Adjustment::AdjustmentType borrow_adjustment
        = (mut == Mutability::Imm) ? Adjustment::AdjustmentType::IMM_REF
                                   : Adjustment::AdjustmentType::MUT_REF;

      TyTy::ReferenceType *borrowed
        = new TyTy::ReferenceType (r->get_ref (), TyTy::TyVar (r->get_ref ()),
                                   mut);

      adjustments.push_back (Adjustment (borrow_adjustment, r, borrowed));
      if (select (*borrowed))
        return true;
      adjustments.pop_back ();
    }

  return false;
}
} // namespace Resolver
} // namespace Rust