From 535b8110a7ef00c2d17231828a0bd611fa3808f0 Mon Sep 17 00:00:00 2001 From: Florian Ritterhoff Date: Thu, 26 Oct 2023 08:42:20 +0200 Subject: [PATCH] chore: better handling of reconnects --- acme/mqtt/client.go | 98 ++++++++++++++++----------------- cmd/step-agent/main.go | 120 ++++++++++++++++++++--------------------- 2 files changed, 109 insertions(+), 109 deletions(-) diff --git a/acme/mqtt/client.go b/acme/mqtt/client.go index f062c7b19..655bcd3b3 100644 --- a/acme/mqtt/client.go +++ b/acme/mqtt/client.go @@ -33,8 +33,56 @@ func Connect(acmeDB acme.DB, host, user, password, organization string) (validat opts.OnConnectionLost = func(cl mqtt.Client, err error) { logrus.Println("mqtt connection lost") } - opts.OnConnect = func(mqtt.Client) { + opts.OnConnect = func(cl mqtt.Client) { logrus.Println("mqtt connection established") + go func() { + cl.Subscribe(fmt.Sprintf("%s/data", organization), 1, func(client mqtt.Client, msg mqtt.Message) { + logrus.Printf("Received message on topic: %s\nMessage: %s\n", msg.Topic(), msg.Payload()) + ctx := context.Background() + data := msg.Payload() + var payload validation.ValidationResponse + err := json.Unmarshal(data, &payload) + if err != nil { + logrus.Errorf("error unmarshalling payload: %v", err) + return + } + + ch, err := acmeDB.GetChallenge(ctx, payload.Challenge, payload.Authz) + if err != nil { + logrus.Errorf("error getting challenge: %v", err) + return + } + + acc, err := acmeDB.GetAccount(ctx, ch.AccountID) + if err != nil { + logrus.Errorf("error getting account: %v", err) + return + } + expected, err := acme.KeyAuthorization(ch.Token, acc.Key) + + if payload.Content != expected || err != nil { + logrus.Errorf("invalid key authorization: %v", err) + return + } + u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} + logrus.Infof("challenge %s validated using mqtt", u.String()) + + if ch.Status != acme.StatusPending && ch.Status != acme.StatusValid { + return + } + + ch.Status = acme.StatusValid + ch.Error = nil + ch.ValidatedAt = clock.Now().Format(time.RFC3339) + + if err = acmeDB.UpdateChallenge(ctx, ch); err != nil { + logrus.Errorf("error updating challenge: %v", err) + } else { + logrus.Infof("challenge %s updated to valid", u.String()) + } + + }) + }() } opts.OnReconnecting = func(mqtt.Client, *mqtt.ClientOptions) { logrus.Println("mqtt attempting to reconnect") @@ -47,54 +95,6 @@ func Connect(acmeDB acme.DB, host, user, password, organization string) (validat return nil, token.Error() } - go func() { - client.Subscribe(fmt.Sprintf("%s/data", organization), 1, func(client mqtt.Client, msg mqtt.Message) { - logrus.Printf("Received message on topic: %s\nMessage: %s\n", msg.Topic(), msg.Payload()) - ctx := context.Background() - data := msg.Payload() - var payload validation.ValidationResponse - err := json.Unmarshal(data, &payload) - if err != nil { - logrus.Errorf("error unmarshalling payload: %v", err) - return - } - - ch, err := acmeDB.GetChallenge(ctx, payload.Challenge, payload.Authz) - if err != nil { - logrus.Errorf("error getting challenge: %v", err) - return - } - - acc, err := acmeDB.GetAccount(ctx, ch.AccountID) - if err != nil { - logrus.Errorf("error getting account: %v", err) - return - } - expected, err := acme.KeyAuthorization(ch.Token, acc.Key) - - if payload.Content != expected || err != nil { - logrus.Errorf("invalid key authorization: %v", err) - return - } - u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} - logrus.Infof("challenge %s validated using mqtt", u.String()) - - if ch.Status != acme.StatusPending && ch.Status != acme.StatusValid { - return - } - - ch.Status = acme.StatusValid - ch.Error = nil - ch.ValidatedAt = clock.Now().Format(time.RFC3339) - - if err = acmeDB.UpdateChallenge(ctx, ch); err != nil { - logrus.Errorf("error updating challenge: %v", err) - } else { - logrus.Infof("challenge %s updated to valid", u.String()) - } - - }) - }() connection := validation.BrokerConnection{Client: client, Organization: organization} return connection, nil } diff --git a/cmd/step-agent/main.go b/cmd/step-agent/main.go index 21a83ee16..85d322a25 100644 --- a/cmd/step-agent/main.go +++ b/cmd/step-agent/main.go @@ -60,9 +60,68 @@ var agent = cli.Command{ options.OnConnectionLost = func(cl mqtt.Client, err error) { logrus.Println("mqtt connection lost") } - options.OnConnect = func(mqtt.Client) { + options.OnConnect = func(cl mqtt.Client) { logrus.Println("mqtt connection established") + // Subscribe to topic + token := cl.Subscribe(fmt.Sprintf("%s/jobs", c.String("organization")), 0, func(client mqtt.Client, msg mqtt.Message) { + logrus.Infof("received message on topic %s", msg.Topic()) + logrus.Infof("message: %s", msg.Payload()) + + var data validation.ValidationRequest + + req := msg.Payload() + json.Unmarshal(req, &data) + + logger := logrus.WithField("authz", data.Authz).WithField("target", data.Target).WithField("account", data.Challenge) + + http := acme.NewClient() + resp, err := http.Get(data.Target) + if err != nil { + logger.WithError(err).Warn("validating failed") + return + } + + defer resp.Body.Close() + if resp.StatusCode >= 400 { + logger.Warnf("validation for %s failed with error: %s", data.Target, resp.Status) + return + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + logger.WithError(err).Warn("parsing body failed") + return + } + + keyAuth := strings.TrimSpace(string(body)) + logger.Infof("keyAuth: %s", keyAuth) + + json, err := json.Marshal(&validation.ValidationResponse{ + Authz: data.Authz, + Challenge: data.Challenge, + Content: keyAuth, + }) + if err != nil { + logger.WithError(err).Warn("marshalling failed") + return + } + // Publish to topic + token := cl.Publish(fmt.Sprintf("%s/data", c.String("organization")), 0, false, json) + if token.WaitTimeout(30*time.Second) && token.Error() != nil { + logger.WithError(token.Error()).Warn("publishing failed") + } else { + logger.Infof("published to topic %s", fmt.Sprintf("%s/data", c.String("organization"))) + } + + }) + + if token.WaitTimeout(30*time.Second) && token.Error() != nil { + logrus.WithError(token.Error()).Warn("subscribing failed") + } else { + logrus.Infof("subscribed to topic %s", fmt.Sprintf("%s/jobs", c.String("organization"))) + } } + options.OnReconnecting = func(mqtt.Client, *mqtt.ClientOptions) { logrus.Println("mqtt reconnecting") } @@ -72,65 +131,6 @@ var agent = cli.Command{ logrus.Warn(token.Error()) } - // Subscribe to topic - token := client.Subscribe(fmt.Sprintf("%s/jobs", c.String("organization")), 0, func(client mqtt.Client, msg mqtt.Message) { - logrus.Infof("received message on topic %s", msg.Topic()) - logrus.Infof("message: %s", msg.Payload()) - - var data validation.ValidationRequest - - req := msg.Payload() - json.Unmarshal(req, &data) - - logger := logrus.WithField("authz", data.Authz).WithField("target", data.Target).WithField("account", data.Challenge) - - http := acme.NewClient() - resp, err := http.Get(data.Target) - if err != nil { - logger.WithError(err).Warn("validating failed") - return - } - - defer resp.Body.Close() - if resp.StatusCode >= 400 { - logger.Warnf("validation for %s failed with error: %s", data.Target, resp.Status) - return - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - logger.WithError(err).Warn("parsing body failed") - return - } - - keyAuth := strings.TrimSpace(string(body)) - logger.Infof("keyAuth: %s", keyAuth) - - json, err := json.Marshal(&validation.ValidationResponse{ - Authz: data.Authz, - Challenge: data.Challenge, - Content: keyAuth, - }) - if err != nil { - logger.WithError(err).Warn("marshalling failed") - return - } - // Publish to topic - token := client.Publish(fmt.Sprintf("%s/data", c.String("organization")), 0, false, json) - if token.WaitTimeout(30*time.Second) && token.Error() != nil { - logger.WithError(token.Error()).Warn("publishing failed") - } else { - logger.Infof("published to topic %s", fmt.Sprintf("%s/data", c.String("organization"))) - } - - }) - - if token.WaitTimeout(30*time.Second) && token.Error() != nil { - logrus.WithError(token.Error()).Warn("subscribing failed") - } else { - logrus.Infof("subscribed to topic %s", fmt.Sprintf("%s/jobs", c.String("organization"))) - } - return nil }, }