summaryrefslogtreecommitdiff
path: root/internal/sshd/server_config.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/sshd/server_config.go')
-rw-r--r--internal/sshd/server_config.go109
1 files changed, 98 insertions, 11 deletions
diff --git a/internal/sshd/server_config.go b/internal/sshd/server_config.go
index 3c1fdbf..394a9c9 100644
--- a/internal/sshd/server_config.go
+++ b/internal/sshd/server_config.go
@@ -4,6 +4,7 @@ import (
"context"
"encoding/base64"
"fmt"
+ "io/ioutil"
"os"
"strconv"
"time"
@@ -40,6 +41,7 @@ type serverConfig struct {
cfg *config.Config
hostKeys []ssh.Signer
hostKeyToCertMap map[string]*ssh.Certificate
+ trustedUserCAKeys map[string]ssh.PublicKey
authorizedKeysClient *authorizedkeys.Client
}
@@ -110,6 +112,33 @@ func parseHostCerts(hostKeys []ssh.Signer, certFiles []string) map[string]*ssh.C
return keyToCertMap
}
+func parseTrustedUserCAKeys(filename string) (map[string]ssh.PublicKey, error) {
+ keys := make(map[string]ssh.PublicKey)
+
+ if filename == "" {
+ return keys, nil
+ }
+
+ keysRaw, err := ioutil.ReadFile(filename)
+ if err != nil {
+ log.WithError(err).WithFields(log.Fields{"filename": filename}).Warn("failed to read trusted user keys")
+ return keys, err
+ }
+
+ for len(keysRaw) > 0 {
+ publicKey, _, _, rest, err := ssh.ParseAuthorizedKey(keysRaw)
+ if err != nil {
+ log.WithError(err).WithFields(log.Fields{"filename": filename}).Warn("failed to parse trusted user keys")
+ return keys, err
+ }
+
+ keys[string(publicKey.Marshal())] = publicKey
+ keysRaw = rest
+ }
+
+ return keys, nil
+}
+
func newServerConfig(cfg *config.Config) (*serverConfig, error) {
authorizedKeysClient, err := authorizedkeys.NewClient(cfg)
if err != nil {
@@ -122,8 +151,19 @@ func newServerConfig(cfg *config.Config) (*serverConfig, error) {
}
hostKeyToCertMap := parseHostCerts(hostKeys, cfg.Server.HostCertFiles)
+ trustedUserCAKeys, err := parseTrustedUserCAKeys(cfg.Server.TrustedUserCAKeys)
+ if err != nil {
+ return nil, fmt.Errorf("failed to load trusted user keys")
+ }
- return &serverConfig{cfg: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys, hostKeyToCertMap: hostKeyToCertMap}, nil
+ return &serverConfig{
+ cfg: cfg,
+ authorizedKeysClient: authorizedKeysClient,
+ hostKeys: hostKeys,
+ hostKeyToCertMap: hostKeyToCertMap,
+ trustedUserCAKeys: trustedUserCAKeys,
+ },
+ nil
}
func (s *serverConfig) getAuthKey(ctx context.Context, user string, key ssh.PublicKey) (*authorizedkeys.Response, error) {
@@ -145,6 +185,57 @@ func (s *serverConfig) getAuthKey(ctx context.Context, user string, key ssh.Publ
return res, nil
}
+func (s *serverConfig) handleUserKey(ctx context.Context, user string, key ssh.PublicKey) (*ssh.Permissions, error) {
+ res, err := s.getAuthKey(ctx, user, key)
+ if err != nil {
+ return nil, err
+ }
+
+ return &ssh.Permissions{
+ // Record the public key used for authentication.
+ Extensions: map[string]string{
+ "key-id": strconv.FormatInt(res.Id, 10),
+ },
+ }, nil
+}
+
+func (s *serverConfig) validUserCertificate(cert *ssh.Certificate) bool {
+ if cert.CertType != ssh.UserCert {
+ return false
+ }
+
+ publicKey := s.trustedUserCAKeys[string(cert.SignatureKey.Marshal())]
+ if publicKey == nil {
+ return false
+ }
+
+ return true
+}
+
+func (s *serverConfig) handleUserCertificate(user string, cert *ssh.Certificate) (*ssh.Permissions, error) {
+ logger := log.WithFields(log.Fields{
+ "ssh_user": user,
+ "certificate_identity": cert.KeyId,
+ "public_key_fingerprint": ssh.FingerprintSHA256(cert.Key),
+ "signing_ca_fingerprint": ssh.FingerprintSHA256(cert.SignatureKey),
+ })
+
+ if !s.validUserCertificate(cert) {
+ logger.Warn("user certificate not signed by trusted key")
+ return nil, fmt.Errorf("user certificate not signed by trusted key")
+ }
+
+ logger.Info("user certificate is valid")
+
+ // The gitlab-shell commands will make an internal API call to /discover
+ // to look up the username, so unlike the SSH key case we don't need to do it here.
+ return &ssh.Permissions{
+ Extensions: map[string]string{
+ "gitlab-username": cert.KeyId,
+ },
+ }, nil
+}
+
func (s *serverConfig) get(ctx context.Context) *ssh.ServerConfig {
var gssapiWithMICConfig *ssh.GSSAPIWithMICConfig
if s.cfg.Server.GSSAPI.Enabled {
@@ -168,17 +259,13 @@ func (s *serverConfig) get(ctx context.Context) *ssh.ServerConfig {
}
sshCfg := &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
- res, err := s.getAuthKey(ctx, conn.User(), key)
- if err != nil {
- return nil, err
- }
+ cert, ok := key.(*ssh.Certificate)
- return &ssh.Permissions{
- // Record the public key used for authentication.
- Extensions: map[string]string{
- "key-id": strconv.FormatInt(res.Id, 10),
- },
- }, nil
+ if !ok {
+ return s.handleUserKey(ctx, conn.User(), key)
+ } else {
+ return s.handleUserCertificate(conn.User(), cert)
+ }
},
GSSAPIWithMICConfig: gssapiWithMICConfig,
ServerVersion: "SSH-2.0-GitLab-SSHD",