Skip to content

Commit

Permalink
some fixes, and adding validator back in
Browse files Browse the repository at this point in the history
  • Loading branch information
metachris committed Nov 19, 2024
1 parent e8bf254 commit 22792e4
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
/measurements.json
/build/
/quotes/
/builder-cert.pem
45 changes: 35 additions & 10 deletions cmd/attested-get/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ import (

"github.com/flashbots/cvm-reverse-proxy/common"
"github.com/flashbots/cvm-reverse-proxy/internal/atls"
azure_tdx "github.com/flashbots/cvm-reverse-proxy/internal/attestation/azure/tdx"
"github.com/flashbots/cvm-reverse-proxy/internal/attestation/measurements"
"github.com/flashbots/cvm-reverse-proxy/internal/attestation/variant"
"github.com/flashbots/cvm-reverse-proxy/internal/config"
"github.com/flashbots/cvm-reverse-proxy/proxy"
"github.com/urfave/cli/v2" // imports as package "cli"
)
Expand All @@ -50,6 +53,11 @@ var flags []cli.Flag = []cli.Flag{
Value: "",
Usage: "Output file for the response payload",
},
&cli.StringFlag{
Name: "attestation-type",
Value: string(proxy.AttestationAzureTDX),
Usage: "type of attestation to present (currently only azure-tdx)",
},
&cli.BoolFlag{
Name: "log-debug",
Value: false,
Expand All @@ -75,6 +83,7 @@ func runClient(cCtx *cli.Context) (err error) {
addr := cCtx.String("addr")
outMeasurements := cCtx.String("out-measurements")
outResponse := cCtx.String("out-response")
attestationTypeStr := cCtx.String("attestation-type")

// Setup logging
log := common.SetupLogger(&common.LoggingOpts{
Expand All @@ -88,7 +97,17 @@ func runClient(cCtx *cli.Context) (err error) {
return errors.New("address needs to start with https://")
}

log.Info("Getting verified measurements from " + addr + " ...")
attestationType, err := proxy.ParseAttestationType(attestationTypeStr)
if err != nil {
log.With("attestation-type", attestationType).Error("invalid attestation-type passed, see --help")
return err
}
if attestationType != proxy.AttestationAzureTDX {
log.Error("currently only azure-tdx attestation is supported")
return errors.New("currently only azure-tdx attestation is supported")
}

log.Info("Executing attested GET request to " + addr + " ...")

// Prepare aTLS stuff
issuer, err := proxy.CreateAttestationIssuer(log, proxy.AttestationAzureTDX)
Expand All @@ -97,24 +116,31 @@ func runClient(cCtx *cli.Context) (err error) {
return err
}

tlsConfig, err := atls.CreateAttestationClientTLSConfig(issuer, []atls.Validator{})
// Prepare an azure-tdx validator without any required measurements
attConfig := config.DefaultForAzureTDX()
attConfig.SetMeasurements(measurements.M{})
validator := azure_tdx.NewValidator(attConfig, proxy.AttestationLogger{Log: log})

// Create the (a)TLS config
tlsConfig, err := atls.CreateAttestationClientTLSConfig(issuer, []atls.Validator{validator})
if err != nil {
log.Error("could not create atls config", "err", err)
return err
}

tr := &http.Transport{
// Prepare the client
client := &http.Client{Transport: &http.Transport{
TLSClientConfig: tlsConfig,
}
client := &http.Client{Transport: tr}
}}

// Execute the GET request
resp, err := client.Get(addr)
if err != nil {
return err
}
certs := resp.TLS.PeerCertificates

// Extract the aTLS variant and measurements from the TLS connection
atlsVariant, extractedMeasurements, err := proxy.GetMeasurementsFromTLS(certs, []asn1.ObjectIdentifier{variant.AzureTDX{}.OID()})
atlsVariant, extractedMeasurements, err := proxy.GetMeasurementsFromTLS(resp.TLS.PeerCertificates, []asn1.ObjectIdentifier{variant.AzureTDX{}.OID()})
if err != nil {
log.Error("Error in getMeasurementsFromTLS", "err", err)
return err
Expand All @@ -130,11 +156,10 @@ func runClient(cCtx *cli.Context) (err error) {
return errors.New("could not marshal measurement extracted from tls extension")
}

log.Info("Variant: " + atlsVariant.String())
log.Info(fmt.Sprintf("Measurements for %s with %d entries:", atlsVariant.String(), len(measurementsInHeaderFormat)))
fmt.Println(string(marshaledPcrs))
if outMeasurements != "" {
if err := os.WriteFile(outMeasurements, marshaledPcrs, 0644); err != nil {
if err := os.WriteFile(outMeasurements, marshaledPcrs, 0o644); err != nil {
return err
}
}
Expand All @@ -148,7 +173,7 @@ func runClient(cCtx *cli.Context) (err error) {
log.Info(fmt.Sprintf("Response body with %d bytes:", len(msg)))
fmt.Println(string(msg))
if outResponse != "" {
if err := os.WriteFile(outResponse, msg, 0644); err != nil {
if err := os.WriteFile(outResponse, msg, 0o644); err != nil {
return err
}
}
Expand Down

0 comments on commit 22792e4

Please sign in to comment.