summaryrefslogtreecommitdiff
path: root/internal/sshd/sshd.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/sshd/sshd.go')
-rw-r--r--internal/sshd/sshd.go214
1 files changed, 214 insertions, 0 deletions
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
new file mode 100644
index 0000000..648e29b
--- /dev/null
+++ b/internal/sshd/sshd.go
@@ -0,0 +1,214 @@
+package sshd
+
+import (
+ "context"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "net"
+ "strconv"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys"
+ "golang.org/x/crypto/ssh"
+ "golang.org/x/sync/semaphore"
+)
+
+func Run(cfg *config.Config) error {
+ authorizedKeysClient, err := authorizedkeys.NewClient(cfg)
+ if err != nil {
+ return fmt.Errorf("failed to initialize GitLab client: %w", err)
+ }
+
+ sshListener, err := net.Listen("tcp", cfg.Server.Listen)
+ if err != nil {
+ return fmt.Errorf("failed to listen for connection: %w", err)
+ }
+
+ config := &ssh.ServerConfig{
+ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
+ if conn.User() != cfg.User {
+ return nil, errors.New("unknown user")
+ }
+ if key.Type() == ssh.KeyAlgoDSA {
+ return nil, errors.New("DSA is prohibited")
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ res, err := authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal()))
+ if err != nil {
+ return nil, err
+ }
+
+ return &ssh.Permissions{
+ // Record the public key used for authentication.
+ Extensions: map[string]string{
+ "key-id": strconv.FormatInt(res.Id, 10),
+ },
+ }, nil
+ },
+ }
+
+ var loadedHostKeys uint
+ for _, filename := range cfg.Server.HostKeyFiles {
+ keyRaw, err := ioutil.ReadFile(filename)
+ if err != nil {
+ log.Warnf("Failed to read host key %v: %v", filename, err)
+ continue
+ }
+ key, err := ssh.ParsePrivateKey(keyRaw)
+ if err != nil {
+ log.Warnf("Failed to parse host key %v: %v", filename, err)
+ continue
+ }
+ loadedHostKeys++
+ config.AddHostKey(key)
+ }
+ if loadedHostKeys == 0 {
+ return fmt.Errorf("No host keys could be loaded, aborting")
+ }
+
+ for {
+ nconn, err := sshListener.Accept()
+ if err != nil {
+ log.Warnf("Failed to accept connection: %v\n", err)
+ continue
+ }
+
+ go handleConn(nconn, config, cfg)
+ }
+}
+
+type execRequest struct {
+ Command string
+}
+
+type exitStatusReq struct {
+ ExitStatus uint32
+}
+
+type envRequest struct {
+ Name string
+ Value string
+}
+
+func exitSession(ch ssh.Channel, exitStatus uint32) {
+ exitStatusReq := exitStatusReq{
+ ExitStatus: exitStatus,
+ }
+ ch.CloseWrite()
+ ch.SendRequest("exit-status", false, ssh.Marshal(exitStatusReq))
+ ch.Close()
+}
+
+func handleConn(nconn net.Conn, sshCfg *ssh.ServerConfig, cfg *config.Config) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ defer nconn.Close()
+ conn, chans, reqs, err := ssh.NewServerConn(nconn, sshCfg)
+ if err != nil {
+ log.Infof("Failed to initialize SSH connection: %v", err)
+ return
+ }
+
+ concurrentSessions := semaphore.NewWeighted(cfg.Server.ConcurrentSessionsLimit)
+
+ go ssh.DiscardRequests(reqs)
+ for newChannel := range chans {
+ if newChannel.ChannelType() != "session" {
+ newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
+ continue
+ }
+ if !concurrentSessions.TryAcquire(1) {
+ newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions")
+ continue
+ }
+ ch, requests, err := newChannel.Accept()
+ if err != nil {
+ log.Infof("Could not accept channel: %v", err)
+ concurrentSessions.Release(1)
+ continue
+ }
+
+ go handleSession(ctx, concurrentSessions, ch, requests, conn, nconn, cfg)
+ }
+}
+
+func handleSession(ctx context.Context, concurrentSessions *semaphore.Weighted, ch ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn, nconn net.Conn, cfg *config.Config) {
+ defer concurrentSessions.Release(1)
+
+ rw := &readwriter.ReadWriter{
+ Out: ch,
+ In: ch,
+ ErrOut: ch.Stderr(),
+ }
+ var gitProtocolVersion string
+
+ for req := range requests {
+ var execCmd string
+ switch req.Type {
+ case "env":
+ var envRequest envRequest
+ if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil {
+ ch.Close()
+ return
+ }
+ var accepted bool
+ if envRequest.Name == commandargs.GitProtocolEnv {
+ gitProtocolVersion = envRequest.Value
+ accepted = true
+ }
+ if req.WantReply {
+ req.Reply(accepted, []byte{})
+ }
+
+ case "exec":
+ var execRequest execRequest
+ if err := ssh.Unmarshal(req.Payload, &execRequest); err != nil {
+ ch.Close()
+ return
+ }
+ execCmd = execRequest.Command
+ fallthrough
+ case "shell":
+ if req.WantReply {
+ req.Reply(true, []byte{})
+ }
+ args := &commandargs.Shell{
+ GitlabKeyId: conn.Permissions.Extensions["key-id"],
+ RemoteAddr: nconn.RemoteAddr().(*net.TCPAddr),
+ GitProtocolVersion: gitProtocolVersion,
+ }
+
+ if err := args.ParseCommand(execCmd); err != nil {
+ fmt.Fprintf(ch.Stderr(), "Failed to parse command: %v\n", err.Error())
+ exitSession(ch, 128)
+ return
+ }
+
+ cmd := command.BuildShellCommand(args, cfg, rw)
+ if cmd == nil {
+ fmt.Fprintf(ch.Stderr(), "Unknown command: %v\n", args.CommandType)
+ exitSession(ch, 128)
+ return
+ }
+ if err := cmd.Execute(ctx); err != nil {
+ fmt.Fprintf(ch.Stderr(), "remote: ERROR: %v\n", err.Error())
+ exitSession(ch, 1)
+ return
+ }
+ exitSession(ch, 0)
+ return
+ default:
+ if req.WantReply {
+ req.Reply(false, []byte{})
+ }
+ }
+ }
+}