diff options
Diffstat (limited to 'internal/sshd/sshd.go')
| -rw-r--r-- | internal/sshd/sshd.go | 214 |
1 files changed, 214 insertions, 0 deletions
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go new file mode 100644 index 0000000..648e29b --- /dev/null +++ b/internal/sshd/sshd.go @@ -0,0 +1,214 @@ +package sshd + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "io/ioutil" + "net" + "strconv" + "time" + + log "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/gitlab-shell/internal/command" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys" + "golang.org/x/crypto/ssh" + "golang.org/x/sync/semaphore" +) + +func Run(cfg *config.Config) error { + authorizedKeysClient, err := authorizedkeys.NewClient(cfg) + if err != nil { + return fmt.Errorf("failed to initialize GitLab client: %w", err) + } + + sshListener, err := net.Listen("tcp", cfg.Server.Listen) + if err != nil { + return fmt.Errorf("failed to listen for connection: %w", err) + } + + config := &ssh.ServerConfig{ + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + if conn.User() != cfg.User { + return nil, errors.New("unknown user") + } + if key.Type() == ssh.KeyAlgoDSA { + return nil, errors.New("DSA is prohibited") + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + res, err := authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal())) + if err != nil { + return nil, err + } + + return &ssh.Permissions{ + // Record the public key used for authentication. + Extensions: map[string]string{ + "key-id": strconv.FormatInt(res.Id, 10), + }, + }, nil + }, + } + + var loadedHostKeys uint + for _, filename := range cfg.Server.HostKeyFiles { + keyRaw, err := ioutil.ReadFile(filename) + if err != nil { + log.Warnf("Failed to read host key %v: %v", filename, err) + continue + } + key, err := ssh.ParsePrivateKey(keyRaw) + if err != nil { + log.Warnf("Failed to parse host key %v: %v", filename, err) + continue + } + loadedHostKeys++ + config.AddHostKey(key) + } + if loadedHostKeys == 0 { + return fmt.Errorf("No host keys could be loaded, aborting") + } + + for { + nconn, err := sshListener.Accept() + if err != nil { + log.Warnf("Failed to accept connection: %v\n", err) + continue + } + + go handleConn(nconn, config, cfg) + } +} + +type execRequest struct { + Command string +} + +type exitStatusReq struct { + ExitStatus uint32 +} + +type envRequest struct { + Name string + Value string +} + +func exitSession(ch ssh.Channel, exitStatus uint32) { + exitStatusReq := exitStatusReq{ + ExitStatus: exitStatus, + } + ch.CloseWrite() + ch.SendRequest("exit-status", false, ssh.Marshal(exitStatusReq)) + ch.Close() +} + +func handleConn(nconn net.Conn, sshCfg *ssh.ServerConfig, cfg *config.Config) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + defer nconn.Close() + conn, chans, reqs, err := ssh.NewServerConn(nconn, sshCfg) + if err != nil { + log.Infof("Failed to initialize SSH connection: %v", err) + return + } + + concurrentSessions := semaphore.NewWeighted(cfg.Server.ConcurrentSessionsLimit) + + go ssh.DiscardRequests(reqs) + for newChannel := range chans { + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + continue + } + if !concurrentSessions.TryAcquire(1) { + newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions") + continue + } + ch, requests, err := newChannel.Accept() + if err != nil { + log.Infof("Could not accept channel: %v", err) + concurrentSessions.Release(1) + continue + } + + go handleSession(ctx, concurrentSessions, ch, requests, conn, nconn, cfg) + } +} + +func handleSession(ctx context.Context, concurrentSessions *semaphore.Weighted, ch ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn, nconn net.Conn, cfg *config.Config) { + defer concurrentSessions.Release(1) + + rw := &readwriter.ReadWriter{ + Out: ch, + In: ch, + ErrOut: ch.Stderr(), + } + var gitProtocolVersion string + + for req := range requests { + var execCmd string + switch req.Type { + case "env": + var envRequest envRequest + if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil { + ch.Close() + return + } + var accepted bool + if envRequest.Name == commandargs.GitProtocolEnv { + gitProtocolVersion = envRequest.Value + accepted = true + } + if req.WantReply { + req.Reply(accepted, []byte{}) + } + + case "exec": + var execRequest execRequest + if err := ssh.Unmarshal(req.Payload, &execRequest); err != nil { + ch.Close() + return + } + execCmd = execRequest.Command + fallthrough + case "shell": + if req.WantReply { + req.Reply(true, []byte{}) + } + args := &commandargs.Shell{ + GitlabKeyId: conn.Permissions.Extensions["key-id"], + RemoteAddr: nconn.RemoteAddr().(*net.TCPAddr), + GitProtocolVersion: gitProtocolVersion, + } + + if err := args.ParseCommand(execCmd); err != nil { + fmt.Fprintf(ch.Stderr(), "Failed to parse command: %v\n", err.Error()) + exitSession(ch, 128) + return + } + + cmd := command.BuildShellCommand(args, cfg, rw) + if cmd == nil { + fmt.Fprintf(ch.Stderr(), "Unknown command: %v\n", args.CommandType) + exitSession(ch, 128) + return + } + if err := cmd.Execute(ctx); err != nil { + fmt.Fprintf(ch.Stderr(), "remote: ERROR: %v\n", err.Error()) + exitSession(ch, 1) + return + } + exitSession(ch, 0) + return + default: + if req.WantReply { + req.Reply(false, []byte{}) + } + } + } +} |
