ssh/tailssh: handle not-authenticated-yet connections in matchRule
Also make more fields in conn.info thread safe, there was previously a data race here. Fixes #5110 Signed-off-by: Maisem Ali <maisem@tailscale.com>pull/5117/head
parent
41e60dae80
commit
480fd6c797
|
@ -86,8 +86,11 @@ func (ss *sshSession) newIncubatorCommand() *exec.Cmd {
|
||||||
// TODO(maisem): this doesn't work with sftp
|
// TODO(maisem): this doesn't work with sftp
|
||||||
return exec.CommandContext(ss.ctx, name, args...)
|
return exec.CommandContext(ss.ctx, name, args...)
|
||||||
}
|
}
|
||||||
|
ss.conn.mu.Lock()
|
||||||
lu := ss.conn.localUser
|
lu := ss.conn.localUser
|
||||||
ci := ss.conn.info
|
ci := ss.conn.info
|
||||||
|
gids := strings.Join(ss.conn.userGroupIDs, ",")
|
||||||
|
ss.conn.mu.Unlock()
|
||||||
remoteUser := ci.uprof.LoginName
|
remoteUser := ci.uprof.LoginName
|
||||||
if len(ci.node.Tags) > 0 {
|
if len(ci.node.Tags) > 0 {
|
||||||
remoteUser = strings.Join(ci.node.Tags, ",")
|
remoteUser = strings.Join(ci.node.Tags, ",")
|
||||||
|
@ -98,7 +101,7 @@ func (ss *sshSession) newIncubatorCommand() *exec.Cmd {
|
||||||
"ssh",
|
"ssh",
|
||||||
"--uid=" + lu.Uid,
|
"--uid=" + lu.Uid,
|
||||||
"--gid=" + lu.Gid,
|
"--gid=" + lu.Gid,
|
||||||
"--groups=" + strings.Join(ss.conn.userGroupIDs, ","),
|
"--groups=" + gids,
|
||||||
"--local-user=" + lu.Username,
|
"--local-user=" + lu.Username,
|
||||||
"--remote-user=" + remoteUser,
|
"--remote-user=" + remoteUser,
|
||||||
"--remote-ip=" + ci.src.IP().String(),
|
"--remote-ip=" + ci.src.IP().String(),
|
||||||
|
|
|
@ -141,6 +141,14 @@ func (srv *server) OnPolicyChange() {
|
||||||
srv.mu.Lock()
|
srv.mu.Lock()
|
||||||
defer srv.mu.Unlock()
|
defer srv.mu.Unlock()
|
||||||
for c := range srv.activeConns {
|
for c := range srv.activeConns {
|
||||||
|
c.mu.Lock()
|
||||||
|
ci := c.info
|
||||||
|
c.mu.Unlock()
|
||||||
|
if ci == nil {
|
||||||
|
// c.info is nil when the connection hasn't been authenticated yet.
|
||||||
|
// In that case, the connection will be terminated when it is.
|
||||||
|
continue
|
||||||
|
}
|
||||||
go c.checkStillValid()
|
go c.checkStillValid()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -155,11 +163,11 @@ type conn struct {
|
||||||
connID string // ID that's shared with control
|
connID string // ID that's shared with control
|
||||||
action0 *tailcfg.SSHAction // first matching action
|
action0 *tailcfg.SSHAction // first matching action
|
||||||
srv *server
|
srv *server
|
||||||
info *sshConnInfo // set by setInfo
|
|
||||||
localUser *user.User // set by checkAuth
|
|
||||||
userGroupIDs []string // set by checkAuth
|
|
||||||
|
|
||||||
mu sync.Mutex // protects the following
|
mu sync.Mutex // protects the following
|
||||||
|
localUser *user.User // set by checkAuth
|
||||||
|
userGroupIDs []string // set by checkAuth
|
||||||
|
info *sshConnInfo // set by setInfo
|
||||||
// idH is the RFC4253 sec8 hash H. It is used to identify the connection,
|
// idH is the RFC4253 sec8 hash H. It is used to identify the connection,
|
||||||
// and is shared among all sessions. It should not be shared outside
|
// and is shared among all sessions. It should not be shared outside
|
||||||
// process. It is confusingly referred to as SessionID by the gliderlabs/ssh
|
// process. It is confusingly referred to as SessionID by the gliderlabs/ssh
|
||||||
|
@ -179,9 +187,13 @@ func (c *conn) logf(format string, args ...any) {
|
||||||
// PublicKeyHandler implements ssh.PublicKeyHandler is called by the the
|
// PublicKeyHandler implements ssh.PublicKeyHandler is called by the the
|
||||||
// ssh.Server when the client presents a public key.
|
// ssh.Server when the client presents a public key.
|
||||||
func (c *conn) PublicKeyHandler(ctx ssh.Context, pubKey ssh.PublicKey) error {
|
func (c *conn) PublicKeyHandler(ctx ssh.Context, pubKey ssh.PublicKey) error {
|
||||||
if c.info == nil {
|
c.mu.Lock()
|
||||||
|
ci := c.info
|
||||||
|
c.mu.Unlock()
|
||||||
|
if ci == nil {
|
||||||
return gossh.ErrDenied
|
return gossh.ErrDenied
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.checkAuth(pubKey); err != nil {
|
if err := c.checkAuth(pubKey); err != nil {
|
||||||
// TODO(maisem/bradfitz): surface the error here.
|
// TODO(maisem/bradfitz): surface the error here.
|
||||||
c.logf("rejecting SSH public key %s: %v", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)), err)
|
c.logf("rejecting SSH public key %s: %v", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)), err)
|
||||||
|
@ -217,7 +229,7 @@ func (c *conn) NoClientAuthCallback(cm gossh.ConnMetadata) (*gossh.Permissions,
|
||||||
func (c *conn) checkAuth(pubKey ssh.PublicKey) error {
|
func (c *conn) checkAuth(pubKey ssh.PublicKey) error {
|
||||||
a, localUser, err := c.evaluatePolicy(pubKey)
|
a, localUser, err := c.evaluatePolicy(pubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if pubKey == nil && c.havePubKeyPolicy(c.info) {
|
if pubKey == nil && c.havePubKeyPolicy() {
|
||||||
return errPubKeyRequired
|
return errPubKeyRequired
|
||||||
}
|
}
|
||||||
return fmt.Errorf("%w: %v", gossh.ErrDenied, err)
|
return fmt.Errorf("%w: %v", gossh.ErrDenied, err)
|
||||||
|
@ -236,6 +248,8 @@ func (c *conn) checkAuth(pubKey ssh.PublicKey) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
c.userGroupIDs = gids
|
c.userGroupIDs = gids
|
||||||
c.localUser = lu
|
c.localUser = lu
|
||||||
return nil
|
return nil
|
||||||
|
@ -329,7 +343,13 @@ func (c *conn) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, de
|
||||||
|
|
||||||
// havePubKeyPolicy reports whether any policy rule may provide access by means
|
// havePubKeyPolicy reports whether any policy rule may provide access by means
|
||||||
// of a ssh.PublicKey.
|
// of a ssh.PublicKey.
|
||||||
func (c *conn) havePubKeyPolicy(ci *sshConnInfo) bool {
|
func (c *conn) havePubKeyPolicy() bool {
|
||||||
|
c.mu.Lock()
|
||||||
|
ci := c.info
|
||||||
|
c.mu.Unlock()
|
||||||
|
if ci == nil {
|
||||||
|
panic("havePubKeyPolicy called before setInfo")
|
||||||
|
}
|
||||||
// Is there any rule that looks like it'd require a public key for this
|
// Is there any rule that looks like it'd require a public key for this
|
||||||
// sshUser?
|
// sshUser?
|
||||||
pol, ok := c.sshPolicy()
|
pol, ok := c.sshPolicy()
|
||||||
|
@ -414,6 +434,8 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error {
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("unknown Tailscale identity from src %v", ci.src)
|
return fmt.Errorf("unknown Tailscale identity from src %v", ci.src)
|
||||||
}
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
ci.node = node
|
ci.node = node
|
||||||
ci.uprof = &uprof
|
ci.uprof = &uprof
|
||||||
|
|
||||||
|
@ -589,8 +611,10 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ss := c.newSSHSession(s)
|
ss := c.newSSHSession(s)
|
||||||
|
c.mu.Lock()
|
||||||
ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.IP(), c.localUser.Username)
|
ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.IP(), c.localUser.Username)
|
||||||
ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Username)
|
ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Username)
|
||||||
|
c.mu.Unlock()
|
||||||
ss.run()
|
ss.run()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -688,7 +712,10 @@ func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (ac
|
||||||
|
|
||||||
func (c *conn) expandDelegateURL(actionURL string) string {
|
func (c *conn) expandDelegateURL(actionURL string) string {
|
||||||
nm := c.srv.lb.NetMap()
|
nm := c.srv.lb.NetMap()
|
||||||
|
c.mu.Lock()
|
||||||
ci := c.info
|
ci := c.info
|
||||||
|
lu := c.localUser
|
||||||
|
c.mu.Unlock()
|
||||||
var dstNodeID string
|
var dstNodeID string
|
||||||
if nm != nil {
|
if nm != nil {
|
||||||
dstNodeID = fmt.Sprint(int64(nm.SelfNode.ID))
|
dstNodeID = fmt.Sprint(int64(nm.SelfNode.ID))
|
||||||
|
@ -699,7 +726,7 @@ func (c *conn) expandDelegateURL(actionURL string) string {
|
||||||
"$DST_NODE_IP", url.QueryEscape(ci.dst.IP().String()),
|
"$DST_NODE_IP", url.QueryEscape(ci.dst.IP().String()),
|
||||||
"$DST_NODE_ID", dstNodeID,
|
"$DST_NODE_ID", dstNodeID,
|
||||||
"$SSH_USER", url.QueryEscape(ci.sshUser),
|
"$SSH_USER", url.QueryEscape(ci.sshUser),
|
||||||
"$LOCAL_USER", url.QueryEscape(c.localUser.Username),
|
"$LOCAL_USER", url.QueryEscape(lu.Username),
|
||||||
).Replace(actionURL)
|
).Replace(actionURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -709,10 +736,12 @@ func (c *conn) expandPublicKeyURL(pubKeyURL string) string {
|
||||||
}
|
}
|
||||||
var localPart string
|
var localPart string
|
||||||
var loginName string
|
var loginName string
|
||||||
|
c.mu.Lock()
|
||||||
if c.info.uprof != nil {
|
if c.info.uprof != nil {
|
||||||
loginName = c.info.uprof.LoginName
|
loginName = c.info.uprof.LoginName
|
||||||
localPart, _, _ = strings.Cut(loginName, "@")
|
localPart, _, _ = strings.Cut(loginName, "@")
|
||||||
}
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
return strings.NewReplacer(
|
return strings.NewReplacer(
|
||||||
"$LOGINNAME_EMAIL", loginName,
|
"$LOGINNAME_EMAIL", loginName,
|
||||||
"$LOGINNAME_LOCALPART", localPart,
|
"$LOGINNAME_LOCALPART", localPart,
|
||||||
|
@ -768,6 +797,8 @@ func (c *conn) isStillValid() bool {
|
||||||
if !a.Accept && a.HoldAndDelegate == "" {
|
if !a.Accept && a.HoldAndDelegate == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
return c.localUser.Username == localUser
|
return c.localUser.Username == localUser
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -944,6 +975,8 @@ func (ss *sshSession) run() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ss.conn.startSessionLocked(ss)
|
ss.conn.startSessionLocked(ss)
|
||||||
|
lu := ss.conn.localUser
|
||||||
|
localUser := lu.Username
|
||||||
srv.mu.Unlock()
|
srv.mu.Unlock()
|
||||||
|
|
||||||
defer ss.conn.endSession(ss)
|
defer ss.conn.endSession(ss)
|
||||||
|
@ -959,8 +992,6 @@ func (ss *sshSession) run() {
|
||||||
}
|
}
|
||||||
|
|
||||||
logf := ss.logf
|
logf := ss.logf
|
||||||
lu := ss.conn.localUser
|
|
||||||
localUser := lu.Username
|
|
||||||
|
|
||||||
if euid := os.Geteuid(); euid != 0 {
|
if euid := os.Geteuid(); euid != 0 {
|
||||||
if lu.Uid != fmt.Sprint(euid) {
|
if lu.Uid != fmt.Sprint(euid) {
|
||||||
|
@ -1110,9 +1141,20 @@ var (
|
||||||
errRuleExpired = errors.New("rule expired")
|
errRuleExpired = errors.New("rule expired")
|
||||||
errPrincipalMatch = errors.New("principal didn't match")
|
errPrincipalMatch = errors.New("principal didn't match")
|
||||||
errUserMatch = errors.New("user didn't match")
|
errUserMatch = errors.New("user didn't match")
|
||||||
|
errInvalidConn = errors.New("invalid connection state")
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, err error) {
|
func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, err error) {
|
||||||
|
if c == nil {
|
||||||
|
return nil, "", errInvalidConn
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
ci := c.info
|
||||||
|
c.mu.Unlock()
|
||||||
|
if ci == nil {
|
||||||
|
c.logf("invalid connection state")
|
||||||
|
return nil, "", errInvalidConn
|
||||||
|
}
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return nil, "", errNilRule
|
return nil, "", errNilRule
|
||||||
}
|
}
|
||||||
|
@ -1126,7 +1168,7 @@ func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg
|
||||||
// For all but Reject rules, SSHUsers is required.
|
// For all but Reject rules, SSHUsers is required.
|
||||||
// If SSHUsers is nil or empty, mapLocalUser will return an
|
// If SSHUsers is nil or empty, mapLocalUser will return an
|
||||||
// empty string anyway.
|
// empty string anyway.
|
||||||
localUser = mapLocalUser(r.SSHUsers, c.info.sshUser)
|
localUser = mapLocalUser(r.SSHUsers, ci.sshUser)
|
||||||
if localUser == "" {
|
if localUser == "" {
|
||||||
return nil, "", errUserMatch
|
return nil, "", errUserMatch
|
||||||
}
|
}
|
||||||
|
@ -1175,7 +1217,9 @@ func (c *conn) principalMatches(p *tailcfg.SSHPrincipal, pubKey gossh.PublicKey)
|
||||||
// that match the Tailscale identity match (Node, NodeIP, UserLogin, Any).
|
// that match the Tailscale identity match (Node, NodeIP, UserLogin, Any).
|
||||||
// This function does not consider PubKeys.
|
// This function does not consider PubKeys.
|
||||||
func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool {
|
func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool {
|
||||||
|
c.mu.Lock()
|
||||||
ci := c.info
|
ci := c.info
|
||||||
|
c.mu.Unlock()
|
||||||
if p.Any {
|
if p.Any {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,13 +47,26 @@ func TestMatchRule(t *testing.T) {
|
||||||
wantErr error
|
wantErr error
|
||||||
wantUser string
|
wantUser string
|
||||||
}{
|
}{
|
||||||
|
{
|
||||||
|
name: "invalid-conn",
|
||||||
|
rule: &tailcfg.SSHRule{
|
||||||
|
Action: someAction,
|
||||||
|
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||||
|
SSHUsers: map[string]string{
|
||||||
|
"*": "ubuntu",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: errInvalidConn,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "nil-rule",
|
name: "nil-rule",
|
||||||
|
ci: &sshConnInfo{},
|
||||||
rule: nil,
|
rule: nil,
|
||||||
wantErr: errNilRule,
|
wantErr: errNilRule,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "nil-action",
|
name: "nil-action",
|
||||||
|
ci: &sshConnInfo{},
|
||||||
rule: &tailcfg.SSHRule{},
|
rule: &tailcfg.SSHRule{},
|
||||||
wantErr: errNilAction,
|
wantErr: errNilAction,
|
||||||
},
|
},
|
||||||
|
@ -180,6 +193,7 @@ func TestMatchRule(t *testing.T) {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
c := &conn{
|
c := &conn{
|
||||||
info: tt.ci,
|
info: tt.ci,
|
||||||
|
srv: &server{logf: t.Logf},
|
||||||
}
|
}
|
||||||
got, gotUser, err := c.matchRule(tt.rule, nil)
|
got, gotUser, err := c.matchRule(tt.rule, nil)
|
||||||
if err != tt.wantErr {
|
if err != tt.wantErr {
|
||||||
|
|
Loading…
Reference in New Issue