summaryrefslogtreecommitdiff
path: root/internal/sshd/connection_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/sshd/connection_test.go')
-rw-r--r--internal/sshd/connection_test.go20
1 files changed, 14 insertions, 6 deletions
diff --git a/internal/sshd/connection_test.go b/internal/sshd/connection_test.go
index d6bd3c0..9b5e158 100644
--- a/internal/sshd/connection_test.go
+++ b/internal/sshd/connection_test.go
@@ -5,6 +5,8 @@ import (
"errors"
"testing"
+ "golang.org/x/sync/semaphore"
+
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
)
@@ -48,7 +50,9 @@ func (f *fakeNewChannel) ExtraData() []byte {
}
func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel) {
- conn := newConnection(sessionsNum, "127.0.0.1:50000")
+ conn := &connection{
+ concurrentSessions: semaphore.NewWeighted(sessionsNum),
+ }
chans := make(chan ssh.NewChannel, 1)
chans <- newChannel
@@ -62,10 +66,11 @@ func TestPanicDuringSessionIsRecovered(t *testing.T) {
numSessions := 0
require.NotPanics(t, func() {
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
numSessions += 1
close(chans)
panic("This is a panic")
+ return nil
})
})
@@ -80,7 +85,7 @@ func TestUnknownChannelType(t *testing.T) {
conn, chans := setup(1, newChannel)
go func() {
- conn.handle(context.Background(), chans, nil)
+ conn.handleRequests(context.Background(), nil, chans, nil)
}()
rejectionData := <-rejectCh
@@ -100,8 +105,9 @@ func TestTooManySessions(t *testing.T) {
defer cancel()
go func() {
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
<-ctx.Done() // Keep the accepted channel open until the end of the test
+ return nil
})
}()
@@ -114,9 +120,10 @@ func TestAcceptSessionSucceeds(t *testing.T) {
conn, chans := setup(1, newChannel)
channelHandled := false
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
close(chans)
+ return nil
})
require.True(t, channelHandled)
@@ -132,8 +139,9 @@ func TestAcceptSessionFails(t *testing.T) {
channelHandled := false
go func() {
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
+ return nil
})
}()