Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

many changes #11

Merged
merged 4 commits into from
Mar 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion admin/src/portr_admin/services/connection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from portr_admin.models.connection import Connection, ConnectionType
from portr_admin.models.connection import Connection, ConnectionStatus, ConnectionType
from portr_admin.models.user import TeamUser
from portr_admin.utils.exception import ServiceError

Expand All @@ -12,6 +12,13 @@ async def create_new_connection(
if type == ConnectionType.http and not subdomain:
raise ServiceError("subdomain is required for http connections")

if type == ConnectionType.http:
active_connection = await Connection.filter(
subdomain=subdomain, status=ConnectionStatus.active.value
).first()
if active_connection:
raise ServiceError("Subdomain already in use")

return await Connection.create(
type=type,
subdomain=subdomain if type == ConnectionType.http else None,
Expand Down
18 changes: 18 additions & 0 deletions admin/src/portr_admin/tests/api_tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,24 @@ async def test_create_new_connection_with_wrong_secret_key_should_fail(self):
assert resp.status_code == 400
assert resp.json() == {"message": "Invalid secret key"}

async def test_create_new_connection_with_active_subdomain_should_fail(self):
await ConnectionFactory.create(
type="http",
subdomain="test-subdomain",
team=self.team_user.team,
status=ConnectionStatus.active,
)
resp = self.client.post(
"/api/v1/connections/",
json={
"connection_type": "http",
"secret_key": self.team_user.secret_key,
"subdomain": "test-subdomain",
},
)
assert resp.status_code == 400
assert resp.json() == {"message": "Subdomain already in use"}

async def test_list_active_connections(self):
resp = self.team_user_client.get(
"/api/v1/connections/",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock
from portr_admin.models.connection import ConnectionType
import pytest
from tortoise.contrib.test import SimpleTestCase
Expand All @@ -20,9 +20,14 @@ async def test_create_http_connection_without_subdomain_should_fail(self):
assert str(e.value) == "subdomain is required for http connections"

@patch("portr_admin.models.connection.Connection.create")
@patch("portr_admin.models.connection.Connection.filter")
async def test_create_http_connection_with_subdomain_should_succeed(
self, create_fn
self, filter_fn, create_fn
):
first_mock = AsyncMock()
first_mock.return_value = None
filter_fn.return_value.first = first_mock

await connection_service.create_new_connection(
type=ConnectionType.http,
created_by=self.team_user,
Expand Down
11 changes: 8 additions & 3 deletions tunnel/cmd/portr/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,18 @@ import (
func startTunnels(c *cli.Context, tunnelFromCli *config.Tunnel) error {
_c := client.NewClient(c.String("config"))

tunnelFromCli.SetDefaults()
var err error

if tunnelFromCli != nil {
tunnelFromCli.SetDefaults()
_c.ReplaceTunnelsFromCli(*tunnelFromCli)
_c.Start(c.Context)
err = _c.Start(c.Context)
} else {
_c.Start(c.Context, c.Args().Slice()...)
err = _c.Start(c.Context, c.Args().Slice()...)
}

if err != nil {
return err
}

signalCh := make(chan os.Signal, 1)
Expand Down
9 changes: 8 additions & 1 deletion tunnel/internal/client/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package client

import (
"context"
"fmt"
"log"
"slices"

Expand All @@ -28,7 +29,7 @@ func NewClient(configFile string) *Client {
}
}

func (c *Client) Start(ctx context.Context, services ...string) {
func (c *Client) Start(ctx context.Context, services ...string) error {
var clientConfigs []config.ClientConfig

db := db.New()
Expand All @@ -48,11 +49,17 @@ func (c *Client) Start(ctx context.Context, services ...string) {
})
}

if len(clientConfigs) == 0 {
return fmt.Errorf("please enter a valid service name")
}

for _, clientConfig := range clientConfigs {
sshc := ssh.New(clientConfig, db)
c.Add(sshc)
go sshc.Start(ctx)
}

return nil
}

func (c *Client) Add(sshc *ssh.SshClient) {
Expand Down
61 changes: 24 additions & 37 deletions tunnel/internal/client/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,13 @@ import (
"net"
"net/http"
"os"
"os/signal"
"syscall"
"time"

"github.com/amalshaji/portr/internal/client/config"
"github.com/amalshaji/portr/internal/client/db"
"github.com/amalshaji/portr/internal/constants"
"github.com/amalshaji/portr/internal/utils"
"github.com/go-resty/resty/v2"
"github.com/labstack/gommon/color"
"gorm.io/datatypes"

"github.com/oklog/ulid/v2"
Expand All @@ -32,27 +30,25 @@ var (
)

type SshClient struct {
config config.ClientConfig
listener net.Listener
log *slog.Logger
db *db.Db
connected chan bool
config config.ClientConfig
listener net.Listener
log *slog.Logger
db *db.Db
}

func New(config config.ClientConfig, db *db.Db) *SshClient {
return &SshClient{
config: config,
listener: nil,
log: slog.New(slog.NewTextHandler(os.Stdout, nil)),
db: db,
connected: make(chan bool),
config: config,
listener: nil,
log: slog.New(slog.NewTextHandler(os.Stdout, nil)),
db: db,
}
}

func (s *SshClient) createNewConnection() (string, error) {
client := resty.New()
var reqErr struct {
Detail any `json:"detail"`
Message string `json:"message"`
}
var response struct {
ConnectionId string `json:"connection_id"`
Expand All @@ -78,8 +74,10 @@ func (s *SshClient) createNewConnection() (string, error) {
}

if resp.StatusCode() != 200 {
s.log.Error("failed to create new connection", "error", reqErr)
return "", fmt.Errorf("failed to create new connection")
if s.config.Debug {
s.log.Error("failed to create new connection", "error", reqErr)
}
return "", fmt.Errorf(reqErr.Message)
}
return response.ConnectionId, nil
}
Expand Down Expand Up @@ -134,10 +132,6 @@ func (s *SshClient) startListenerForClient() error {

defer s.listener.Close()

s.connected <- true

fmt.Println()

if tunnelType == constants.Http {
fmt.Printf(
"Tunnel connected: %s -> 🌐 -> %s\n",
Expand Down Expand Up @@ -336,6 +330,10 @@ func (s *SshClient) tcpTunnel(src, dst net.Conn) {
}

func (s *SshClient) Shutdown(ctx context.Context) error {
if s.listener == nil {
return nil
}

err := s.listener.Close()
if err != nil {
return err
Expand All @@ -345,23 +343,12 @@ func (s *SshClient) Shutdown(ctx context.Context) error {
}

func (s *SshClient) Start(_ context.Context) {
utils.ShowLoading("Tunnel connecting", s.connected)

done := make(chan os.Signal, 1)
signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)

go func() {
if err := s.startListenerForClient(); err != nil {
log.Fatalf("failed to establish tunnel connection: error=%v\n", err)
}
}()
fmt.Println("Tunnel connecting...")
fmt.Println(s.config.Tunnel)

<-done
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer func() { cancel() }()
if err := s.Shutdown(ctx); err != nil {
if s.config.Debug {
s.log.Error("failed to stop tunnel client", "error", err)
}
if err := s.startListenerForClient(); err != nil {
fmt.Println()
fmt.Println(color.Red(err))
os.Exit(1)
}
}
18 changes: 0 additions & 18 deletions tunnel/internal/utils/loading.go

This file was deleted.

4 changes: 2 additions & 2 deletions tunnel/internal/utils/port.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ func GenerateRandomHttpPorts() []int {
var startPort int = 20000
var endPort int = 30000

return GenerateRandomNumbers(startPort, endPort, 100)
return GenerateRandomNumbers(startPort, endPort, 10)
}

func GenerateRandomTcpPorts() []int {
var startPort int = 30001
var endPort int = 40001

return GenerateRandomNumbers(startPort, endPort, 100)
return GenerateRandomNumbers(startPort, endPort, 10)
}
Loading