diff --git a/src/expand.rs b/src/expand.rs index 33ffd14..6c7d179 100644 --- a/src/expand.rs +++ b/src/expand.rs @@ -3,6 +3,7 @@ use crate::parse::Item; use crate::receiver::ReplaceReceiver; use proc_macro2::{Span, TokenStream}; use quote::{quote, ToTokens}; +use std::mem; use syn::punctuated::Punctuated; use syn::visit_mut::VisitMut; use syn::{ @@ -28,6 +29,7 @@ enum Context<'a> { supertraits: &'a Supertraits, }, Impl { + impl_generics: &'a Generics, receiver: &'a Type, as_trait: &'a Path, }, @@ -57,6 +59,7 @@ pub fn expand(input: &mut Item) { } Item::Impl(input) => { let context = Context::Impl { + impl_generics: &input.generics, receiver: &input.self_ty, as_trait: &input.trait_.as_ref().unwrap().1, }; @@ -203,12 +206,6 @@ fn transform_sig(context: Context, sig: &mut MethodSig, has_default: bool) { // Pin::from(Box::new(async_trait_method::(self, x))) fn transform_block(context: Context, sig: &MethodSig, block: &mut Block) { let inner = Ident::new(&format!("__{}", sig.ident), sig.ident.span()); - let mut types = sig - .decl - .generics - .type_params() - .map(|param| param.ident.clone()) - .collect::>(); let args = sig .decl .inputs @@ -225,11 +222,35 @@ fn transform_block(context: Context, sig: &MethodSig, block: &mut Block) { let mut standalone = sig.clone(); standalone.ident = inner.clone(); + + let outer_generics = match context { + Context::Trait { generics, .. } => generics, + Context::Impl { impl_generics, .. } => impl_generics, + }; + let fn_generics = mem::replace(&mut standalone.decl.generics, outer_generics.clone()); + standalone.decl.generics.params.extend(fn_generics.params); + if let Some(where_clause) = fn_generics.where_clause { + standalone + .decl + .generics + .make_where_clause() + .predicates + .extend(where_clause.predicates); + } + standalone .decl .generics .params .push(parse_quote!('async_trait)); + + let mut types = standalone + .decl + .generics + .type_params() + .map(|param| param.ident.clone()) + .collect::>(); + match standalone.decl.inputs.iter_mut().next() { Some(arg @ FnArg::SelfRef(_)) => { let (lifetime, mutability) = match arg { @@ -291,9 +312,9 @@ fn transform_block(context: Context, sig: &MethodSig, block: &mut Block) { let mut replace = match context { Context::Trait { .. } => ReplaceReceiver::with(parse_quote!(AsyncTrait)), - Context::Impl { receiver, as_trait } => { - ReplaceReceiver::with_as_trait(receiver.clone(), as_trait.clone()) - } + Context::Impl { + receiver, as_trait, .. + } => ReplaceReceiver::with_as_trait(receiver.clone(), as_trait.clone()), }; replace.visit_method_sig_mut(&mut standalone); replace.visit_block_mut(block); diff --git a/tests/test.rs b/tests/test.rs index e55b68d..9bde405 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -116,6 +116,21 @@ pub async fn test_object_safe_with_default() { object.f().await; } +// https://github.com/dtolnay/async-trait/issues/1 +mod issue1 { + use async_trait::async_trait; + + #[async_trait] + trait Trait { + async fn f(&self); + } + + #[async_trait] + impl Trait for Vec { + async fn f(&self) {} + } +} + // https://github.com/dtolnay/async-trait/issues/2 mod issue2 { use async_trait::async_trait;