// Copyright 2012 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package agent import ( "crypto/dsa" "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" "crypto/rsa" "encoding/binary" "errors" "fmt" "io" "log" "math/big" "github.com/Neur0toxine/sshpoke/pkg/proto/ssh" ) // server wraps an Agent and uses it to implement the agent side of // the SSH-agent, wire protocol. type server struct { agent Agent } func (s *server) processRequestBytes(reqData []byte) []byte { rep, err := s.processRequest(reqData) if err != nil { if err != errLocked { // TODO(hanwen): provide better logging interface? log.Printf("agent %d: %v", reqData[0], err) } return []byte{agentFailure} } if err == nil && rep == nil { return []byte{agentSuccess} } return ssh.Marshal(rep) } func marshalKey(k *Key) []byte { var record struct { Blob []byte Comment string } record.Blob = k.Marshal() record.Comment = k.Comment return ssh.Marshal(&record) } // See [PROTOCOL.agent], section 2.5.1. const agentV1IdentitiesAnswer = 2 type agentV1IdentityMsg struct { Numkeys uint32 `sshtype:"2"` } type agentRemoveIdentityMsg struct { KeyBlob []byte `sshtype:"18"` } type agentLockMsg struct { Passphrase []byte `sshtype:"22"` } type agentUnlockMsg struct { Passphrase []byte `sshtype:"23"` } func (s *server) processRequest(data []byte) (interface{}, error) { switch data[0] { case agentRequestV1Identities: return &agentV1IdentityMsg{0}, nil case agentRemoveAllV1Identities: return nil, nil case agentRemoveIdentity: var req agentRemoveIdentityMsg if err := ssh.Unmarshal(data, &req); err != nil { return nil, err } var wk wireKey if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil { return nil, err } return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob}) case agentRemoveAllIdentities: return nil, s.agent.RemoveAll() case agentLock: var req agentLockMsg if err := ssh.Unmarshal(data, &req); err != nil { return nil, err } return nil, s.agent.Lock(req.Passphrase) case agentUnlock: var req agentUnlockMsg if err := ssh.Unmarshal(data, &req); err != nil { return nil, err } return nil, s.agent.Unlock(req.Passphrase) case agentSignRequest: var req signRequestAgentMsg if err := ssh.Unmarshal(data, &req); err != nil { return nil, err } var wk wireKey if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil { return nil, err } k := &Key{ Format: wk.Format, Blob: req.KeyBlob, } var sig *ssh.Signature var err error if extendedAgent, ok := s.agent.(ExtendedAgent); ok { sig, err = extendedAgent.SignWithFlags(k, req.Data, SignatureFlags(req.Flags)) } else { sig, err = s.agent.Sign(k, req.Data) } if err != nil { return nil, err } return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil case agentRequestIdentities: keys, err := s.agent.List() if err != nil { return nil, err } rep := identitiesAnswerAgentMsg{ NumKeys: uint32(len(keys)), } for _, k := range keys { rep.Keys = append(rep.Keys, marshalKey(k)...) } return rep, nil case agentAddIDConstrained, agentAddIdentity: return nil, s.insertIdentity(data) case agentExtension: // Return a stub object where the whole contents of the response gets marshaled. var responseStub struct { Rest []byte `ssh:"rest"` } if extendedAgent, ok := s.agent.(ExtendedAgent); !ok { // If this agent doesn't implement extensions, [PROTOCOL.agent] section 4.7 // requires that we return a standard SSH_AGENT_FAILURE message. responseStub.Rest = []byte{agentFailure} } else { var req extensionAgentMsg if err := ssh.Unmarshal(data, &req); err != nil { return nil, err } res, err := extendedAgent.Extension(req.ExtensionType, req.Contents) if err != nil { // If agent extensions are unsupported, return a standard SSH_AGENT_FAILURE // message as required by [PROTOCOL.agent] section 4.7. if err == ErrExtensionUnsupported { responseStub.Rest = []byte{agentFailure} } else { // As the result of any other error processing an extension request, // [PROTOCOL.agent] section 4.7 requires that we return a // SSH_AGENT_EXTENSION_FAILURE code. responseStub.Rest = []byte{agentExtensionFailure} } } else { if len(res) == 0 { return nil, nil } responseStub.Rest = res } } return responseStub, nil } return nil, fmt.Errorf("unknown opcode %d", data[0]) } func parseConstraints(constraints []byte) (lifetimeSecs uint32, confirmBeforeUse bool, extensions []ConstraintExtension, err error) { for len(constraints) != 0 { switch constraints[0] { case agentConstrainLifetime: lifetimeSecs = binary.BigEndian.Uint32(constraints[1:5]) constraints = constraints[5:] case agentConstrainConfirm: confirmBeforeUse = true constraints = constraints[1:] case agentConstrainExtension, agentConstrainExtensionV00: var msg constrainExtensionAgentMsg if err = ssh.Unmarshal(constraints, &msg); err != nil { return 0, false, nil, err } extensions = append(extensions, ConstraintExtension{ ExtensionName: msg.ExtensionName, ExtensionDetails: msg.ExtensionDetails, }) constraints = msg.Rest default: return 0, false, nil, fmt.Errorf("unknown constraint type: %d", constraints[0]) } } return } func setConstraints(key *AddedKey, constraintBytes []byte) error { lifetimeSecs, confirmBeforeUse, constraintExtensions, err := parseConstraints(constraintBytes) if err != nil { return err } key.LifetimeSecs = lifetimeSecs key.ConfirmBeforeUse = confirmBeforeUse key.ConstraintExtensions = constraintExtensions return nil } func parseRSAKey(req []byte) (*AddedKey, error) { var k rsaKeyMsg if err := ssh.Unmarshal(req, &k); err != nil { return nil, err } if k.E.BitLen() > 30 { return nil, errors.New("agent: RSA public exponent too large") } priv := &rsa.PrivateKey{ PublicKey: rsa.PublicKey{ E: int(k.E.Int64()), N: k.N, }, D: k.D, Primes: []*big.Int{k.P, k.Q}, } priv.Precompute() addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments} if err := setConstraints(addedKey, k.Constraints); err != nil { return nil, err } return addedKey, nil } func parseEd25519Key(req []byte) (*AddedKey, error) { var k ed25519KeyMsg if err := ssh.Unmarshal(req, &k); err != nil { return nil, err } priv := ed25519.PrivateKey(k.Priv) addedKey := &AddedKey{PrivateKey: &priv, Comment: k.Comments} if err := setConstraints(addedKey, k.Constraints); err != nil { return nil, err } return addedKey, nil } func parseDSAKey(req []byte) (*AddedKey, error) { var k dsaKeyMsg if err := ssh.Unmarshal(req, &k); err != nil { return nil, err } priv := &dsa.PrivateKey{ PublicKey: dsa.PublicKey{ Parameters: dsa.Parameters{ P: k.P, Q: k.Q, G: k.G, }, Y: k.Y, }, X: k.X, } addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments} if err := setConstraints(addedKey, k.Constraints); err != nil { return nil, err } return addedKey, nil } func unmarshalECDSA(curveName string, keyBytes []byte, privScalar *big.Int) (priv *ecdsa.PrivateKey, err error) { priv = &ecdsa.PrivateKey{ D: privScalar, } switch curveName { case "nistp256": priv.Curve = elliptic.P256() case "nistp384": priv.Curve = elliptic.P384() case "nistp521": priv.Curve = elliptic.P521() default: return nil, fmt.Errorf("agent: unknown curve %q", curveName) } priv.X, priv.Y = elliptic.Unmarshal(priv.Curve, keyBytes) if priv.X == nil || priv.Y == nil { return nil, errors.New("agent: point not on curve") } return priv, nil } func parseEd25519Cert(req []byte) (*AddedKey, error) { var k ed25519CertMsg if err := ssh.Unmarshal(req, &k); err != nil { return nil, err } pubKey, err := ssh.ParsePublicKey(k.CertBytes) if err != nil { return nil, err } priv := ed25519.PrivateKey(k.Priv) cert, ok := pubKey.(*ssh.Certificate) if !ok { return nil, errors.New("agent: bad ED25519 certificate") } addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments} if err := setConstraints(addedKey, k.Constraints); err != nil { return nil, err } return addedKey, nil } func parseECDSAKey(req []byte) (*AddedKey, error) { var k ecdsaKeyMsg if err := ssh.Unmarshal(req, &k); err != nil { return nil, err } priv, err := unmarshalECDSA(k.Curve, k.KeyBytes, k.D) if err != nil { return nil, err } addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments} if err := setConstraints(addedKey, k.Constraints); err != nil { return nil, err } return addedKey, nil } func parseRSACert(req []byte) (*AddedKey, error) { var k rsaCertMsg if err := ssh.Unmarshal(req, &k); err != nil { return nil, err } pubKey, err := ssh.ParsePublicKey(k.CertBytes) if err != nil { return nil, err } cert, ok := pubKey.(*ssh.Certificate) if !ok { return nil, errors.New("agent: bad RSA certificate") } // An RSA publickey as marshaled by rsaPublicKey.Marshal() in keys.go var rsaPub struct { Name string E *big.Int N *big.Int } if err := ssh.Unmarshal(cert.Key.Marshal(), &rsaPub); err != nil { return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err) } if rsaPub.E.BitLen() > 30 { return nil, errors.New("agent: RSA public exponent too large") } priv := rsa.PrivateKey{ PublicKey: rsa.PublicKey{ E: int(rsaPub.E.Int64()), N: rsaPub.N, }, D: k.D, Primes: []*big.Int{k.Q, k.P}, } priv.Precompute() addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments} if err := setConstraints(addedKey, k.Constraints); err != nil { return nil, err } return addedKey, nil } func parseDSACert(req []byte) (*AddedKey, error) { var k dsaCertMsg if err := ssh.Unmarshal(req, &k); err != nil { return nil, err } pubKey, err := ssh.ParsePublicKey(k.CertBytes) if err != nil { return nil, err } cert, ok := pubKey.(*ssh.Certificate) if !ok { return nil, errors.New("agent: bad DSA certificate") } // A DSA publickey as marshaled by dsaPublicKey.Marshal() in keys.go var w struct { Name string P, Q, G, Y *big.Int } if err := ssh.Unmarshal(cert.Key.Marshal(), &w); err != nil { return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err) } priv := &dsa.PrivateKey{ PublicKey: dsa.PublicKey{ Parameters: dsa.Parameters{ P: w.P, Q: w.Q, G: w.G, }, Y: w.Y, }, X: k.X, } addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments} if err := setConstraints(addedKey, k.Constraints); err != nil { return nil, err } return addedKey, nil } func parseECDSACert(req []byte) (*AddedKey, error) { var k ecdsaCertMsg if err := ssh.Unmarshal(req, &k); err != nil { return nil, err } pubKey, err := ssh.ParsePublicKey(k.CertBytes) if err != nil { return nil, err } cert, ok := pubKey.(*ssh.Certificate) if !ok { return nil, errors.New("agent: bad ECDSA certificate") } // An ECDSA publickey as marshaled by ecdsaPublicKey.Marshal() in keys.go var ecdsaPub struct { Name string ID string Key []byte } if err := ssh.Unmarshal(cert.Key.Marshal(), &ecdsaPub); err != nil { return nil, err } priv, err := unmarshalECDSA(ecdsaPub.ID, ecdsaPub.Key, k.D) if err != nil { return nil, err } addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments} if err := setConstraints(addedKey, k.Constraints); err != nil { return nil, err } return addedKey, nil } func (s *server) insertIdentity(req []byte) error { var record struct { Type string `sshtype:"17|25"` Rest []byte `ssh:"rest"` } if err := ssh.Unmarshal(req, &record); err != nil { return err } var addedKey *AddedKey var err error switch record.Type { case ssh.KeyAlgoRSA: addedKey, err = parseRSAKey(req) case ssh.KeyAlgoDSA: addedKey, err = parseDSAKey(req) case ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521: addedKey, err = parseECDSAKey(req) case ssh.KeyAlgoED25519: addedKey, err = parseEd25519Key(req) case ssh.CertAlgoRSAv01: addedKey, err = parseRSACert(req) case ssh.CertAlgoDSAv01: addedKey, err = parseDSACert(req) case ssh.CertAlgoECDSA256v01, ssh.CertAlgoECDSA384v01, ssh.CertAlgoECDSA521v01: addedKey, err = parseECDSACert(req) case ssh.CertAlgoED25519v01: addedKey, err = parseEd25519Cert(req) default: return fmt.Errorf("agent: not implemented: %q", record.Type) } if err != nil { return err } return s.agent.Add(*addedKey) } // ServeAgent serves the agent protocol on the given connection. It // returns when an I/O error occurs. func ServeAgent(agent Agent, c io.ReadWriter) error { s := &server{agent} var length [4]byte for { if _, err := io.ReadFull(c, length[:]); err != nil { return err } l := binary.BigEndian.Uint32(length[:]) if l == 0 { return fmt.Errorf("agent: request size is 0") } if l > maxAgentResponseBytes { // We also cap requests. return fmt.Errorf("agent: request too large: %d", l) } req := make([]byte, l) if _, err := io.ReadFull(c, req); err != nil { return err } repData := s.processRequestBytes(req) if len(repData) > maxAgentResponseBytes { return fmt.Errorf("agent: reply too large: %d bytes", len(repData)) } binary.BigEndian.PutUint32(length[:], uint32(len(repData))) if _, err := c.Write(length[:]); err != nil { return err } if _, err := c.Write(repData); err != nil { return err } } }