diff options
Diffstat (limited to 'internal/sshd/connection_test.go')
-rw-r--r-- | internal/sshd/connection_test.go | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/internal/sshd/connection_test.go b/internal/sshd/connection_test.go index d6bd3c0..9b5e158 100644 --- a/internal/sshd/connection_test.go +++ b/internal/sshd/connection_test.go @@ -5,6 +5,8 @@ import ( "errors" "testing" + "golang.org/x/sync/semaphore" + "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" ) @@ -48,7 +50,9 @@ func (f *fakeNewChannel) ExtraData() []byte { } func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel) { - conn := newConnection(sessionsNum, "127.0.0.1:50000") + conn := &connection{ + concurrentSessions: semaphore.NewWeighted(sessionsNum), + } chans := make(chan ssh.NewChannel, 1) chans <- newChannel @@ -62,10 +66,11 @@ func TestPanicDuringSessionIsRecovered(t *testing.T) { numSessions := 0 require.NotPanics(t, func() { - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) { + conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { numSessions += 1 close(chans) panic("This is a panic") + return nil }) }) @@ -80,7 +85,7 @@ func TestUnknownChannelType(t *testing.T) { conn, chans := setup(1, newChannel) go func() { - conn.handle(context.Background(), chans, nil) + conn.handleRequests(context.Background(), nil, chans, nil) }() rejectionData := <-rejectCh @@ -100,8 +105,9 @@ func TestTooManySessions(t *testing.T) { defer cancel() go func() { - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) { + conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { <-ctx.Done() // Keep the accepted channel open until the end of the test + return nil }) }() @@ -114,9 +120,10 @@ func TestAcceptSessionSucceeds(t *testing.T) { conn, chans := setup(1, newChannel) channelHandled := false - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) { + conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { channelHandled = true close(chans) + return nil }) require.True(t, channelHandled) @@ -132,8 +139,9 @@ func TestAcceptSessionFails(t *testing.T) { channelHandled := false go func() { - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) { + conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { channelHandled = true + return nil }) }() |