update crypto

This commit is contained in:
mxi4oyu 2017-05-19 18:03:26 +08:00
parent d213d3ea6d
commit bf43273def
97 changed files with 7361 additions and 2794 deletions

2
x/crypto/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
# Add no patterns to .hgignore except for files generated by the build.
last-change

View File

@ -15,6 +15,7 @@ package acme
import ( import (
"bytes" "bytes"
"context"
"crypto" "crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
@ -36,9 +37,6 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"golang.org/x/net/context"
"golang.org/x/net/context/ctxhttp"
) )
// LetsEncryptURL is the Directory endpoint of Let's Encrypt CA. // LetsEncryptURL is the Directory endpoint of Let's Encrypt CA.
@ -47,6 +45,10 @@ const LetsEncryptURL = "https://acme-v01.api.letsencrypt.org/directory"
const ( const (
maxChainLen = 5 // max depth and breadth of a certificate chain maxChainLen = 5 // max depth and breadth of a certificate chain
maxCertSize = 1 << 20 // max size of a certificate, in bytes maxCertSize = 1 << 20 // max size of a certificate, in bytes
// Max number of collected nonces kept in memory.
// Expect usual peak of 1 or 2.
maxNonces = 100
) )
// CertOption is an optional argument type for Client methods which manipulate // CertOption is an optional argument type for Client methods which manipulate
@ -108,6 +110,9 @@ type Client struct {
dirMu sync.Mutex // guards writes to dir dirMu sync.Mutex // guards writes to dir
dir *Directory // cached result of Client's Discover method dir *Directory // cached result of Client's Discover method
noncesMu sync.Mutex
nonces map[string]struct{} // nonces collected from previous responses
} }
// Discover performs ACME server discovery using c.DirectoryURL. // Discover performs ACME server discovery using c.DirectoryURL.
@ -126,11 +131,12 @@ func (c *Client) Discover(ctx context.Context) (Directory, error) {
if dirURL == "" { if dirURL == "" {
dirURL = LetsEncryptURL dirURL = LetsEncryptURL
} }
res, err := ctxhttp.Get(ctx, c.HTTPClient, dirURL) res, err := c.get(ctx, dirURL)
if err != nil { if err != nil {
return Directory{}, err return Directory{}, err
} }
defer res.Body.Close() defer res.Body.Close()
c.addNonce(res.Header)
if res.StatusCode != http.StatusOK { if res.StatusCode != http.StatusOK {
return Directory{}, responseError(res) return Directory{}, responseError(res)
} }
@ -146,7 +152,7 @@ func (c *Client) Discover(ctx context.Context) (Directory, error) {
CAA []string `json:"caa-identities"` CAA []string `json:"caa-identities"`
} }
} }
if json.NewDecoder(res.Body).Decode(&v); err != nil { if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
return Directory{}, err return Directory{}, err
} }
c.dir = &Directory{ c.dir = &Directory{
@ -192,7 +198,7 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
req.NotAfter = now.Add(exp).Format(time.RFC3339) req.NotAfter = now.Add(exp).Format(time.RFC3339)
} }
res, err := postJWS(ctx, c.HTTPClient, c.Key, c.dir.CertURL, req) res, err := c.retryPostJWS(ctx, c.Key, c.dir.CertURL, req)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
@ -208,7 +214,7 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
return cert, curl, err return cert, curl, err
} }
// slurp issued cert and CA chain, if requested // slurp issued cert and CA chain, if requested
cert, err := responseCert(ctx, c.HTTPClient, res, bundle) cert, err := c.responseCert(ctx, res, bundle)
return cert, curl, err return cert, curl, err
} }
@ -223,13 +229,13 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
// and has expected features. // and has expected features.
func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) { func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) {
for { for {
res, err := ctxhttp.Get(ctx, c.HTTPClient, url) res, err := c.get(ctx, url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer res.Body.Close() defer res.Body.Close()
if res.StatusCode == http.StatusOK { if res.StatusCode == http.StatusOK {
return responseCert(ctx, c.HTTPClient, res, bundle) return c.responseCert(ctx, res, bundle)
} }
if res.StatusCode > 299 { if res.StatusCode > 299 {
return nil, responseError(res) return nil, responseError(res)
@ -267,7 +273,7 @@ func (c *Client) RevokeCert(ctx context.Context, key crypto.Signer, cert []byte,
if key == nil { if key == nil {
key = c.Key key = c.Key
} }
res, err := postJWS(ctx, c.HTTPClient, key, c.dir.RevokeURL, body) res, err := c.retryPostJWS(ctx, key, c.dir.RevokeURL, body)
if err != nil { if err != nil {
return err return err
} }
@ -355,7 +361,7 @@ func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization,
Resource: "new-authz", Resource: "new-authz",
Identifier: authzID{Type: "dns", Value: domain}, Identifier: authzID{Type: "dns", Value: domain},
} }
res, err := postJWS(ctx, c.HTTPClient, c.Key, c.dir.AuthzURL, req) res, err := c.retryPostJWS(ctx, c.Key, c.dir.AuthzURL, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -379,7 +385,7 @@ func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization,
// If a caller needs to poll an authorization until its status is final, // If a caller needs to poll an authorization until its status is final,
// see the WaitAuthorization method. // see the WaitAuthorization method.
func (c *Client) GetAuthorization(ctx context.Context, url string) (*Authorization, error) { func (c *Client) GetAuthorization(ctx context.Context, url string) (*Authorization, error) {
res, err := ctxhttp.Get(ctx, c.HTTPClient, url) res, err := c.get(ctx, url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -413,7 +419,7 @@ func (c *Client) RevokeAuthorization(ctx context.Context, url string) error {
Status: "deactivated", Status: "deactivated",
Delete: true, Delete: true,
} }
res, err := postJWS(ctx, c.HTTPClient, c.Key, url, req) res, err := c.retryPostJWS(ctx, c.Key, url, req)
if err != nil { if err != nil {
return err return err
} }
@ -430,25 +436,11 @@ func (c *Client) RevokeAuthorization(ctx context.Context, url string) error {
// //
// It returns a non-nil Authorization only if its Status is StatusValid. // It returns a non-nil Authorization only if its Status is StatusValid.
// In all other cases WaitAuthorization returns an error. // In all other cases WaitAuthorization returns an error.
// If the Status is StatusInvalid, the returned error is ErrAuthorizationFailed. // If the Status is StatusInvalid, the returned error is of type *AuthorizationError.
func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorization, error) { func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorization, error) {
var count int sleep := sleeper(ctx)
sleep := func(v string, inc int) error {
count += inc
d := backoff(count, 10*time.Second)
d = retryAfter(v, d)
wakeup := time.NewTimer(d)
defer wakeup.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-wakeup.C:
return nil
}
}
for { for {
res, err := ctxhttp.Get(ctx, c.HTTPClient, url) res, err := c.get(ctx, url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -473,7 +465,7 @@ func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorizat
return raw.authorization(url), nil return raw.authorization(url), nil
} }
if raw.Status == StatusInvalid { if raw.Status == StatusInvalid {
return nil, ErrAuthorizationFailed return nil, raw.error(url)
} }
if err := sleep(retry, 0); err != nil { if err := sleep(retry, 0); err != nil {
return nil, err return nil, err
@ -485,7 +477,7 @@ func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorizat
// //
// A client typically polls a challenge status using this method. // A client typically polls a challenge status using this method.
func (c *Client) GetChallenge(ctx context.Context, url string) (*Challenge, error) { func (c *Client) GetChallenge(ctx context.Context, url string) (*Challenge, error) {
res, err := ctxhttp.Get(ctx, c.HTTPClient, url) res, err := c.get(ctx, url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -519,7 +511,7 @@ func (c *Client) Accept(ctx context.Context, chal *Challenge) (*Challenge, error
Type: chal.Type, Type: chal.Type,
Auth: auth, Auth: auth,
} }
res, err := postJWS(ctx, c.HTTPClient, c.Key, chal.URI, req) res, err := c.retryPostJWS(ctx, c.Key, chal.URI, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -652,7 +644,7 @@ func (c *Client) doReg(ctx context.Context, url string, typ string, acct *Accoun
req.Contact = acct.Contact req.Contact = acct.Contact
req.Agreement = acct.AgreedTerms req.Agreement = acct.AgreedTerms
} }
res, err := postJWS(ctx, c.HTTPClient, c.Key, url, req) res, err := c.retryPostJWS(ctx, c.Key, url, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -689,7 +681,170 @@ func (c *Client) doReg(ctx context.Context, url string, typ string, acct *Accoun
}, nil }, nil
} }
func responseCert(ctx context.Context, client *http.Client, res *http.Response, bundle bool) ([][]byte, error) { // retryPostJWS will retry calls to postJWS if there is a badNonce error,
// clearing the stored nonces after each error.
// If the response was 4XX-5XX, then responseError is called on the body,
// the body is closed, and the error returned.
func (c *Client) retryPostJWS(ctx context.Context, key crypto.Signer, url string, body interface{}) (*http.Response, error) {
sleep := sleeper(ctx)
for {
res, err := c.postJWS(ctx, key, url, body)
if err != nil {
return nil, err
}
// handle errors 4XX-5XX with responseError
if res.StatusCode >= 400 && res.StatusCode <= 599 {
err := responseError(res)
res.Body.Close()
// according to spec badNonce is urn:ietf:params:acme:error:badNonce
// however, acme servers in the wild return their version of the error
// https://tools.ietf.org/html/draft-ietf-acme-acme-02#section-5.4
if ae, ok := err.(*Error); ok && strings.HasSuffix(strings.ToLower(ae.ProblemType), ":badnonce") {
// clear any nonces that we might've stored that might now be
// considered bad
c.clearNonces()
retry := res.Header.Get("retry-after")
if err := sleep(retry, 1); err != nil {
return nil, err
}
continue
}
return nil, err
}
return res, nil
}
}
// postJWS signs the body with the given key and POSTs it to the provided url.
// The body argument must be JSON-serializable.
func (c *Client) postJWS(ctx context.Context, key crypto.Signer, url string, body interface{}) (*http.Response, error) {
nonce, err := c.popNonce(ctx, url)
if err != nil {
return nil, err
}
b, err := jwsEncodeJSON(body, key, nonce)
if err != nil {
return nil, err
}
res, err := c.post(ctx, url, "application/jose+json", bytes.NewReader(b))
if err != nil {
return nil, err
}
c.addNonce(res.Header)
return res, nil
}
// popNonce returns a nonce value previously stored with c.addNonce
// or fetches a fresh one from the given URL.
func (c *Client) popNonce(ctx context.Context, url string) (string, error) {
c.noncesMu.Lock()
defer c.noncesMu.Unlock()
if len(c.nonces) == 0 {
return c.fetchNonce(ctx, url)
}
var nonce string
for nonce = range c.nonces {
delete(c.nonces, nonce)
break
}
return nonce, nil
}
// clearNonces clears any stored nonces
func (c *Client) clearNonces() {
c.noncesMu.Lock()
defer c.noncesMu.Unlock()
c.nonces = make(map[string]struct{})
}
// addNonce stores a nonce value found in h (if any) for future use.
func (c *Client) addNonce(h http.Header) {
v := nonceFromHeader(h)
if v == "" {
return
}
c.noncesMu.Lock()
defer c.noncesMu.Unlock()
if len(c.nonces) >= maxNonces {
return
}
if c.nonces == nil {
c.nonces = make(map[string]struct{})
}
c.nonces[v] = struct{}{}
}
func (c *Client) httpClient() *http.Client {
if c.HTTPClient != nil {
return c.HTTPClient
}
return http.DefaultClient
}
func (c *Client) get(ctx context.Context, urlStr string) (*http.Response, error) {
req, err := http.NewRequest("GET", urlStr, nil)
if err != nil {
return nil, err
}
return c.do(ctx, req)
}
func (c *Client) head(ctx context.Context, urlStr string) (*http.Response, error) {
req, err := http.NewRequest("HEAD", urlStr, nil)
if err != nil {
return nil, err
}
return c.do(ctx, req)
}
func (c *Client) post(ctx context.Context, urlStr, contentType string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequest("POST", urlStr, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", contentType)
return c.do(ctx, req)
}
func (c *Client) do(ctx context.Context, req *http.Request) (*http.Response, error) {
res, err := c.httpClient().Do(req.WithContext(ctx))
if err != nil {
select {
case <-ctx.Done():
// Prefer the unadorned context error.
// (The acme package had tests assuming this, previously from ctxhttp's
// behavior, predating net/http supporting contexts natively)
// TODO(bradfitz): reconsider this in the future. But for now this
// requires no test updates.
return nil, ctx.Err()
default:
return nil, err
}
}
return res, nil
}
func (c *Client) fetchNonce(ctx context.Context, url string) (string, error) {
resp, err := c.head(ctx, url)
if err != nil {
return "", err
}
defer resp.Body.Close()
nonce := nonceFromHeader(resp.Header)
if nonce == "" {
if resp.StatusCode > 299 {
return "", responseError(resp)
}
return "", errors.New("acme: nonce not found")
}
return nonce, nil
}
func nonceFromHeader(h http.Header) string {
return h.Get("Replay-Nonce")
}
func (c *Client) responseCert(ctx context.Context, res *http.Response, bundle bool) ([][]byte, error) {
b, err := ioutil.ReadAll(io.LimitReader(res.Body, maxCertSize+1)) b, err := ioutil.ReadAll(io.LimitReader(res.Body, maxCertSize+1))
if err != nil { if err != nil {
return nil, fmt.Errorf("acme: response stream: %v", err) return nil, fmt.Errorf("acme: response stream: %v", err)
@ -713,7 +868,7 @@ func responseCert(ctx context.Context, client *http.Client, res *http.Response,
return nil, errors.New("acme: rel=up link is too large") return nil, errors.New("acme: rel=up link is too large")
} }
for _, url := range up { for _, url := range up {
cc, err := chainCert(ctx, client, url, 0) cc, err := c.chainCert(ctx, url, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -727,14 +882,8 @@ func responseError(resp *http.Response) error {
// don't care if ReadAll returns an error: // don't care if ReadAll returns an error:
// json.Unmarshal will fail in that case anyway // json.Unmarshal will fail in that case anyway
b, _ := ioutil.ReadAll(resp.Body) b, _ := ioutil.ReadAll(resp.Body)
e := struct { e := &wireError{Status: resp.StatusCode}
Status int if err := json.Unmarshal(b, e); err != nil {
Type string
Detail string
}{
Status: resp.StatusCode,
}
if err := json.Unmarshal(b, &e); err != nil {
// this is not a regular error response: // this is not a regular error response:
// populate detail with anything we received, // populate detail with anything we received,
// e.Status will already contain HTTP response code value // e.Status will already contain HTTP response code value
@ -743,12 +892,7 @@ func responseError(resp *http.Response) error {
e.Detail = resp.Status e.Detail = resp.Status
} }
} }
return &Error{ return e.error(resp.Header)
StatusCode: e.Status,
ProblemType: e.Type,
Detail: e.Detail,
Header: resp.Header,
}
} }
// chainCert fetches CA certificate chain recursively by following "up" links. // chainCert fetches CA certificate chain recursively by following "up" links.
@ -756,12 +900,12 @@ func responseError(resp *http.Response) error {
// if the recursion level reaches maxChainLen. // if the recursion level reaches maxChainLen.
// //
// First chainCert call starts with depth of 0. // First chainCert call starts with depth of 0.
func chainCert(ctx context.Context, client *http.Client, url string, depth int) ([][]byte, error) { func (c *Client) chainCert(ctx context.Context, url string, depth int) ([][]byte, error) {
if depth >= maxChainLen { if depth >= maxChainLen {
return nil, errors.New("acme: certificate chain is too deep") return nil, errors.New("acme: certificate chain is too deep")
} }
res, err := ctxhttp.Get(ctx, client, url) res, err := c.get(ctx, url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -783,7 +927,7 @@ func chainCert(ctx context.Context, client *http.Client, url string, depth int)
return nil, errors.New("acme: certificate chain is too large") return nil, errors.New("acme: certificate chain is too large")
} }
for _, up := range uplink { for _, up := range uplink {
cc, err := chainCert(ctx, client, up, depth+1) cc, err := c.chainCert(ctx, up, depth+1)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -793,33 +937,6 @@ func chainCert(ctx context.Context, client *http.Client, url string, depth int)
return chain, nil return chain, nil
} }
// postJWS signs the body with the given key and POSTs it to the provided url.
// The body argument must be JSON-serializable.
func postJWS(ctx context.Context, client *http.Client, key crypto.Signer, url string, body interface{}) (*http.Response, error) {
nonce, err := fetchNonce(ctx, client, url)
if err != nil {
return nil, err
}
b, err := jwsEncodeJSON(body, key, nonce)
if err != nil {
return nil, err
}
return ctxhttp.Post(ctx, client, url, "application/jose+json", bytes.NewReader(b))
}
func fetchNonce(ctx context.Context, client *http.Client, url string) (string, error) {
resp, err := ctxhttp.Head(ctx, client, url)
if err != nil {
return "", nil
}
defer resp.Body.Close()
enc := resp.Header.Get("replay-nonce")
if enc == "" {
return "", errors.New("acme: nonce not found")
}
return enc, nil
}
// linkHeader returns URI-Reference values of all Link headers // linkHeader returns URI-Reference values of all Link headers
// with relation-type rel. // with relation-type rel.
// See https://tools.ietf.org/html/rfc5988#section-5 for details. // See https://tools.ietf.org/html/rfc5988#section-5 for details.
@ -840,6 +957,28 @@ func linkHeader(h http.Header, rel string) []string {
return links return links
} }
// sleeper returns a function that accepts the Retry-After HTTP header value
// and an increment that's used with backoff to increasingly sleep on
// consecutive calls until the context is done. If the Retry-After header
// cannot be parsed, then backoff is used with a maximum sleep time of 10
// seconds.
func sleeper(ctx context.Context) func(ra string, inc int) error {
var count int
return func(ra string, inc int) error {
count += inc
d := backoff(count, 10*time.Second)
d = retryAfter(ra, d)
wakeup := time.NewTimer(d)
defer wakeup.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-wakeup.C:
return nil
}
}
}
// retryAfter parses a Retry-After HTTP header value, // retryAfter parses a Retry-After HTTP header value,
// trying to convert v into an int (seconds) or use http.ParseTime otherwise. // trying to convert v into an int (seconds) or use http.ParseTime otherwise.
// It returns d if v cannot be parsed. // It returns d if v cannot be parsed.
@ -921,7 +1060,8 @@ func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) {
NotBefore: time.Now(), NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour), NotAfter: time.Now().Add(24 * time.Hour),
BasicConstraintsValid: true, BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageKeyEncipherment, KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
} }
} }
tmpl.DNSNames = san tmpl.DNSNames = san

View File

@ -6,6 +6,7 @@ package acme
import ( import (
"bytes" "bytes"
"context"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/tls" "crypto/tls"
@ -23,8 +24,6 @@ import (
"strings" "strings"
"testing" "testing"
"time" "time"
"golang.org/x/net/context"
) )
// Decodes a JWS-encoded request and unmarshals the decoded JSON into a provided // Decodes a JWS-encoded request and unmarshals the decoded JSON into a provided
@ -45,6 +44,28 @@ func decodeJWSRequest(t *testing.T, v interface{}, r *http.Request) {
} }
} }
type jwsHead struct {
Alg string
Nonce string
JWK map[string]string `json:"jwk"`
}
func decodeJWSHead(r *http.Request) (*jwsHead, error) {
var req struct{ Protected string }
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, err
}
b, err := base64.RawURLEncoding.DecodeString(req.Protected)
if err != nil {
return nil, err
}
var head jwsHead
if err := json.Unmarshal(b, &head); err != nil {
return nil, err
}
return &head, nil
}
func TestDiscover(t *testing.T) { func TestDiscover(t *testing.T) {
const ( const (
reg = "https://example.com/acme/new-reg" reg = "https://example.com/acme/new-reg"
@ -522,6 +543,9 @@ func TestWaitAuthorizationInvalid(t *testing.T) {
if err == nil { if err == nil {
t.Error("err is nil") t.Error("err is nil")
} }
if _, ok := err.(*AuthorizationError); !ok {
t.Errorf("err is %T; want *AuthorizationError", err)
}
} }
} }
@ -916,7 +940,30 @@ func TestRevokeCert(t *testing.T) {
} }
} }
func TestFetchNonce(t *testing.T) { func TestNonce_add(t *testing.T) {
var c Client
c.addNonce(http.Header{"Replay-Nonce": {"nonce"}})
c.addNonce(http.Header{"Replay-Nonce": {}})
c.addNonce(http.Header{"Replay-Nonce": {"nonce"}})
nonces := map[string]struct{}{"nonce": struct{}{}}
if !reflect.DeepEqual(c.nonces, nonces) {
t.Errorf("c.nonces = %q; want %q", c.nonces, nonces)
}
}
func TestNonce_addMax(t *testing.T) {
c := &Client{nonces: make(map[string]struct{})}
for i := 0; i < maxNonces; i++ {
c.nonces[fmt.Sprintf("%d", i)] = struct{}{}
}
c.addNonce(http.Header{"Replay-Nonce": {"nonce"}})
if n := len(c.nonces); n != maxNonces {
t.Errorf("len(c.nonces) = %d; want %d", n, maxNonces)
}
}
func TestNonce_fetch(t *testing.T) {
tests := []struct { tests := []struct {
code int code int
nonce string nonce string
@ -936,7 +983,8 @@ func TestFetchNonce(t *testing.T) {
defer ts.Close() defer ts.Close()
for ; i < len(tests); i++ { for ; i < len(tests); i++ {
test := tests[i] test := tests[i]
n, err := fetchNonce(context.Background(), http.DefaultClient, ts.URL) c := &Client{}
n, err := c.fetchNonce(context.Background(), ts.URL)
if n != test.nonce { if n != test.nonce {
t.Errorf("%d: n=%q; want %q", i, n, test.nonce) t.Errorf("%d: n=%q; want %q", i, n, test.nonce)
} }
@ -949,6 +997,115 @@ func TestFetchNonce(t *testing.T) {
} }
} }
func TestNonce_fetchError(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTooManyRequests)
}))
defer ts.Close()
c := &Client{}
_, err := c.fetchNonce(context.Background(), ts.URL)
e, ok := err.(*Error)
if !ok {
t.Fatalf("err is %T; want *Error", err)
}
if e.StatusCode != http.StatusTooManyRequests {
t.Errorf("e.StatusCode = %d; want %d", e.StatusCode, http.StatusTooManyRequests)
}
}
func TestNonce_postJWS(t *testing.T) {
var count int
seen := make(map[string]bool)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
count++
w.Header().Set("replay-nonce", fmt.Sprintf("nonce%d", count))
if r.Method == "HEAD" {
// We expect the client do a HEAD request
// but only to fetch the first nonce.
return
}
// Make client.Authorize happy; we're not testing its result.
defer func() {
w.WriteHeader(http.StatusCreated)
w.Write([]byte(`{"status":"valid"}`))
}()
head, err := decodeJWSHead(r)
if err != nil {
t.Errorf("decodeJWSHead: %v", err)
return
}
if head.Nonce == "" {
t.Error("head.Nonce is empty")
return
}
if seen[head.Nonce] {
t.Errorf("nonce is already used: %q", head.Nonce)
}
seen[head.Nonce] = true
}))
defer ts.Close()
client := Client{Key: testKey, dir: &Directory{AuthzURL: ts.URL}}
if _, err := client.Authorize(context.Background(), "example.com"); err != nil {
t.Errorf("client.Authorize 1: %v", err)
}
// The second call should not generate another extra HEAD request.
if _, err := client.Authorize(context.Background(), "example.com"); err != nil {
t.Errorf("client.Authorize 2: %v", err)
}
if count != 3 {
t.Errorf("total requests count: %d; want 3", count)
}
if n := len(client.nonces); n != 1 {
t.Errorf("len(client.nonces) = %d; want 1", n)
}
for k := range seen {
if _, exist := client.nonces[k]; exist {
t.Errorf("used nonce %q in client.nonces", k)
}
}
}
func TestRetryPostJWS(t *testing.T) {
var count int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
count++
w.Header().Set("replay-nonce", fmt.Sprintf("nonce%d", count))
if r.Method == "HEAD" {
// We expect the client to do 2 head requests to fetch
// nonces, one to start and another after getting badNonce
return
}
head, err := decodeJWSHead(r)
if err != nil {
t.Errorf("decodeJWSHead: %v", err)
} else if head.Nonce == "" {
t.Error("head.Nonce is empty")
} else if head.Nonce == "nonce1" {
// return a badNonce error to force the call to retry
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"type":"urn:ietf:params:acme:error:badNonce"}`))
return
}
// Make client.Authorize happy; we're not testing its result.
w.WriteHeader(http.StatusCreated)
w.Write([]byte(`{"status":"valid"}`))
}))
defer ts.Close()
client := Client{Key: testKey, dir: &Directory{AuthzURL: ts.URL}}
// This call will fail with badNonce, causing a retry
if _, err := client.Authorize(context.Background(), "example.com"); err != nil {
t.Errorf("client.Authorize 1: %v", err)
}
if count != 4 {
t.Errorf("total requests count: %d; want 4", count)
}
}
func TestLinkHeader(t *testing.T) { func TestLinkHeader(t *testing.T) {
h := http.Header{"Link": { h := http.Header{"Link": {
`<https://example.com/acme/new-authz>;rel="next"`, `<https://example.com/acme/new-authz>;rel="next"`,

View File

@ -10,6 +10,7 @@ package autocert
import ( import (
"bytes" "bytes"
"context"
"crypto" "crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
@ -30,9 +31,14 @@ import (
"time" "time"
"golang.org/x/crypto/acme" "golang.org/x/crypto/acme"
"golang.org/x/net/context"
) )
// createCertRetryAfter is how much time to wait before removing a failed state
// entry due to an unsuccessful createCert call.
// This is a variable instead of a const for testing.
// TODO: Consider making it configurable or an exp backoff?
var createCertRetryAfter = time.Minute
// pseudoRand is safe for concurrent use. // pseudoRand is safe for concurrent use.
var pseudoRand *lockedMathRand var pseudoRand *lockedMathRand
@ -41,8 +47,9 @@ func init() {
pseudoRand = &lockedMathRand{rnd: mathrand.New(src)} pseudoRand = &lockedMathRand{rnd: mathrand.New(src)}
} }
// AcceptTOS always returns true to indicate the acceptance of a CA Terms of Service // AcceptTOS is a Manager.Prompt function that always returns true to
// during account registration. // indicate acceptance of the CA's Terms of Service during account
// registration.
func AcceptTOS(tosURL string) bool { return true } func AcceptTOS(tosURL string) bool { return true }
// HostPolicy specifies which host names the Manager is allowed to respond to. // HostPolicy specifies which host names the Manager is allowed to respond to.
@ -76,18 +83,6 @@ func defaultHostPolicy(context.Context, string) error {
// It obtains and refreshes certificates automatically, // It obtains and refreshes certificates automatically,
// as well as providing them to a TLS server via tls.Config. // as well as providing them to a TLS server via tls.Config.
// //
// A simple usage example:
//
// m := autocert.Manager{
// Prompt: autocert.AcceptTOS,
// HostPolicy: autocert.HostWhitelist("example.org"),
// }
// s := &http.Server{
// Addr: ":https",
// TLSConfig: &tls.Config{GetCertificate: m.GetCertificate},
// }
// s.ListenAndServeTLS("", "")
//
// To preserve issued certificates and improve overall performance, // To preserve issued certificates and improve overall performance,
// use a cache implementation of Cache. For instance, DirCache. // use a cache implementation of Cache. For instance, DirCache.
type Manager struct { type Manager struct {
@ -123,7 +118,7 @@ type Manager struct {
// RenewBefore optionally specifies how early certificates should // RenewBefore optionally specifies how early certificates should
// be renewed before they expire. // be renewed before they expire.
// //
// If zero, they're renewed 1 week before expiration. // If zero, they're renewed 30 days before expiration.
RenewBefore time.Duration RenewBefore time.Duration
// Client is used to perform low-level operations, such as account registration // Client is used to perform low-level operations, such as account registration
@ -173,10 +168,23 @@ type Manager struct {
// The error is propagated back to the caller of GetCertificate and is user-visible. // The error is propagated back to the caller of GetCertificate and is user-visible.
// This does not affect cached certs. See HostPolicy field description for more details. // This does not affect cached certs. See HostPolicy field description for more details.
func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
if m.Prompt == nil {
return nil, errors.New("acme/autocert: Manager.Prompt not set")
}
name := hello.ServerName name := hello.ServerName
if name == "" { if name == "" {
return nil, errors.New("acme/autocert: missing server name") return nil, errors.New("acme/autocert: missing server name")
} }
if !strings.Contains(strings.Trim(name, "."), ".") {
return nil, errors.New("acme/autocert: server name component count invalid")
}
if strings.ContainsAny(name, `/\`) {
return nil, errors.New("acme/autocert: server name contains invalid character")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
// check whether this is a token cert requested for TLS-SNI challenge // check whether this is a token cert requested for TLS-SNI challenge
if strings.HasSuffix(name, ".acme.invalid") { if strings.HasSuffix(name, ".acme.invalid") {
@ -185,7 +193,7 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
if cert := m.tokenCert[name]; cert != nil { if cert := m.tokenCert[name]; cert != nil {
return cert, nil return cert, nil
} }
if cert, err := m.cacheGet(name); err == nil { if cert, err := m.cacheGet(ctx, name); err == nil {
return cert, nil return cert, nil
} }
// TODO: cache error results? // TODO: cache error results?
@ -194,7 +202,7 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
// regular domain // regular domain
name = strings.TrimSuffix(name, ".") // golang.org/issue/18114 name = strings.TrimSuffix(name, ".") // golang.org/issue/18114
cert, err := m.cert(name) cert, err := m.cert(ctx, name)
if err == nil { if err == nil {
return cert, nil return cert, nil
} }
@ -203,7 +211,6 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
} }
// first-time // first-time
ctx := context.Background() // TODO: use a deadline?
if err := m.hostPolicy()(ctx, name); err != nil { if err := m.hostPolicy()(ctx, name); err != nil {
return nil, err return nil, err
} }
@ -211,14 +218,14 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
if err != nil { if err != nil {
return nil, err return nil, err
} }
m.cachePut(name, cert) m.cachePut(ctx, name, cert)
return cert, nil return cert, nil
} }
// cert returns an existing certificate either from m.state or cache. // cert returns an existing certificate either from m.state or cache.
// If a certificate is found in cache but not in m.state, the latter will be filled // If a certificate is found in cache but not in m.state, the latter will be filled
// with the cached value. // with the cached value.
func (m *Manager) cert(name string) (*tls.Certificate, error) { func (m *Manager) cert(ctx context.Context, name string) (*tls.Certificate, error) {
m.stateMu.Lock() m.stateMu.Lock()
if s, ok := m.state[name]; ok { if s, ok := m.state[name]; ok {
m.stateMu.Unlock() m.stateMu.Unlock()
@ -227,7 +234,7 @@ func (m *Manager) cert(name string) (*tls.Certificate, error) {
return s.tlscert() return s.tlscert()
} }
defer m.stateMu.Unlock() defer m.stateMu.Unlock()
cert, err := m.cacheGet(name) cert, err := m.cacheGet(ctx, name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -249,12 +256,11 @@ func (m *Manager) cert(name string) (*tls.Certificate, error) {
} }
// cacheGet always returns a valid certificate, or an error otherwise. // cacheGet always returns a valid certificate, or an error otherwise.
func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) { // If a cached certficate exists but is not valid, ErrCacheMiss is returned.
func (m *Manager) cacheGet(ctx context.Context, domain string) (*tls.Certificate, error) {
if m.Cache == nil { if m.Cache == nil {
return nil, ErrCacheMiss return nil, ErrCacheMiss
} }
// TODO: might want to define a cache timeout on m
ctx := context.Background()
data, err := m.Cache.Get(ctx, domain) data, err := m.Cache.Get(ctx, domain)
if err != nil { if err != nil {
return nil, err return nil, err
@ -263,7 +269,7 @@ func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) {
// private // private
priv, pub := pem.Decode(data) priv, pub := pem.Decode(data)
if priv == nil || !strings.Contains(priv.Type, "PRIVATE") { if priv == nil || !strings.Contains(priv.Type, "PRIVATE") {
return nil, errors.New("acme/autocert: no private key found in cache") return nil, ErrCacheMiss
} }
privKey, err := parsePrivateKey(priv.Bytes) privKey, err := parsePrivateKey(priv.Bytes)
if err != nil { if err != nil {
@ -281,13 +287,14 @@ func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) {
pubDER = append(pubDER, b.Bytes) pubDER = append(pubDER, b.Bytes)
} }
if len(pub) > 0 { if len(pub) > 0 {
return nil, errors.New("acme/autocert: invalid public key") // Leftover content not consumed by pem.Decode. Corrupt. Ignore.
return nil, ErrCacheMiss
} }
// verify and create TLS cert // verify and create TLS cert
leaf, err := validCert(domain, pubDER, privKey) leaf, err := validCert(domain, pubDER, privKey)
if err != nil { if err != nil {
return nil, err return nil, ErrCacheMiss
} }
tlscert := &tls.Certificate{ tlscert := &tls.Certificate{
Certificate: pubDER, Certificate: pubDER,
@ -297,7 +304,7 @@ func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) {
return tlscert, nil return tlscert, nil
} }
func (m *Manager) cachePut(domain string, tlscert *tls.Certificate) error { func (m *Manager) cachePut(ctx context.Context, domain string, tlscert *tls.Certificate) error {
if m.Cache == nil { if m.Cache == nil {
return nil return nil
} }
@ -329,8 +336,6 @@ func (m *Manager) cachePut(domain string, tlscert *tls.Certificate) error {
} }
} }
// TODO: might want to define a cache timeout on m
ctx := context.Background()
return m.Cache.Put(ctx, domain, buf.Bytes()) return m.Cache.Put(ctx, domain, buf.Bytes())
} }
@ -370,6 +375,23 @@ func (m *Manager) createCert(ctx context.Context, domain string) (*tls.Certifica
der, leaf, err := m.authorizedCert(ctx, state.key, domain) der, leaf, err := m.authorizedCert(ctx, state.key, domain)
if err != nil { if err != nil {
// Remove the failed state after some time,
// making the manager call createCert again on the following TLS hello.
time.AfterFunc(createCertRetryAfter, func() {
defer testDidRemoveState(domain)
m.stateMu.Lock()
defer m.stateMu.Unlock()
// Verify the state hasn't changed and it's still invalid
// before deleting.
s, ok := m.state[domain]
if !ok {
return
}
if _, err := validCert(domain, s.cert, s.key); err == nil {
return
}
delete(m.state, domain)
})
return nil, err return nil, err
} }
state.cert = der state.cert = der
@ -418,7 +440,6 @@ func (m *Manager) certState(domain string) (*certState, error) {
// authorizedCert starts domain ownership verification process and requests a new cert upon success. // authorizedCert starts domain ownership verification process and requests a new cert upon success.
// The key argument is the certificate private key. // The key argument is the certificate private key.
func (m *Manager) authorizedCert(ctx context.Context, key crypto.Signer, domain string) (der [][]byte, leaf *x509.Certificate, err error) { func (m *Manager) authorizedCert(ctx context.Context, key crypto.Signer, domain string) (der [][]byte, leaf *x509.Certificate, err error) {
// TODO: make m.verify retry or retry m.verify calls here
if err := m.verify(ctx, domain); err != nil { if err := m.verify(ctx, domain); err != nil {
return nil, nil, err return nil, nil, err
} }
@ -494,7 +515,7 @@ func (m *Manager) verify(ctx context.Context, domain string) error {
if err != nil { if err != nil {
return err return err
} }
m.putTokenCert(name, &cert) m.putTokenCert(ctx, name, &cert)
defer func() { defer func() {
// verification has ended at this point // verification has ended at this point
// don't need token cert anymore // don't need token cert anymore
@ -512,14 +533,14 @@ func (m *Manager) verify(ctx context.Context, domain string) error {
// putTokenCert stores the cert under the named key in both m.tokenCert map // putTokenCert stores the cert under the named key in both m.tokenCert map
// and m.Cache. // and m.Cache.
func (m *Manager) putTokenCert(name string, cert *tls.Certificate) { func (m *Manager) putTokenCert(ctx context.Context, name string, cert *tls.Certificate) {
m.tokenCertMu.Lock() m.tokenCertMu.Lock()
defer m.tokenCertMu.Unlock() defer m.tokenCertMu.Unlock()
if m.tokenCert == nil { if m.tokenCert == nil {
m.tokenCert = make(map[string]*tls.Certificate) m.tokenCert = make(map[string]*tls.Certificate)
} }
m.tokenCert[name] = cert m.tokenCert[name] = cert
m.cachePut(name, cert) m.cachePut(ctx, name, cert)
} }
// deleteTokenCert removes the token certificate for the specified domain name // deleteTokenCert removes the token certificate for the specified domain name
@ -644,10 +665,10 @@ func (m *Manager) hostPolicy() HostPolicy {
} }
func (m *Manager) renewBefore() time.Duration { func (m *Manager) renewBefore() time.Duration {
if m.RenewBefore > maxRandRenew { if m.RenewBefore > renewJitter {
return m.RenewBefore return m.RenewBefore
} }
return 7 * 24 * time.Hour // 1 week return 720 * time.Hour // 30 days
} }
// certState is ready when its mutex is unlocked for reading. // certState is ready when its mutex is unlocked for reading.
@ -789,5 +810,10 @@ func (r *lockedMathRand) int63n(max int64) int64 {
return n return n
} }
// for easier testing // For easier testing.
var timeNow = time.Now var (
timeNow = time.Now
// Called when a state is removed.
testDidRemoveState = func(domain string) {}
)

View File

@ -5,6 +5,7 @@
package autocert package autocert
import ( import (
"context"
"crypto" "crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
@ -22,11 +23,11 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
"sync"
"testing" "testing"
"time" "time"
"golang.org/x/crypto/acme" "golang.org/x/crypto/acme"
"golang.org/x/net/context"
) )
var discoTmpl = template.Must(template.New("disco").Parse(`{ var discoTmpl = template.Must(template.New("disco").Parse(`{
@ -51,26 +52,44 @@ var authzTmpl = template.Must(template.New("authz").Parse(`{
] ]
}`)) }`))
type memCache map[string][]byte type memCache struct {
mu sync.Mutex
keyData map[string][]byte
}
func (m memCache) Get(ctx context.Context, key string) ([]byte, error) { func (m *memCache) Get(ctx context.Context, key string) ([]byte, error) {
v, ok := m[key] m.mu.Lock()
defer m.mu.Unlock()
v, ok := m.keyData[key]
if !ok { if !ok {
return nil, ErrCacheMiss return nil, ErrCacheMiss
} }
return v, nil return v, nil
} }
func (m memCache) Put(ctx context.Context, key string, data []byte) error { func (m *memCache) Put(ctx context.Context, key string, data []byte) error {
m[key] = data m.mu.Lock()
defer m.mu.Unlock()
m.keyData[key] = data
return nil return nil
} }
func (m memCache) Delete(ctx context.Context, key string) error { func (m *memCache) Delete(ctx context.Context, key string) error {
delete(m, key) m.mu.Lock()
defer m.mu.Unlock()
delete(m.keyData, key)
return nil return nil
} }
func newMemCache() *memCache {
return &memCache{
keyData: make(map[string][]byte),
}
}
func dummyCert(pub interface{}, san ...string) ([]byte, error) { func dummyCert(pub interface{}, san ...string) ([]byte, error) {
return dateDummyCert(pub, time.Now(), time.Now().Add(90*24*time.Hour), san...) return dateDummyCert(pub, time.Now(), time.Now().Add(90*24*time.Hour), san...)
} }
@ -124,14 +143,14 @@ func TestGetCertificate_trailingDot(t *testing.T) {
func TestGetCertificate_ForceRSA(t *testing.T) { func TestGetCertificate_ForceRSA(t *testing.T) {
man := &Manager{ man := &Manager{
Prompt: AcceptTOS, Prompt: AcceptTOS,
Cache: make(memCache), Cache: newMemCache(),
ForceRSA: true, ForceRSA: true,
} }
defer man.stopRenew() defer man.stopRenew()
hello := &tls.ClientHelloInfo{ServerName: "example.org"} hello := &tls.ClientHelloInfo{ServerName: "example.org"}
testGetCertificate(t, man, "example.org", hello) testGetCertificate(t, man, "example.org", hello)
cert, err := man.cacheGet("example.org") cert, err := man.cacheGet(context.Background(), "example.org")
if err != nil { if err != nil {
t.Fatalf("man.cacheGet: %v", err) t.Fatalf("man.cacheGet: %v", err)
} }
@ -140,9 +159,110 @@ func TestGetCertificate_ForceRSA(t *testing.T) {
} }
} }
// tests man.GetCertificate flow using the provided hello argument. func TestGetCertificate_nilPrompt(t *testing.T) {
man := &Manager{}
defer man.stopRenew()
url, finish := startACMEServerStub(t, man, "example.org")
defer finish()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
man.Client = &acme.Client{
Key: key,
DirectoryURL: url,
}
hello := &tls.ClientHelloInfo{ServerName: "example.org"}
if _, err := man.GetCertificate(hello); err == nil {
t.Error("got certificate for example.org; wanted error")
}
}
func TestGetCertificate_expiredCache(t *testing.T) {
// Make an expired cert and cache it.
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "example.org"},
NotAfter: time.Now(),
}
pub, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &pk.PublicKey, pk)
if err != nil {
t.Fatal(err)
}
tlscert := &tls.Certificate{
Certificate: [][]byte{pub},
PrivateKey: pk,
}
man := &Manager{Prompt: AcceptTOS, Cache: newMemCache()}
defer man.stopRenew()
if err := man.cachePut(context.Background(), "example.org", tlscert); err != nil {
t.Fatalf("man.cachePut: %v", err)
}
// The expired cached cert should trigger a new cert issuance
// and return without an error.
hello := &tls.ClientHelloInfo{ServerName: "example.org"}
testGetCertificate(t, man, "example.org", hello)
}
func TestGetCertificate_failedAttempt(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
}))
defer ts.Close()
const example = "example.org"
d := createCertRetryAfter
f := testDidRemoveState
defer func() {
createCertRetryAfter = d
testDidRemoveState = f
}()
createCertRetryAfter = 0
done := make(chan struct{})
testDidRemoveState = func(domain string) {
if domain != example {
t.Errorf("testDidRemoveState: domain = %q; want %q", domain, example)
}
close(done)
}
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
man := &Manager{
Prompt: AcceptTOS,
Client: &acme.Client{
Key: key,
DirectoryURL: ts.URL,
},
}
defer man.stopRenew()
hello := &tls.ClientHelloInfo{ServerName: example}
if _, err := man.GetCertificate(hello); err == nil {
t.Error("GetCertificate: err is nil")
}
select {
case <-time.After(5 * time.Second):
t.Errorf("took too long to remove the %q state", example)
case <-done:
man.stateMu.Lock()
defer man.stateMu.Unlock()
if v, exist := man.state[example]; exist {
t.Errorf("state exists for %q: %+v", example, v)
}
}
}
// startACMEServerStub runs an ACME server
// The domain argument is the expected domain name of a certificate request. // The domain argument is the expected domain name of a certificate request.
func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.ClientHelloInfo) { func startACMEServerStub(t *testing.T, man *Manager, domain string) (url string, finish func()) {
// echo token-02 | shasum -a 256 // echo token-02 | shasum -a 256
// then divide result in 2 parts separated by dot // then divide result in 2 parts separated by dot
tokenCertName := "4e8eb87631187e9ff2153b56b13a4dec.13a35d002e485d60ff37354b32f665d9.token.acme.invalid" tokenCertName := "4e8eb87631187e9ff2153b56b13a4dec.13a35d002e485d60ff37354b32f665d9.token.acme.invalid"
@ -168,7 +288,7 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
// discovery // discovery
case "/": case "/":
if err := discoTmpl.Execute(w, ca.URL); err != nil { if err := discoTmpl.Execute(w, ca.URL); err != nil {
t.Fatalf("discoTmpl: %v", err) t.Errorf("discoTmpl: %v", err)
} }
// client key registration // client key registration
case "/new-reg": case "/new-reg":
@ -178,7 +298,7 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
w.Header().Set("location", ca.URL+"/authz/1") w.Header().Set("location", ca.URL+"/authz/1")
w.WriteHeader(http.StatusCreated) w.WriteHeader(http.StatusCreated)
if err := authzTmpl.Execute(w, ca.URL); err != nil { if err := authzTmpl.Execute(w, ca.URL); err != nil {
t.Fatalf("authzTmpl: %v", err) t.Errorf("authzTmpl: %v", err)
} }
// accept tls-sni-02 challenge // accept tls-sni-02 challenge
case "/challenge/2": case "/challenge/2":
@ -196,14 +316,14 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
b, _ := base64.RawURLEncoding.DecodeString(req.CSR) b, _ := base64.RawURLEncoding.DecodeString(req.CSR)
csr, err := x509.ParseCertificateRequest(b) csr, err := x509.ParseCertificateRequest(b)
if err != nil { if err != nil {
t.Fatalf("new-cert: CSR: %v", err) t.Errorf("new-cert: CSR: %v", err)
} }
if csr.Subject.CommonName != domain { if csr.Subject.CommonName != domain {
t.Errorf("CommonName in CSR = %q; want %q", csr.Subject.CommonName, domain) t.Errorf("CommonName in CSR = %q; want %q", csr.Subject.CommonName, domain)
} }
der, err := dummyCert(csr.PublicKey, domain) der, err := dummyCert(csr.PublicKey, domain)
if err != nil { if err != nil {
t.Fatalf("new-cert: dummyCert: %v", err) t.Errorf("new-cert: dummyCert: %v", err)
} }
chainUp := fmt.Sprintf("<%s/ca-cert>; rel=up", ca.URL) chainUp := fmt.Sprintf("<%s/ca-cert>; rel=up", ca.URL)
w.Header().Set("link", chainUp) w.Header().Set("link", chainUp)
@ -213,14 +333,51 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
case "/ca-cert": case "/ca-cert":
der, err := dummyCert(nil, "ca") der, err := dummyCert(nil, "ca")
if err != nil { if err != nil {
t.Fatalf("ca-cert: dummyCert: %v", err) t.Errorf("ca-cert: dummyCert: %v", err)
} }
w.Write(der) w.Write(der)
default: default:
t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path) t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path)
} }
})) }))
defer ca.Close() finish = func() {
ca.Close()
// make sure token cert was removed
cancel := make(chan struct{})
done := make(chan struct{})
go func() {
defer close(done)
tick := time.NewTicker(100 * time.Millisecond)
defer tick.Stop()
for {
hello := &tls.ClientHelloInfo{ServerName: tokenCertName}
if _, err := man.GetCertificate(hello); err != nil {
return
}
select {
case <-tick.C:
case <-cancel:
return
}
}
}()
select {
case <-done:
case <-time.After(5 * time.Second):
close(cancel)
t.Error("token cert was not removed")
<-done
}
}
return ca.URL, finish
}
// tests man.GetCertificate flow using the provided hello argument.
// The domain argument is the expected domain name of a certificate request.
func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.ClientHelloInfo) {
url, finish := startACMEServerStub(t, man, domain)
defer finish()
// use EC key to run faster on 386 // use EC key to run faster on 386
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
@ -229,7 +386,7 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
} }
man.Client = &acme.Client{ man.Client = &acme.Client{
Key: key, Key: key,
DirectoryURL: ca.URL, DirectoryURL: url,
} }
// simulate tls.Config.GetCertificate // simulate tls.Config.GetCertificate
@ -260,28 +417,10 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
t.Errorf("cert.DNSNames = %v; want %q", cert.DNSNames, domain) t.Errorf("cert.DNSNames = %v; want %q", cert.DNSNames, domain)
} }
// make sure token cert was removed
done = make(chan struct{})
go func() {
for {
hello := &tls.ClientHelloInfo{ServerName: tokenCertName}
if _, err := man.GetCertificate(hello); err != nil {
break
}
time.Sleep(100 * time.Millisecond)
}
close(done)
}()
select {
case <-time.After(5 * time.Second):
t.Error("token cert was not removed")
case <-done:
}
} }
func TestAccountKeyCache(t *testing.T) { func TestAccountKeyCache(t *testing.T) {
cache := make(memCache) m := Manager{Cache: newMemCache()}
m := Manager{Cache: cache}
ctx := context.Background() ctx := context.Background()
k1, err := m.accountKey(ctx) k1, err := m.accountKey(ctx)
if err != nil { if err != nil {
@ -315,13 +454,13 @@ func TestCache(t *testing.T) {
PrivateKey: privKey, PrivateKey: privKey,
} }
cache := make(memCache) man := &Manager{Cache: newMemCache()}
man := &Manager{Cache: cache}
defer man.stopRenew() defer man.stopRenew()
if err := man.cachePut("example.org", tlscert); err != nil { ctx := context.Background()
if err := man.cachePut(ctx, "example.org", tlscert); err != nil {
t.Fatalf("man.cachePut: %v", err) t.Fatalf("man.cachePut: %v", err)
} }
res, err := man.cacheGet("example.org") res, err := man.cacheGet(ctx, "example.org")
if err != nil { if err != nil {
t.Fatalf("man.cacheGet: %v", err) t.Fatalf("man.cacheGet: %v", err)
} }
@ -421,3 +560,47 @@ func TestValidCert(t *testing.T) {
} }
} }
} }
type cacheGetFunc func(ctx context.Context, key string) ([]byte, error)
func (f cacheGetFunc) Get(ctx context.Context, key string) ([]byte, error) {
return f(ctx, key)
}
func (f cacheGetFunc) Put(ctx context.Context, key string, data []byte) error {
return fmt.Errorf("unsupported Put of %q = %q", key, data)
}
func (f cacheGetFunc) Delete(ctx context.Context, key string) error {
return fmt.Errorf("unsupported Delete of %q", key)
}
func TestManagerGetCertificateBogusSNI(t *testing.T) {
m := Manager{
Prompt: AcceptTOS,
Cache: cacheGetFunc(func(ctx context.Context, key string) ([]byte, error) {
return nil, fmt.Errorf("cache.Get of %s", key)
}),
}
tests := []struct {
name string
wantErr string
}{
{"foo.com", "cache.Get of foo.com"},
{"foo.com.", "cache.Get of foo.com"},
{`a\b.com`, "acme/autocert: server name contains invalid character"},
{`a/b.com`, "acme/autocert: server name contains invalid character"},
{"", "acme/autocert: missing server name"},
{"foo", "acme/autocert: server name component count invalid"},
{".foo", "acme/autocert: server name component count invalid"},
{"foo.", "acme/autocert: server name component count invalid"},
{"fo.o", "cache.Get of fo.o"},
}
for _, tt := range tests {
_, err := m.GetCertificate(&tls.ClientHelloInfo{ServerName: tt.name})
got := fmt.Sprint(err)
if got != tt.wantErr {
t.Errorf("GetCertificate(SNI = %q) = %q; want %q", tt.name, got, tt.wantErr)
}
}
}

View File

@ -5,12 +5,11 @@
package autocert package autocert
import ( import (
"context"
"errors" "errors"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"golang.org/x/net/context"
) )
// ErrCacheMiss is returned when a certificate is not found in cache. // ErrCacheMiss is returned when a certificate is not found in cache.
@ -27,7 +26,7 @@ type Cache interface {
Get(ctx context.Context, key string) ([]byte, error) Get(ctx context.Context, key string) ([]byte, error)
// Put stores the data in the cache under the specified key. // Put stores the data in the cache under the specified key.
// Inderlying implementations may use any data storage format, // Underlying implementations may use any data storage format,
// as long as the reverse operation, Get, results in the original data. // as long as the reverse operation, Get, results in the original data.
Put(ctx context.Context, key string, data []byte) error Put(ctx context.Context, key string, data []byte) error
@ -78,12 +77,13 @@ func (d DirCache) Put(ctx context.Context, name string, data []byte) error {
if tmp, err = d.writeTempFile(name, data); err != nil { if tmp, err = d.writeTempFile(name, data); err != nil {
return return
} }
// prevent overwriting the file if the context was cancelled select {
if ctx.Err() != nil { case <-ctx.Done():
return // no need to set err // Don't overwrite the file if the context was canceled.
default:
newName := filepath.Join(string(d), name)
err = os.Rename(tmp, newName)
} }
name = filepath.Join(string(d), name)
err = os.Rename(tmp, name)
}() }()
select { select {
case <-ctx.Done(): case <-ctx.Done():

View File

@ -5,13 +5,12 @@
package autocert package autocert
import ( import (
"context"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"testing" "testing"
"golang.org/x/net/context"
) )
// make sure DirCache satisfies Cache interface // make sure DirCache satisfies Cache interface

View File

@ -0,0 +1,34 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package autocert_test
import (
"crypto/tls"
"fmt"
"log"
"net/http"
"golang.org/x/crypto/acme/autocert"
)
func ExampleNewListener() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Hello, TLS user! Your config: %+v", r.TLS)
})
log.Fatal(http.Serve(autocert.NewListener("example.com"), mux))
}
func ExampleManager() {
m := autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist("example.org"),
}
s := &http.Server{
Addr: ":https",
TLSConfig: &tls.Config{GetCertificate: m.GetCertificate},
}
s.ListenAndServeTLS("", "")
}

View File

@ -0,0 +1,153 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package autocert
import (
"crypto/tls"
"log"
"net"
"os"
"path/filepath"
"runtime"
"time"
)
// NewListener returns a net.Listener that listens on the standard TLS
// port (443) on all interfaces and returns *tls.Conn connections with
// LetsEncrypt certificates for the provided domain or domains.
//
// It enables one-line HTTPS servers:
//
// log.Fatal(http.Serve(autocert.NewListener("example.com"), handler))
//
// NewListener is a convenience function for a common configuration.
// More complex or custom configurations can use the autocert.Manager
// type instead.
//
// Use of this function implies acceptance of the LetsEncrypt Terms of
// Service. If domains is not empty, the provided domains are passed
// to HostWhitelist. If domains is empty, the listener will do
// LetsEncrypt challenges for any requested domain, which is not
// recommended.
//
// Certificates are cached in a "golang-autocert" directory under an
// operating system-specific cache or temp directory. This may not
// be suitable for servers spanning multiple machines.
//
// The returned Listener also enables TCP keep-alives on the accepted
// connections. The returned *tls.Conn are returned before their TLS
// handshake has completed.
func NewListener(domains ...string) net.Listener {
m := &Manager{
Prompt: AcceptTOS,
}
if len(domains) > 0 {
m.HostPolicy = HostWhitelist(domains...)
}
dir := cacheDir()
if err := os.MkdirAll(dir, 0700); err != nil {
log.Printf("warning: autocert.NewListener not using a cache: %v", err)
} else {
m.Cache = DirCache(dir)
}
return m.Listener()
}
// Listener listens on the standard TLS port (443) on all interfaces
// and returns a net.Listener returning *tls.Conn connections.
//
// The returned Listener also enables TCP keep-alives on the accepted
// connections. The returned *tls.Conn are returned before their TLS
// handshake has completed.
//
// Unlike NewListener, it is the caller's responsibility to initialize
// the Manager m's Prompt, Cache, HostPolicy, and other desired options.
func (m *Manager) Listener() net.Listener {
ln := &listener{
m: m,
conf: &tls.Config{
GetCertificate: m.GetCertificate, // bonus: panic on nil m
},
}
ln.tcpListener, ln.tcpListenErr = net.Listen("tcp", ":443")
return ln
}
type listener struct {
m *Manager
conf *tls.Config
tcpListener net.Listener
tcpListenErr error
}
func (ln *listener) Accept() (net.Conn, error) {
if ln.tcpListenErr != nil {
return nil, ln.tcpListenErr
}
conn, err := ln.tcpListener.Accept()
if err != nil {
return nil, err
}
tcpConn := conn.(*net.TCPConn)
// Because Listener is a convenience function, help out with
// this too. This is not possible for the caller to set once
// we return a *tcp.Conn wrapping an inaccessible net.Conn.
// If callers don't want this, they can do things the manual
// way and tweak as needed. But this is what net/http does
// itself, so copy that. If net/http changes, we can change
// here too.
tcpConn.SetKeepAlive(true)
tcpConn.SetKeepAlivePeriod(3 * time.Minute)
return tls.Server(tcpConn, ln.conf), nil
}
func (ln *listener) Addr() net.Addr {
if ln.tcpListener != nil {
return ln.tcpListener.Addr()
}
// net.Listen failed. Return something non-nil in case callers
// call Addr before Accept:
return &net.TCPAddr{IP: net.IP{0, 0, 0, 0}, Port: 443}
}
func (ln *listener) Close() error {
if ln.tcpListenErr != nil {
return ln.tcpListenErr
}
return ln.tcpListener.Close()
}
func homeDir() string {
if runtime.GOOS == "windows" {
return os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH")
}
if h := os.Getenv("HOME"); h != "" {
return h
}
return "/"
}
func cacheDir() string {
const base = "golang-autocert"
switch runtime.GOOS {
case "darwin":
return filepath.Join(homeDir(), "Library", "Caches", base)
case "windows":
for _, ev := range []string{"APPDATA", "CSIDL_APPDATA", "TEMP", "TMP"} {
if v := os.Getenv(ev); v != "" {
return filepath.Join(v, base)
}
}
// Worst case:
return filepath.Join(homeDir(), base)
}
if xdg := os.Getenv("XDG_CACHE_HOME"); xdg != "" {
return filepath.Join(xdg, base)
}
return filepath.Join(homeDir(), ".cache", base)
}

View File

@ -5,15 +5,14 @@
package autocert package autocert
import ( import (
"context"
"crypto" "crypto"
"sync" "sync"
"time" "time"
"golang.org/x/net/context"
) )
// maxRandRenew is a maximum deviation from Manager.RenewBefore. // renewJitter is the maximum deviation from Manager.RenewBefore.
const maxRandRenew = time.Hour const renewJitter = time.Hour
// domainRenewal tracks the state used by the periodic timers // domainRenewal tracks the state used by the periodic timers
// renewing a single domain's cert. // renewing a single domain's cert.
@ -65,7 +64,7 @@ func (dr *domainRenewal) renew() {
// TODO: rotate dr.key at some point? // TODO: rotate dr.key at some point?
next, err := dr.do(ctx) next, err := dr.do(ctx)
if err != nil { if err != nil {
next = maxRandRenew / 2 next = renewJitter / 2
next += time.Duration(pseudoRand.int63n(int64(next))) next += time.Duration(pseudoRand.int63n(int64(next)))
} }
dr.timer = time.AfterFunc(next, dr.renew) dr.timer = time.AfterFunc(next, dr.renew)
@ -83,9 +82,9 @@ func (dr *domainRenewal) renew() {
func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) { func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
// a race is likely unavoidable in a distributed environment // a race is likely unavoidable in a distributed environment
// but we try nonetheless // but we try nonetheless
if tlscert, err := dr.m.cacheGet(dr.domain); err == nil { if tlscert, err := dr.m.cacheGet(ctx, dr.domain); err == nil {
next := dr.next(tlscert.Leaf.NotAfter) next := dr.next(tlscert.Leaf.NotAfter)
if next > dr.m.renewBefore()+maxRandRenew { if next > dr.m.renewBefore()+renewJitter {
return next, nil return next, nil
} }
} }
@ -103,7 +102,7 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
dr.m.cachePut(dr.domain, tlscert) dr.m.cachePut(ctx, dr.domain, tlscert)
dr.m.stateMu.Lock() dr.m.stateMu.Lock()
defer dr.m.stateMu.Unlock() defer dr.m.stateMu.Unlock()
// m.state is guaranteed to be non-nil at this point // m.state is guaranteed to be non-nil at this point
@ -114,7 +113,7 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
func (dr *domainRenewal) next(expiry time.Time) time.Duration { func (dr *domainRenewal) next(expiry time.Time) time.Duration {
d := expiry.Sub(timeNow()) - dr.m.renewBefore() d := expiry.Sub(timeNow()) - dr.m.renewBefore()
// add a bit of randomness to renew deadline // add a bit of randomness to renew deadline
n := pseudoRand.int63n(int64(maxRandRenew)) n := pseudoRand.int63n(int64(renewJitter))
d -= time.Duration(n) d -= time.Duration(n)
if d < 0 { if d < 0 {
return 0 return 0

View File

@ -5,6 +5,7 @@
package autocert package autocert
import ( import (
"context"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
"crypto/rand" "crypto/rand"
@ -31,7 +32,7 @@ func TestRenewalNext(t *testing.T) {
expiry time.Time expiry time.Time
min, max time.Duration min, max time.Duration
}{ }{
{now.Add(90 * 24 * time.Hour), 83*24*time.Hour - maxRandRenew, 83 * 24 * time.Hour}, {now.Add(90 * 24 * time.Hour), 83*24*time.Hour - renewJitter, 83 * 24 * time.Hour},
{now.Add(time.Hour), 0, 1}, {now.Add(time.Hour), 0, 1},
{now, 0, 1}, {now, 0, 1},
{now.Add(-time.Hour), 0, 1}, {now.Add(-time.Hour), 0, 1},
@ -111,7 +112,7 @@ func TestRenewFromCache(t *testing.T) {
} }
man := &Manager{ man := &Manager{
Prompt: AcceptTOS, Prompt: AcceptTOS,
Cache: make(memCache), Cache: newMemCache(),
RenewBefore: 24 * time.Hour, RenewBefore: 24 * time.Hour,
Client: &acme.Client{ Client: &acme.Client{
Key: key, Key: key,
@ -127,7 +128,7 @@ func TestRenewFromCache(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
tlscert := &tls.Certificate{PrivateKey: key, Certificate: [][]byte{cert}} tlscert := &tls.Certificate{PrivateKey: key, Certificate: [][]byte{cert}}
if err := man.cachePut(domain, tlscert); err != nil { if err := man.cachePut(context.Background(), domain, tlscert); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -151,7 +152,7 @@ func TestRenewFromCache(t *testing.T) {
// ensure the new cert is cached // ensure the new cert is cached
after := time.Now().Add(future) after := time.Now().Add(future)
tlscert, err := man.cacheGet(domain) tlscert, err := man.cacheGet(context.Background(), domain)
if err != nil { if err != nil {
t.Fatalf("man.cacheGet: %v", err) t.Fatalf("man.cacheGet: %v", err)
} }

View File

@ -134,7 +134,7 @@ func jwsHasher(key crypto.Signer) (string, crypto.Hash) {
return "ES256", crypto.SHA256 return "ES256", crypto.SHA256
case "P-384": case "P-384":
return "ES384", crypto.SHA384 return "ES384", crypto.SHA384
case "P-512": case "P-521":
return "ES512", crypto.SHA512 return "ES512", crypto.SHA512
} }
} }

View File

@ -12,11 +12,13 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt"
"math/big" "math/big"
"testing" "testing"
) )
const testKeyPEM = ` const (
testKeyPEM = `
-----BEGIN RSA PRIVATE KEY----- -----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEA4xgZ3eRPkwoRvy7qeRUbmMDe0V+xH9eWLdu0iheeLlrmD2mq MIIEowIBAAKCAQEA4xgZ3eRPkwoRvy7qeRUbmMDe0V+xH9eWLdu0iheeLlrmD2mq
WXfP9IeSKApbn34g8TuAS9g5zhq8ELQ3kmjr+KV86GAMgI6VAcGlq3QrzpTCf/30 WXfP9IeSKApbn34g8TuAS9g5zhq8ELQ3kmjr+KV86GAMgI6VAcGlq3QrzpTCf/30
@ -46,10 +48,9 @@ EQeIP6dZtv8IMgtGIb91QX9pXvP0aznzQKwYIA8nZgoENCPfiMTPiEDT9e/0lObO
-----END RSA PRIVATE KEY----- -----END RSA PRIVATE KEY-----
` `
// This thumbprint is for the testKey defined above. // This thumbprint is for the testKey defined above.
const testKeyThumbprint = "6nicxzh6WETQlrvdchkz-U3e3DOQZ4heJKU63rfqMqQ" testKeyThumbprint = "6nicxzh6WETQlrvdchkz-U3e3DOQZ4heJKU63rfqMqQ"
const (
// openssl ecparam -name secp256k1 -genkey -noout // openssl ecparam -name secp256k1 -genkey -noout
testKeyECPEM = ` testKeyECPEM = `
-----BEGIN EC PRIVATE KEY----- -----BEGIN EC PRIVATE KEY-----
@ -58,39 +59,78 @@ AwEHoUQDQgAE5lhEug5xK4xBDZ2nAbaxLtaLiv85bxJ7ePd1dkO23HThqIrvawF5
QAaS/RNouybCiRhRjI3EaxLkQwgrCw0gqQ== QAaS/RNouybCiRhRjI3EaxLkQwgrCw0gqQ==
-----END EC PRIVATE KEY----- -----END EC PRIVATE KEY-----
` `
// 1. opnessl ec -in key.pem -noout -text // openssl ecparam -name secp384r1 -genkey -noout
testKeyEC384PEM = `
-----BEGIN EC PRIVATE KEY-----
MIGkAgEBBDAQ4lNtXRORWr1bgKR1CGysr9AJ9SyEk4jiVnlUWWUChmSNL+i9SLSD
Oe/naPqXJ6CgBwYFK4EEACKhZANiAAQzKtj+Ms0vHoTX5dzv3/L5YMXOWuI5UKRj
JigpahYCqXD2BA1j0E/2xt5vlPf+gm0PL+UHSQsCokGnIGuaHCsJAp3ry0gHQEke
WYXapUUFdvaK1R2/2hn5O+eiQM8YzCg=
-----END EC PRIVATE KEY-----
`
// openssl ecparam -name secp521r1 -genkey -noout
testKeyEC512PEM = `
-----BEGIN EC PRIVATE KEY-----
MIHcAgEBBEIBSNZKFcWzXzB/aJClAb305ibalKgtDA7+70eEkdPt28/3LZMM935Z
KqYHh/COcxuu3Kt8azRAUz3gyr4zZKhlKUSgBwYFK4EEACOhgYkDgYYABAHUNKbx
7JwC7H6pa2sV0tERWhHhB3JmW+OP6SUgMWryvIKajlx73eS24dy4QPGrWO9/ABsD
FqcRSkNVTXnIv6+0mAF25knqIBIg5Q8M9BnOu9GGAchcwt3O7RDHmqewnJJDrbjd
GGnm6rb+NnWR9DIopM0nKNkToWoF/hzopxu4Ae/GsQ==
-----END EC PRIVATE KEY-----
`
// 1. openssl ec -in key.pem -noout -text
// 2. remove first byte, 04 (the header); the rest is X and Y // 2. remove first byte, 04 (the header); the rest is X and Y
// 3. covert each with: echo <val> | xxd -r -p | base64 | tr -d '=' | tr '/+' '_-' // 3. convert each with: echo <val> | xxd -r -p | base64 -w 100 | tr -d '=' | tr '/+' '_-'
testKeyECPubX = "5lhEug5xK4xBDZ2nAbaxLtaLiv85bxJ7ePd1dkO23HQ" testKeyECPubX = "5lhEug5xK4xBDZ2nAbaxLtaLiv85bxJ7ePd1dkO23HQ"
testKeyECPubY = "4aiK72sBeUAGkv0TaLsmwokYUYyNxGsS5EMIKwsNIKk" testKeyECPubY = "4aiK72sBeUAGkv0TaLsmwokYUYyNxGsS5EMIKwsNIKk"
testKeyEC384PubX = "MyrY_jLNLx6E1-Xc79_y-WDFzlriOVCkYyYoKWoWAqlw9gQNY9BP9sbeb5T3_oJt"
testKeyEC384PubY = "Dy_lB0kLAqJBpyBrmhwrCQKd68tIB0BJHlmF2qVFBXb2itUdv9oZ-TvnokDPGMwo"
testKeyEC512PubX = "AdQ0pvHsnALsfqlraxXS0RFaEeEHcmZb44_pJSAxavK8gpqOXHvd5Lbh3LhA8atY738AGwMWpxFKQ1VNeci_r7SY"
testKeyEC512PubY = "AXbmSeogEiDlDwz0Gc670YYByFzC3c7tEMeap7CckkOtuN0Yaebqtv42dZH0MiikzSco2ROhagX-HOinG7gB78ax"
// echo -n '{"crv":"P-256","kty":"EC","x":"<testKeyECPubX>","y":"<testKeyECPubY>"}' | \ // echo -n '{"crv":"P-256","kty":"EC","x":"<testKeyECPubX>","y":"<testKeyECPubY>"}' | \
// openssl dgst -binary -sha256 | base64 | tr -d '=' | tr '/+' '_-' // openssl dgst -binary -sha256 | base64 | tr -d '=' | tr '/+' '_-'
testKeyECThumbprint = "zedj-Bd1Zshp8KLePv2MB-lJ_Hagp7wAwdkA0NUTniU" testKeyECThumbprint = "zedj-Bd1Zshp8KLePv2MB-lJ_Hagp7wAwdkA0NUTniU"
) )
var ( var (
testKey *rsa.PrivateKey testKey *rsa.PrivateKey
testKeyEC *ecdsa.PrivateKey testKeyEC *ecdsa.PrivateKey
testKeyEC384 *ecdsa.PrivateKey
testKeyEC512 *ecdsa.PrivateKey
) )
func init() { func init() {
d, _ := pem.Decode([]byte(testKeyPEM)) testKey = parseRSA(testKeyPEM, "testKeyPEM")
if d == nil { testKeyEC = parseEC(testKeyECPEM, "testKeyECPEM")
panic("no block found in testKeyPEM") testKeyEC384 = parseEC(testKeyEC384PEM, "testKeyEC384PEM")
} testKeyEC512 = parseEC(testKeyEC512PEM, "testKeyEC512PEM")
var err error }
testKey, err = x509.ParsePKCS1PrivateKey(d.Bytes)
if err != nil {
panic(err.Error())
}
if d, _ = pem.Decode([]byte(testKeyECPEM)); d == nil { func decodePEM(s, name string) []byte {
panic("no block found in testKeyECPEM") d, _ := pem.Decode([]byte(s))
if d == nil {
panic("no block found in " + name)
} }
testKeyEC, err = x509.ParseECPrivateKey(d.Bytes) return d.Bytes
}
func parseRSA(s, name string) *rsa.PrivateKey {
b := decodePEM(s, name)
k, err := x509.ParsePKCS1PrivateKey(b)
if err != nil { if err != nil {
panic(err.Error()) panic(fmt.Sprintf("%s: %v", name, err))
} }
return k
}
func parseEC(s, name string) *ecdsa.PrivateKey {
b := decodePEM(s, name)
k, err := x509.ParseECPrivateKey(b)
if err != nil {
panic(fmt.Sprintf("%s: %v", name, err))
}
return k
} }
func TestJWSEncodeJSON(t *testing.T) { func TestJWSEncodeJSON(t *testing.T) {
@ -141,50 +181,63 @@ func TestJWSEncodeJSON(t *testing.T) {
} }
func TestJWSEncodeJSONEC(t *testing.T) { func TestJWSEncodeJSONEC(t *testing.T) {
claims := struct{ Msg string }{"Hello JWS"} tt := []struct {
key *ecdsa.PrivateKey
x, y string
alg, crv string
}{
{testKeyEC, testKeyECPubX, testKeyECPubY, "ES256", "P-256"},
{testKeyEC384, testKeyEC384PubX, testKeyEC384PubY, "ES384", "P-384"},
{testKeyEC512, testKeyEC512PubX, testKeyEC512PubY, "ES512", "P-521"},
}
for i, test := range tt {
claims := struct{ Msg string }{"Hello JWS"}
b, err := jwsEncodeJSON(claims, test.key, "nonce")
if err != nil {
t.Errorf("%d: %v", i, err)
continue
}
var jws struct{ Protected, Payload, Signature string }
if err := json.Unmarshal(b, &jws); err != nil {
t.Errorf("%d: %v", i, err)
continue
}
b, err := jwsEncodeJSON(claims, testKeyEC, "nonce") b, err = base64.RawURLEncoding.DecodeString(jws.Protected)
if err != nil { if err != nil {
t.Fatal(err) t.Errorf("%d: jws.Protected: %v", i, err)
} }
var jws struct{ Protected, Payload, Signature string } var head struct {
if err := json.Unmarshal(b, &jws); err != nil { Alg string
t.Fatal(err) Nonce string
} JWK struct {
Crv string
if b, err = base64.RawURLEncoding.DecodeString(jws.Protected); err != nil { Kty string
t.Fatalf("jws.Protected: %v", err) X string
} Y string
var head struct { } `json:"jwk"`
Alg string }
Nonce string if err := json.Unmarshal(b, &head); err != nil {
JWK struct { t.Errorf("%d: jws.Protected: %v", i, err)
Crv string }
Kty string if head.Alg != test.alg {
X string t.Errorf("%d: head.Alg = %q; want %q", i, head.Alg, test.alg)
Y string }
} `json:"jwk"` if head.Nonce != "nonce" {
} t.Errorf("%d: head.Nonce = %q; want nonce", i, head.Nonce)
if err := json.Unmarshal(b, &head); err != nil { }
t.Fatalf("jws.Protected: %v", err) if head.JWK.Crv != test.crv {
} t.Errorf("%d: head.JWK.Crv = %q; want %q", i, head.JWK.Crv, test.crv)
if head.Alg != "ES256" { }
t.Errorf("head.Alg = %q; want ES256", head.Alg) if head.JWK.Kty != "EC" {
} t.Errorf("%d: head.JWK.Kty = %q; want EC", i, head.JWK.Kty)
if head.Nonce != "nonce" { }
t.Errorf("head.Nonce = %q; want nonce", head.Nonce) if head.JWK.X != test.x {
} t.Errorf("%d: head.JWK.X = %q; want %q", i, head.JWK.X, test.x)
if head.JWK.Crv != "P-256" { }
t.Errorf("head.JWK.Crv = %q; want P-256", head.JWK.Crv) if head.JWK.Y != test.y {
} t.Errorf("%d: head.JWK.Y = %q; want %q", i, head.JWK.Y, test.y)
if head.JWK.Kty != "EC" { }
t.Errorf("head.JWK.Kty = %q; want EC", head.JWK.Kty)
}
if head.JWK.X != testKeyECPubX {
t.Errorf("head.JWK.X = %q; want %q", head.JWK.X, testKeyECPubX)
}
if head.JWK.Y != testKeyECPubY {
t.Errorf("head.JWK.Y = %q; want %q", head.JWK.Y, testKeyECPubY)
} }
} }

View File

@ -1,9 +1,15 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package acme package acme
import ( import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"strings"
"time"
) )
// ACME server response statuses used to describe Authorization and Challenge states. // ACME server response statuses used to describe Authorization and Challenge states.
@ -33,14 +39,8 @@ const (
CRLReasonAACompromise CRLReasonCode = 10 CRLReasonAACompromise CRLReasonCode = 10
) )
var ( // ErrUnsupportedKey is returned when an unsupported key type is encountered.
// ErrAuthorizationFailed indicates that an authorization for an identifier var ErrUnsupportedKey = errors.New("acme: unknown key type; only RSA and ECDSA are supported")
// did not succeed.
ErrAuthorizationFailed = errors.New("acme: identifier authorization failed")
// ErrUnsupportedKey is returned when an unsupported key type is encountered.
ErrUnsupportedKey = errors.New("acme: unknown key type; only RSA and ECDSA are supported")
)
// Error is an ACME error, defined in Problem Details for HTTP APIs doc // Error is an ACME error, defined in Problem Details for HTTP APIs doc
// http://tools.ietf.org/html/draft-ietf-appsawg-http-problem. // http://tools.ietf.org/html/draft-ietf-appsawg-http-problem.
@ -53,6 +53,7 @@ type Error struct {
// Detail is a human-readable explanation specific to this occurrence of the problem. // Detail is a human-readable explanation specific to this occurrence of the problem.
Detail string Detail string
// Header is the original server error response headers. // Header is the original server error response headers.
// It may be nil.
Header http.Header Header http.Header
} }
@ -60,6 +61,50 @@ func (e *Error) Error() string {
return fmt.Sprintf("%d %s: %s", e.StatusCode, e.ProblemType, e.Detail) return fmt.Sprintf("%d %s: %s", e.StatusCode, e.ProblemType, e.Detail)
} }
// AuthorizationError indicates that an authorization for an identifier
// did not succeed.
// It contains all errors from Challenge items of the failed Authorization.
type AuthorizationError struct {
// URI uniquely identifies the failed Authorization.
URI string
// Identifier is an AuthzID.Value of the failed Authorization.
Identifier string
// Errors is a collection of non-nil error values of Challenge items
// of the failed Authorization.
Errors []error
}
func (a *AuthorizationError) Error() string {
e := make([]string, len(a.Errors))
for i, err := range a.Errors {
e[i] = err.Error()
}
return fmt.Sprintf("acme: authorization error for %s: %s", a.Identifier, strings.Join(e, "; "))
}
// RateLimit reports whether err represents a rate limit error and
// any Retry-After duration returned by the server.
//
// See the following for more details on rate limiting:
// https://tools.ietf.org/html/draft-ietf-acme-acme-05#section-5.6
func RateLimit(err error) (time.Duration, bool) {
e, ok := err.(*Error)
if !ok {
return 0, false
}
// Some CA implementations may return incorrect values.
// Use case-insensitive comparison.
if !strings.HasSuffix(strings.ToLower(e.ProblemType), ":ratelimited") {
return 0, false
}
if e.Header == nil {
return 0, true
}
return retryAfter(e.Header.Get("Retry-After"), 0), true
}
// Account is a user account. It is associated with a private key. // Account is a user account. It is associated with a private key.
type Account struct { type Account struct {
// URI is the account unique ID, which is also a URL used to retrieve // URI is the account unique ID, which is also a URL used to retrieve
@ -118,6 +163,8 @@ type Directory struct {
} }
// Challenge encodes a returned CA challenge. // Challenge encodes a returned CA challenge.
// Its Error field may be non-nil if the challenge is part of an Authorization
// with StatusInvalid.
type Challenge struct { type Challenge struct {
// Type is the challenge type, e.g. "http-01", "tls-sni-02", "dns-01". // Type is the challenge type, e.g. "http-01", "tls-sni-02", "dns-01".
Type string Type string
@ -130,6 +177,11 @@ type Challenge struct {
// Status identifies the status of this challenge. // Status identifies the status of this challenge.
Status string Status string
// Error indicates the reason for an authorization failure
// when this challenge was used.
// The type of a non-nil value is *Error.
Error error
} }
// Authorization encodes an authorization response. // Authorization encodes an authorization response.
@ -187,12 +239,26 @@ func (z *wireAuthz) authorization(uri string) *Authorization {
return a return a
} }
func (z *wireAuthz) error(uri string) *AuthorizationError {
err := &AuthorizationError{
URI: uri,
Identifier: z.Identifier.Value,
}
for _, raw := range z.Challenges {
if raw.Error != nil {
err.Errors = append(err.Errors, raw.Error.error(nil))
}
}
return err
}
// wireChallenge is ACME JSON challenge representation. // wireChallenge is ACME JSON challenge representation.
type wireChallenge struct { type wireChallenge struct {
URI string `json:"uri"` URI string `json:"uri"`
Type string Type string
Token string Token string
Status string Status string
Error *wireError
} }
func (c *wireChallenge) challenge() *Challenge { func (c *wireChallenge) challenge() *Challenge {
@ -205,5 +271,25 @@ func (c *wireChallenge) challenge() *Challenge {
if v.Status == "" { if v.Status == "" {
v.Status = StatusPending v.Status = StatusPending
} }
if c.Error != nil {
v.Error = c.Error.error(nil)
}
return v return v
} }
// wireError is a subset of fields of the Problem Details object
// as described in https://tools.ietf.org/html/rfc7807#section-3.1.
type wireError struct {
Status int
Type string
Detail string
}
func (e *wireError) error(h http.Header) *Error {
return &Error{
StatusCode: e.Status,
ProblemType: e.Type,
Detail: e.Detail,
Header: h,
}
}

View File

@ -0,0 +1,63 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package acme
import (
"errors"
"net/http"
"testing"
"time"
)
func TestRateLimit(t *testing.T) {
now := time.Date(2017, 04, 27, 10, 0, 0, 0, time.UTC)
f := timeNow
defer func() { timeNow = f }()
timeNow = func() time.Time { return now }
h120, hTime := http.Header{}, http.Header{}
h120.Set("Retry-After", "120")
hTime.Set("Retry-After", "Tue Apr 27 11:00:00 2017")
err1 := &Error{
ProblemType: "urn:ietf:params:acme:error:nolimit",
Header: h120,
}
err2 := &Error{
ProblemType: "urn:ietf:params:acme:error:rateLimited",
Header: h120,
}
err3 := &Error{
ProblemType: "urn:ietf:params:acme:error:rateLimited",
Header: nil,
}
err4 := &Error{
ProblemType: "urn:ietf:params:acme:error:rateLimited",
Header: hTime,
}
tt := []struct {
err error
res time.Duration
ok bool
}{
{nil, 0, false},
{errors.New("dummy"), 0, false},
{err1, 0, false},
{err2, 2 * time.Minute, true},
{err3, 0, true},
{err4, time.Hour, true},
}
for i, test := range tt {
res, ok := RateLimit(test.err)
if ok != test.ok {
t.Errorf("%d: RateLimit(%+v): ok = %v; want %v", i, test.err, ok, test.ok)
continue
}
if res != test.res {
t.Errorf("%d: RateLimit(%+v) = %v; want %v", i, test.err, res, test.res)
}
}
}

View File

@ -4,7 +4,7 @@
// Package blake2b implements the BLAKE2b hash algorithm as // Package blake2b implements the BLAKE2b hash algorithm as
// defined in RFC 7693. // defined in RFC 7693.
package blake2b package blake2b // import "golang.org/x/crypto/blake2b"
import ( import (
"encoding/binary" "encoding/binary"
@ -23,6 +23,12 @@ const (
Size256 = 32 Size256 = 32
) )
var (
useAVX2 bool
useAVX bool
useSSE4 bool
)
var errKeySize = errors.New("blake2b: invalid key size") var errKeySize = errors.New("blake2b: invalid key size")
var iv = [8]uint64{ var iv = [8]uint64{

View File

@ -6,24 +6,35 @@
package blake2b package blake2b
var useAVX2 = supportAVX2() func init() {
var useSSE4 = supportSSE4() useAVX2 = supportsAVX2()
useAVX = supportsAVX()
useSSE4 = supportsSSE4()
}
//go:noescape //go:noescape
func supportSSE4() bool func supportsSSE4() bool
//go:noescape //go:noescape
func supportAVX2() bool func supportsAVX() bool
//go:noescape
func supportsAVX2() bool
//go:noescape //go:noescape
func hashBlocksAVX2(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte) func hashBlocksAVX2(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte)
//go:noescape
func hashBlocksAVX(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte)
//go:noescape //go:noescape
func hashBlocksSSE4(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte) func hashBlocksSSE4(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte)
func hashBlocks(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte) { func hashBlocks(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte) {
if useAVX2 { if useAVX2 {
hashBlocksAVX2(h, c, flag, blocks) hashBlocksAVX2(h, c, flag, blocks)
} else if useAVX {
hashBlocksAVX(h, c, flag, blocks)
} else if useSSE4 { } else if useSSE4 {
hashBlocksSSE4(h, c, flag, blocks) hashBlocksSSE4(h, c, flag, blocks)
} else { } else {

View File

@ -6,92 +6,271 @@
#include "textflag.h" #include "textflag.h"
DATA ·AVX2_iv0<>+0x00(SB)/8, $0x6a09e667f3bcc908
DATA ·AVX2_iv0<>+0x08(SB)/8, $0xbb67ae8584caa73b
DATA ·AVX2_iv0<>+0x10(SB)/8, $0x3c6ef372fe94f82b
DATA ·AVX2_iv0<>+0x18(SB)/8, $0xa54ff53a5f1d36f1
GLOBL ·AVX2_iv0<>(SB), (NOPTR+RODATA), $32
DATA ·AVX2_iv1<>+0x00(SB)/8, $0x510e527fade682d1
DATA ·AVX2_iv1<>+0x08(SB)/8, $0x9b05688c2b3e6c1f
DATA ·AVX2_iv1<>+0x10(SB)/8, $0x1f83d9abfb41bd6b
DATA ·AVX2_iv1<>+0x18(SB)/8, $0x5be0cd19137e2179
GLOBL ·AVX2_iv1<>(SB), (NOPTR+RODATA), $32
DATA ·AVX2_c40<>+0x00(SB)/8, $0x0201000706050403
DATA ·AVX2_c40<>+0x08(SB)/8, $0x0a09080f0e0d0c0b
DATA ·AVX2_c40<>+0x10(SB)/8, $0x0201000706050403
DATA ·AVX2_c40<>+0x18(SB)/8, $0x0a09080f0e0d0c0b
GLOBL ·AVX2_c40<>(SB), (NOPTR+RODATA), $32
DATA ·AVX2_c48<>+0x00(SB)/8, $0x0100070605040302
DATA ·AVX2_c48<>+0x08(SB)/8, $0x09080f0e0d0c0b0a
DATA ·AVX2_c48<>+0x10(SB)/8, $0x0100070605040302
DATA ·AVX2_c48<>+0x18(SB)/8, $0x09080f0e0d0c0b0a
GLOBL ·AVX2_c48<>(SB), (NOPTR+RODATA), $32
DATA ·AVX_iv0<>+0x00(SB)/8, $0x6a09e667f3bcc908 DATA ·AVX_iv0<>+0x00(SB)/8, $0x6a09e667f3bcc908
DATA ·AVX_iv0<>+0x08(SB)/8, $0xbb67ae8584caa73b DATA ·AVX_iv0<>+0x08(SB)/8, $0xbb67ae8584caa73b
DATA ·AVX_iv0<>+0x10(SB)/8, $0x3c6ef372fe94f82b GLOBL ·AVX_iv0<>(SB), (NOPTR+RODATA), $16
DATA ·AVX_iv0<>+0x18(SB)/8, $0xa54ff53a5f1d36f1
GLOBL ·AVX_iv0<>(SB), (NOPTR+RODATA), $32
DATA ·AVX_iv1<>+0x00(SB)/8, $0x510e527fade682d1 DATA ·AVX_iv1<>+0x00(SB)/8, $0x3c6ef372fe94f82b
DATA ·AVX_iv1<>+0x08(SB)/8, $0x9b05688c2b3e6c1f DATA ·AVX_iv1<>+0x08(SB)/8, $0xa54ff53a5f1d36f1
DATA ·AVX_iv1<>+0x10(SB)/8, $0x1f83d9abfb41bd6b GLOBL ·AVX_iv1<>(SB), (NOPTR+RODATA), $16
DATA ·AVX_iv1<>+0x18(SB)/8, $0x5be0cd19137e2179
GLOBL ·AVX_iv1<>(SB), (NOPTR+RODATA), $32 DATA ·AVX_iv2<>+0x00(SB)/8, $0x510e527fade682d1
DATA ·AVX_iv2<>+0x08(SB)/8, $0x9b05688c2b3e6c1f
GLOBL ·AVX_iv2<>(SB), (NOPTR+RODATA), $16
DATA ·AVX_iv3<>+0x00(SB)/8, $0x1f83d9abfb41bd6b
DATA ·AVX_iv3<>+0x08(SB)/8, $0x5be0cd19137e2179
GLOBL ·AVX_iv3<>(SB), (NOPTR+RODATA), $16
DATA ·AVX_c40<>+0x00(SB)/8, $0x0201000706050403 DATA ·AVX_c40<>+0x00(SB)/8, $0x0201000706050403
DATA ·AVX_c40<>+0x08(SB)/8, $0x0a09080f0e0d0c0b DATA ·AVX_c40<>+0x08(SB)/8, $0x0a09080f0e0d0c0b
DATA ·AVX_c40<>+0x10(SB)/8, $0x0201000706050403 GLOBL ·AVX_c40<>(SB), (NOPTR+RODATA), $16
DATA ·AVX_c40<>+0x18(SB)/8, $0x0a09080f0e0d0c0b
GLOBL ·AVX_c40<>(SB), (NOPTR+RODATA), $32
DATA ·AVX_c48<>+0x00(SB)/8, $0x0100070605040302 DATA ·AVX_c48<>+0x00(SB)/8, $0x0100070605040302
DATA ·AVX_c48<>+0x08(SB)/8, $0x09080f0e0d0c0b0a DATA ·AVX_c48<>+0x08(SB)/8, $0x09080f0e0d0c0b0a
DATA ·AVX_c48<>+0x10(SB)/8, $0x0100070605040302 GLOBL ·AVX_c48<>(SB), (NOPTR+RODATA), $16
DATA ·AVX_c48<>+0x18(SB)/8, $0x09080f0e0d0c0b0a
GLOBL ·AVX_c48<>(SB), (NOPTR+RODATA), $32
// unfortunately the BYTE representation of VPERMQ must be used #define VPERMQ_0x39_Y1_Y1 BYTE $0xc4; BYTE $0xe3; BYTE $0xfd; BYTE $0x00; BYTE $0xc9; BYTE $0x39
#define ROUND(m0, m1, m2, m3, t, c40, c48) \ #define VPERMQ_0x93_Y1_Y1 BYTE $0xc4; BYTE $0xe3; BYTE $0xfd; BYTE $0x00; BYTE $0xc9; BYTE $0x93
VPADDQ m0, Y0, Y0; \ #define VPERMQ_0x4E_Y2_Y2 BYTE $0xc4; BYTE $0xe3; BYTE $0xfd; BYTE $0x00; BYTE $0xd2; BYTE $0x4e
VPADDQ Y1, Y0, Y0; \ #define VPERMQ_0x93_Y3_Y3 BYTE $0xc4; BYTE $0xe3; BYTE $0xfd; BYTE $0x00; BYTE $0xdb; BYTE $0x93
VPXOR Y0, Y3, Y3; \ #define VPERMQ_0x39_Y3_Y3 BYTE $0xc4; BYTE $0xe3; BYTE $0xfd; BYTE $0x00; BYTE $0xdb; BYTE $0x39
VPSHUFD $-79, Y3, Y3; \
VPADDQ Y3, Y2, Y2; \
VPXOR Y2, Y1, Y1; \
VPSHUFB c40, Y1, Y1; \
VPADDQ m1, Y0, Y0; \
VPADDQ Y1, Y0, Y0; \
VPXOR Y0, Y3, Y3; \
VPSHUFB c48, Y3, Y3; \
VPADDQ Y3, Y2, Y2; \
VPXOR Y2, Y1, Y1; \
VPADDQ Y1, Y1, t; \
VPSRLQ $63, Y1, Y1; \
VPXOR t, Y1, Y1; \
BYTE $0xc4; BYTE $0xe3; BYTE $0xfd; BYTE $0x00; BYTE $0xc9; BYTE $0x39 \ // VPERMQ 0x39, Y1, Y1
BYTE $0xc4; BYTE $0xe3; BYTE $0xfd; BYTE $0x00; BYTE $0xd2; BYTE $0x4e \ // VPERMQ 0x4e, Y2, Y2
BYTE $0xc4; BYTE $0xe3; BYTE $0xfd; BYTE $0x00; BYTE $0xdb; BYTE $0x93 \ // VPERMQ 0x93, Y3, Y3
VPADDQ m2, Y0, Y0; \
VPADDQ Y1, Y0, Y0; \
VPXOR Y0, Y3, Y3; \
VPSHUFD $-79, Y3, Y3; \
VPADDQ Y3, Y2, Y2; \
VPXOR Y2, Y1, Y1; \
VPSHUFB c40, Y1, Y1; \
VPADDQ m3, Y0, Y0; \
VPADDQ Y1, Y0, Y0; \
VPXOR Y0, Y3, Y3; \
VPSHUFB c48, Y3, Y3; \
VPADDQ Y3, Y2, Y2; \
VPXOR Y2, Y1, Y1; \
VPADDQ Y1, Y1, t; \
VPSRLQ $63, Y1, Y1; \
VPXOR t, Y1, Y1; \
BYTE $0xc4; BYTE $0xe3; BYTE $0xfd; BYTE $0x00; BYTE $0xdb; BYTE $0x39 \ // VPERMQ 0x39, Y3, Y3
BYTE $0xc4; BYTE $0xe3; BYTE $0xfd; BYTE $0x00; BYTE $0xd2; BYTE $0x4e \ // VPERMQ 0x4e, Y2, Y2
BYTE $0xc4; BYTE $0xe3; BYTE $0xfd; BYTE $0x00; BYTE $0xc9; BYTE $0x93 \ // VPERMQ 0x93, Y1, Y1
// load msg into Y12, Y13, Y14, Y15 #define ROUND_AVX2(m0, m1, m2, m3, t, c40, c48) \
#define LOAD_MSG(src, i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, i14, i15) \ VPADDQ m0, Y0, Y0; \
MOVQ i0*8(src), X12; \ VPADDQ Y1, Y0, Y0; \
PINSRQ $1, i1*8(src), X12; \ VPXOR Y0, Y3, Y3; \
MOVQ i2*8(src), X11; \ VPSHUFD $-79, Y3, Y3; \
PINSRQ $1, i3*8(src), X11; \ VPADDQ Y3, Y2, Y2; \
VINSERTI128 $1, X11, Y12, Y12; \ VPXOR Y2, Y1, Y1; \
MOVQ i4*8(src), X13; \ VPSHUFB c40, Y1, Y1; \
PINSRQ $1, i5*8(src), X13; \ VPADDQ m1, Y0, Y0; \
MOVQ i6*8(src), X11; \ VPADDQ Y1, Y0, Y0; \
PINSRQ $1, i7*8(src), X11; \ VPXOR Y0, Y3, Y3; \
VINSERTI128 $1, X11, Y13, Y13; \ VPSHUFB c48, Y3, Y3; \
MOVQ i8*8(src), X14; \ VPADDQ Y3, Y2, Y2; \
PINSRQ $1, i9*8(src), X14; \ VPXOR Y2, Y1, Y1; \
MOVQ i10*8(src), X11; \ VPADDQ Y1, Y1, t; \
PINSRQ $1, i11*8(src), X11; \ VPSRLQ $63, Y1, Y1; \
VPXOR t, Y1, Y1; \
VPERMQ_0x39_Y1_Y1; \
VPERMQ_0x4E_Y2_Y2; \
VPERMQ_0x93_Y3_Y3; \
VPADDQ m2, Y0, Y0; \
VPADDQ Y1, Y0, Y0; \
VPXOR Y0, Y3, Y3; \
VPSHUFD $-79, Y3, Y3; \
VPADDQ Y3, Y2, Y2; \
VPXOR Y2, Y1, Y1; \
VPSHUFB c40, Y1, Y1; \
VPADDQ m3, Y0, Y0; \
VPADDQ Y1, Y0, Y0; \
VPXOR Y0, Y3, Y3; \
VPSHUFB c48, Y3, Y3; \
VPADDQ Y3, Y2, Y2; \
VPXOR Y2, Y1, Y1; \
VPADDQ Y1, Y1, t; \
VPSRLQ $63, Y1, Y1; \
VPXOR t, Y1, Y1; \
VPERMQ_0x39_Y3_Y3; \
VPERMQ_0x4E_Y2_Y2; \
VPERMQ_0x93_Y1_Y1
#define VMOVQ_SI_X11_0 BYTE $0xC5; BYTE $0x7A; BYTE $0x7E; BYTE $0x1E
#define VMOVQ_SI_X12_0 BYTE $0xC5; BYTE $0x7A; BYTE $0x7E; BYTE $0x26
#define VMOVQ_SI_X13_0 BYTE $0xC5; BYTE $0x7A; BYTE $0x7E; BYTE $0x2E
#define VMOVQ_SI_X14_0 BYTE $0xC5; BYTE $0x7A; BYTE $0x7E; BYTE $0x36
#define VMOVQ_SI_X15_0 BYTE $0xC5; BYTE $0x7A; BYTE $0x7E; BYTE $0x3E
#define VMOVQ_SI_X11(n) BYTE $0xC5; BYTE $0x7A; BYTE $0x7E; BYTE $0x5E; BYTE $n
#define VMOVQ_SI_X12(n) BYTE $0xC5; BYTE $0x7A; BYTE $0x7E; BYTE $0x66; BYTE $n
#define VMOVQ_SI_X13(n) BYTE $0xC5; BYTE $0x7A; BYTE $0x7E; BYTE $0x6E; BYTE $n
#define VMOVQ_SI_X14(n) BYTE $0xC5; BYTE $0x7A; BYTE $0x7E; BYTE $0x76; BYTE $n
#define VMOVQ_SI_X15(n) BYTE $0xC5; BYTE $0x7A; BYTE $0x7E; BYTE $0x7E; BYTE $n
#define VPINSRQ_1_SI_X11_0 BYTE $0xC4; BYTE $0x63; BYTE $0xA1; BYTE $0x22; BYTE $0x1E; BYTE $0x01
#define VPINSRQ_1_SI_X12_0 BYTE $0xC4; BYTE $0x63; BYTE $0x99; BYTE $0x22; BYTE $0x26; BYTE $0x01
#define VPINSRQ_1_SI_X13_0 BYTE $0xC4; BYTE $0x63; BYTE $0x91; BYTE $0x22; BYTE $0x2E; BYTE $0x01
#define VPINSRQ_1_SI_X14_0 BYTE $0xC4; BYTE $0x63; BYTE $0x89; BYTE $0x22; BYTE $0x36; BYTE $0x01
#define VPINSRQ_1_SI_X15_0 BYTE $0xC4; BYTE $0x63; BYTE $0x81; BYTE $0x22; BYTE $0x3E; BYTE $0x01
#define VPINSRQ_1_SI_X11(n) BYTE $0xC4; BYTE $0x63; BYTE $0xA1; BYTE $0x22; BYTE $0x5E; BYTE $n; BYTE $0x01
#define VPINSRQ_1_SI_X12(n) BYTE $0xC4; BYTE $0x63; BYTE $0x99; BYTE $0x22; BYTE $0x66; BYTE $n; BYTE $0x01
#define VPINSRQ_1_SI_X13(n) BYTE $0xC4; BYTE $0x63; BYTE $0x91; BYTE $0x22; BYTE $0x6E; BYTE $n; BYTE $0x01
#define VPINSRQ_1_SI_X14(n) BYTE $0xC4; BYTE $0x63; BYTE $0x89; BYTE $0x22; BYTE $0x76; BYTE $n; BYTE $0x01
#define VPINSRQ_1_SI_X15(n) BYTE $0xC4; BYTE $0x63; BYTE $0x81; BYTE $0x22; BYTE $0x7E; BYTE $n; BYTE $0x01
#define VMOVQ_R8_X15 BYTE $0xC4; BYTE $0x41; BYTE $0xF9; BYTE $0x6E; BYTE $0xF8
#define VPINSRQ_1_R9_X15 BYTE $0xC4; BYTE $0x43; BYTE $0x81; BYTE $0x22; BYTE $0xF9; BYTE $0x01
// load msg: Y12 = (i0, i1, i2, i3)
// i0, i1, i2, i3 must not be 0
#define LOAD_MSG_AVX2_Y12(i0, i1, i2, i3) \
VMOVQ_SI_X12(i0*8); \
VMOVQ_SI_X11(i2*8); \
VPINSRQ_1_SI_X12(i1*8); \
VPINSRQ_1_SI_X11(i3*8); \
VINSERTI128 $1, X11, Y12, Y12
// load msg: Y13 = (i0, i1, i2, i3)
// i0, i1, i2, i3 must not be 0
#define LOAD_MSG_AVX2_Y13(i0, i1, i2, i3) \
VMOVQ_SI_X13(i0*8); \
VMOVQ_SI_X11(i2*8); \
VPINSRQ_1_SI_X13(i1*8); \
VPINSRQ_1_SI_X11(i3*8); \
VINSERTI128 $1, X11, Y13, Y13
// load msg: Y14 = (i0, i1, i2, i3)
// i0, i1, i2, i3 must not be 0
#define LOAD_MSG_AVX2_Y14(i0, i1, i2, i3) \
VMOVQ_SI_X14(i0*8); \
VMOVQ_SI_X11(i2*8); \
VPINSRQ_1_SI_X14(i1*8); \
VPINSRQ_1_SI_X11(i3*8); \
VINSERTI128 $1, X11, Y14, Y14
// load msg: Y15 = (i0, i1, i2, i3)
// i0, i1, i2, i3 must not be 0
#define LOAD_MSG_AVX2_Y15(i0, i1, i2, i3) \
VMOVQ_SI_X15(i0*8); \
VMOVQ_SI_X11(i2*8); \
VPINSRQ_1_SI_X15(i1*8); \
VPINSRQ_1_SI_X11(i3*8); \
VINSERTI128 $1, X11, Y15, Y15
#define LOAD_MSG_AVX2_0_2_4_6_1_3_5_7_8_10_12_14_9_11_13_15() \
VMOVQ_SI_X12_0; \
VMOVQ_SI_X11(4*8); \
VPINSRQ_1_SI_X12(2*8); \
VPINSRQ_1_SI_X11(6*8); \
VINSERTI128 $1, X11, Y12, Y12; \
LOAD_MSG_AVX2_Y13(1, 3, 5, 7); \
LOAD_MSG_AVX2_Y14(8, 10, 12, 14); \
LOAD_MSG_AVX2_Y15(9, 11, 13, 15)
#define LOAD_MSG_AVX2_14_4_9_13_10_8_15_6_1_0_11_5_12_2_7_3() \
LOAD_MSG_AVX2_Y12(14, 4, 9, 13); \
LOAD_MSG_AVX2_Y13(10, 8, 15, 6); \
VMOVQ_SI_X11(11*8); \
VPSHUFD $0x4E, 0*8(SI), X14; \
VPINSRQ_1_SI_X11(5*8); \
VINSERTI128 $1, X11, Y14, Y14; \ VINSERTI128 $1, X11, Y14, Y14; \
MOVQ i12*8(src), X15; \ LOAD_MSG_AVX2_Y15(12, 2, 7, 3)
PINSRQ $1, i13*8(src), X15; \
MOVQ i14*8(src), X11; \ #define LOAD_MSG_AVX2_11_12_5_15_8_0_2_13_10_3_7_9_14_6_1_4() \
PINSRQ $1, i15*8(src), X11; \ VMOVQ_SI_X11(5*8); \
VMOVDQU 11*8(SI), X12; \
VPINSRQ_1_SI_X11(15*8); \
VINSERTI128 $1, X11, Y12, Y12; \
VMOVQ_SI_X13(8*8); \
VMOVQ_SI_X11(2*8); \
VPINSRQ_1_SI_X13_0; \
VPINSRQ_1_SI_X11(13*8); \
VINSERTI128 $1, X11, Y13, Y13; \
LOAD_MSG_AVX2_Y14(10, 3, 7, 9); \
LOAD_MSG_AVX2_Y15(14, 6, 1, 4)
#define LOAD_MSG_AVX2_7_3_13_11_9_1_12_14_2_5_4_15_6_10_0_8() \
LOAD_MSG_AVX2_Y12(7, 3, 13, 11); \
LOAD_MSG_AVX2_Y13(9, 1, 12, 14); \
LOAD_MSG_AVX2_Y14(2, 5, 4, 15); \
VMOVQ_SI_X15(6*8); \
VMOVQ_SI_X11_0; \
VPINSRQ_1_SI_X15(10*8); \
VPINSRQ_1_SI_X11(8*8); \
VINSERTI128 $1, X11, Y15, Y15
#define LOAD_MSG_AVX2_9_5_2_10_0_7_4_15_14_11_6_3_1_12_8_13() \
LOAD_MSG_AVX2_Y12(9, 5, 2, 10); \
VMOVQ_SI_X13_0; \
VMOVQ_SI_X11(4*8); \
VPINSRQ_1_SI_X13(7*8); \
VPINSRQ_1_SI_X11(15*8); \
VINSERTI128 $1, X11, Y13, Y13; \
LOAD_MSG_AVX2_Y14(14, 11, 6, 3); \
LOAD_MSG_AVX2_Y15(1, 12, 8, 13)
#define LOAD_MSG_AVX2_2_6_0_8_12_10_11_3_4_7_15_1_13_5_14_9() \
VMOVQ_SI_X12(2*8); \
VMOVQ_SI_X11_0; \
VPINSRQ_1_SI_X12(6*8); \
VPINSRQ_1_SI_X11(8*8); \
VINSERTI128 $1, X11, Y12, Y12; \
LOAD_MSG_AVX2_Y13(12, 10, 11, 3); \
LOAD_MSG_AVX2_Y14(4, 7, 15, 1); \
LOAD_MSG_AVX2_Y15(13, 5, 14, 9)
#define LOAD_MSG_AVX2_12_1_14_4_5_15_13_10_0_6_9_8_7_3_2_11() \
LOAD_MSG_AVX2_Y12(12, 1, 14, 4); \
LOAD_MSG_AVX2_Y13(5, 15, 13, 10); \
VMOVQ_SI_X14_0; \
VPSHUFD $0x4E, 8*8(SI), X11; \
VPINSRQ_1_SI_X14(6*8); \
VINSERTI128 $1, X11, Y14, Y14; \
LOAD_MSG_AVX2_Y15(7, 3, 2, 11)
#define LOAD_MSG_AVX2_13_7_12_3_11_14_1_9_5_15_8_2_0_4_6_10() \
LOAD_MSG_AVX2_Y12(13, 7, 12, 3); \
LOAD_MSG_AVX2_Y13(11, 14, 1, 9); \
LOAD_MSG_AVX2_Y14(5, 15, 8, 2); \
VMOVQ_SI_X15_0; \
VMOVQ_SI_X11(6*8); \
VPINSRQ_1_SI_X15(4*8); \
VPINSRQ_1_SI_X11(10*8); \
VINSERTI128 $1, X11, Y15, Y15
#define LOAD_MSG_AVX2_6_14_11_0_15_9_3_8_12_13_1_10_2_7_4_5() \
VMOVQ_SI_X12(6*8); \
VMOVQ_SI_X11(11*8); \
VPINSRQ_1_SI_X12(14*8); \
VPINSRQ_1_SI_X11_0; \
VINSERTI128 $1, X11, Y12, Y12; \
LOAD_MSG_AVX2_Y13(15, 9, 3, 8); \
VMOVQ_SI_X11(1*8); \
VMOVDQU 12*8(SI), X14; \
VPINSRQ_1_SI_X11(10*8); \
VINSERTI128 $1, X11, Y14, Y14; \
VMOVQ_SI_X15(2*8); \
VMOVDQU 4*8(SI), X11; \
VPINSRQ_1_SI_X15(7*8); \
VINSERTI128 $1, X11, Y15, Y15
#define LOAD_MSG_AVX2_10_8_7_1_2_4_6_5_15_9_3_13_11_14_12_0() \
LOAD_MSG_AVX2_Y12(10, 8, 7, 1); \
VMOVQ_SI_X13(2*8); \
VPSHUFD $0x4E, 5*8(SI), X11; \
VPINSRQ_1_SI_X13(4*8); \
VINSERTI128 $1, X11, Y13, Y13; \
LOAD_MSG_AVX2_Y14(15, 9, 3, 13); \
VMOVQ_SI_X15(11*8); \
VMOVQ_SI_X11(12*8); \
VPINSRQ_1_SI_X15(14*8); \
VPINSRQ_1_SI_X11_0; \
VINSERTI128 $1, X11, Y15, Y15 VINSERTI128 $1, X11, Y15, Y15
// func hashBlocksAVX2(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte) // func hashBlocksAVX2(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte)
@ -112,13 +291,13 @@ TEXT ·hashBlocksAVX2(SB), 4, $320-48 // frame size = 288 + 32 byte alignment
XORQ CX, CX XORQ CX, CX
MOVQ CX, 24(SP) MOVQ CX, 24(SP)
VMOVDQU ·AVX_c40<>(SB), Y4 VMOVDQU ·AVX2_c40<>(SB), Y4
VMOVDQU ·AVX_c48<>(SB), Y5 VMOVDQU ·AVX2_c48<>(SB), Y5
VMOVDQU 0(AX), Y8 VMOVDQU 0(AX), Y8
VMOVDQU 32(AX), Y9 VMOVDQU 32(AX), Y9
VMOVDQU ·AVX_iv0<>(SB), Y6 VMOVDQU ·AVX2_iv0<>(SB), Y6
VMOVDQU ·AVX_iv1<>(SB), Y7 VMOVDQU ·AVX2_iv1<>(SB), Y7
MOVQ 0(BX), R8 MOVQ 0(BX), R8
MOVQ 8(BX), R9 MOVQ 8(BX), R9
@ -135,41 +314,41 @@ loop:
noinc: noinc:
VMOVDQA Y8, Y0 VMOVDQA Y8, Y0
VMOVDQA Y9, Y1 VMOVDQA Y9, Y1
VMOVDQU Y6, Y2 VMOVDQA Y6, Y2
VPXOR 0(SP), Y7, Y3 VPXOR 0(SP), Y7, Y3
LOAD_MSG(SI, 0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15) LOAD_MSG_AVX2_0_2_4_6_1_3_5_7_8_10_12_14_9_11_13_15()
VMOVDQA Y12, 32(SP) VMOVDQA Y12, 32(SP)
VMOVDQA Y13, 64(SP) VMOVDQA Y13, 64(SP)
VMOVDQA Y14, 96(SP) VMOVDQA Y14, 96(SP)
VMOVDQA Y15, 128(SP) VMOVDQA Y15, 128(SP)
ROUND(Y12, Y13, Y14, Y15, Y10, Y4, Y5) ROUND_AVX2(Y12, Y13, Y14, Y15, Y10, Y4, Y5)
LOAD_MSG(SI, 14, 4, 9, 13, 10, 8, 15, 6, 1, 0, 11, 5, 12, 2, 7, 3) LOAD_MSG_AVX2_14_4_9_13_10_8_15_6_1_0_11_5_12_2_7_3()
VMOVDQA Y12, 160(SP) VMOVDQA Y12, 160(SP)
VMOVDQA Y13, 192(SP) VMOVDQA Y13, 192(SP)
VMOVDQA Y14, 224(SP) VMOVDQA Y14, 224(SP)
VMOVDQA Y15, 256(SP) VMOVDQA Y15, 256(SP)
ROUND(Y12, Y13, Y14, Y15, Y10, Y4, Y5) ROUND_AVX2(Y12, Y13, Y14, Y15, Y10, Y4, Y5)
LOAD_MSG(SI, 11, 12, 5, 15, 8, 0, 2, 13, 10, 3, 7, 9, 14, 6, 1, 4) LOAD_MSG_AVX2_11_12_5_15_8_0_2_13_10_3_7_9_14_6_1_4()
ROUND(Y12, Y13, Y14, Y15, Y10, Y4, Y5) ROUND_AVX2(Y12, Y13, Y14, Y15, Y10, Y4, Y5)
LOAD_MSG(SI, 7, 3, 13, 11, 9, 1, 12, 14, 2, 5, 4, 15, 6, 10, 0, 8) LOAD_MSG_AVX2_7_3_13_11_9_1_12_14_2_5_4_15_6_10_0_8()
ROUND(Y12, Y13, Y14, Y15, Y10, Y4, Y5) ROUND_AVX2(Y12, Y13, Y14, Y15, Y10, Y4, Y5)
LOAD_MSG(SI, 9, 5, 2, 10, 0, 7, 4, 15, 14, 11, 6, 3, 1, 12, 8, 13) LOAD_MSG_AVX2_9_5_2_10_0_7_4_15_14_11_6_3_1_12_8_13()
ROUND(Y12, Y13, Y14, Y15, Y10, Y4, Y5) ROUND_AVX2(Y12, Y13, Y14, Y15, Y10, Y4, Y5)
LOAD_MSG(SI, 2, 6, 0, 8, 12, 10, 11, 3, 4, 7, 15, 1, 13, 5, 14, 9) LOAD_MSG_AVX2_2_6_0_8_12_10_11_3_4_7_15_1_13_5_14_9()
ROUND(Y12, Y13, Y14, Y15, Y10, Y4, Y5) ROUND_AVX2(Y12, Y13, Y14, Y15, Y10, Y4, Y5)
LOAD_MSG(SI, 12, 1, 14, 4, 5, 15, 13, 10, 0, 6, 9, 8, 7, 3, 2, 11) LOAD_MSG_AVX2_12_1_14_4_5_15_13_10_0_6_9_8_7_3_2_11()
ROUND(Y12, Y13, Y14, Y15, Y10, Y4, Y5) ROUND_AVX2(Y12, Y13, Y14, Y15, Y10, Y4, Y5)
LOAD_MSG(SI, 13, 7, 12, 3, 11, 14, 1, 9, 5, 15, 8, 2, 0, 4, 6, 10) LOAD_MSG_AVX2_13_7_12_3_11_14_1_9_5_15_8_2_0_4_6_10()
ROUND(Y12, Y13, Y14, Y15, Y10, Y4, Y5) ROUND_AVX2(Y12, Y13, Y14, Y15, Y10, Y4, Y5)
LOAD_MSG(SI, 6, 14, 11, 0, 15, 9, 3, 8, 12, 13, 1, 10, 2, 7, 4, 5) LOAD_MSG_AVX2_6_14_11_0_15_9_3_8_12_13_1_10_2_7_4_5()
ROUND(Y12, Y13, Y14, Y15, Y10, Y4, Y5) ROUND_AVX2(Y12, Y13, Y14, Y15, Y10, Y4, Y5)
LOAD_MSG(SI, 10, 8, 7, 1, 2, 4, 6, 5, 15, 9, 3, 13, 11, 14, 12, 0) LOAD_MSG_AVX2_10_8_7_1_2_4_6_5_15_9_3_13_11_14_12_0()
ROUND(Y12, Y13, Y14, Y15, Y10, Y4, Y5) ROUND_AVX2(Y12, Y13, Y14, Y15, Y10, Y4, Y5)
ROUND(32(SP), 64(SP), 96(SP), 128(SP), Y10, Y4, Y5) ROUND_AVX2(32(SP), 64(SP), 96(SP), 128(SP), Y10, Y4, Y5)
ROUND(160(SP), 192(SP), 224(SP), 256(SP), Y10, Y4, Y5) ROUND_AVX2(160(SP), 192(SP), 224(SP), 256(SP), Y10, Y4, Y5)
VPXOR Y0, Y8, Y8 VPXOR Y0, Y8, Y8
VPXOR Y1, Y9, Y9 VPXOR Y1, Y9, Y9
@ -185,12 +364,399 @@ noinc:
VMOVDQU Y8, 0(AX) VMOVDQU Y8, 0(AX)
VMOVDQU Y9, 32(AX) VMOVDQU Y9, 32(AX)
VZEROUPPER
MOVQ DX, SP MOVQ DX, SP
RET RET
// func supportAVX2() bool #define VPUNPCKLQDQ_X2_X2_X15 BYTE $0xC5; BYTE $0x69; BYTE $0x6C; BYTE $0xFA
TEXT ·supportAVX2(SB), 4, $0-1 #define VPUNPCKLQDQ_X3_X3_X15 BYTE $0xC5; BYTE $0x61; BYTE $0x6C; BYTE $0xFB
#define VPUNPCKLQDQ_X7_X7_X15 BYTE $0xC5; BYTE $0x41; BYTE $0x6C; BYTE $0xFF
#define VPUNPCKLQDQ_X13_X13_X15 BYTE $0xC4; BYTE $0x41; BYTE $0x11; BYTE $0x6C; BYTE $0xFD
#define VPUNPCKLQDQ_X14_X14_X15 BYTE $0xC4; BYTE $0x41; BYTE $0x09; BYTE $0x6C; BYTE $0xFE
#define VPUNPCKHQDQ_X15_X2_X2 BYTE $0xC4; BYTE $0xC1; BYTE $0x69; BYTE $0x6D; BYTE $0xD7
#define VPUNPCKHQDQ_X15_X3_X3 BYTE $0xC4; BYTE $0xC1; BYTE $0x61; BYTE $0x6D; BYTE $0xDF
#define VPUNPCKHQDQ_X15_X6_X6 BYTE $0xC4; BYTE $0xC1; BYTE $0x49; BYTE $0x6D; BYTE $0xF7
#define VPUNPCKHQDQ_X15_X7_X7 BYTE $0xC4; BYTE $0xC1; BYTE $0x41; BYTE $0x6D; BYTE $0xFF
#define VPUNPCKHQDQ_X15_X3_X2 BYTE $0xC4; BYTE $0xC1; BYTE $0x61; BYTE $0x6D; BYTE $0xD7
#define VPUNPCKHQDQ_X15_X7_X6 BYTE $0xC4; BYTE $0xC1; BYTE $0x41; BYTE $0x6D; BYTE $0xF7
#define VPUNPCKHQDQ_X15_X13_X3 BYTE $0xC4; BYTE $0xC1; BYTE $0x11; BYTE $0x6D; BYTE $0xDF
#define VPUNPCKHQDQ_X15_X13_X7 BYTE $0xC4; BYTE $0xC1; BYTE $0x11; BYTE $0x6D; BYTE $0xFF
#define SHUFFLE_AVX() \
VMOVDQA X6, X13; \
VMOVDQA X2, X14; \
VMOVDQA X4, X6; \
VPUNPCKLQDQ_X13_X13_X15; \
VMOVDQA X5, X4; \
VMOVDQA X6, X5; \
VPUNPCKHQDQ_X15_X7_X6; \
VPUNPCKLQDQ_X7_X7_X15; \
VPUNPCKHQDQ_X15_X13_X7; \
VPUNPCKLQDQ_X3_X3_X15; \
VPUNPCKHQDQ_X15_X2_X2; \
VPUNPCKLQDQ_X14_X14_X15; \
VPUNPCKHQDQ_X15_X3_X3; \
#define SHUFFLE_AVX_INV() \
VMOVDQA X2, X13; \
VMOVDQA X4, X14; \
VPUNPCKLQDQ_X2_X2_X15; \
VMOVDQA X5, X4; \
VPUNPCKHQDQ_X15_X3_X2; \
VMOVDQA X14, X5; \
VPUNPCKLQDQ_X3_X3_X15; \
VMOVDQA X6, X14; \
VPUNPCKHQDQ_X15_X13_X3; \
VPUNPCKLQDQ_X7_X7_X15; \
VPUNPCKHQDQ_X15_X6_X6; \
VPUNPCKLQDQ_X14_X14_X15; \
VPUNPCKHQDQ_X15_X7_X7; \
#define HALF_ROUND_AVX(v0, v1, v2, v3, v4, v5, v6, v7, m0, m1, m2, m3, t0, c40, c48) \
VPADDQ m0, v0, v0; \
VPADDQ v2, v0, v0; \
VPADDQ m1, v1, v1; \
VPADDQ v3, v1, v1; \
VPXOR v0, v6, v6; \
VPXOR v1, v7, v7; \
VPSHUFD $-79, v6, v6; \
VPSHUFD $-79, v7, v7; \
VPADDQ v6, v4, v4; \
VPADDQ v7, v5, v5; \
VPXOR v4, v2, v2; \
VPXOR v5, v3, v3; \
VPSHUFB c40, v2, v2; \
VPSHUFB c40, v3, v3; \
VPADDQ m2, v0, v0; \
VPADDQ v2, v0, v0; \
VPADDQ m3, v1, v1; \
VPADDQ v3, v1, v1; \
VPXOR v0, v6, v6; \
VPXOR v1, v7, v7; \
VPSHUFB c48, v6, v6; \
VPSHUFB c48, v7, v7; \
VPADDQ v6, v4, v4; \
VPADDQ v7, v5, v5; \
VPXOR v4, v2, v2; \
VPXOR v5, v3, v3; \
VPADDQ v2, v2, t0; \
VPSRLQ $63, v2, v2; \
VPXOR t0, v2, v2; \
VPADDQ v3, v3, t0; \
VPSRLQ $63, v3, v3; \
VPXOR t0, v3, v3
// load msg: X12 = (i0, i1), X13 = (i2, i3), X14 = (i4, i5), X15 = (i6, i7)
// i0, i1, i2, i3, i4, i5, i6, i7 must not be 0
#define LOAD_MSG_AVX(i0, i1, i2, i3, i4, i5, i6, i7) \
VMOVQ_SI_X12(i0*8); \
VMOVQ_SI_X13(i2*8); \
VMOVQ_SI_X14(i4*8); \
VMOVQ_SI_X15(i6*8); \
VPINSRQ_1_SI_X12(i1*8); \
VPINSRQ_1_SI_X13(i3*8); \
VPINSRQ_1_SI_X14(i5*8); \
VPINSRQ_1_SI_X15(i7*8)
// load msg: X12 = (0, 2), X13 = (4, 6), X14 = (1, 3), X15 = (5, 7)
#define LOAD_MSG_AVX_0_2_4_6_1_3_5_7() \
VMOVQ_SI_X12_0; \
VMOVQ_SI_X13(4*8); \
VMOVQ_SI_X14(1*8); \
VMOVQ_SI_X15(5*8); \
VPINSRQ_1_SI_X12(2*8); \
VPINSRQ_1_SI_X13(6*8); \
VPINSRQ_1_SI_X14(3*8); \
VPINSRQ_1_SI_X15(7*8)
// load msg: X12 = (1, 0), X13 = (11, 5), X14 = (12, 2), X15 = (7, 3)
#define LOAD_MSG_AVX_1_0_11_5_12_2_7_3() \
VPSHUFD $0x4E, 0*8(SI), X12; \
VMOVQ_SI_X13(11*8); \
VMOVQ_SI_X14(12*8); \
VMOVQ_SI_X15(7*8); \
VPINSRQ_1_SI_X13(5*8); \
VPINSRQ_1_SI_X14(2*8); \
VPINSRQ_1_SI_X15(3*8)
// load msg: X12 = (11, 12), X13 = (5, 15), X14 = (8, 0), X15 = (2, 13)
#define LOAD_MSG_AVX_11_12_5_15_8_0_2_13() \
VMOVDQU 11*8(SI), X12; \
VMOVQ_SI_X13(5*8); \
VMOVQ_SI_X14(8*8); \
VMOVQ_SI_X15(2*8); \
VPINSRQ_1_SI_X13(15*8); \
VPINSRQ_1_SI_X14_0; \
VPINSRQ_1_SI_X15(13*8)
// load msg: X12 = (2, 5), X13 = (4, 15), X14 = (6, 10), X15 = (0, 8)
#define LOAD_MSG_AVX_2_5_4_15_6_10_0_8() \
VMOVQ_SI_X12(2*8); \
VMOVQ_SI_X13(4*8); \
VMOVQ_SI_X14(6*8); \
VMOVQ_SI_X15_0; \
VPINSRQ_1_SI_X12(5*8); \
VPINSRQ_1_SI_X13(15*8); \
VPINSRQ_1_SI_X14(10*8); \
VPINSRQ_1_SI_X15(8*8)
// load msg: X12 = (9, 5), X13 = (2, 10), X14 = (0, 7), X15 = (4, 15)
#define LOAD_MSG_AVX_9_5_2_10_0_7_4_15() \
VMOVQ_SI_X12(9*8); \
VMOVQ_SI_X13(2*8); \
VMOVQ_SI_X14_0; \
VMOVQ_SI_X15(4*8); \
VPINSRQ_1_SI_X12(5*8); \
VPINSRQ_1_SI_X13(10*8); \
VPINSRQ_1_SI_X14(7*8); \
VPINSRQ_1_SI_X15(15*8)
// load msg: X12 = (2, 6), X13 = (0, 8), X14 = (12, 10), X15 = (11, 3)
#define LOAD_MSG_AVX_2_6_0_8_12_10_11_3() \
VMOVQ_SI_X12(2*8); \
VMOVQ_SI_X13_0; \
VMOVQ_SI_X14(12*8); \
VMOVQ_SI_X15(11*8); \
VPINSRQ_1_SI_X12(6*8); \
VPINSRQ_1_SI_X13(8*8); \
VPINSRQ_1_SI_X14(10*8); \
VPINSRQ_1_SI_X15(3*8)
// load msg: X12 = (0, 6), X13 = (9, 8), X14 = (7, 3), X15 = (2, 11)
#define LOAD_MSG_AVX_0_6_9_8_7_3_2_11() \
MOVQ 0*8(SI), X12; \
VPSHUFD $0x4E, 8*8(SI), X13; \
MOVQ 7*8(SI), X14; \
MOVQ 2*8(SI), X15; \
VPINSRQ_1_SI_X12(6*8); \
VPINSRQ_1_SI_X14(3*8); \
VPINSRQ_1_SI_X15(11*8)
// load msg: X12 = (6, 14), X13 = (11, 0), X14 = (15, 9), X15 = (3, 8)
#define LOAD_MSG_AVX_6_14_11_0_15_9_3_8() \
MOVQ 6*8(SI), X12; \
MOVQ 11*8(SI), X13; \
MOVQ 15*8(SI), X14; \
MOVQ 3*8(SI), X15; \
VPINSRQ_1_SI_X12(14*8); \
VPINSRQ_1_SI_X13_0; \
VPINSRQ_1_SI_X14(9*8); \
VPINSRQ_1_SI_X15(8*8)
// load msg: X12 = (5, 15), X13 = (8, 2), X14 = (0, 4), X15 = (6, 10)
#define LOAD_MSG_AVX_5_15_8_2_0_4_6_10() \
MOVQ 5*8(SI), X12; \
MOVQ 8*8(SI), X13; \
MOVQ 0*8(SI), X14; \
MOVQ 6*8(SI), X15; \
VPINSRQ_1_SI_X12(15*8); \
VPINSRQ_1_SI_X13(2*8); \
VPINSRQ_1_SI_X14(4*8); \
VPINSRQ_1_SI_X15(10*8)
// load msg: X12 = (12, 13), X13 = (1, 10), X14 = (2, 7), X15 = (4, 5)
#define LOAD_MSG_AVX_12_13_1_10_2_7_4_5() \
VMOVDQU 12*8(SI), X12; \
MOVQ 1*8(SI), X13; \
MOVQ 2*8(SI), X14; \
VPINSRQ_1_SI_X13(10*8); \
VPINSRQ_1_SI_X14(7*8); \
VMOVDQU 4*8(SI), X15
// load msg: X12 = (15, 9), X13 = (3, 13), X14 = (11, 14), X15 = (12, 0)
#define LOAD_MSG_AVX_15_9_3_13_11_14_12_0() \
MOVQ 15*8(SI), X12; \
MOVQ 3*8(SI), X13; \
MOVQ 11*8(SI), X14; \
MOVQ 12*8(SI), X15; \
VPINSRQ_1_SI_X12(9*8); \
VPINSRQ_1_SI_X13(13*8); \
VPINSRQ_1_SI_X14(14*8); \
VPINSRQ_1_SI_X15_0
// func hashBlocksAVX(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte)
TEXT ·hashBlocksAVX(SB), 4, $288-48 // frame size = 272 + 16 byte alignment
MOVQ h+0(FP), AX
MOVQ c+8(FP), BX
MOVQ flag+16(FP), CX
MOVQ blocks_base+24(FP), SI
MOVQ blocks_len+32(FP), DI
MOVQ SP, BP
MOVQ SP, R9
ADDQ $15, R9
ANDQ $~15, R9
MOVQ R9, SP
VMOVDQU ·AVX_c40<>(SB), X0
VMOVDQU ·AVX_c48<>(SB), X1
VMOVDQA X0, X8
VMOVDQA X1, X9
VMOVDQU ·AVX_iv3<>(SB), X0
VMOVDQA X0, 0(SP)
XORQ CX, 0(SP) // 0(SP) = ·AVX_iv3 ^ (CX || 0)
VMOVDQU 0(AX), X10
VMOVDQU 16(AX), X11
VMOVDQU 32(AX), X2
VMOVDQU 48(AX), X3
MOVQ 0(BX), R8
MOVQ 8(BX), R9
loop:
ADDQ $128, R8
CMPQ R8, $128
JGE noinc
INCQ R9
noinc:
VMOVQ_R8_X15
VPINSRQ_1_R9_X15
VMOVDQA X10, X0
VMOVDQA X11, X1
VMOVDQU ·AVX_iv0<>(SB), X4
VMOVDQU ·AVX_iv1<>(SB), X5
VMOVDQU ·AVX_iv2<>(SB), X6
VPXOR X15, X6, X6
VMOVDQA 0(SP), X7
LOAD_MSG_AVX_0_2_4_6_1_3_5_7()
VMOVDQA X12, 16(SP)
VMOVDQA X13, 32(SP)
VMOVDQA X14, 48(SP)
VMOVDQA X15, 64(SP)
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX()
LOAD_MSG_AVX(8, 10, 12, 14, 9, 11, 13, 15)
VMOVDQA X12, 80(SP)
VMOVDQA X13, 96(SP)
VMOVDQA X14, 112(SP)
VMOVDQA X15, 128(SP)
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX_INV()
LOAD_MSG_AVX(14, 4, 9, 13, 10, 8, 15, 6)
VMOVDQA X12, 144(SP)
VMOVDQA X13, 160(SP)
VMOVDQA X14, 176(SP)
VMOVDQA X15, 192(SP)
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX()
LOAD_MSG_AVX_1_0_11_5_12_2_7_3()
VMOVDQA X12, 208(SP)
VMOVDQA X13, 224(SP)
VMOVDQA X14, 240(SP)
VMOVDQA X15, 256(SP)
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX_INV()
LOAD_MSG_AVX_11_12_5_15_8_0_2_13()
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX()
LOAD_MSG_AVX(10, 3, 7, 9, 14, 6, 1, 4)
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX_INV()
LOAD_MSG_AVX(7, 3, 13, 11, 9, 1, 12, 14)
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX()
LOAD_MSG_AVX_2_5_4_15_6_10_0_8()
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX_INV()
LOAD_MSG_AVX_9_5_2_10_0_7_4_15()
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX()
LOAD_MSG_AVX(14, 11, 6, 3, 1, 12, 8, 13)
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX_INV()
LOAD_MSG_AVX_2_6_0_8_12_10_11_3()
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX()
LOAD_MSG_AVX(4, 7, 15, 1, 13, 5, 14, 9)
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX_INV()
LOAD_MSG_AVX(12, 1, 14, 4, 5, 15, 13, 10)
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX()
LOAD_MSG_AVX_0_6_9_8_7_3_2_11()
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX_INV()
LOAD_MSG_AVX(13, 7, 12, 3, 11, 14, 1, 9)
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX()
LOAD_MSG_AVX_5_15_8_2_0_4_6_10()
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX_INV()
LOAD_MSG_AVX_6_14_11_0_15_9_3_8()
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX()
LOAD_MSG_AVX_12_13_1_10_2_7_4_5()
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX_INV()
LOAD_MSG_AVX(10, 8, 7, 1, 2, 4, 6, 5)
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX()
LOAD_MSG_AVX_15_9_3_13_11_14_12_0()
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, X12, X13, X14, X15, X15, X8, X9)
SHUFFLE_AVX_INV()
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, 16(SP), 32(SP), 48(SP), 64(SP), X15, X8, X9)
SHUFFLE_AVX()
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, 80(SP), 96(SP), 112(SP), 128(SP), X15, X8, X9)
SHUFFLE_AVX_INV()
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, 144(SP), 160(SP), 176(SP), 192(SP), X15, X8, X9)
SHUFFLE_AVX()
HALF_ROUND_AVX(X0, X1, X2, X3, X4, X5, X6, X7, 208(SP), 224(SP), 240(SP), 256(SP), X15, X8, X9)
SHUFFLE_AVX_INV()
VMOVDQU 32(AX), X14
VMOVDQU 48(AX), X15
VPXOR X0, X10, X10
VPXOR X1, X11, X11
VPXOR X2, X14, X14
VPXOR X3, X15, X15
VPXOR X4, X10, X10
VPXOR X5, X11, X11
VPXOR X6, X14, X2
VPXOR X7, X15, X3
VMOVDQU X2, 32(AX)
VMOVDQU X3, 48(AX)
LEAQ 128(SI), SI
SUBQ $128, DI
JNE loop
VMOVDQU X10, 0(AX)
VMOVDQU X11, 16(AX)
MOVQ R8, 0(BX)
MOVQ R9, 8(BX)
VZEROUPPER
MOVQ BP, SP
RET
// func supportsAVX2() bool
TEXT ·supportsAVX2(SB), 4, $0-1
MOVQ runtime·support_avx2(SB), AX MOVQ runtime·support_avx2(SB), AX
MOVB AX, ret+0(FP) MOVB AX, ret+0(FP)
RET RET
// func supportsAVX() bool
TEXT ·supportsAVX(SB), 4, $0-1
MOVQ runtime·support_avx(SB), AX
MOVB AX, ret+0(FP)
RET

View File

@ -6,11 +6,12 @@
package blake2b package blake2b
var useAVX2 = false func init() {
var useSSE4 = supportSSE4() useSSE4 = supportsSSE4()
}
//go:noescape //go:noescape
func supportSSE4() bool func supportsSSE4() bool
//go:noescape //go:noescape
func hashBlocksSSE4(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte) func hashBlocksSSE4(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte)

View File

@ -30,16 +30,16 @@ DATA ·c48<>+0x00(SB)/8, $0x0100070605040302
DATA ·c48<>+0x08(SB)/8, $0x09080f0e0d0c0b0a DATA ·c48<>+0x08(SB)/8, $0x09080f0e0d0c0b0a
GLOBL ·c48<>(SB), (NOPTR+RODATA), $16 GLOBL ·c48<>(SB), (NOPTR+RODATA), $16
#define SHUFFLE(v2, v3, v4, v5, v6, v7, t0, t1, t2) \ #define SHUFFLE(v2, v3, v4, v5, v6, v7, t1, t2) \
MOVO v4, t0; \ MOVO v4, t1; \
MOVO v5, v4; \ MOVO v5, v4; \
MOVO t0, v5; \ MOVO t1, v5; \
MOVO v6, t0; \ MOVO v6, t1; \
PUNPCKLQDQ v6, t2; \ PUNPCKLQDQ v6, t2; \
PUNPCKHQDQ v7, v6; \ PUNPCKHQDQ v7, v6; \
PUNPCKHQDQ t2, v6; \ PUNPCKHQDQ t2, v6; \
PUNPCKLQDQ v7, t2; \ PUNPCKLQDQ v7, t2; \
MOVO t0, v7; \ MOVO t1, v7; \
MOVO v2, t1; \ MOVO v2, t1; \
PUNPCKHQDQ t2, v7; \ PUNPCKHQDQ t2, v7; \
PUNPCKLQDQ v3, t2; \ PUNPCKLQDQ v3, t2; \
@ -47,16 +47,16 @@ GLOBL ·c48<>(SB), (NOPTR+RODATA), $16
PUNPCKLQDQ t1, t2; \ PUNPCKLQDQ t1, t2; \
PUNPCKHQDQ t2, v3 PUNPCKHQDQ t2, v3
#define SHUFFLE_INV(v2, v3, v4, v5, v6, v7, t0, t1, t2) \ #define SHUFFLE_INV(v2, v3, v4, v5, v6, v7, t1, t2) \
MOVO v4, t0; \ MOVO v4, t1; \
MOVO v5, v4; \ MOVO v5, v4; \
MOVO t0, v5; \ MOVO t1, v5; \
MOVO v2, t0; \ MOVO v2, t1; \
PUNPCKLQDQ v2, t2; \ PUNPCKLQDQ v2, t2; \
PUNPCKHQDQ v3, v2; \ PUNPCKHQDQ v3, v2; \
PUNPCKHQDQ t2, v2; \ PUNPCKHQDQ t2, v2; \
PUNPCKLQDQ v3, t2; \ PUNPCKLQDQ v3, t2; \
MOVO t0, v3; \ MOVO t1, v3; \
MOVO v6, t1; \ MOVO v6, t1; \
PUNPCKHQDQ t2, v3; \ PUNPCKHQDQ t2, v3; \
PUNPCKLQDQ v7, t2; \ PUNPCKLQDQ v7, t2; \
@ -64,7 +64,7 @@ GLOBL ·c48<>(SB), (NOPTR+RODATA), $16
PUNPCKLQDQ t1, t2; \ PUNPCKLQDQ t1, t2; \
PUNPCKHQDQ t2, v7 PUNPCKHQDQ t2, v7
#define HALF_ROUND(v0, v1, v2, v3, v4, v5, v6, v7, m0, m1, m2, m3, t0, t1, t2, c40, c48) \ #define HALF_ROUND(v0, v1, v2, v3, v4, v5, v6, v7, m0, m1, m2, m3, t0, c40, c48) \
PADDQ m0, v0; \ PADDQ m0, v0; \
PADDQ m1, v1; \ PADDQ m1, v1; \
PADDQ v2, v0; \ PADDQ v2, v0; \
@ -91,14 +91,14 @@ GLOBL ·c48<>(SB), (NOPTR+RODATA), $16
PADDQ v7, v5; \ PADDQ v7, v5; \
PXOR v4, v2; \ PXOR v4, v2; \
PXOR v5, v3; \ PXOR v5, v3; \
MOVOU v2, t2; \ MOVOU v2, t0; \
PADDQ v2, t2; \ PADDQ v2, t0; \
PSRLQ $63, v2; \ PSRLQ $63, v2; \
PXOR t2, v2; \ PXOR t0, v2; \
MOVOU v3, t2; \ MOVOU v3, t0; \
PADDQ v3, t2; \ PADDQ v3, t0; \
PSRLQ $63, v3; \ PSRLQ $63, v3; \
PXOR t2, v3 PXOR t0, v3
#define LOAD_MSG(m0, m1, m2, m3, src, i0, i1, i2, i3, i4, i5, i6, i7) \ #define LOAD_MSG(m0, m1, m2, m3, src, i0, i1, i2, i3, i4, i5, i6, i7) \
MOVQ i0*8(src), m0; \ MOVQ i0*8(src), m0; \
@ -111,7 +111,7 @@ GLOBL ·c48<>(SB), (NOPTR+RODATA), $16
PINSRQ $1, i7*8(src), m3 PINSRQ $1, i7*8(src), m3
// func hashBlocksSSE4(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte) // func hashBlocksSSE4(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte)
TEXT ·hashBlocksSSE4(SB), 4, $32-48 // frame size = 16 + 16 byte alignment TEXT ·hashBlocksSSE4(SB), 4, $288-48 // frame size = 272 + 16 byte alignment
MOVQ h+0(FP), AX MOVQ h+0(FP), AX
MOVQ c+8(FP), BX MOVQ c+8(FP), BX
MOVQ flag+16(FP), CX MOVQ flag+16(FP), CX
@ -131,6 +131,9 @@ TEXT ·hashBlocksSSE4(SB), 4, $32-48 // frame size = 16 + 16 byte alignment
MOVOU ·c40<>(SB), X13 MOVOU ·c40<>(SB), X13
MOVOU ·c48<>(SB), X14 MOVOU ·c48<>(SB), X14
MOVOU 0(AX), X12
MOVOU 16(AX), X15
MOVQ 0(BX), R8 MOVQ 0(BX), R8
MOVQ 8(BX), R9 MOVQ 8(BX), R9
@ -141,118 +144,126 @@ loop:
INCQ R9 INCQ R9
noinc: noinc:
MOVQ R8, X15 MOVQ R8, X8
PINSRQ $1, R9, X15 PINSRQ $1, R9, X8
MOVOU 0(AX), X0 MOVO X12, X0
MOVOU 16(AX), X1 MOVO X15, X1
MOVOU 32(AX), X2 MOVOU 32(AX), X2
MOVOU 48(AX), X3 MOVOU 48(AX), X3
MOVOU ·iv0<>(SB), X4 MOVOU ·iv0<>(SB), X4
MOVOU ·iv1<>(SB), X5 MOVOU ·iv1<>(SB), X5
MOVOU ·iv2<>(SB), X6 MOVOU ·iv2<>(SB), X6
PXOR X15, X6 PXOR X8, X6
MOVO 0(SP), X7 MOVO 0(SP), X7
LOAD_MSG(X8, X9, X10, X11, SI, 0, 2, 4, 6, 1, 3, 5, 7) LOAD_MSG(X8, X9, X10, X11, SI, 0, 2, 4, 6, 1, 3, 5, 7)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) MOVO X8, 16(SP)
SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9, X10) MOVO X9, 32(SP)
MOVO X10, 48(SP)
MOVO X11, 64(SP)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 8, 10, 12, 14, 9, 11, 13, 15) LOAD_MSG(X8, X9, X10, X11, SI, 8, 10, 12, 14, 9, 11, 13, 15)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) MOVO X8, 80(SP)
SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9, X10) MOVO X9, 96(SP)
MOVO X10, 112(SP)
MOVO X11, 128(SP)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 14, 4, 9, 13, 10, 8, 15, 6) LOAD_MSG(X8, X9, X10, X11, SI, 14, 4, 9, 13, 10, 8, 15, 6)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) MOVO X8, 144(SP)
SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9, X10) MOVO X9, 160(SP)
MOVO X10, 176(SP)
MOVO X11, 192(SP)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 1, 0, 11, 5, 12, 2, 7, 3) LOAD_MSG(X8, X9, X10, X11, SI, 1, 0, 11, 5, 12, 2, 7, 3)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) MOVO X8, 208(SP)
SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9, X10) MOVO X9, 224(SP)
MOVO X10, 240(SP)
MOVO X11, 256(SP)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 11, 12, 5, 15, 8, 0, 2, 13) LOAD_MSG(X8, X9, X10, X11, SI, 11, 12, 5, 15, 8, 0, 2, 13)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9, X10) SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 10, 3, 7, 9, 14, 6, 1, 4) LOAD_MSG(X8, X9, X10, X11, SI, 10, 3, 7, 9, 14, 6, 1, 4)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9, X10) SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 7, 3, 13, 11, 9, 1, 12, 14) LOAD_MSG(X8, X9, X10, X11, SI, 7, 3, 13, 11, 9, 1, 12, 14)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9, X10) SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 2, 5, 4, 15, 6, 10, 0, 8) LOAD_MSG(X8, X9, X10, X11, SI, 2, 5, 4, 15, 6, 10, 0, 8)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9, X10) SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 9, 5, 2, 10, 0, 7, 4, 15) LOAD_MSG(X8, X9, X10, X11, SI, 9, 5, 2, 10, 0, 7, 4, 15)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9, X10) SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 14, 11, 6, 3, 1, 12, 8, 13) LOAD_MSG(X8, X9, X10, X11, SI, 14, 11, 6, 3, 1, 12, 8, 13)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9, X10) SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 2, 6, 0, 8, 12, 10, 11, 3) LOAD_MSG(X8, X9, X10, X11, SI, 2, 6, 0, 8, 12, 10, 11, 3)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9, X10) SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 4, 7, 15, 1, 13, 5, 14, 9) LOAD_MSG(X8, X9, X10, X11, SI, 4, 7, 15, 1, 13, 5, 14, 9)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9, X10) SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 12, 1, 14, 4, 5, 15, 13, 10) LOAD_MSG(X8, X9, X10, X11, SI, 12, 1, 14, 4, 5, 15, 13, 10)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9, X10) SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 0, 6, 9, 8, 7, 3, 2, 11) LOAD_MSG(X8, X9, X10, X11, SI, 0, 6, 9, 8, 7, 3, 2, 11)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9, X10) SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 13, 7, 12, 3, 11, 14, 1, 9) LOAD_MSG(X8, X9, X10, X11, SI, 13, 7, 12, 3, 11, 14, 1, 9)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9, X10) SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 5, 15, 8, 2, 0, 4, 6, 10) LOAD_MSG(X8, X9, X10, X11, SI, 5, 15, 8, 2, 0, 4, 6, 10)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9, X10) SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 6, 14, 11, 0, 15, 9, 3, 8) LOAD_MSG(X8, X9, X10, X11, SI, 6, 14, 11, 0, 15, 9, 3, 8)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9, X10) SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 12, 13, 1, 10, 2, 7, 4, 5) LOAD_MSG(X8, X9, X10, X11, SI, 12, 13, 1, 10, 2, 7, 4, 5)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9, X10) SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 10, 8, 7, 1, 2, 4, 6, 5) LOAD_MSG(X8, X9, X10, X11, SI, 10, 8, 7, 1, 2, 4, 6, 5)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9, X10) SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 15, 9, 3, 13, 11, 14, 12, 0) LOAD_MSG(X8, X9, X10, X11, SI, 15, 9, 3, 13, 11, 14, 12, 0)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X11, X13, X14)
SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9, X10) SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9)
LOAD_MSG(X8, X9, X10, X11, SI, 0, 2, 4, 6, 1, 3, 5, 7) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, 16(SP), 32(SP), 48(SP), 64(SP), X11, X13, X14)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9)
SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9, X10) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, 80(SP), 96(SP), 112(SP), 128(SP), X11, X13, X14)
LOAD_MSG(X8, X9, X10, X11, SI, 8, 10, 12, 14, 9, 11, 13, 15) SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14)
SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9, X10)
LOAD_MSG(X8, X9, X10, X11, SI, 14, 4, 9, 13, 10, 8, 15, 6) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, 144(SP), 160(SP), 176(SP), 192(SP), X11, X13, X14)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14) SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9)
SHUFFLE(X2, X3, X4, X5, X6, X7, X8, X9, X10) HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, 208(SP), 224(SP), 240(SP), 256(SP), X11, X13, X14)
LOAD_MSG(X8, X9, X10, X11, SI, 1, 0, 11, 5, 12, 2, 7, 3) SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9)
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X8, X9, X12, X13, X14)
SHUFFLE_INV(X2, X3, X4, X5, X6, X7, X8, X9, X10)
MOVOU 0(AX), X8
MOVOU 16(AX), X9
MOVOU 32(AX), X10 MOVOU 32(AX), X10
MOVOU 48(AX), X11 MOVOU 48(AX), X11
PXOR X0, X8 PXOR X0, X12
PXOR X1, X9 PXOR X1, X15
PXOR X2, X10 PXOR X2, X10
PXOR X3, X11 PXOR X3, X11
PXOR X4, X8 PXOR X4, X12
PXOR X5, X9 PXOR X5, X15
PXOR X6, X10 PXOR X6, X10
PXOR X7, X11 PXOR X7, X11
MOVOU X8, 0(AX)
MOVOU X9, 16(AX)
MOVOU X10, 32(AX) MOVOU X10, 32(AX)
MOVOU X11, 48(AX) MOVOU X11, 48(AX)
@ -260,13 +271,17 @@ noinc:
SUBQ $128, DI SUBQ $128, DI
JNE loop JNE loop
MOVOU X15, 0(BX) MOVOU X12, 0(AX)
MOVOU X15, 16(AX)
MOVQ R8, 0(BX)
MOVQ R9, 8(BX)
MOVQ BP, SP MOVQ BP, SP
RET RET
// func supportSSE4() bool // func supportsSSE4() bool
TEXT ·supportSSE4(SB), 4, $0-1 TEXT ·supportsSSE4(SB), 4, $0-1
MOVL $1, AX MOVL $1, AX
CPUID CPUID
SHRL $19, CX // Bit 19 indicates SSE4 support SHRL $19, CX // Bit 19 indicates SSE4 support

View File

@ -6,9 +6,6 @@
package blake2b package blake2b
var useAVX2 = false
var useSSE4 = false
func hashBlocks(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte) { func hashBlocks(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte) {
hashBlocksGeneric(h, c, flag, blocks) hashBlocksGeneric(h, c, flag, blocks)
} }

View File

@ -21,15 +21,20 @@ func fromHex(s string) []byte {
} }
func TestHashes(t *testing.T) { func TestHashes(t *testing.T) {
defer func(sse4, avx2 bool) { defer func(sse4, avx, avx2 bool) {
useSSE4, useAVX2 = sse4, avx2 useSSE4, useAVX, useAVX2 = sse4, avx, avx2
}(useSSE4, useAVX2) }(useSSE4, useAVX, useAVX2)
if useAVX2 { if useAVX2 {
t.Log("AVX2 version") t.Log("AVX2 version")
testHashes(t) testHashes(t)
useAVX2 = false useAVX2 = false
} }
if useAVX {
t.Log("AVX version")
testHashes(t)
useAVX = false
}
if useSSE4 { if useSSE4 {
t.Log("SSE4 version") t.Log("SSE4 version")
testHashes(t) testHashes(t)

View File

@ -0,0 +1,32 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.9
package blake2b
import (
"crypto"
"hash"
)
func init() {
newHash256 := func() hash.Hash {
h, _ := New256(nil)
return h
}
newHash384 := func() hash.Hash {
h, _ := New384(nil)
return h
}
newHash512 := func() hash.Hash {
h, _ := New512(nil)
return h
}
crypto.RegisterHash(crypto.BLAKE2b_256, newHash256)
crypto.RegisterHash(crypto.BLAKE2b_384, newHash384)
crypto.RegisterHash(crypto.BLAKE2b_512, newHash512)
}

View File

@ -4,7 +4,7 @@
// Package blake2s implements the BLAKE2s hash algorithm as // Package blake2s implements the BLAKE2s hash algorithm as
// defined in RFC 7693. // defined in RFC 7693.
package blake2s package blake2s // import "golang.org/x/crypto/blake2s"
import ( import (
"encoding/binary" "encoding/binary"

View File

@ -0,0 +1,21 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.9
package blake2s
import (
"crypto"
"hash"
)
func init() {
newHash256 := func() hash.Hash {
h, _ := New256(nil)
return h
}
crypto.RegisterHash(crypto.BLAKE2s_256, newHash256)
}

View File

@ -3,7 +3,7 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package chacha20poly1305 implements the ChaCha20-Poly1305 AEAD as specified in RFC 7539. // Package chacha20poly1305 implements the ChaCha20-Poly1305 AEAD as specified in RFC 7539.
package chacha20poly1305 package chacha20poly1305 // import "golang.org/x/crypto/chacha20poly1305"
import ( import (
"crypto/cipher" "crypto/cipher"

View File

@ -14,13 +14,60 @@ func chacha20Poly1305Open(dst []byte, key []uint32, src, ad []byte) bool
//go:noescape //go:noescape
func chacha20Poly1305Seal(dst []byte, key []uint32, src, ad []byte) func chacha20Poly1305Seal(dst []byte, key []uint32, src, ad []byte)
//go:noescape // cpuid is implemented in chacha20poly1305_amd64.s.
func haveSSSE3() bool func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32)
var canUseASM bool // xgetbv with ecx = 0 is implemented in chacha20poly1305_amd64.s.
func xgetbv() (eax, edx uint32)
var (
useASM bool
useAVX2 bool
)
func init() { func init() {
canUseASM = haveSSSE3() detectCPUFeatures()
}
// detectCPUFeatures is used to detect if cpu instructions
// used by the functions implemented in assembler in
// chacha20poly1305_amd64.s are supported.
func detectCPUFeatures() {
maxID, _, _, _ := cpuid(0, 0)
if maxID < 1 {
return
}
_, _, ecx1, _ := cpuid(1, 0)
haveSSSE3 := isSet(9, ecx1)
useASM = haveSSSE3
haveOSXSAVE := isSet(27, ecx1)
osSupportsAVX := false
// For XGETBV, OSXSAVE bit is required and sufficient.
if haveOSXSAVE {
eax, _ := xgetbv()
// Check if XMM and YMM registers have OS support.
osSupportsAVX = isSet(1, eax) && isSet(2, eax)
}
haveAVX := isSet(28, ecx1) && osSupportsAVX
if maxID < 7 {
return
}
_, ebx7, _, _ := cpuid(7, 0)
haveAVX2 := isSet(5, ebx7) && haveAVX
haveBMI2 := isSet(8, ebx7)
useAVX2 = haveAVX2 && haveBMI2
}
// isSet checks if bit at bitpos is set in value.
func isSet(bitpos uint, value uint32) bool {
return value&(1<<bitpos) != 0
} }
// setupState writes a ChaCha20 input matrix to state. See // setupState writes a ChaCha20 input matrix to state. See
@ -47,7 +94,7 @@ func setupState(state *[16]uint32, key *[32]byte, nonce []byte) {
} }
func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) []byte { func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) []byte {
if !canUseASM { if !useASM {
return c.sealGeneric(dst, nonce, plaintext, additionalData) return c.sealGeneric(dst, nonce, plaintext, additionalData)
} }
@ -60,7 +107,7 @@ func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) []
} }
func (c *chacha20poly1305) open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) { func (c *chacha20poly1305) open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
if !canUseASM { if !useASM {
return c.openGeneric(dst, nonce, ciphertext, additionalData) return c.openGeneric(dst, nonce, ciphertext, additionalData)
} }

View File

@ -278,7 +278,7 @@ TEXT ·chacha20Poly1305Open(SB), 0, $288-97
MOVQ ad+72(FP), adp MOVQ ad+72(FP), adp
// Check for AVX2 support // Check for AVX2 support
CMPB runtime·support_avx2(SB), $1 CMPB ·useAVX2(SB), $1
JE chacha20Poly1305Open_AVX2 JE chacha20Poly1305Open_AVX2
// Special optimization, for very short buffers // Special optimization, for very short buffers
@ -1484,8 +1484,7 @@ TEXT ·chacha20Poly1305Seal(SB), 0, $288-96
MOVQ src_len+56(FP), inl MOVQ src_len+56(FP), inl
MOVQ ad+72(FP), adp MOVQ ad+72(FP), adp
// Check for AVX2 support CMPB ·useAVX2(SB), $1
CMPB runtime·support_avx2(SB), $1
JE chacha20Poly1305Seal_AVX2 JE chacha20Poly1305Seal_AVX2
// Special optimization, for very short buffers // Special optimization, for very short buffers
@ -1691,7 +1690,7 @@ sealSSETail64:
MOVO D1, ctr0Store MOVO D1, ctr0Store
sealSSETail64LoopA: sealSSETail64LoopA:
// Perform ChaCha rounds, while hashing the prevsiosly encrpyted ciphertext // Perform ChaCha rounds, while hashing the previously encrypted ciphertext
polyAdd(0(oup)) polyAdd(0(oup))
polyMul polyMul
LEAQ 16(oup), oup LEAQ 16(oup), oup
@ -1725,7 +1724,7 @@ sealSSETail128:
MOVO A0, A1; MOVO B0, B1; MOVO C0, C1; MOVO D0, D1; PADDL ·sseIncMask<>(SB), D1; MOVO D1, ctr1Store MOVO A0, A1; MOVO B0, B1; MOVO C0, C1; MOVO D0, D1; PADDL ·sseIncMask<>(SB), D1; MOVO D1, ctr1Store
sealSSETail128LoopA: sealSSETail128LoopA:
// Perform ChaCha rounds, while hashing the prevsiosly encrpyted ciphertext // Perform ChaCha rounds, while hashing the previously encrypted ciphertext
polyAdd(0(oup)) polyAdd(0(oup))
polyMul polyMul
LEAQ 16(oup), oup LEAQ 16(oup), oup
@ -1771,7 +1770,7 @@ sealSSETail192:
MOVO A1, A2; MOVO B1, B2; MOVO C1, C2; MOVO D1, D2; PADDL ·sseIncMask<>(SB), D2; MOVO D2, ctr2Store MOVO A1, A2; MOVO B1, B2; MOVO C1, C2; MOVO D1, D2; PADDL ·sseIncMask<>(SB), D2; MOVO D2, ctr2Store
sealSSETail192LoopA: sealSSETail192LoopA:
// Perform ChaCha rounds, while hashing the prevsiosly encrpyted ciphertext // Perform ChaCha rounds, while hashing the previously encrypted ciphertext
polyAdd(0(oup)) polyAdd(0(oup))
polyMul polyMul
LEAQ 16(oup), oup LEAQ 16(oup), oup
@ -2695,13 +2694,21 @@ sealAVX2Tail512LoopB:
JMP sealAVX2SealHash JMP sealAVX2SealHash
// func haveSSSE3() bool // func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32)
TEXT ·haveSSSE3(SB), NOSPLIT, $0 TEXT ·cpuid(SB), NOSPLIT, $0-24
XORQ AX, AX MOVL eaxArg+0(FP), AX
INCL AX MOVL ecxArg+4(FP), CX
CPUID CPUID
SHRQ $9, CX MOVL AX, eax+8(FP)
ANDQ $1, CX MOVL BX, ebx+12(FP)
MOVB CX, ret+0(FP) MOVL CX, ecx+16(FP)
MOVL DX, edx+20(FP)
RET RET
// func xgetbv() (eax, edx uint32)
TEXT ·xgetbv(SB),NOSPLIT,$0-8
MOVL $0, CX
XGETBV
MOVL AX, eax+0(FP)
MOVL DX, edx+4(FP)
RET

File diff suppressed because one or more lines are too long

View File

@ -1,3 +1,7 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package chacha20 package chacha20
import ( import (

604
x/crypto/cryptobyte/asn1.go Normal file
View File

@ -0,0 +1,604 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cryptobyte
import (
"encoding/asn1"
"fmt"
"math/big"
"reflect"
"time"
)
// This file contains ASN.1-related methods for String and Builder.
// Tag represents an ASN.1 tag number and class (together also referred to as
// identifier octets). Methods in this package only support the low-tag-number
// form, i.e. a single identifier octet with bits 7-8 encoding the class and
// bits 1-6 encoding the tag number.
type Tag uint8
// Contructed returns t with the context-specific class bit set.
func (t Tag) ContextSpecific() Tag { return t | 0x80 }
// Contructed returns t with the constructed class bit set.
func (t Tag) Constructed() Tag { return t | 0x20 }
// Builder
// AddASN1Int64 appends a DER-encoded ASN.1 INTEGER.
func (b *Builder) AddASN1Int64(v int64) {
b.addASN1Signed(asn1.TagInteger, v)
}
// AddASN1Enum appends a DER-encoded ASN.1 ENUMERATION.
func (b *Builder) AddASN1Enum(v int64) {
b.addASN1Signed(asn1.TagEnum, v)
}
func (b *Builder) addASN1Signed(tag Tag, v int64) {
b.AddASN1(tag, func(c *Builder) {
length := 1
for i := v; i >= 0x80 || i < -0x80; i >>= 8 {
length++
}
for ; length > 0; length-- {
i := v >> uint((length-1)*8) & 0xff
c.AddUint8(uint8(i))
}
})
}
// AddASN1Uint64 appends a DER-encoded ASN.1 INTEGER.
func (b *Builder) AddASN1Uint64(v uint64) {
b.AddASN1(asn1.TagInteger, func(c *Builder) {
length := 1
for i := v; i >= 0x80; i >>= 8 {
length++
}
for ; length > 0; length-- {
i := v >> uint((length-1)*8) & 0xff
c.AddUint8(uint8(i))
}
})
}
// AddASN1BigInt appends a DER-encoded ASN.1 INTEGER.
func (b *Builder) AddASN1BigInt(n *big.Int) {
if b.err != nil {
return
}
b.AddASN1(asn1.TagInteger, func(c *Builder) {
if n.Sign() < 0 {
// A negative number has to be converted to two's-complement form. So we
// invert and subtract 1. If the most-significant-bit isn't set then
// we'll need to pad the beginning with 0xff in order to keep the number
// negative.
nMinus1 := new(big.Int).Neg(n)
nMinus1.Sub(nMinus1, bigOne)
bytes := nMinus1.Bytes()
for i := range bytes {
bytes[i] ^= 0xff
}
if bytes[0]&0x80 == 0 {
c.add(0xff)
}
c.add(bytes...)
} else if n.Sign() == 0 {
c.add(0)
} else {
bytes := n.Bytes()
if bytes[0]&0x80 != 0 {
c.add(0)
}
c.add(bytes...)
}
})
}
// AddASN1OctetString appends a DER-encoded ASN.1 OCTET STRING.
func (b *Builder) AddASN1OctetString(bytes []byte) {
b.AddASN1(asn1.TagOctetString, func(c *Builder) {
c.AddBytes(bytes)
})
}
const generalizedTimeFormatStr = "20060102150405Z0700"
// AddASN1GeneralizedTime appends a DER-encoded ASN.1 GENERALIZEDTIME.
func (b *Builder) AddASN1GeneralizedTime(t time.Time) {
if t.Year() < 0 || t.Year() > 9999 {
b.err = fmt.Errorf("cryptobyte: cannot represent %v as a GeneralizedTime", t)
return
}
b.AddASN1(asn1.TagGeneralizedTime, func(c *Builder) {
c.AddBytes([]byte(t.Format(generalizedTimeFormatStr)))
})
}
// AddASN1BitString appends a DER-encoded ASN.1 BIT STRING.
func (b *Builder) AddASN1BitString(s asn1.BitString) {
// TODO(martinkr): Implement.
b.MarshalASN1(s)
}
// MarshalASN1 calls asn1.Marshal on its input and appends the result if
// successful or records an error if one occurred.
func (b *Builder) MarshalASN1(v interface{}) {
// NOTE(martinkr): This is somewhat of a hack to allow propagation of
// asn1.Marshal errors into Builder.err. N.B. if you call MarshalASN1 with a
// value embedded into a struct, its tag information is lost.
if b.err != nil {
return
}
bytes, err := asn1.Marshal(v)
if err != nil {
b.err = err
return
}
b.AddBytes(bytes)
}
// AddASN1 appends an ASN.1 object. The object is prefixed with the given tag.
// Tags greater than 30 are not supported and result in an error (i.e.
// low-tag-number form only). The child builder passed to the
// BuilderContinuation can be used to build the content of the ASN.1 object.
func (b *Builder) AddASN1(tag Tag, f BuilderContinuation) {
if b.err != nil {
return
}
// Identifiers with the low five bits set indicate high-tag-number format
// (two or more octets), which we don't support.
if tag&0x1f == 0x1f {
b.err = fmt.Errorf("cryptobyte: high-tag number identifier octects not supported: 0x%x", tag)
return
}
b.AddUint8(uint8(tag))
b.addLengthPrefixed(1, true, f)
}
// String
var bigIntType = reflect.TypeOf((*big.Int)(nil)).Elem()
// ReadASN1Integer decodes an ASN.1 INTEGER into out and advances. If out does
// not point to an integer or to a big.Int, it panics. It returns true on
// success and false on error.
func (s *String) ReadASN1Integer(out interface{}) bool {
if reflect.TypeOf(out).Kind() != reflect.Ptr {
panic("out is not a pointer")
}
switch reflect.ValueOf(out).Elem().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
var i int64
if !s.readASN1Int64(&i) || reflect.ValueOf(out).Elem().OverflowInt(i) {
return false
}
reflect.ValueOf(out).Elem().SetInt(i)
return true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
var u uint64
if !s.readASN1Uint64(&u) || reflect.ValueOf(out).Elem().OverflowUint(u) {
return false
}
reflect.ValueOf(out).Elem().SetUint(u)
return true
case reflect.Struct:
if reflect.TypeOf(out).Elem() == bigIntType {
return s.readASN1BigInt(out.(*big.Int))
}
}
panic("out does not point to an integer type")
}
func checkASN1Integer(bytes []byte) bool {
if len(bytes) == 0 {
// An INTEGER is encoded with at least one octet.
return false
}
if len(bytes) == 1 {
return true
}
if bytes[0] == 0 && bytes[1]&0x80 == 0 || bytes[0] == 0xff && bytes[1]&0x80 == 0x80 {
// Value is not minimally encoded.
return false
}
return true
}
var bigOne = big.NewInt(1)
func (s *String) readASN1BigInt(out *big.Int) bool {
var bytes String
if !s.ReadASN1(&bytes, asn1.TagInteger) || !checkASN1Integer(bytes) {
return false
}
if bytes[0]&0x80 == 0x80 {
// Negative number.
neg := make([]byte, len(bytes))
for i, b := range bytes {
neg[i] = ^b
}
out.SetBytes(neg)
out.Add(out, bigOne)
out.Neg(out)
} else {
out.SetBytes(bytes)
}
return true
}
func (s *String) readASN1Int64(out *int64) bool {
var bytes String
if !s.ReadASN1(&bytes, asn1.TagInteger) || !checkASN1Integer(bytes) || !asn1Signed(out, bytes) {
return false
}
return true
}
func asn1Signed(out *int64, n []byte) bool {
length := len(n)
if length > 8 {
return false
}
for i := 0; i < length; i++ {
*out <<= 8
*out |= int64(n[i])
}
// Shift up and down in order to sign extend the result.
*out <<= 64 - uint8(length)*8
*out >>= 64 - uint8(length)*8
return true
}
func (s *String) readASN1Uint64(out *uint64) bool {
var bytes String
if !s.ReadASN1(&bytes, asn1.TagInteger) || !checkASN1Integer(bytes) || !asn1Unsigned(out, bytes) {
return false
}
return true
}
func asn1Unsigned(out *uint64, n []byte) bool {
length := len(n)
if length > 9 || length == 9 && n[0] != 0 {
// Too large for uint64.
return false
}
if n[0]&0x80 != 0 {
// Negative number.
return false
}
for i := 0; i < length; i++ {
*out <<= 8
*out |= uint64(n[i])
}
return true
}
// ReadASN1Enum decodes an ASN.1 ENUMERATION into out and advances. It returns
// true on success and false on error.
func (s *String) ReadASN1Enum(out *int) bool {
var bytes String
var i int64
if !s.ReadASN1(&bytes, asn1.TagEnum) || !checkASN1Integer(bytes) || !asn1Signed(&i, bytes) {
return false
}
if int64(int(i)) != i {
return false
}
*out = int(i)
return true
}
func (s *String) readBase128Int(out *int) bool {
ret := 0
for i := 0; len(*s) > 0; i++ {
if i == 4 {
return false
}
ret <<= 7
b := s.read(1)[0]
ret |= int(b & 0x7f)
if b&0x80 == 0 {
*out = ret
return true
}
}
return false // truncated
}
// ReadASN1ObjectIdentifier decodes an ASN.1 OBJECT IDENTIFIER into out and
// advances. It returns true on success and false on error.
func (s *String) ReadASN1ObjectIdentifier(out *asn1.ObjectIdentifier) bool {
var bytes String
if !s.ReadASN1(&bytes, asn1.TagOID) || len(bytes) == 0 {
return false
}
// In the worst case, we get two elements from the first byte (which is
// encoded differently) and then every varint is a single byte long.
components := make([]int, len(bytes)+1)
// The first varint is 40*value1 + value2:
// According to this packing, value1 can take the values 0, 1 and 2 only.
// When value1 = 0 or value1 = 1, then value2 is <= 39. When value1 = 2,
// then there are no restrictions on value2.
var v int
if !bytes.readBase128Int(&v) {
return false
}
if v < 80 {
components[0] = v / 40
components[1] = v % 40
} else {
components[0] = 2
components[1] = v - 80
}
i := 2
for ; len(bytes) > 0; i++ {
if !bytes.readBase128Int(&v) {
return false
}
components[i] = v
}
*out = components[:i]
return true
}
// ReadASN1GeneralizedTime decodes an ASN.1 GENERALIZEDTIME into out and
// advances. It returns true on success and false on error.
func (s *String) ReadASN1GeneralizedTime(out *time.Time) bool {
var bytes String
if !s.ReadASN1(&bytes, asn1.TagGeneralizedTime) {
return false
}
t := string(bytes)
res, err := time.Parse(generalizedTimeFormatStr, t)
if err != nil {
return false
}
if serialized := res.Format(generalizedTimeFormatStr); serialized != t {
return false
}
*out = res
return true
}
// ReadASN1BitString decodes an ASN.1 BIT STRING into out and advances. It
// returns true on success and false on error.
func (s *String) ReadASN1BitString(out *asn1.BitString) bool {
var bytes String
if !s.ReadASN1(&bytes, asn1.TagBitString) || len(bytes) == 0 {
return false
}
paddingBits := uint8(bytes[0])
bytes = bytes[1:]
if paddingBits > 7 ||
len(bytes) == 0 && paddingBits != 0 ||
len(bytes) > 0 && bytes[len(bytes)-1]&(1<<paddingBits-1) != 0 {
return false
}
out.BitLength = len(bytes)*8 - int(paddingBits)
out.Bytes = bytes
return true
}
// ReadASN1Bytes reads the contents of a DER-encoded ASN.1 element (not including
// tag and length bytes) into out, and advances. The element must match the
// given tag. It returns true on success and false on error.
func (s *String) ReadASN1Bytes(out *[]byte, tag Tag) bool {
return s.ReadASN1((*String)(out), tag)
}
// ReadASN1 reads the contents of a DER-encoded ASN.1 element (not including
// tag and length bytes) into out, and advances. The element must match the
// given tag. It returns true on success and false on error.
//
// Tags greater than 30 are not supported (i.e. low-tag-number format only).
func (s *String) ReadASN1(out *String, tag Tag) bool {
var t Tag
if !s.ReadAnyASN1(out, &t) || t != tag {
return false
}
return true
}
// ReadASN1Element reads the contents of a DER-encoded ASN.1 element (including
// tag and length bytes) into out, and advances. The element must match the
// given tag. It returns true on success and false on error.
//
// Tags greater than 30 are not supported (i.e. low-tag-number format only).
func (s *String) ReadASN1Element(out *String, tag Tag) bool {
var t Tag
if !s.ReadAnyASN1Element(out, &t) || t != tag {
return false
}
return true
}
// ReadAnyASN1 reads the contents of a DER-encoded ASN.1 element (not including
// tag and length bytes) into out, sets outTag to its tag, and advances. It
// returns true on success and false on error.
//
// Tags greater than 30 are not supported (i.e. low-tag-number format only).
func (s *String) ReadAnyASN1(out *String, outTag *Tag) bool {
return s.readASN1(out, outTag, true /* skip header */)
}
// ReadAnyASN1Element reads the contents of a DER-encoded ASN.1 element
// (including tag and length bytes) into out, sets outTag to is tag, and
// advances. It returns true on success and false on error.
//
// Tags greater than 30 are not supported (i.e. low-tag-number format only).
func (s *String) ReadAnyASN1Element(out *String, outTag *Tag) bool {
return s.readASN1(out, outTag, false /* include header */)
}
// PeekASN1Tag returns true if the next ASN.1 value on the string starts with
// the given tag.
func (s String) PeekASN1Tag(tag Tag) bool {
if len(s) == 0 {
return false
}
return Tag(s[0]) == tag
}
// ReadOptionalASN1 attempts to read the contents of a DER-encoded ASN.Element
// (not including tag and length bytes) tagged with the given tag into out. It
// stores whether an element with the tag was found in outPresent, unless
// outPresent is nil. It returns true on success and false on error.
func (s *String) ReadOptionalASN1(out *String, outPresent *bool, tag Tag) bool {
present := s.PeekASN1Tag(tag)
if outPresent != nil {
*outPresent = present
}
if present && !s.ReadASN1(out, tag) {
return false
}
return true
}
// ReadOptionalASN1Integer attempts to read an optional ASN.1 INTEGER
// explicitly tagged with tag into out and advances. If no element with a
// matching tag is present, it writes defaultValue into out instead. If out
// does not point to an integer or to a big.Int, it panics. It returns true on
// success and false on error.
func (s *String) ReadOptionalASN1Integer(out interface{}, tag Tag, defaultValue interface{}) bool {
if reflect.TypeOf(out).Kind() != reflect.Ptr {
panic("out is not a pointer")
}
var present bool
var i String
if !s.ReadOptionalASN1(&i, &present, tag) {
return false
}
if !present {
switch reflect.ValueOf(out).Elem().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
reflect.ValueOf(out).Elem().Set(reflect.ValueOf(defaultValue))
case reflect.Struct:
if reflect.TypeOf(out).Elem() != bigIntType {
panic("invalid integer type")
}
if reflect.TypeOf(defaultValue).Kind() != reflect.Ptr ||
reflect.TypeOf(defaultValue).Elem() != bigIntType {
panic("out points to big.Int, but defaultValue does not")
}
out.(*big.Int).Set(defaultValue.(*big.Int))
default:
panic("invalid integer type")
}
return true
}
if !i.ReadASN1Integer(out) || !i.Empty() {
return false
}
return true
}
// ReadOptionalASN1OctetString attempts to read an optional ASN.1 OCTET STRING
// explicitly tagged with tag into out and advances. If no element with a
// matching tag is present, it writes defaultValue into out instead. It returns
// true on success and false on error.
func (s *String) ReadOptionalASN1OctetString(out *[]byte, outPresent *bool, tag Tag) bool {
var present bool
var child String
if !s.ReadOptionalASN1(&child, &present, tag) {
return false
}
if outPresent != nil {
*outPresent = present
}
if present {
var oct String
if !child.ReadASN1(&oct, asn1.TagOctetString) || !child.Empty() {
return false
}
*out = oct
} else {
*out = nil
}
return true
}
func (s *String) readASN1(out *String, outTag *Tag, skipHeader bool) bool {
if len(*s) < 2 {
return false
}
tag, lenByte := (*s)[0], (*s)[1]
if tag&0x1f == 0x1f {
// ITU-T X.690 section 8.1.2
//
// An identifier octet with a tag part of 0x1f indicates a high-tag-number
// form identifier with two or more octets. We only support tags less than
// 31 (i.e. low-tag-number form, single octet identifier).
return false
}
if outTag != nil {
*outTag = Tag(tag)
}
// ITU-T X.690 section 8.1.3
//
// Bit 8 of the first length byte indicates whether the length is short- or
// long-form.
var length, headerLen uint32 // length includes headerLen
if lenByte&0x80 == 0 {
// Short-form length (section 8.1.3.4), encoded in bits 1-7.
length = uint32(lenByte) + 2
headerLen = 2
} else {
// Long-form length (section 8.1.3.5). Bits 1-7 encode the number of octets
// used to encode the length.
lenLen := lenByte & 0x7f
var len32 uint32
if lenLen == 0 || lenLen > 4 || len(*s) < int(2+lenLen) {
return false
}
lenBytes := String((*s)[2 : 2+lenLen])
if !lenBytes.readUnsigned(&len32, int(lenLen)) {
return false
}
// ITU-T X.690 section 10.1 (DER length forms) requires encoding the length
// with the minimum number of octets.
if len32 < 128 {
// Length should have used short-form encoding.
return false
}
if len32>>((lenLen-1)*8) == 0 {
// Leading octet is 0. Length should have been at least one byte shorter.
return false
}
headerLen = 2 + uint32(lenLen)
if headerLen+len32 < len32 {
// Overflow.
return false
}
length = headerLen + len32
}
if uint32(int(length)) != length || !s.ReadBytes((*[]byte)(out), int(length)) {
return false
}
if skipHeader && !out.Skip(int(headerLen)) {
panic("cryptobyte: internal error")
}
return true
}

View File

@ -0,0 +1,285 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cryptobyte
import (
"bytes"
"encoding/asn1"
"math/big"
"reflect"
"testing"
"time"
)
type readASN1Test struct {
name string
in []byte
tag Tag
ok bool
out interface{}
}
var readASN1TestData = []readASN1Test{
{"valid", []byte{0x30, 2, 1, 2}, 0x30, true, []byte{1, 2}},
{"truncated", []byte{0x30, 3, 1, 2}, 0x30, false, nil},
{"zero length of length", []byte{0x30, 0x80}, 0x30, false, nil},
{"invalid long form length", []byte{0x30, 0x81, 1, 1}, 0x30, false, nil},
{"non-minimal length", append([]byte{0x30, 0x82, 0, 0x80}, make([]byte, 0x80)...), 0x30, false, nil},
{"invalid tag", []byte{0xa1, 3, 0x4, 1, 1}, 31, false, nil},
{"high tag", []byte{0x1f, 0x81, 0x80, 0x01, 2, 1, 2}, 0xff /* actually 0x4001, but tag is uint8 */, false, nil},
}
func TestReadASN1(t *testing.T) {
for _, test := range readASN1TestData {
t.Run(test.name, func(t *testing.T) {
var in, out String = test.in, nil
ok := in.ReadASN1(&out, test.tag)
if ok != test.ok || ok && !bytes.Equal(out, test.out.([]byte)) {
t.Errorf("in.ReadASN1() = %v, want %v; out = %v, want %v", ok, test.ok, out, test.out)
}
})
}
}
func TestReadASN1Optional(t *testing.T) {
var empty String
var present bool
ok := empty.ReadOptionalASN1(nil, &present, 0xa0)
if !ok || present {
t.Errorf("empty.ReadOptionalASN1() = %v, want true; present = %v want false", ok, present)
}
var in, out String = []byte{0xa1, 3, 0x4, 1, 1}, nil
ok = in.ReadOptionalASN1(&out, &present, 0xa0)
if !ok || present {
t.Errorf("in.ReadOptionalASN1() = %v, want true, present = %v, want false", ok, present)
}
ok = in.ReadOptionalASN1(&out, &present, 0xa1)
wantBytes := []byte{4, 1, 1}
if !ok || !present || !bytes.Equal(out, wantBytes) {
t.Errorf("in.ReadOptionalASN1() = %v, want true; present = %v, want true; out = %v, want = %v", ok, present, out, wantBytes)
}
}
var optionalOctetStringTestData = []struct {
readASN1Test
present bool
}{
{readASN1Test{"empty", []byte{}, 0xa0, true, []byte{}}, false},
{readASN1Test{"invalid", []byte{0xa1, 3, 0x4, 2, 1}, 0xa1, false, []byte{}}, true},
{readASN1Test{"missing", []byte{0xa1, 3, 0x4, 1, 1}, 0xa0, true, []byte{}}, false},
{readASN1Test{"present", []byte{0xa1, 3, 0x4, 1, 1}, 0xa1, true, []byte{1}}, true},
}
func TestReadASN1OptionalOctetString(t *testing.T) {
for _, test := range optionalOctetStringTestData {
t.Run(test.name, func(t *testing.T) {
in := String(test.in)
var out []byte
var present bool
ok := in.ReadOptionalASN1OctetString(&out, &present, test.tag)
if ok != test.ok || present != test.present || !bytes.Equal(out, test.out.([]byte)) {
t.Errorf("in.ReadOptionalASN1OctetString() = %v, want %v; present = %v want %v; out = %v, want %v", ok, test.ok, present, test.present, out, test.out)
}
})
}
}
const defaultInt = -1
var optionalIntTestData = []readASN1Test{
{"empty", []byte{}, 0xa0, true, defaultInt},
{"invalid", []byte{0xa1, 3, 0x2, 2, 127}, 0xa1, false, 0},
{"missing", []byte{0xa1, 3, 0x2, 1, 127}, 0xa0, true, defaultInt},
{"present", []byte{0xa1, 3, 0x2, 1, 42}, 0xa1, true, 42},
}
func TestReadASN1OptionalInteger(t *testing.T) {
for _, test := range optionalIntTestData {
t.Run(test.name, func(t *testing.T) {
in := String(test.in)
var out int
ok := in.ReadOptionalASN1Integer(&out, test.tag, defaultInt)
if ok != test.ok || ok && out != test.out.(int) {
t.Errorf("in.ReadOptionalASN1Integer() = %v, want %v; out = %v, want %v", ok, test.ok, out, test.out)
}
})
}
}
func TestReadASN1IntegerSigned(t *testing.T) {
testData64 := []struct {
in []byte
out int64
}{
{[]byte{2, 3, 128, 0, 0}, -0x800000},
{[]byte{2, 2, 255, 0}, -256},
{[]byte{2, 2, 255, 127}, -129},
{[]byte{2, 1, 128}, -128},
{[]byte{2, 1, 255}, -1},
{[]byte{2, 1, 0}, 0},
{[]byte{2, 1, 1}, 1},
{[]byte{2, 1, 2}, 2},
{[]byte{2, 1, 127}, 127},
{[]byte{2, 2, 0, 128}, 128},
{[]byte{2, 2, 1, 0}, 256},
{[]byte{2, 4, 0, 128, 0, 0}, 0x800000},
}
for i, test := range testData64 {
in := String(test.in)
var out int64
ok := in.ReadASN1Integer(&out)
if !ok || out != test.out {
t.Errorf("#%d: in.ReadASN1Integer() = %v, want true; out = %d, want %d", i, ok, out, test.out)
}
}
// Repeat the same cases, reading into a big.Int.
t.Run("big.Int", func(t *testing.T) {
for i, test := range testData64 {
in := String(test.in)
var out big.Int
ok := in.ReadASN1Integer(&out)
if !ok || out.Int64() != test.out {
t.Errorf("#%d: in.ReadASN1Integer() = %v, want true; out = %d, want %d", i, ok, out.Int64(), test.out)
}
}
})
}
func TestReadASN1IntegerUnsigned(t *testing.T) {
testData := []struct {
in []byte
out uint64
}{
{[]byte{2, 1, 0}, 0},
{[]byte{2, 1, 1}, 1},
{[]byte{2, 1, 2}, 2},
{[]byte{2, 1, 127}, 127},
{[]byte{2, 2, 0, 128}, 128},
{[]byte{2, 2, 1, 0}, 256},
{[]byte{2, 4, 0, 128, 0, 0}, 0x800000},
{[]byte{2, 8, 127, 255, 255, 255, 255, 255, 255, 255}, 0x7fffffffffffffff},
{[]byte{2, 9, 0, 128, 0, 0, 0, 0, 0, 0, 0}, 0x8000000000000000},
{[]byte{2, 9, 0, 255, 255, 255, 255, 255, 255, 255, 255}, 0xffffffffffffffff},
}
for i, test := range testData {
in := String(test.in)
var out uint64
ok := in.ReadASN1Integer(&out)
if !ok || out != test.out {
t.Errorf("#%d: in.ReadASN1Integer() = %v, want true; out = %d, want %d", i, ok, out, test.out)
}
}
}
func TestReadASN1IntegerInvalid(t *testing.T) {
testData := []String{
[]byte{3, 1, 0}, // invalid tag
// truncated
[]byte{2, 1},
[]byte{2, 2, 0},
// not minimally encoded
[]byte{2, 2, 0, 1},
[]byte{2, 2, 0xff, 0xff},
}
for i, test := range testData {
var out int64
if test.ReadASN1Integer(&out) {
t.Errorf("#%d: in.ReadASN1Integer() = true, want false (out = %d)", i, out)
}
}
}
func TestReadASN1ObjectIdentifier(t *testing.T) {
testData := []struct {
in []byte
ok bool
out []int
}{
{[]byte{}, false, []int{}},
{[]byte{6, 0}, false, []int{}},
{[]byte{5, 1, 85}, false, []int{2, 5}},
{[]byte{6, 1, 85}, true, []int{2, 5}},
{[]byte{6, 2, 85, 0x02}, true, []int{2, 5, 2}},
{[]byte{6, 4, 85, 0x02, 0xc0, 0x00}, true, []int{2, 5, 2, 0x2000}},
{[]byte{6, 3, 0x81, 0x34, 0x03}, true, []int{2, 100, 3}},
{[]byte{6, 7, 85, 0x02, 0xc0, 0x80, 0x80, 0x80, 0x80}, false, []int{}},
}
for i, test := range testData {
in := String(test.in)
var out asn1.ObjectIdentifier
ok := in.ReadASN1ObjectIdentifier(&out)
if ok != test.ok || ok && !out.Equal(test.out) {
t.Errorf("#%d: in.ReadASN1ObjectIdentifier() = %v, want %v; out = %v, want %v", i, ok, test.ok, out, test.out)
}
}
}
func TestReadASN1GeneralizedTime(t *testing.T) {
testData := []struct {
in string
ok bool
out time.Time
}{
{"20100102030405Z", true, time.Date(2010, 01, 02, 03, 04, 05, 0, time.UTC)},
{"20100102030405", false, time.Time{}},
{"20100102030405+0607", true, time.Date(2010, 01, 02, 03, 04, 05, 0, time.FixedZone("", 6*60*60+7*60))},
{"20100102030405-0607", true, time.Date(2010, 01, 02, 03, 04, 05, 0, time.FixedZone("", -6*60*60-7*60))},
/* These are invalid times. However, the time package normalises times
* and they were accepted in some versions. See #11134. */
{"00000100000000Z", false, time.Time{}},
{"20101302030405Z", false, time.Time{}},
{"20100002030405Z", false, time.Time{}},
{"20100100030405Z", false, time.Time{}},
{"20100132030405Z", false, time.Time{}},
{"20100231030405Z", false, time.Time{}},
{"20100102240405Z", false, time.Time{}},
{"20100102036005Z", false, time.Time{}},
{"20100102030460Z", false, time.Time{}},
{"-20100102030410Z", false, time.Time{}},
{"2010-0102030410Z", false, time.Time{}},
{"2010-0002030410Z", false, time.Time{}},
{"201001-02030410Z", false, time.Time{}},
{"20100102-030410Z", false, time.Time{}},
{"2010010203-0410Z", false, time.Time{}},
{"201001020304-10Z", false, time.Time{}},
}
for i, test := range testData {
in := String(append([]byte{asn1.TagGeneralizedTime, byte(len(test.in))}, test.in...))
var out time.Time
ok := in.ReadASN1GeneralizedTime(&out)
if ok != test.ok || ok && !reflect.DeepEqual(out, test.out) {
t.Errorf("#%d: in.ReadASN1GeneralizedTime() = %v, want %v; out = %q, want %q", i, ok, test.ok, out, test.out)
}
}
}
func TestReadASN1BitString(t *testing.T) {
testData := []struct {
in []byte
ok bool
out asn1.BitString
}{
{[]byte{}, false, asn1.BitString{}},
{[]byte{0x00}, true, asn1.BitString{}},
{[]byte{0x07, 0x00}, true, asn1.BitString{Bytes: []byte{0}, BitLength: 1}},
{[]byte{0x07, 0x01}, false, asn1.BitString{}},
{[]byte{0x07, 0x40}, false, asn1.BitString{}},
{[]byte{0x08, 0x00}, false, asn1.BitString{}},
{[]byte{0xff}, false, asn1.BitString{}},
{[]byte{0xfe, 0x00}, false, asn1.BitString{}},
}
for i, test := range testData {
in := String(append([]byte{3, byte(len(test.in))}, test.in...))
var out asn1.BitString
ok := in.ReadASN1BitString(&out)
if ok != test.ok || ok && (!bytes.Equal(out.Bytes, test.out.Bytes) || out.BitLength != test.out.BitLength) {
t.Errorf("#%d: in.ReadASN1BitString() = %v, want %v; out = %v, want %v", i, ok, test.ok, out, test.out)
}
}
}

View File

@ -0,0 +1,255 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cryptobyte
import (
"errors"
"fmt"
)
// A Builder builds byte strings from fixed-length and length-prefixed values.
// The zero value is a usable Builder that allocates space as needed.
type Builder struct {
err error
result []byte
fixedSize bool
child *Builder
offset int
pendingLenLen int
pendingIsASN1 bool
}
// NewBuilder creates a Builder that appends its output to the given buffer.
// Like append(), the slice will be reallocated if its capacity is exceeded.
// Use Bytes to get the final buffer.
func NewBuilder(buffer []byte) *Builder {
return &Builder{
result: buffer,
}
}
// NewFixedBuilder creates a Builder that appends its output into the given
// buffer. This builder does not reallocate the output buffer. Writes that
// would exceed the buffer's capacity are treated as an error.
func NewFixedBuilder(buffer []byte) *Builder {
return &Builder{
result: buffer,
fixedSize: true,
}
}
// Bytes returns the bytes written by the builder or an error if one has
// occurred during during building.
func (b *Builder) Bytes() ([]byte, error) {
if b.err != nil {
return nil, b.err
}
return b.result[b.offset:], nil
}
// BytesOrPanic returns the bytes written by the builder or panics if an error
// has occurred during building.
func (b *Builder) BytesOrPanic() []byte {
if b.err != nil {
panic(b.err)
}
return b.result[b.offset:]
}
// AddUint8 appends an 8-bit value to the byte string.
func (b *Builder) AddUint8(v uint8) {
b.add(byte(v))
}
// AddUint16 appends a big-endian, 16-bit value to the byte string.
func (b *Builder) AddUint16(v uint16) {
b.add(byte(v>>8), byte(v))
}
// AddUint24 appends a big-endian, 24-bit value to the byte string. The highest
// byte of the 32-bit input value is silently truncated.
func (b *Builder) AddUint24(v uint32) {
b.add(byte(v>>16), byte(v>>8), byte(v))
}
// AddUint32 appends a big-endian, 32-bit value to the byte string.
func (b *Builder) AddUint32(v uint32) {
b.add(byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
}
// AddBytes appends a sequence of bytes to the byte string.
func (b *Builder) AddBytes(v []byte) {
b.add(v...)
}
// BuilderContinuation is continuation-passing interface for building
// length-prefixed byte sequences. Builder methods for length-prefixed
// sequences (AddUint8LengthPrefixed etc.) will invoke the BuilderContinuation
// supplied to them. The child builder passed to the continuation can be used
// to build the content of the length-prefixed sequence. Example:
//
// parent := cryptobyte.NewBuilder()
// parent.AddUint8LengthPrefixed(func (child *Builder) {
// child.AddUint8(42)
// child.AddUint8LengthPrefixed(func (grandchild *Builder) {
// grandchild.AddUint8(5)
// })
// })
//
// It is an error to write more bytes to the child than allowed by the reserved
// length prefix. After the continuation returns, the child must be considered
// invalid, i.e. users must not store any copies or references of the child
// that outlive the continuation.
type BuilderContinuation func(child *Builder)
// AddUint8LengthPrefixed adds a 8-bit length-prefixed byte sequence.
func (b *Builder) AddUint8LengthPrefixed(f BuilderContinuation) {
b.addLengthPrefixed(1, false, f)
}
// AddUint16LengthPrefixed adds a big-endian, 16-bit length-prefixed byte sequence.
func (b *Builder) AddUint16LengthPrefixed(f BuilderContinuation) {
b.addLengthPrefixed(2, false, f)
}
// AddUint24LengthPrefixed adds a big-endian, 24-bit length-prefixed byte sequence.
func (b *Builder) AddUint24LengthPrefixed(f BuilderContinuation) {
b.addLengthPrefixed(3, false, f)
}
func (b *Builder) addLengthPrefixed(lenLen int, isASN1 bool, f BuilderContinuation) {
// Subsequent writes can be ignored if the builder has encountered an error.
if b.err != nil {
return
}
offset := len(b.result)
b.add(make([]byte, lenLen)...)
b.child = &Builder{
result: b.result,
fixedSize: b.fixedSize,
offset: offset,
pendingLenLen: lenLen,
pendingIsASN1: isASN1,
}
f(b.child)
b.flushChild()
if b.child != nil {
panic("cryptobyte: internal error")
}
}
func (b *Builder) flushChild() {
if b.child == nil {
return
}
b.child.flushChild()
child := b.child
b.child = nil
if child.err != nil {
b.err = child.err
return
}
length := len(child.result) - child.pendingLenLen - child.offset
if length < 0 {
panic("cryptobyte: internal error") // result unexpectedly shrunk
}
if child.pendingIsASN1 {
// For ASN.1, we reserved a single byte for the length. If that turned out
// to be incorrect, we have to move the contents along in order to make
// space.
if child.pendingLenLen != 1 {
panic("cryptobyte: internal error")
}
var lenLen, lenByte uint8
if int64(length) > 0xfffffffe {
b.err = errors.New("pending ASN.1 child too long")
return
} else if length > 0xffffff {
lenLen = 5
lenByte = 0x80 | 4
} else if length > 0xffff {
lenLen = 4
lenByte = 0x80 | 3
} else if length > 0xff {
lenLen = 3
lenByte = 0x80 | 2
} else if length > 0x7f {
lenLen = 2
lenByte = 0x80 | 1
} else {
lenLen = 1
lenByte = uint8(length)
length = 0
}
// Insert the initial length byte, make space for successive length bytes,
// and adjust the offset.
child.result[child.offset] = lenByte
extraBytes := int(lenLen - 1)
if extraBytes != 0 {
child.add(make([]byte, extraBytes)...)
childStart := child.offset + child.pendingLenLen
copy(child.result[childStart+extraBytes:], child.result[childStart:])
}
child.offset++
child.pendingLenLen = extraBytes
}
l := length
for i := child.pendingLenLen - 1; i >= 0; i-- {
child.result[child.offset+i] = uint8(l)
l >>= 8
}
if l != 0 {
b.err = fmt.Errorf("cryptobyte: pending child length %d exceeds %d-byte length prefix", length, child.pendingLenLen)
return
}
if !b.fixedSize {
b.result = child.result // In case child reallocated result.
}
}
func (b *Builder) add(bytes ...byte) {
if b.err != nil {
return
}
if b.child != nil {
panic("attempted write while child is pending")
}
if len(b.result)+len(bytes) < len(bytes) {
b.err = errors.New("cryptobyte: length overflow")
}
if b.fixedSize && len(b.result)+len(bytes) > cap(b.result) {
b.err = errors.New("cryptobyte: Builder is exceeding its fixed-size buffer")
return
}
b.result = append(b.result, bytes...)
}
// A MarshalingValue marshals itself into a Builder.
type MarshalingValue interface {
// Marshal is called by Builder.AddValue. It receives a pointer to a builder
// to marshal itself into. It may return an error that occurred during
// marshaling, such as unset or invalid values.
Marshal(b *Builder) error
}
// AddValue calls Marshal on v, passing a pointer to the builder to append to.
// If Marshal returns an error, it is set on the Builder so that subsequent
// appends don't have an effect.
func (b *Builder) AddValue(v MarshalingValue) {
err := v.Marshal(b)
if err != nil {
b.err = err
}
}

View File

@ -0,0 +1,379 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cryptobyte
import (
"bytes"
"fmt"
"testing"
)
func builderBytesEq(b *Builder, want ...byte) error {
got := b.BytesOrPanic()
if !bytes.Equal(got, want) {
return fmt.Errorf("Bytes() = %v, want %v", got, want)
}
return nil
}
func TestBytes(t *testing.T) {
var b Builder
v := []byte("foobarbaz")
b.AddBytes(v[0:3])
b.AddBytes(v[3:4])
b.AddBytes(v[4:9])
if err := builderBytesEq(&b, v...); err != nil {
t.Error(err)
}
s := String(b.BytesOrPanic())
for _, w := range []string{"foo", "bar", "baz"} {
var got []byte
if !s.ReadBytes(&got, 3) {
t.Errorf("ReadBytes() = false, want true (w = %v)", w)
}
want := []byte(w)
if !bytes.Equal(got, want) {
t.Errorf("ReadBytes(): got = %v, want %v", got, want)
}
}
if len(s) != 0 {
t.Errorf("len(s) = %d, want 0", len(s))
}
}
func TestUint8(t *testing.T) {
var b Builder
b.AddUint8(42)
if err := builderBytesEq(&b, 42); err != nil {
t.Error(err)
}
var s String = b.BytesOrPanic()
var v uint8
if !s.ReadUint8(&v) {
t.Error("ReadUint8() = false, want true")
}
if v != 42 {
t.Errorf("v = %d, want 42", v)
}
if len(s) != 0 {
t.Errorf("len(s) = %d, want 0", len(s))
}
}
func TestUint16(t *testing.T) {
var b Builder
b.AddUint16(65534)
if err := builderBytesEq(&b, 255, 254); err != nil {
t.Error(err)
}
var s String = b.BytesOrPanic()
var v uint16
if !s.ReadUint16(&v) {
t.Error("ReadUint16() == false, want true")
}
if v != 65534 {
t.Errorf("v = %d, want 65534", v)
}
if len(s) != 0 {
t.Errorf("len(s) = %d, want 0", len(s))
}
}
func TestUint24(t *testing.T) {
var b Builder
b.AddUint24(0xfffefd)
if err := builderBytesEq(&b, 255, 254, 253); err != nil {
t.Error(err)
}
var s String = b.BytesOrPanic()
var v uint32
if !s.ReadUint24(&v) {
t.Error("ReadUint8() = false, want true")
}
if v != 0xfffefd {
t.Errorf("v = %d, want fffefd", v)
}
if len(s) != 0 {
t.Errorf("len(s) = %d, want 0", len(s))
}
}
func TestUint24Truncation(t *testing.T) {
var b Builder
b.AddUint24(0x10111213)
if err := builderBytesEq(&b, 0x11, 0x12, 0x13); err != nil {
t.Error(err)
}
}
func TestUint32(t *testing.T) {
var b Builder
b.AddUint32(0xfffefdfc)
if err := builderBytesEq(&b, 255, 254, 253, 252); err != nil {
t.Error(err)
}
var s String = b.BytesOrPanic()
var v uint32
if !s.ReadUint32(&v) {
t.Error("ReadUint8() = false, want true")
}
if v != 0xfffefdfc {
t.Errorf("v = %x, want fffefdfc", v)
}
if len(s) != 0 {
t.Errorf("len(s) = %d, want 0", len(s))
}
}
func TestUMultiple(t *testing.T) {
var b Builder
b.AddUint8(23)
b.AddUint32(0xfffefdfc)
b.AddUint16(42)
if err := builderBytesEq(&b, 23, 255, 254, 253, 252, 0, 42); err != nil {
t.Error(err)
}
var s String = b.BytesOrPanic()
var (
x uint8
y uint32
z uint16
)
if !s.ReadUint8(&x) || !s.ReadUint32(&y) || !s.ReadUint16(&z) {
t.Error("ReadUint8() = false, want true")
}
if x != 23 || y != 0xfffefdfc || z != 42 {
t.Errorf("x, y, z = %d, %d, %d; want 23, 4294901244, 5", x, y, z)
}
if len(s) != 0 {
t.Errorf("len(s) = %d, want 0", len(s))
}
}
func TestUint8LengthPrefixedSimple(t *testing.T) {
var b Builder
b.AddUint8LengthPrefixed(func(c *Builder) {
c.AddUint8(23)
c.AddUint8(42)
})
if err := builderBytesEq(&b, 2, 23, 42); err != nil {
t.Error(err)
}
var base, child String = b.BytesOrPanic(), nil
var x, y uint8
if !base.ReadUint8LengthPrefixed(&child) || !child.ReadUint8(&x) ||
!child.ReadUint8(&y) {
t.Error("parsing failed")
}
if x != 23 || y != 42 {
t.Errorf("want x, y == 23, 42; got %d, %d", x, y)
}
if len(base) != 0 {
t.Errorf("len(base) = %d, want 0", len(base))
}
if len(child) != 0 {
t.Errorf("len(child) = %d, want 0", len(child))
}
}
func TestUint8LengthPrefixedMulti(t *testing.T) {
var b Builder
b.AddUint8LengthPrefixed(func(c *Builder) {
c.AddUint8(23)
c.AddUint8(42)
})
b.AddUint8(5)
b.AddUint8LengthPrefixed(func(c *Builder) {
c.AddUint8(123)
c.AddUint8(234)
})
if err := builderBytesEq(&b, 2, 23, 42, 5, 2, 123, 234); err != nil {
t.Error(err)
}
var s, child String = b.BytesOrPanic(), nil
var u, v, w, x, y uint8
if !s.ReadUint8LengthPrefixed(&child) || !child.ReadUint8(&u) || !child.ReadUint8(&v) ||
!s.ReadUint8(&w) || !s.ReadUint8LengthPrefixed(&child) || !child.ReadUint8(&x) || !child.ReadUint8(&y) {
t.Error("parsing failed")
}
if u != 23 || v != 42 || w != 5 || x != 123 || y != 234 {
t.Errorf("u, v, w, x, y = %d, %d, %d, %d, %d; want 23, 42, 5, 123, 234",
u, v, w, x, y)
}
if len(s) != 0 {
t.Errorf("len(s) = %d, want 0", len(s))
}
if len(child) != 0 {
t.Errorf("len(child) = %d, want 0", len(child))
}
}
func TestUint8LengthPrefixedNested(t *testing.T) {
var b Builder
b.AddUint8LengthPrefixed(func(c *Builder) {
c.AddUint8(5)
c.AddUint8LengthPrefixed(func(d *Builder) {
d.AddUint8(23)
d.AddUint8(42)
})
c.AddUint8(123)
})
if err := builderBytesEq(&b, 5, 5, 2, 23, 42, 123); err != nil {
t.Error(err)
}
var base, child1, child2 String = b.BytesOrPanic(), nil, nil
var u, v, w, x uint8
if !base.ReadUint8LengthPrefixed(&child1) {
t.Error("parsing base failed")
}
if !child1.ReadUint8(&u) || !child1.ReadUint8LengthPrefixed(&child2) || !child1.ReadUint8(&x) {
t.Error("parsing child1 failed")
}
if !child2.ReadUint8(&v) || !child2.ReadUint8(&w) {
t.Error("parsing child2 failed")
}
if u != 5 || v != 23 || w != 42 || x != 123 {
t.Errorf("u, v, w, x = %d, %d, %d, %d, want 5, 23, 42, 123",
u, v, w, x)
}
if len(base) != 0 {
t.Errorf("len(base) = %d, want 0", len(base))
}
if len(child1) != 0 {
t.Errorf("len(child1) = %d, want 0", len(child1))
}
if len(base) != 0 {
t.Errorf("len(child2) = %d, want 0", len(child2))
}
}
func TestPreallocatedBuffer(t *testing.T) {
var buf [5]byte
b := NewBuilder(buf[0:0])
b.AddUint8(1)
b.AddUint8LengthPrefixed(func(c *Builder) {
c.AddUint8(3)
c.AddUint8(4)
})
b.AddUint16(1286) // Outgrow buf by one byte.
want := []byte{1, 2, 3, 4, 0}
if !bytes.Equal(buf[:], want) {
t.Errorf("buf = %v want %v", buf, want)
}
if err := builderBytesEq(b, 1, 2, 3, 4, 5, 6); err != nil {
t.Error(err)
}
}
func TestWriteWithPendingChild(t *testing.T) {
var b Builder
b.AddUint8LengthPrefixed(func(c *Builder) {
c.AddUint8LengthPrefixed(func(d *Builder) {
defer func() {
if recover() == nil {
t.Errorf("recover() = nil, want error; c.AddUint8() did not panic")
}
}()
c.AddUint8(2) // panics
defer func() {
if recover() == nil {
t.Errorf("recover() = nil, want error; b.AddUint8() did not panic")
}
}()
b.AddUint8(2) // panics
})
defer func() {
if recover() == nil {
t.Errorf("recover() = nil, want error; b.AddUint8() did not panic")
}
}()
b.AddUint8(2) // panics
})
}
// ASN.1
func TestASN1Int64(t *testing.T) {
tests := []struct {
in int64
want []byte
}{
{-0x800000, []byte{2, 3, 128, 0, 0}},
{-256, []byte{2, 2, 255, 0}},
{-129, []byte{2, 2, 255, 127}},
{-128, []byte{2, 1, 128}},
{-1, []byte{2, 1, 255}},
{0, []byte{2, 1, 0}},
{1, []byte{2, 1, 1}},
{2, []byte{2, 1, 2}},
{127, []byte{2, 1, 127}},
{128, []byte{2, 2, 0, 128}},
{256, []byte{2, 2, 1, 0}},
{0x800000, []byte{2, 4, 0, 128, 0, 0}},
}
for i, tt := range tests {
var b Builder
b.AddASN1Int64(tt.in)
if err := builderBytesEq(&b, tt.want...); err != nil {
t.Errorf("%v, (i = %d; in = %v)", err, i, tt.in)
}
var n int64
s := String(b.BytesOrPanic())
ok := s.ReadASN1Integer(&n)
if !ok || n != tt.in {
t.Errorf("s.ReadASN1Integer(&n) = %v, n = %d; want true, n = %d (i = %d)",
ok, n, tt.in, i)
}
if len(s) != 0 {
t.Errorf("len(s) = %d, want 0", len(s))
}
}
}
func TestASN1Uint64(t *testing.T) {
tests := []struct {
in uint64
want []byte
}{
{0, []byte{2, 1, 0}},
{1, []byte{2, 1, 1}},
{2, []byte{2, 1, 2}},
{127, []byte{2, 1, 127}},
{128, []byte{2, 2, 0, 128}},
{256, []byte{2, 2, 1, 0}},
{0x800000, []byte{2, 4, 0, 128, 0, 0}},
{0x7fffffffffffffff, []byte{2, 8, 127, 255, 255, 255, 255, 255, 255, 255}},
{0x8000000000000000, []byte{2, 9, 0, 128, 0, 0, 0, 0, 0, 0, 0}},
{0xffffffffffffffff, []byte{2, 9, 0, 255, 255, 255, 255, 255, 255, 255, 255}},
}
for i, tt := range tests {
var b Builder
b.AddASN1Uint64(tt.in)
if err := builderBytesEq(&b, tt.want...); err != nil {
t.Errorf("%v, (i = %d; in = %v)", err, i, tt.in)
}
var n uint64
s := String(b.BytesOrPanic())
ok := s.ReadASN1Integer(&n)
if !ok || n != tt.in {
t.Errorf("s.ReadASN1Integer(&n) = %v, n = %d; want true, n = %d (i = %d)",
ok, n, tt.in, i)
}
if len(s) != 0 {
t.Errorf("len(s) = %d, want 0", len(s))
}
}
}

View File

@ -0,0 +1,120 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cryptobyte_test
import (
"encoding/asn1"
"fmt"
"golang.org/x/crypto/cryptobyte"
)
func ExampleString_lengthPrefixed() {
// This is an example of parsing length-prefixed data (as found in, for
// example, TLS). Imagine a 16-bit prefixed series of 8-bit prefixed
// strings.
input := cryptobyte.String([]byte{0, 12, 5, 'h', 'e', 'l', 'l', 'o', 5, 'w', 'o', 'r', 'l', 'd'})
var result []string
var values cryptobyte.String
if !input.ReadUint16LengthPrefixed(&values) ||
!input.Empty() {
panic("bad format")
}
for !values.Empty() {
var value cryptobyte.String
if !values.ReadUint8LengthPrefixed(&value) {
panic("bad format")
}
result = append(result, string(value))
}
// Output: []string{"hello", "world"}
fmt.Printf("%#v\n", result)
}
func ExampleString_asn1() {
// This is an example of parsing ASN.1 data that looks like:
// Foo ::= SEQUENCE {
// version [6] INTEGER DEFAULT 0
// data OCTET STRING
// }
input := cryptobyte.String([]byte{0x30, 12, 0xa6, 3, 2, 1, 2, 4, 5, 'h', 'e', 'l', 'l', 'o'})
var (
version int64
data, inner, versionBytes cryptobyte.String
haveVersion bool
)
if !input.ReadASN1(&inner, cryptobyte.Tag(asn1.TagSequence).Constructed()) ||
!input.Empty() ||
!inner.ReadOptionalASN1(&versionBytes, &haveVersion, cryptobyte.Tag(6).Constructed().ContextSpecific()) ||
(haveVersion && !versionBytes.ReadASN1Integer(&version)) ||
(haveVersion && !versionBytes.Empty()) ||
!inner.ReadASN1(&data, asn1.TagOctetString) ||
!inner.Empty() {
panic("bad format")
}
// Output: haveVersion: true, version: 2, data: hello
fmt.Printf("haveVersion: %t, version: %d, data: %s\n", haveVersion, version, string(data))
}
func ExampleBuilder_asn1() {
// This is an example of building ASN.1 data that looks like:
// Foo ::= SEQUENCE {
// version [6] INTEGER DEFAULT 0
// data OCTET STRING
// }
version := int64(2)
data := []byte("hello")
const defaultVersion = 0
var b cryptobyte.Builder
b.AddASN1(cryptobyte.Tag(asn1.TagSequence).Constructed(), func(b *cryptobyte.Builder) {
if version != defaultVersion {
b.AddASN1(cryptobyte.Tag(6).Constructed().ContextSpecific(), func(b *cryptobyte.Builder) {
b.AddASN1Int64(version)
})
}
b.AddASN1OctetString(data)
})
result, err := b.Bytes()
if err != nil {
panic(err)
}
// Output: 300ca603020102040568656c6c6f
fmt.Printf("%x\n", result)
}
func ExampleBuilder_lengthPrefixed() {
// This is an example of building length-prefixed data (as found in,
// for example, TLS). Imagine a 16-bit prefixed series of 8-bit
// prefixed strings.
input := []string{"hello", "world"}
var b cryptobyte.Builder
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, value := range input {
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes([]byte(value))
})
}
})
result, err := b.Bytes()
if err != nil {
panic(err)
}
// Output: 000c0568656c6c6f05776f726c64
fmt.Printf("%x\n", result)
}

View File

@ -0,0 +1,157 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package cryptobyte implements building and parsing of byte strings for
// DER-encoded ASN.1 and TLS messages. See the examples for the Builder and
// String types to get started.
package cryptobyte // import "golang.org/x/crypto/cryptobyte"
// String represents a string of bytes. It provides methods for parsing
// fixed-length and length-prefixed values from it.
type String []byte
// read advances a String by n bytes and returns them. If less than n bytes
// remain, it returns nil.
func (s *String) read(n int) []byte {
if len(*s) < n {
return nil
}
v := (*s)[:n]
*s = (*s)[n:]
return v
}
// Skip advances the String by n byte and reports whether it was successful.
func (s *String) Skip(n int) bool {
return s.read(n) != nil
}
// ReadUint8 decodes an 8-bit value into out and advances over it. It
// returns true on success and false on error.
func (s *String) ReadUint8(out *uint8) bool {
v := s.read(1)
if v == nil {
return false
}
*out = uint8(v[0])
return true
}
// ReadUint16 decodes a big-endian, 16-bit value into out and advances over it.
// It returns true on success and false on error.
func (s *String) ReadUint16(out *uint16) bool {
v := s.read(2)
if v == nil {
return false
}
*out = uint16(v[0])<<8 | uint16(v[1])
return true
}
// ReadUint24 decodes a big-endian, 24-bit value into out and advances over it.
// It returns true on success and false on error.
func (s *String) ReadUint24(out *uint32) bool {
v := s.read(3)
if v == nil {
return false
}
*out = uint32(v[0])<<16 | uint32(v[1])<<8 | uint32(v[2])
return true
}
// ReadUint32 decodes a big-endian, 32-bit value into out and advances over it.
// It returns true on success and false on error.
func (s *String) ReadUint32(out *uint32) bool {
v := s.read(4)
if v == nil {
return false
}
*out = uint32(v[0])<<24 | uint32(v[1])<<16 | uint32(v[2])<<8 | uint32(v[3])
return true
}
func (s *String) readUnsigned(out *uint32, length int) bool {
v := s.read(length)
if v == nil {
return false
}
var result uint32
for i := 0; i < length; i++ {
result <<= 8
result |= uint32(v[i])
}
*out = result
return true
}
func (s *String) readLengthPrefixed(lenLen int, outChild *String) bool {
lenBytes := s.read(lenLen)
if lenBytes == nil {
return false
}
var length uint32
for _, b := range lenBytes {
length = length << 8
length = length | uint32(b)
}
if int(length) < 0 {
// This currently cannot overflow because we read uint24 at most, but check
// anyway in case that changes in the future.
return false
}
v := s.read(int(length))
if v == nil {
return false
}
*outChild = v
return true
}
// ReadUint8LengthPrefixed reads the content of an 8-bit length-prefixed value
// into out and advances over it. It returns true on success and false on
// error.
func (s *String) ReadUint8LengthPrefixed(out *String) bool {
return s.readLengthPrefixed(1, out)
}
// ReadUint16LengthPrefixed reads the content of a big-endian, 16-bit
// length-prefixed value into out and advances over it. It returns true on
// success and false on error.
func (s *String) ReadUint16LengthPrefixed(out *String) bool {
return s.readLengthPrefixed(2, out)
}
// ReadUint24LengthPrefixed reads the content of a big-endian, 24-bit
// length-prefixed value into out and advances over it. It returns true on
// success and false on error.
func (s *String) ReadUint24LengthPrefixed(out *String) bool {
return s.readLengthPrefixed(3, out)
}
// ReadBytes reads n bytes into out and advances over them. It returns true on
// success and false and error.
func (s *String) ReadBytes(out *[]byte, n int) bool {
v := s.read(n)
if v == nil {
return false
}
*out = v
return true
}
// CopyBytes copies len(out) bytes into out and advances over them. It returns
// true on success and false on error.
func (s *String) CopyBytes(out []byte) bool {
n := len(out)
v := s.read(n)
if v == nil {
return false
}
return copy(out, v) == n
}
// Empty reports whether the string does not contain any bytes.
func (s String) Empty() bool {
return len(s) == 0
}

View File

@ -0,0 +1,8 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This code was translated into a form compatible with 6a from the public
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html
#define REDMASK51 0x0007FFFFFFFFFFFF

View File

@ -7,8 +7,8 @@
// +build amd64,!gccgo,!appengine // +build amd64,!gccgo,!appengine
DATA ·REDMASK51(SB)/8, $0x0007FFFFFFFFFFFF // These constants cannot be encoded in non-MOVQ immediates.
GLOBL ·REDMASK51(SB), 8, $8 // We access them directly from memory instead.
DATA ·_121666_213(SB)/8, $996687872 DATA ·_121666_213(SB)/8, $996687872
GLOBL ·_121666_213(SB), 8, $8 GLOBL ·_121666_213(SB), 8, $8

View File

@ -2,87 +2,64 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This code was translated into a form compatible with 6a from the public
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html
// +build amd64,!gccgo,!appengine // +build amd64,!gccgo,!appengine
// func cswap(inout *[5]uint64, v uint64) // func cswap(inout *[4][5]uint64, v uint64)
TEXT ·cswap(SB),7,$0 TEXT ·cswap(SB),7,$0
MOVQ inout+0(FP),DI MOVQ inout+0(FP),DI
MOVQ v+8(FP),SI MOVQ v+8(FP),SI
CMPQ SI,$1 SUBQ $1, SI
MOVQ 0(DI),SI NOTQ SI
MOVQ 80(DI),DX MOVQ SI, X15
MOVQ 8(DI),CX PSHUFD $0x44, X15, X15
MOVQ 88(DI),R8
MOVQ SI,R9 MOVOU 0(DI), X0
CMOVQEQ DX,SI MOVOU 16(DI), X2
CMOVQEQ R9,DX MOVOU 32(DI), X4
MOVQ CX,R9 MOVOU 48(DI), X6
CMOVQEQ R8,CX MOVOU 64(DI), X8
CMOVQEQ R9,R8 MOVOU 80(DI), X1
MOVQ SI,0(DI) MOVOU 96(DI), X3
MOVQ DX,80(DI) MOVOU 112(DI), X5
MOVQ CX,8(DI) MOVOU 128(DI), X7
MOVQ R8,88(DI) MOVOU 144(DI), X9
MOVQ 16(DI),SI
MOVQ 96(DI),DX MOVO X1, X10
MOVQ 24(DI),CX MOVO X3, X11
MOVQ 104(DI),R8 MOVO X5, X12
MOVQ SI,R9 MOVO X7, X13
CMOVQEQ DX,SI MOVO X9, X14
CMOVQEQ R9,DX
MOVQ CX,R9 PXOR X0, X10
CMOVQEQ R8,CX PXOR X2, X11
CMOVQEQ R9,R8 PXOR X4, X12
MOVQ SI,16(DI) PXOR X6, X13
MOVQ DX,96(DI) PXOR X8, X14
MOVQ CX,24(DI) PAND X15, X10
MOVQ R8,104(DI) PAND X15, X11
MOVQ 32(DI),SI PAND X15, X12
MOVQ 112(DI),DX PAND X15, X13
MOVQ 40(DI),CX PAND X15, X14
MOVQ 120(DI),R8 PXOR X10, X0
MOVQ SI,R9 PXOR X10, X1
CMOVQEQ DX,SI PXOR X11, X2
CMOVQEQ R9,DX PXOR X11, X3
MOVQ CX,R9 PXOR X12, X4
CMOVQEQ R8,CX PXOR X12, X5
CMOVQEQ R9,R8 PXOR X13, X6
MOVQ SI,32(DI) PXOR X13, X7
MOVQ DX,112(DI) PXOR X14, X8
MOVQ CX,40(DI) PXOR X14, X9
MOVQ R8,120(DI)
MOVQ 48(DI),SI MOVOU X0, 0(DI)
MOVQ 128(DI),DX MOVOU X2, 16(DI)
MOVQ 56(DI),CX MOVOU X4, 32(DI)
MOVQ 136(DI),R8 MOVOU X6, 48(DI)
MOVQ SI,R9 MOVOU X8, 64(DI)
CMOVQEQ DX,SI MOVOU X1, 80(DI)
CMOVQEQ R9,DX MOVOU X3, 96(DI)
MOVQ CX,R9 MOVOU X5, 112(DI)
CMOVQEQ R8,CX MOVOU X7, 128(DI)
CMOVQEQ R9,R8 MOVOU X9, 144(DI)
MOVQ SI,48(DI)
MOVQ DX,128(DI)
MOVQ CX,56(DI)
MOVQ R8,136(DI)
MOVQ 64(DI),SI
MOVQ 144(DI),DX
MOVQ 72(DI),CX
MOVQ 152(DI),R8
MOVQ SI,R9
CMOVQEQ DX,SI
CMOVQEQ R9,DX
MOVQ CX,R9
CMOVQEQ R8,CX
CMOVQEQ R9,R8
MOVQ SI,64(DI)
MOVQ DX,144(DI)
MOVQ CX,72(DI)
MOVQ R8,152(DI)
MOVQ DI,AX
MOVQ SI,DX
RET RET

View File

@ -8,6 +8,10 @@
package curve25519 package curve25519
import (
"encoding/binary"
)
// This code is a port of the public domain, "ref10" implementation of // This code is a port of the public domain, "ref10" implementation of
// curve25519 from SUPERCOP 20130419 by D. J. Bernstein. // curve25519 from SUPERCOP 20130419 by D. J. Bernstein.
@ -50,17 +54,11 @@ func feCopy(dst, src *fieldElement) {
// //
// Preconditions: b in {0,1}. // Preconditions: b in {0,1}.
func feCSwap(f, g *fieldElement, b int32) { func feCSwap(f, g *fieldElement, b int32) {
var x fieldElement
b = -b b = -b
for i := range x {
x[i] = b & (f[i] ^ g[i])
}
for i := range f { for i := range f {
f[i] ^= x[i] t := b & (f[i] ^ g[i])
} f[i] ^= t
for i := range g { g[i] ^= t
g[i] ^= x[i]
} }
} }
@ -75,12 +73,7 @@ func load3(in []byte) int64 {
// load4 reads a 32-bit, little-endian value from in. // load4 reads a 32-bit, little-endian value from in.
func load4(in []byte) int64 { func load4(in []byte) int64 {
var r int64 return int64(binary.LittleEndian.Uint32(in))
r = int64(in[0])
r |= int64(in[1]) << 8
r |= int64(in[2]) << 16
r |= int64(in[3]) << 24
return r
} }
func feFromBytes(dst *fieldElement, src *[32]byte) { func feFromBytes(dst *fieldElement, src *[32]byte) {

View File

@ -27,3 +27,13 @@ func TestBaseScalarMult(t *testing.T) {
t.Errorf("incorrect result: got %s, want %s", result, expectedHex) t.Errorf("incorrect result: got %s, want %s", result, expectedHex)
} }
} }
func BenchmarkScalarBaseMult(b *testing.B) {
var in, out [32]byte
in[0] = 1
b.SetBytes(32)
for i := 0; i < b.N; i++ {
ScalarBaseMult(&out, &in)
}
}

View File

@ -7,6 +7,8 @@
// +build amd64,!gccgo,!appengine // +build amd64,!gccgo,!appengine
#include "const_amd64.h"
// func freeze(inout *[5]uint64) // func freeze(inout *[5]uint64)
TEXT ·freeze(SB),7,$0-8 TEXT ·freeze(SB),7,$0-8
MOVQ inout+0(FP), DI MOVQ inout+0(FP), DI
@ -16,7 +18,7 @@ TEXT ·freeze(SB),7,$0-8
MOVQ 16(DI),CX MOVQ 16(DI),CX
MOVQ 24(DI),R8 MOVQ 24(DI),R8
MOVQ 32(DI),R9 MOVQ 32(DI),R9
MOVQ ·REDMASK51(SB),AX MOVQ $REDMASK51,AX
MOVQ AX,R10 MOVQ AX,R10
SUBQ $18,R10 SUBQ $18,R10
MOVQ $3,R11 MOVQ $3,R11

View File

@ -7,6 +7,8 @@
// +build amd64,!gccgo,!appengine // +build amd64,!gccgo,!appengine
#include "const_amd64.h"
// func ladderstep(inout *[5][5]uint64) // func ladderstep(inout *[5][5]uint64)
TEXT ·ladderstep(SB),0,$296-8 TEXT ·ladderstep(SB),0,$296-8
MOVQ inout+0(FP),DI MOVQ inout+0(FP),DI
@ -118,7 +120,7 @@ TEXT ·ladderstep(SB),0,$296-8
MULQ 72(SP) MULQ 72(SP)
ADDQ AX,R12 ADDQ AX,R12
ADCQ DX,R13 ADCQ DX,R13
MOVQ ·REDMASK51(SB),DX MOVQ $REDMASK51,DX
SHLQ $13,CX:SI SHLQ $13,CX:SI
ANDQ DX,SI ANDQ DX,SI
SHLQ $13,R9:R8 SHLQ $13,R9:R8
@ -233,7 +235,7 @@ TEXT ·ladderstep(SB),0,$296-8
MULQ 32(SP) MULQ 32(SP)
ADDQ AX,R12 ADDQ AX,R12
ADCQ DX,R13 ADCQ DX,R13
MOVQ ·REDMASK51(SB),DX MOVQ $REDMASK51,DX
SHLQ $13,CX:SI SHLQ $13,CX:SI
ANDQ DX,SI ANDQ DX,SI
SHLQ $13,R9:R8 SHLQ $13,R9:R8
@ -438,7 +440,7 @@ TEXT ·ladderstep(SB),0,$296-8
MULQ 72(SP) MULQ 72(SP)
ADDQ AX,R12 ADDQ AX,R12
ADCQ DX,R13 ADCQ DX,R13
MOVQ ·REDMASK51(SB),DX MOVQ $REDMASK51,DX
SHLQ $13,CX:SI SHLQ $13,CX:SI
ANDQ DX,SI ANDQ DX,SI
SHLQ $13,R9:R8 SHLQ $13,R9:R8
@ -588,7 +590,7 @@ TEXT ·ladderstep(SB),0,$296-8
MULQ 32(SP) MULQ 32(SP)
ADDQ AX,R12 ADDQ AX,R12
ADCQ DX,R13 ADCQ DX,R13
MOVQ ·REDMASK51(SB),DX MOVQ $REDMASK51,DX
SHLQ $13,CX:SI SHLQ $13,CX:SI
ANDQ DX,SI ANDQ DX,SI
SHLQ $13,R9:R8 SHLQ $13,R9:R8
@ -728,7 +730,7 @@ TEXT ·ladderstep(SB),0,$296-8
MULQ 152(DI) MULQ 152(DI)
ADDQ AX,R12 ADDQ AX,R12
ADCQ DX,R13 ADCQ DX,R13
MOVQ ·REDMASK51(SB),DX MOVQ $REDMASK51,DX
SHLQ $13,CX:SI SHLQ $13,CX:SI
ANDQ DX,SI ANDQ DX,SI
SHLQ $13,R9:R8 SHLQ $13,R9:R8
@ -843,7 +845,7 @@ TEXT ·ladderstep(SB),0,$296-8
MULQ 192(DI) MULQ 192(DI)
ADDQ AX,R12 ADDQ AX,R12
ADCQ DX,R13 ADCQ DX,R13
MOVQ ·REDMASK51(SB),DX MOVQ $REDMASK51,DX
SHLQ $13,CX:SI SHLQ $13,CX:SI
ANDQ DX,SI ANDQ DX,SI
SHLQ $13,R9:R8 SHLQ $13,R9:R8
@ -993,7 +995,7 @@ TEXT ·ladderstep(SB),0,$296-8
MULQ 32(DI) MULQ 32(DI)
ADDQ AX,R12 ADDQ AX,R12
ADCQ DX,R13 ADCQ DX,R13
MOVQ ·REDMASK51(SB),DX MOVQ $REDMASK51,DX
SHLQ $13,CX:SI SHLQ $13,CX:SI
ANDQ DX,SI ANDQ DX,SI
SHLQ $13,R9:R8 SHLQ $13,R9:R8
@ -1143,7 +1145,7 @@ TEXT ·ladderstep(SB),0,$296-8
MULQ 112(SP) MULQ 112(SP)
ADDQ AX,R12 ADDQ AX,R12
ADCQ DX,R13 ADCQ DX,R13
MOVQ ·REDMASK51(SB),DX MOVQ $REDMASK51,DX
SHLQ $13,CX:SI SHLQ $13,CX:SI
ANDQ DX,SI ANDQ DX,SI
SHLQ $13,R9:R8 SHLQ $13,R9:R8
@ -1329,7 +1331,7 @@ TEXT ·ladderstep(SB),0,$296-8
MULQ 192(SP) MULQ 192(SP)
ADDQ AX,R12 ADDQ AX,R12
ADCQ DX,R13 ADCQ DX,R13
MOVQ ·REDMASK51(SB),DX MOVQ $REDMASK51,DX
SHLQ $13,CX:SI SHLQ $13,CX:SI
ANDQ DX,SI ANDQ DX,SI
SHLQ $13,R9:R8 SHLQ $13,R9:R8

View File

@ -7,6 +7,8 @@
// +build amd64,!gccgo,!appengine // +build amd64,!gccgo,!appengine
#include "const_amd64.h"
// func mul(dest, a, b *[5]uint64) // func mul(dest, a, b *[5]uint64)
TEXT ·mul(SB),0,$16-24 TEXT ·mul(SB),0,$16-24
MOVQ dest+0(FP), DI MOVQ dest+0(FP), DI
@ -121,7 +123,7 @@ TEXT ·mul(SB),0,$16-24
MULQ 32(CX) MULQ 32(CX)
ADDQ AX,R14 ADDQ AX,R14
ADCQ DX,R15 ADCQ DX,R15
MOVQ ·REDMASK51(SB),SI MOVQ $REDMASK51,SI
SHLQ $13,R9:R8 SHLQ $13,R9:R8
ANDQ SI,R8 ANDQ SI,R8
SHLQ $13,R11:R10 SHLQ $13,R11:R10

View File

@ -7,6 +7,8 @@
// +build amd64,!gccgo,!appengine // +build amd64,!gccgo,!appengine
#include "const_amd64.h"
// func square(out, in *[5]uint64) // func square(out, in *[5]uint64)
TEXT ·square(SB),7,$0-16 TEXT ·square(SB),7,$0-16
MOVQ out+0(FP), DI MOVQ out+0(FP), DI
@ -84,7 +86,7 @@ TEXT ·square(SB),7,$0-16
MULQ 32(SI) MULQ 32(SI)
ADDQ AX,R13 ADDQ AX,R13
ADCQ DX,R14 ADCQ DX,R14
MOVQ ·REDMASK51(SB),SI MOVQ $REDMASK51,SI
SHLQ $13,R8:CX SHLQ $13,R8:CX
ANDQ SI,CX ANDQ SI,CX
SHLQ $13,R10:R9 SHLQ $13,R10:R9

View File

@ -116,12 +116,11 @@ type basicResponse struct {
} }
type responseData struct { type responseData struct {
Raw asn1.RawContent Raw asn1.RawContent
Version int `asn1:"optional,default:0,explicit,tag:0"` Version int `asn1:"optional,default:0,explicit,tag:0"`
RawResponderName asn1.RawValue `asn1:"optional,explicit,tag:1"` RawResponderID asn1.RawValue
KeyHash []byte `asn1:"optional,explicit,tag:2"` ProducedAt time.Time `asn1:"generalized"`
ProducedAt time.Time `asn1:"generalized"` Responses []singleResponse
Responses []singleResponse
} }
type singleResponse struct { type singleResponse struct {
@ -363,6 +362,15 @@ type Response struct {
// If zero, the default is crypto.SHA1. // If zero, the default is crypto.SHA1.
IssuerHash crypto.Hash IssuerHash crypto.Hash
// RawResponderName optionally contains the DER-encoded subject of the
// responder certificate. Exactly one of RawResponderName and
// ResponderKeyHash is set.
RawResponderName []byte
// ResponderKeyHash optionally contains the SHA-1 hash of the
// responder's public key. Exactly one of RawResponderName and
// ResponderKeyHash is set.
ResponderKeyHash []byte
// Extensions contains raw X.509 extensions from the singleExtensions field // Extensions contains raw X.509 extensions from the singleExtensions field
// of the OCSP response. When parsing certificates, this can be used to // of the OCSP response. When parsing certificates, this can be used to
// extract non-critical extensions that are not parsed by this package. When // extract non-critical extensions that are not parsed by this package. When
@ -494,6 +502,25 @@ func ParseResponseForCert(bytes []byte, cert, issuer *x509.Certificate) (*Respon
SignatureAlgorithm: getSignatureAlgorithmFromOID(basicResp.SignatureAlgorithm.Algorithm), SignatureAlgorithm: getSignatureAlgorithmFromOID(basicResp.SignatureAlgorithm.Algorithm),
} }
// Handle the ResponderID CHOICE tag. ResponderID can be flattened into
// TBSResponseData once https://go-review.googlesource.com/34503 has been
// released.
rawResponderID := basicResp.TBSResponseData.RawResponderID
switch rawResponderID.Tag {
case 1: // Name
var rdn pkix.RDNSequence
if rest, err := asn1.Unmarshal(rawResponderID.Bytes, &rdn); err != nil || len(rest) != 0 {
return nil, ParseError("invalid responder name")
}
ret.RawResponderName = rawResponderID.Bytes
case 2: // KeyHash
if rest, err := asn1.Unmarshal(rawResponderID.Bytes, &ret.ResponderKeyHash); err != nil || len(rest) != 0 {
return nil, ParseError("invalid responder key hash")
}
default:
return nil, ParseError("invalid responder id tag")
}
if len(basicResp.Certificates) > 0 { if len(basicResp.Certificates) > 0 {
ret.Certificate, err = x509.ParseCertificate(basicResp.Certificates[0].FullBytes) ret.Certificate, err = x509.ParseCertificate(basicResp.Certificates[0].FullBytes)
if err != nil { if err != nil {
@ -501,17 +528,17 @@ func ParseResponseForCert(bytes []byte, cert, issuer *x509.Certificate) (*Respon
} }
if err := ret.CheckSignatureFrom(ret.Certificate); err != nil { if err := ret.CheckSignatureFrom(ret.Certificate); err != nil {
return nil, ParseError("bad OCSP signature") return nil, ParseError("bad signature on embedded certificate: " + err.Error())
} }
if issuer != nil { if issuer != nil {
if err := issuer.CheckSignature(ret.Certificate.SignatureAlgorithm, ret.Certificate.RawTBSCertificate, ret.Certificate.Signature); err != nil { if err := issuer.CheckSignature(ret.Certificate.SignatureAlgorithm, ret.Certificate.RawTBSCertificate, ret.Certificate.Signature); err != nil {
return nil, ParseError("bad signature on embedded certificate") return nil, ParseError("bad OCSP signature: " + err.Error())
} }
} }
} else if issuer != nil { } else if issuer != nil {
if err := ret.CheckSignatureFrom(issuer); err != nil { if err := ret.CheckSignatureFrom(issuer); err != nil {
return nil, ParseError("bad OCSP signature") return nil, ParseError("bad OCSP signature: " + err.Error())
} }
} }
@ -620,8 +647,8 @@ func CreateRequest(cert, issuer *x509.Certificate, opts *RequestOptions) ([]byte
// CreateResponse returns a DER-encoded OCSP response with the specified contents. // CreateResponse returns a DER-encoded OCSP response with the specified contents.
// The fields in the response are populated as follows: // The fields in the response are populated as follows:
// //
// The responder cert is used to populate the ResponderName field, and the certificate // The responder cert is used to populate the responder's name field, and the
// itself is provided alongside the OCSP response signature. // certificate itself is provided alongside the OCSP response signature.
// //
// The issuer cert is used to puplate the IssuerNameHash and IssuerKeyHash fields. // The issuer cert is used to puplate the IssuerNameHash and IssuerKeyHash fields.
// //
@ -649,7 +676,7 @@ func CreateResponse(issuer, responderCert *x509.Certificate, template Response,
} }
if !template.IssuerHash.Available() { if !template.IssuerHash.Available() {
return nil, fmt.Errorf("issuer hash algorithm %v not linked into binarya", template.IssuerHash) return nil, fmt.Errorf("issuer hash algorithm %v not linked into binary", template.IssuerHash)
} }
h := template.IssuerHash.New() h := template.IssuerHash.New()
h.Write(publicKeyInfo.PublicKey.RightAlign()) h.Write(publicKeyInfo.PublicKey.RightAlign())
@ -686,17 +713,17 @@ func CreateResponse(issuer, responderCert *x509.Certificate, template Response,
} }
} }
responderName := asn1.RawValue{ rawResponderID := asn1.RawValue{
Class: 2, // context-specific Class: 2, // context-specific
Tag: 1, // explicit tag Tag: 1, // Name (explicit tag)
IsCompound: true, IsCompound: true,
Bytes: responderCert.RawSubject, Bytes: responderCert.RawSubject,
} }
tbsResponseData := responseData{ tbsResponseData := responseData{
Version: 0, Version: 0,
RawResponderName: responderName, RawResponderID: rawResponderID,
ProducedAt: time.Now().Truncate(time.Minute).UTC(), ProducedAt: time.Now().Truncate(time.Minute).UTC(),
Responses: []singleResponse{innerResponse}, Responses: []singleResponse{innerResponse},
} }
tbsResponseDataDER, err := asn1.Marshal(tbsResponseData) tbsResponseDataDER, err := asn1.Marshal(tbsResponseData)

View File

@ -24,7 +24,13 @@ func TestOCSPDecode(t *testing.T) {
responseBytes, _ := hex.DecodeString(ocspResponseHex) responseBytes, _ := hex.DecodeString(ocspResponseHex)
resp, err := ParseResponse(responseBytes, nil) resp, err := ParseResponse(responseBytes, nil)
if err != nil { if err != nil {
t.Error(err) t.Fatal(err)
}
responderCert, _ := hex.DecodeString(startComResponderCertHex)
responder, err := x509.ParseCertificate(responderCert)
if err != nil {
t.Fatal(err)
} }
expected := Response{ expected := Response{
@ -33,6 +39,7 @@ func TestOCSPDecode(t *testing.T) {
RevocationReason: Unspecified, RevocationReason: Unspecified,
ThisUpdate: time.Date(2010, 7, 7, 15, 1, 5, 0, time.UTC), ThisUpdate: time.Date(2010, 7, 7, 15, 1, 5, 0, time.UTC),
NextUpdate: time.Date(2010, 7, 7, 18, 35, 17, 0, time.UTC), NextUpdate: time.Date(2010, 7, 7, 18, 35, 17, 0, time.UTC),
RawResponderName: responder.RawSubject,
} }
if !reflect.DeepEqual(resp.ThisUpdate, expected.ThisUpdate) { if !reflect.DeepEqual(resp.ThisUpdate, expected.ThisUpdate) {
@ -54,6 +61,14 @@ func TestOCSPDecode(t *testing.T) {
if resp.RevocationReason != expected.RevocationReason { if resp.RevocationReason != expected.RevocationReason {
t.Errorf("resp.RevocationReason: got %d, want %d", resp.RevocationReason, expected.RevocationReason) t.Errorf("resp.RevocationReason: got %d, want %d", resp.RevocationReason, expected.RevocationReason)
} }
if !bytes.Equal(resp.RawResponderName, expected.RawResponderName) {
t.Errorf("resp.RawResponderName: got %x, want %x", resp.RawResponderName, expected.RawResponderName)
}
if !bytes.Equal(resp.ResponderKeyHash, expected.ResponderKeyHash) {
t.Errorf("resp.ResponderKeyHash: got %x, want %x", resp.ResponderKeyHash, expected.ResponderKeyHash)
}
} }
func TestOCSPDecodeWithoutCert(t *testing.T) { func TestOCSPDecodeWithoutCert(t *testing.T) {
@ -210,7 +225,6 @@ func TestOCSPResponse(t *testing.T) {
}, },
} }
producedAt := time.Now().Truncate(time.Minute)
thisUpdate := time.Date(2010, 7, 7, 15, 1, 5, 0, time.UTC) thisUpdate := time.Date(2010, 7, 7, 15, 1, 5, 0, time.UTC)
nextUpdate := time.Date(2010, 7, 7, 18, 35, 17, 0, time.UTC) nextUpdate := time.Date(2010, 7, 7, 18, 35, 17, 0, time.UTC)
template := Response{ template := Response{
@ -269,8 +283,9 @@ func TestOCSPResponse(t *testing.T) {
t.Errorf("resp.Extensions: got %v, want %v", resp.Extensions, template.ExtraExtensions) t.Errorf("resp.Extensions: got %v, want %v", resp.Extensions, template.ExtraExtensions)
} }
if !resp.ProducedAt.Equal(producedAt) { delay := time.Since(resp.ProducedAt)
t.Errorf("resp.ProducedAt: got %d, want %d", resp.ProducedAt, producedAt) if delay < -time.Hour || delay > time.Hour {
t.Errorf("resp.ProducedAt: got %s, want close to current time (%s)", resp.ProducedAt, time.Now())
} }
if resp.Status != template.Status { if resp.Status != template.Status {
@ -386,6 +401,41 @@ const ocspResponseHex = "308206bc0a0100a08206b5308206b106092b0601050507300101048
"a1d24ce16e41a9941568fec5b42771e118f16c106a54ccc339a4b02166445a167902e75e" + "a1d24ce16e41a9941568fec5b42771e118f16c106a54ccc339a4b02166445a167902e75e" +
"6d8620b0825dcd18a069b90fd851d10fa8effd409deec02860d26d8d833f304b10669b42" "6d8620b0825dcd18a069b90fd851d10fa8effd409deec02860d26d8d833f304b10669b42"
const startComResponderCertHex = "308204b23082039aa003020102020101300d06092a864886f70d010105050030818c310b" +
"300906035504061302494c31163014060355040a130d5374617274436f6d204c74642e31" +
"2b3029060355040b1322536563757265204469676974616c204365727469666963617465" +
"205369676e696e67313830360603550403132f5374617274436f6d20436c617373203120" +
"5072696d61727920496e7465726d65646961746520536572766572204341301e170d3037" +
"313032353030323330365a170d3132313032333030323330365a304c310b300906035504" +
"061302494c31163014060355040a130d5374617274436f6d204c74642e31253023060355" +
"0403131c5374617274436f6d20436c6173732031204f435350205369676e657230820122" +
"300d06092a864886f70d01010105000382010f003082010a0282010100b9561b4c453187" +
"17178084e96e178df2255e18ed8d8ecc7c2b7b51a6c1c2e6bf0aa3603066f132fe10ae97" +
"b50e99fa24b83fc53dd2777496387d14e1c3a9b6a4933e2ac12413d085570a95b8147414" +
"a0bc007c7bcf222446ef7f1a156d7ea1c577fc5f0facdfd42eb0f5974990cb2f5cefebce" +
"ef4d1bdc7ae5c1075c5a99a93171f2b0845b4ff0864e973fcfe32f9d7511ff87a3e94341" +
"0c90a4493a306b6944359340a9ca96f02b66ce67f028df2980a6aaee8d5d5d452b8b0eb9" +
"3f923cc1e23fcccbdbe7ffcb114d08fa7a6a3c404f825d1a0e715935cf623a8c7b596700" +
"14ed0622f6089a9447a7a19010f7fe58f84129a2765ea367824d1c3bb2fda30853020301" +
"0001a382015c30820158300c0603551d130101ff04023000300b0603551d0f0404030203" +
"a8301e0603551d250417301506082b0601050507030906092b0601050507300105301d06" +
"03551d0e0416041445e0a36695414c5dd449bc00e33cdcdbd2343e173081a80603551d23" +
"0481a030819d8014eb4234d098b0ab9ff41b6b08f7cc642eef0e2c45a18181a47f307d31" +
"0b300906035504061302494c31163014060355040a130d5374617274436f6d204c74642e" +
"312b3029060355040b1322536563757265204469676974616c2043657274696669636174" +
"65205369676e696e6731293027060355040313205374617274436f6d2043657274696669" +
"636174696f6e20417574686f7269747982010a30230603551d12041c301a861868747470" +
"3a2f2f7777772e737461727473736c2e636f6d2f302c06096086480186f842010d041f16" +
"1d5374617274436f6d205265766f636174696f6e20417574686f72697479300d06092a86" +
"4886f70d01010505000382010100182d22158f0fc0291324fa8574c49bb8ff2835085adc" +
"bf7b7fc4191c397ab6951328253fffe1e5ec2a7da0d50fca1a404e6968481366939e666c" +
"0a6209073eca57973e2fefa9ed1718e8176f1d85527ff522c08db702e3b2b180f1cbff05" +
"d98128252cf0f450f7dd2772f4188047f19dc85317366f94bc52d60f453a550af58e308a" +
"aab00ced33040b62bf37f5b1ab2a4f7f0f80f763bf4d707bc8841d7ad9385ee2a4244469" +
"260b6f2bf085977af9074796048ecc2f9d48a1d24ce16e41a9941568fec5b42771e118f1" +
"6c106a54ccc339a4b02166445a167902e75e6d8620b0825dcd18a069b90fd851d10fa8ef" +
"fd409deec02860d26d8d833f304b10669b42"
const startComHex = "308206343082041ca003020102020118300d06092a864886f70d0101050500307d310b30" + const startComHex = "308206343082041ca003020102020118300d06092a864886f70d0101050500307d310b30" +
"0906035504061302494c31163014060355040a130d5374617274436f6d204c74642e312b" + "0906035504061302494c31163014060355040a130d5374617274436f6d204c74642e312b" +
"3029060355040b1322536563757265204469676974616c20436572746966696361746520" + "3029060355040b1322536563757265204469676974616c20436572746966696361746520" +

View File

@ -307,8 +307,6 @@ func readToNextPublicKey(packets *packet.Reader) (err error) {
return return
} }
} }
panic("unreachable")
} }
// ReadEntity reads an entity (public key, identities, subkeys etc) from the // ReadEntity reads an entity (public key, identities, subkeys etc) from the

View File

@ -300,7 +300,7 @@ func TestNewEntityWithoutPreferredHash(t *testing.T) {
for _, identity := range entity.Identities { for _, identity := range entity.Identities {
if len(identity.SelfSignature.PreferredHash) != 0 { if len(identity.SelfSignature.PreferredHash) != 0 {
t.Fatal("Expected preferred hash to be empty but got length %d", len(identity.SelfSignature.PreferredHash)) t.Fatalf("Expected preferred hash to be empty but got length %d", len(identity.SelfSignature.PreferredHash))
} }
} }
} }

View File

@ -273,8 +273,6 @@ func consumeAll(r io.Reader) (n int64, err error) {
return return
} }
} }
panic("unreachable")
} }
// packetType represents the numeric ids of the different OpenPGP packet types. See // packetType represents the numeric ids of the different OpenPGP packet types. See

View File

@ -540,7 +540,6 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err erro
default: default:
return errors.SignatureError("Unsupported public key algorithm used in signature") return errors.SignatureError("Unsupported public key algorithm used in signature")
} }
panic("unreachable")
} }
// VerifySignatureV3 returns nil iff sig is a valid signature, made by this // VerifySignatureV3 returns nil iff sig is a valid signature, made by this
@ -585,7 +584,6 @@ func (pk *PublicKey) VerifySignatureV3(signed hash.Hash, sig *SignatureV3) (err
default: default:
panic("shouldn't happen") panic("shouldn't happen")
} }
panic("unreachable")
} }
// keySignatureHash returns a Hash of the message that needs to be signed for // keySignatureHash returns a Hash of the message that needs to be signed for

View File

@ -216,7 +216,6 @@ func (pk *PublicKeyV3) VerifySignatureV3(signed hash.Hash, sig *SignatureV3) (er
// V3 public keys only support RSA. // V3 public keys only support RSA.
panic("shouldn't happen") panic("shouldn't happen")
} }
panic("unreachable")
} }
// VerifyUserIdSignatureV3 returns nil iff sig is a valid signature, made by this // VerifyUserIdSignatureV3 returns nil iff sig is a valid signature, made by this

View File

@ -80,7 +80,7 @@ func TestNewEntity(t *testing.T) {
t.Errorf("failed to find bit length: %s", err) t.Errorf("failed to find bit length: %s", err)
} }
if int(bl) != defaultRSAKeyBits { if int(bl) != defaultRSAKeyBits {
t.Errorf("BitLength %v, expected %v", defaultRSAKeyBits) t.Errorf("BitLength %v, expected %v", int(bl), defaultRSAKeyBits)
} }
// Check bit-length with a config. // Check bit-length with a config.
@ -238,7 +238,7 @@ func TestEncryption(t *testing.T) {
signKey, _ := kring[0].signingKey(testTime) signKey, _ := kring[0].signingKey(testTime)
expectedKeyId := signKey.PublicKey.KeyId expectedKeyId := signKey.PublicKey.KeyId
if md.SignedByKeyId != expectedKeyId { if md.SignedByKeyId != expectedKeyId {
t.Errorf("#%d: message signed by wrong key id, got: %d, want: %d", i, *md.SignedBy, expectedKeyId) t.Errorf("#%d: message signed by wrong key id, got: %v, want: %v", i, *md.SignedBy, expectedKeyId)
} }
if md.SignedBy == nil { if md.SignedBy == nil {
t.Errorf("#%d: failed to find the signing Entity", i) t.Errorf("#%d: failed to find the signing Entity", i)

View File

@ -943,6 +943,7 @@ func (c *Conversation) processData(in []byte) (out []byte, tlvs []tlv, err error
t.data, tlvData, ok3 = getNBytes(tlvData, int(t.length)) t.data, tlvData, ok3 = getNBytes(tlvData, int(t.length))
if !ok1 || !ok2 || !ok3 { if !ok1 || !ok2 || !ok3 {
err = errors.New("otr: corrupt tlv data") err = errors.New("otr: corrupt tlv data")
return
} }
tlvs = append(tlvs, t) tlvs = append(tlvs, t)
} }
@ -1313,6 +1314,12 @@ func (priv *PrivateKey) Import(in []byte) bool {
mpis[i] = new(big.Int).SetBytes(mpiBytes) mpis[i] = new(big.Int).SetBytes(mpiBytes)
} }
for _, mpi := range mpis {
if mpi.Sign() <= 0 {
return false
}
}
priv.PrivateKey.P = mpis[0] priv.PrivateKey.P = mpis[0]
priv.PrivateKey.Q = mpis[1] priv.PrivateKey.Q = mpis[1]
priv.PrivateKey.G = mpis[2] priv.PrivateKey.G = mpis[2]

View File

@ -109,6 +109,10 @@ func ToPEM(pfxData []byte, password string) ([]*pem.Block, error) {
bags, encodedPassword, err := getSafeContents(pfxData, encodedPassword) bags, encodedPassword, err := getSafeContents(pfxData, encodedPassword)
if err != nil {
return nil, err
}
blocks := make([]*pem.Block, 0, len(bags)) blocks := make([]*pem.Block, 0, len(bags))
for _, bag := range bags { for _, bag := range bags {
block, err := convertBag(&bag, encodedPassword) block, err := convertBag(&bag, encodedPassword)

View File

@ -6,10 +6,14 @@ package poly1305
import ( import (
"bytes" "bytes"
"encoding/hex"
"flag"
"testing" "testing"
"unsafe" "unsafe"
) )
var stressFlag = flag.Bool("stress", false, "run slow stress tests")
var testData = []struct { var testData = []struct {
in, k, correct []byte in, k, correct []byte
}{ }{
@ -39,6 +43,36 @@ var testData = []struct {
[]byte{0x3b, 0x3a, 0x29, 0xe9, 0x3b, 0x21, 0x3a, 0x5c, 0x5c, 0x3b, 0x3b, 0x05, 0x3a, 0x3a, 0x8c, 0x0d}, []byte{0x3b, 0x3a, 0x29, 0xe9, 0x3b, 0x21, 0x3a, 0x5c, 0x5c, 0x3b, 0x3b, 0x05, 0x3a, 0x3a, 0x8c, 0x0d},
[]byte{0x6d, 0xc1, 0x8b, 0x8c, 0x34, 0x4c, 0xd7, 0x99, 0x27, 0x11, 0x8b, 0xbe, 0x84, 0xb7, 0xf3, 0x14}, []byte{0x6d, 0xc1, 0x8b, 0x8c, 0x34, 0x4c, 0xd7, 0x99, 0x27, 0x11, 0x8b, 0xbe, 0x84, 0xb7, 0xf3, 0x14},
}, },
{
// This test generates a result of (2^130-1) % (2^130-5).
[]byte{
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
},
[]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
[]byte{4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
{
// This test generates a result of (2^130-6) % (2^130-5).
[]byte{
0xfa, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
},
[]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
[]byte{0xfa, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
},
{
// This test generates a result of (2^130-5) % (2^130-5).
[]byte{
0xfb, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
},
[]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
[]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
} }
func testSum(t *testing.T, unaligned bool) { func testSum(t *testing.T, unaligned bool) {
@ -58,6 +92,39 @@ func testSum(t *testing.T, unaligned bool) {
} }
} }
func TestBurnin(t *testing.T) {
// This test can be used to sanity-check significant changes. It can
// take about many minutes to run, even on fast machines. It's disabled
// by default.
if !*stressFlag {
t.Skip("skipping without -stress")
}
var key [32]byte
var input [25]byte
var output [16]byte
for i := range key {
key[i] = 1
}
for i := range input {
input[i] = 2
}
for i := uint64(0); i < 1e10; i++ {
Sum(&output, input[:], &key)
copy(key[0:], output[:])
copy(key[16:], output[:])
copy(input[:], output[:])
copy(input[16:], output[:])
}
const expected = "5e3b866aea0b636d240c83c428f84bfa"
if got := hex.EncodeToString(output[:]); got != expected {
t.Errorf("expected %s, got %s", expected, got)
}
}
func TestSum(t *testing.T) { testSum(t, false) } func TestSum(t *testing.T) { testSum(t, false) }
func TestSumUnaligned(t *testing.T) { testSum(t, true) } func TestSumUnaligned(t *testing.T) { testSum(t, true) }

View File

@ -54,9 +54,9 @@
ADCQ t3, h1; \ ADCQ t3, h1; \
ADCQ $0, h2 ADCQ $0, h2
DATA poly1305Mask<>+0x00(SB)/8, $0x0FFFFFFC0FFFFFFF DATA ·poly1305Mask<>+0x00(SB)/8, $0x0FFFFFFC0FFFFFFF
DATA poly1305Mask<>+0x08(SB)/8, $0x0FFFFFFC0FFFFFFC DATA ·poly1305Mask<>+0x08(SB)/8, $0x0FFFFFFC0FFFFFFC
GLOBL poly1305Mask<>(SB), RODATA, $16 GLOBL ·poly1305Mask<>(SB), RODATA, $16
// func poly1305(out *[16]byte, m *byte, mlen uint64, key *[32]key) // func poly1305(out *[16]byte, m *byte, mlen uint64, key *[32]key)
TEXT ·poly1305(SB), $0-32 TEXT ·poly1305(SB), $0-32
@ -67,8 +67,8 @@ TEXT ·poly1305(SB), $0-32
MOVQ 0(AX), R11 MOVQ 0(AX), R11
MOVQ 8(AX), R12 MOVQ 8(AX), R12
ANDQ poly1305Mask<>(SB), R11 // r0 ANDQ ·poly1305Mask<>(SB), R11 // r0
ANDQ poly1305Mask<>+8(SB), R12 // r1 ANDQ ·poly1305Mask<>+8(SB), R12 // r1
XORQ R8, R8 // h0 XORQ R8, R8 // h0
XORQ R9, R9 // h1 XORQ R9, R9 // h1
XORQ R10, R10 // h2 XORQ R10, R10 // h2

View File

@ -9,12 +9,12 @@
// This code was translated into a form compatible with 5a from the public // This code was translated into a form compatible with 5a from the public
// domain source by Andrew Moon: github.com/floodyberry/poly1305-opt/blob/master/app/extensions/poly1305. // domain source by Andrew Moon: github.com/floodyberry/poly1305-opt/blob/master/app/extensions/poly1305.
DATA poly1305_init_constants_armv6<>+0x00(SB)/4, $0x3ffffff DATA ·poly1305_init_constants_armv6<>+0x00(SB)/4, $0x3ffffff
DATA poly1305_init_constants_armv6<>+0x04(SB)/4, $0x3ffff03 DATA ·poly1305_init_constants_armv6<>+0x04(SB)/4, $0x3ffff03
DATA poly1305_init_constants_armv6<>+0x08(SB)/4, $0x3ffc0ff DATA ·poly1305_init_constants_armv6<>+0x08(SB)/4, $0x3ffc0ff
DATA poly1305_init_constants_armv6<>+0x0c(SB)/4, $0x3f03fff DATA ·poly1305_init_constants_armv6<>+0x0c(SB)/4, $0x3f03fff
DATA poly1305_init_constants_armv6<>+0x10(SB)/4, $0x00fffff DATA ·poly1305_init_constants_armv6<>+0x10(SB)/4, $0x00fffff
GLOBL poly1305_init_constants_armv6<>(SB), 8, $20 GLOBL ·poly1305_init_constants_armv6<>(SB), 8, $20
// Warning: the linker may use R11 to synthesize certain instructions. Please // Warning: the linker may use R11 to synthesize certain instructions. Please
// take care and verify that no synthetic instructions use it. // take care and verify that no synthetic instructions use it.
@ -27,7 +27,7 @@ TEXT poly1305_init_ext_armv6<>(SB), NOSPLIT, $0
ADD $4, R13, R8 ADD $4, R13, R8
MOVM.IB [R4-R7], (R8) MOVM.IB [R4-R7], (R8)
MOVM.IA.W (R1), [R2-R5] MOVM.IA.W (R1), [R2-R5]
MOVW $poly1305_init_constants_armv6<>(SB), R7 MOVW $·poly1305_init_constants_armv6<>(SB), R7
MOVW R2, R8 MOVW R2, R8
MOVW R2>>26, R9 MOVW R2>>26, R9
MOVW R3>>20, g MOVW R3>>20, g

File diff suppressed because it is too large Load Diff

View File

@ -182,7 +182,10 @@ func TestCert(t *testing.T) {
func netPipe() (net.Conn, net.Conn, error) { func netPipe() (net.Conn, net.Conn, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
return nil, nil, err listener, err = net.Listen("tcp", "[::1]:0")
if err != nil {
return nil, nil, err
}
} }
defer listener.Close() defer listener.Close()
c1, err := net.Dial("tcp", listener.Addr().String()) c1, err := net.Dial("tcp", listener.Addr().String())
@ -200,6 +203,9 @@ func netPipe() (net.Conn, net.Conn, error) {
} }
func TestAuth(t *testing.T) { func TestAuth(t *testing.T) {
agent, _, cleanup := startAgent(t)
defer cleanup()
a, b, err := netPipe() a, b, err := netPipe()
if err != nil { if err != nil {
t.Fatalf("netPipe: %v", err) t.Fatalf("netPipe: %v", err)
@ -208,9 +214,6 @@ func TestAuth(t *testing.T) {
defer a.Close() defer a.Close()
defer b.Close() defer b.Close()
agent, _, cleanup := startAgent(t)
defer cleanup()
if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment"}); err != nil { if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment"}); err != nil {
t.Errorf("Add: %v", err) t.Errorf("Add: %v", err)
} }
@ -233,7 +236,9 @@ func TestAuth(t *testing.T) {
conn.Close() conn.Close()
}() }()
conf := ssh.ClientConfig{} conf := ssh.ClientConfig{
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers)) conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers))
conn, _, _, err := ssh.NewClientConn(b, "", &conf) conn, _, _, err := ssh.NewClientConn(b, "", &conf)
if err != nil { if err != nil {

View File

@ -6,20 +6,20 @@ package agent_test
import ( import (
"log" "log"
"os"
"net" "net"
"os"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent" "golang.org/x/crypto/ssh/agent"
) )
func ExampleClientAgent() { func ExampleClientAgent() {
// ssh-agent has a UNIX socket under $SSH_AUTH_SOCK // ssh-agent has a UNIX socket under $SSH_AUTH_SOCK
socket := os.Getenv("SSH_AUTH_SOCK") socket := os.Getenv("SSH_AUTH_SOCK")
conn, err := net.Dial("unix", socket) conn, err := net.Dial("unix", socket)
if err != nil { if err != nil {
log.Fatalf("net.Dial: %v", err) log.Fatalf("net.Dial: %v", err)
} }
agentClient := agent.NewClient(conn) agentClient := agent.NewClient(conn)
config := &ssh.ClientConfig{ config := &ssh.ClientConfig{
User: "username", User: "username",
@ -29,6 +29,7 @@ func ExampleClientAgent() {
// wants it. // wants it.
ssh.PublicKeysCallback(agentClient.Signers), ssh.PublicKeysCallback(agentClient.Signers),
}, },
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
} }
sshc, err := ssh.Dial("tcp", "localhost:22", config) sshc, err := ssh.Dial("tcp", "localhost:22", config)

View File

@ -56,7 +56,9 @@ func TestSetupForwardAgent(t *testing.T) {
incoming <- conn incoming <- conn
}() }()
conf := ssh.ClientConfig{} conf := ssh.ClientConfig{
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf) conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf)
if err != nil { if err != nil {
t.Fatalf("NewClientConn: %v", err) t.Fatalf("NewClientConn: %v", err)

View File

@ -251,10 +251,18 @@ type CertChecker struct {
// for user certificates. // for user certificates.
SupportedCriticalOptions []string SupportedCriticalOptions []string
// IsAuthority should return true if the key is recognized as // IsUserAuthority should return true if the key is recognized as an
// an authority. This allows for certificates to be signed by other // authority for the given user certificate. This allows for
// certificates. // certificates to be signed by other certificates. This must be set
IsAuthority func(auth PublicKey) bool // if this CertChecker will be checking user certificates.
IsUserAuthority func(auth PublicKey) bool
// IsHostAuthority should report whether the key is recognized as
// an authority for this host. This allows for certificates to be
// signed by other keys, and for those other keys to only be valid
// signers for particular hostnames. This must be set if this
// CertChecker will be checking host certificates.
IsHostAuthority func(auth PublicKey, address string) bool
// Clock is used for verifying time stamps. If nil, time.Now // Clock is used for verifying time stamps. If nil, time.Now
// is used. // is used.
@ -268,7 +276,7 @@ type CertChecker struct {
// HostKeyFallback is called when CertChecker.CheckHostKey encounters a // HostKeyFallback is called when CertChecker.CheckHostKey encounters a
// public key that is not a certificate. It must implement host key // public key that is not a certificate. It must implement host key
// validation or else, if nil, all such keys are rejected. // validation or else, if nil, all such keys are rejected.
HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error HostKeyFallback HostKeyCallback
// IsRevoked is called for each certificate so that revocation checking // IsRevoked is called for each certificate so that revocation checking
// can be implemented. It should return true if the given certificate // can be implemented. It should return true if the given certificate
@ -356,7 +364,13 @@ func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
} }
} }
if !c.IsAuthority(cert.SignatureKey) { // if this is a host cert, principal is the remote hostname as passed
// to CheckHostCert.
if cert.CertType == HostCert && !c.IsHostAuthority(cert.SignatureKey, principal) {
return fmt.Errorf("ssh: no authorities for hostname: %v", principal)
}
if cert.CertType == UserCert && !c.IsUserAuthority(cert.SignatureKey) {
return fmt.Errorf("ssh: certificate signed by unrecognized authority") return fmt.Errorf("ssh: certificate signed by unrecognized authority")
} }

View File

@ -104,7 +104,7 @@ func TestValidateCert(t *testing.T) {
t.Fatalf("got %v (%T), want *Certificate", key, key) t.Fatalf("got %v (%T), want *Certificate", key, key)
} }
checker := CertChecker{} checker := CertChecker{}
checker.IsAuthority = func(k PublicKey) bool { checker.IsUserAuthority = func(k PublicKey) bool {
return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal()) return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal())
} }
@ -142,7 +142,7 @@ func TestValidateCertTime(t *testing.T) {
checker := CertChecker{ checker := CertChecker{
Clock: func() time.Time { return time.Unix(ts, 0) }, Clock: func() time.Time { return time.Unix(ts, 0) },
} }
checker.IsAuthority = func(k PublicKey) bool { checker.IsUserAuthority = func(k PublicKey) bool {
return bytes.Equal(k.Marshal(), return bytes.Equal(k.Marshal(),
testPublicKeys["ecdsa"].Marshal()) testPublicKeys["ecdsa"].Marshal())
} }
@ -160,7 +160,7 @@ func TestValidateCertTime(t *testing.T) {
func TestHostKeyCert(t *testing.T) { func TestHostKeyCert(t *testing.T) {
cert := &Certificate{ cert := &Certificate{
ValidPrincipals: []string{"hostname", "hostname.domain"}, ValidPrincipals: []string{"hostname", "hostname.domain", "otherhost"},
Key: testPublicKeys["rsa"], Key: testPublicKeys["rsa"],
ValidBefore: CertTimeInfinity, ValidBefore: CertTimeInfinity,
CertType: HostCert, CertType: HostCert,
@ -168,8 +168,8 @@ func TestHostKeyCert(t *testing.T) {
cert.SignCert(rand.Reader, testSigners["ecdsa"]) cert.SignCert(rand.Reader, testSigners["ecdsa"])
checker := &CertChecker{ checker := &CertChecker{
IsAuthority: func(p PublicKey) bool { IsHostAuthority: func(p PublicKey, h string) bool {
return bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal()) return h == "hostname" && bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal())
}, },
} }
@ -178,7 +178,7 @@ func TestHostKeyCert(t *testing.T) {
t.Errorf("NewCertSigner: %v", err) t.Errorf("NewCertSigner: %v", err)
} }
for _, name := range []string{"hostname", "otherhost"} { for _, name := range []string{"hostname", "otherhost", "lasthost"} {
c1, c2, err := netPipe() c1, c2, err := netPipe()
if err != nil { if err != nil {
t.Fatalf("netPipe: %v", err) t.Fatalf("netPipe: %v", err)

View File

@ -461,8 +461,8 @@ func (m *mux) newChannel(chanType string, direction channelDirection, extraData
pending: newBuffer(), pending: newBuffer(),
extPending: newBuffer(), extPending: newBuffer(),
direction: direction, direction: direction,
incomingRequests: make(chan *Request, 16), incomingRequests: make(chan *Request, chanSize),
msg: make(chan interface{}, 16), msg: make(chan interface{}, chanSize),
chanType: chanType, chanType: chanType,
extraData: extraData, extraData: extraData,
mux: m, mux: m,

View File

@ -135,6 +135,7 @@ const prefixLen = 5
type streamPacketCipher struct { type streamPacketCipher struct {
mac hash.Hash mac hash.Hash
cipher cipher.Stream cipher cipher.Stream
etm bool
// The following members are to avoid per-packet allocations. // The following members are to avoid per-packet allocations.
prefix [prefixLen]byte prefix [prefixLen]byte
@ -150,7 +151,14 @@ func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, err
return nil, err return nil, err
} }
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) var encryptedPaddingLength [1]byte
if s.mac != nil && s.etm {
copy(encryptedPaddingLength[:], s.prefix[4:5])
s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5])
} else {
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
}
length := binary.BigEndian.Uint32(s.prefix[0:4]) length := binary.BigEndian.Uint32(s.prefix[0:4])
paddingLength := uint32(s.prefix[4]) paddingLength := uint32(s.prefix[4])
@ -159,7 +167,12 @@ func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, err
s.mac.Reset() s.mac.Reset()
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
s.mac.Write(s.seqNumBytes[:]) s.mac.Write(s.seqNumBytes[:])
s.mac.Write(s.prefix[:]) if s.etm {
s.mac.Write(s.prefix[:4])
s.mac.Write(encryptedPaddingLength[:])
} else {
s.mac.Write(s.prefix[:])
}
macSize = uint32(s.mac.Size()) macSize = uint32(s.mac.Size())
} }
@ -184,10 +197,17 @@ func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, err
} }
mac := s.packetData[length-1:] mac := s.packetData[length-1:]
data := s.packetData[:length-1] data := s.packetData[:length-1]
if s.mac != nil && s.etm {
s.mac.Write(data)
}
s.cipher.XORKeyStream(data, data) s.cipher.XORKeyStream(data, data)
if s.mac != nil { if s.mac != nil {
s.mac.Write(data) if !s.etm {
s.mac.Write(data)
}
s.macResult = s.mac.Sum(s.macResult[:0]) s.macResult = s.mac.Sum(s.macResult[:0])
if subtle.ConstantTimeCompare(s.macResult, mac) != 1 { if subtle.ConstantTimeCompare(s.macResult, mac) != 1 {
return nil, errors.New("ssh: MAC failure") return nil, errors.New("ssh: MAC failure")
@ -203,7 +223,13 @@ func (s *streamPacketCipher) writePacket(seqNum uint32, w io.Writer, rand io.Rea
return errors.New("ssh: packet too large") return errors.New("ssh: packet too large")
} }
paddingLength := packetSizeMultiple - (prefixLen+len(packet))%packetSizeMultiple aadlen := 0
if s.mac != nil && s.etm {
// packet length is not encrypted for EtM modes
aadlen = 4
}
paddingLength := packetSizeMultiple - (prefixLen+len(packet)-aadlen)%packetSizeMultiple
if paddingLength < 4 { if paddingLength < 4 {
paddingLength += packetSizeMultiple paddingLength += packetSizeMultiple
} }
@ -220,15 +246,37 @@ func (s *streamPacketCipher) writePacket(seqNum uint32, w io.Writer, rand io.Rea
s.mac.Reset() s.mac.Reset()
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
s.mac.Write(s.seqNumBytes[:]) s.mac.Write(s.seqNumBytes[:])
if s.etm {
// For EtM algorithms, the packet length must stay unencrypted,
// but the following data (padding length) must be encrypted
s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5])
}
s.mac.Write(s.prefix[:]) s.mac.Write(s.prefix[:])
if !s.etm {
// For non-EtM algorithms, the algorithm is applied on unencrypted data
s.mac.Write(packet)
s.mac.Write(padding)
}
}
if !(s.mac != nil && s.etm) {
// For EtM algorithms, the padding length has already been encrypted
// and the packet length must remain unencrypted
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
}
s.cipher.XORKeyStream(packet, packet)
s.cipher.XORKeyStream(padding, padding)
if s.mac != nil && s.etm {
// For EtM algorithms, packet and padding must be encrypted
s.mac.Write(packet) s.mac.Write(packet)
s.mac.Write(padding) s.mac.Write(padding)
} }
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
s.cipher.XORKeyStream(packet, packet)
s.cipher.XORKeyStream(padding, padding)
if _, err := w.Write(s.prefix[:]); err != nil { if _, err := w.Write(s.prefix[:]); err != nil {
return err return err
} }

View File

@ -26,39 +26,41 @@ func TestPacketCiphers(t *testing.T) {
defer delete(cipherModes, aes128cbcID) defer delete(cipherModes, aes128cbcID)
for cipher := range cipherModes { for cipher := range cipherModes {
kr := &kexResult{Hash: crypto.SHA1} for mac := range macModes {
algs := directionAlgorithms{ kr := &kexResult{Hash: crypto.SHA1}
Cipher: cipher, algs := directionAlgorithms{
MAC: "hmac-sha1", Cipher: cipher,
Compression: "none", MAC: mac,
} Compression: "none",
client, err := newPacketCipher(clientKeys, algs, kr) }
if err != nil { client, err := newPacketCipher(clientKeys, algs, kr)
t.Errorf("newPacketCipher(client, %q): %v", cipher, err) if err != nil {
continue t.Errorf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
} continue
server, err := newPacketCipher(clientKeys, algs, kr) }
if err != nil { server, err := newPacketCipher(clientKeys, algs, kr)
t.Errorf("newPacketCipher(client, %q): %v", cipher, err) if err != nil {
continue t.Errorf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
} continue
}
want := "bla bla" want := "bla bla"
input := []byte(want) input := []byte(want)
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
if err := client.writePacket(0, buf, rand.Reader, input); err != nil { if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
t.Errorf("writePacket(%q): %v", cipher, err) t.Errorf("writePacket(%q, %q): %v", cipher, mac, err)
continue continue
} }
packet, err := server.readPacket(0, buf) packet, err := server.readPacket(0, buf)
if err != nil { if err != nil {
t.Errorf("readPacket(%q): %v", cipher, err) t.Errorf("readPacket(%q, %q): %v", cipher, mac, err)
continue continue
} }
if string(packet) != want { if string(packet) != want {
t.Errorf("roundtrip(%q): got %q, want %q", cipher, packet, want) t.Errorf("roundtrip(%q, %q): got %q, want %q", cipher, mac, packet, want)
}
} }
} }
} }

View File

@ -5,6 +5,7 @@
package ssh package ssh
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@ -13,7 +14,7 @@ import (
) )
// Client implements a traditional SSH client that supports shells, // Client implements a traditional SSH client that supports shells,
// subprocesses, port forwarding and tunneled dialing. // subprocesses, TCP port/streamlocal forwarding and tunneled dialing.
type Client struct { type Client struct {
Conn Conn
@ -40,7 +41,7 @@ func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel {
return nil return nil
} }
ch = make(chan NewChannel, 16) ch = make(chan NewChannel, chanSize)
c.channelHandlers[channelType] = ch c.channelHandlers[channelType] = ch
return ch return ch
} }
@ -59,6 +60,7 @@ func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
conn.forwards.closeAll() conn.forwards.closeAll()
}() }()
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip")) go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip"))
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
return conn return conn
} }
@ -68,6 +70,11 @@ func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) { func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) {
fullConf := *config fullConf := *config
fullConf.SetDefaults() fullConf.SetDefaults()
if fullConf.HostKeyCallback == nil {
c.Close()
return nil, nil, nil, errors.New("ssh: must specify HostKeyCallback")
}
conn := &connection{ conn := &connection{
sshConn: sshConn{conn: c}, sshConn: sshConn{conn: c},
} }
@ -97,13 +104,11 @@ func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) e
c.transport = newClientTransport( c.transport = newClientTransport(
newTransport(c.sshConn.conn, config.Rand, true /* is client */), newTransport(c.sshConn.conn, config.Rand, true /* is client */),
c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr()) c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr())
if err := c.transport.requestInitialKeyChange(); err != nil { if err := c.transport.waitSession(); err != nil {
return err return err
} }
// We just did the key change, so the session ID is established.
c.sessionID = c.transport.getSessionID() c.sessionID = c.transport.getSessionID()
return c.clientAuthenticate(config) return c.clientAuthenticate(config)
} }
@ -175,6 +180,13 @@ func Dial(network, addr string, config *ClientConfig) (*Client, error) {
return NewClient(c, chans, reqs), nil return NewClient(c, chans, reqs), nil
} }
// HostKeyCallback is the function type used for verifying server
// keys. A HostKeyCallback must return nil if the host key is OK, or
// an error to reject it. It receives the hostname as passed to Dial
// or NewClientConn. The remote address is the RemoteAddr of the
// net.Conn underlying the the SSH connection.
type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
// A ClientConfig structure is used to configure a Client. It must not be // A ClientConfig structure is used to configure a Client. It must not be
// modified after having been passed to an SSH function. // modified after having been passed to an SSH function.
type ClientConfig struct { type ClientConfig struct {
@ -190,10 +202,12 @@ type ClientConfig struct {
// be used during authentication. // be used during authentication.
Auth []AuthMethod Auth []AuthMethod
// HostKeyCallback, if not nil, is called during the cryptographic // HostKeyCallback is called during the cryptographic
// handshake to validate the server's host key. A nil HostKeyCallback // handshake to validate the server's host key. The client
// implies that all host keys are accepted. // configuration must supply this callback for the connection
HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error // to succeed. The functions InsecureIgnoreHostKey or
// FixedHostKey can be used for simplistic host key checks.
HostKeyCallback HostKeyCallback
// ClientVersion contains the version identification string that will // ClientVersion contains the version identification string that will
// be used for the connection. If empty, a reasonable default is used. // be used for the connection. If empty, a reasonable default is used.
@ -211,3 +225,33 @@ type ClientConfig struct {
// A Timeout of zero means no timeout. // A Timeout of zero means no timeout.
Timeout time.Duration Timeout time.Duration
} }
// InsecureIgnoreHostKey returns a function that can be used for
// ClientConfig.HostKeyCallback to accept any host key. It should
// not be used for production code.
func InsecureIgnoreHostKey() HostKeyCallback {
return func(hostname string, remote net.Addr, key PublicKey) error {
return nil
}
}
type fixedHostKey struct {
key PublicKey
}
func (f *fixedHostKey) check(hostname string, remote net.Addr, key PublicKey) error {
if f.key == nil {
return fmt.Errorf("ssh: required host key was nil")
}
if !bytes.Equal(key.Marshal(), f.key.Marshal()) {
return fmt.Errorf("ssh: host key mismatch")
}
return nil
}
// FixedHostKey returns a function for use in
// ClientConfig.HostKeyCallback to accept only a specific host key.
func FixedHostKey(key PublicKey) HostKeyCallback {
hk := &fixedHostKey{key}
return hk.check
}

View File

@ -30,8 +30,10 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error {
// then any untried methods suggested by the server. // then any untried methods suggested by the server.
tried := make(map[string]bool) tried := make(map[string]bool)
var lastMethods []string var lastMethods []string
sessionID := c.transport.getSessionID()
for auth := AuthMethod(new(noneAuth)); auth != nil; { for auth := AuthMethod(new(noneAuth)); auth != nil; {
ok, methods, err := auth.auth(c.transport.getSessionID(), config.User, c.transport, config.Rand) ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand)
if err != nil { if err != nil {
return err return err
} }
@ -177,31 +179,26 @@ func (cb publicKeyCallback) method() string {
} }
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
// Authentication is performed in two stages. The first stage sends an // Authentication is performed by sending an enquiry to test if a key is
// enquiry to test if each key is acceptable to the remote. The second // acceptable to the remote. If the key is acceptable, the client will
// stage attempts to authenticate with the valid keys obtained in the // attempt to authenticate with the valid key. If not the client will repeat
// first stage. // the process with the remaining keys.
signers, err := cb() signers, err := cb()
if err != nil { if err != nil {
return false, nil, err return false, nil, err
} }
var validKeys []Signer
for _, signer := range signers {
if ok, err := validateKey(signer.PublicKey(), user, c); ok {
validKeys = append(validKeys, signer)
} else {
if err != nil {
return false, nil, err
}
}
}
// methods that may continue if this auth is not successful.
var methods []string var methods []string
for _, signer := range validKeys { for _, signer := range signers {
pub := signer.PublicKey() ok, err := validateKey(signer.PublicKey(), user, c)
if err != nil {
return false, nil, err
}
if !ok {
continue
}
pub := signer.PublicKey()
pubKey := pub.Marshal() pubKey := pub.Marshal()
sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{ sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{
User: user, User: user,
@ -234,13 +231,29 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
if err != nil { if err != nil {
return false, nil, err return false, nil, err
} }
if success {
// If authentication succeeds or the list of available methods does not
// contain the "publickey" method, do not attempt to authenticate with any
// other keys. According to RFC 4252 Section 7, the latter can occur when
// additional authentication methods are required.
if success || !containsMethod(methods, cb.method()) {
return success, methods, err return success, methods, err
} }
} }
return false, methods, nil return false, methods, nil
} }
func containsMethod(methods []string, method string) bool {
for _, m := range methods {
if m == method {
return true
}
}
return false
}
// validateKey validates the key provided is acceptable to the server. // validateKey validates the key provided is acceptable to the server.
func validateKey(key PublicKey, user string, c packetConn) (bool, error) { func validateKey(key PublicKey, user string, c packetConn) (bool, error) {
pubKey := key.Marshal() pubKey := key.Marshal()

View File

@ -38,7 +38,7 @@ func tryAuth(t *testing.T, config *ClientConfig) error {
defer c2.Close() defer c2.Close()
certChecker := CertChecker{ certChecker := CertChecker{
IsAuthority: func(k PublicKey) bool { IsUserAuthority: func(k PublicKey) bool {
return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal()) return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal())
}, },
UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
@ -76,9 +76,6 @@ func tryAuth(t *testing.T, config *ClientConfig) error {
} }
return nil, errors.New("keyboard-interactive failed") return nil, errors.New("keyboard-interactive failed")
}, },
AuthLogCallback: func(conn ConnMetadata, method string, err error) {
t.Logf("user %q, method %q: %v", conn.User(), method, err)
},
} }
serverConfig.AddHostKey(testSigners["rsa"]) serverConfig.AddHostKey(testSigners["rsa"])
@ -93,6 +90,7 @@ func TestClientAuthPublicKey(t *testing.T) {
Auth: []AuthMethod{ Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]), PublicKeys(testSigners["rsa"]),
}, },
HostKeyCallback: InsecureIgnoreHostKey(),
} }
if err := tryAuth(t, config); err != nil { if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err) t.Fatalf("unable to dial remote side: %s", err)
@ -105,6 +103,7 @@ func TestAuthMethodPassword(t *testing.T) {
Auth: []AuthMethod{ Auth: []AuthMethod{
Password(clientPassword), Password(clientPassword),
}, },
HostKeyCallback: InsecureIgnoreHostKey(),
} }
if err := tryAuth(t, config); err != nil { if err := tryAuth(t, config); err != nil {
@ -124,6 +123,7 @@ func TestAuthMethodFallback(t *testing.T) {
return "WRONG", nil return "WRONG", nil
}), }),
}, },
HostKeyCallback: InsecureIgnoreHostKey(),
} }
if err := tryAuth(t, config); err != nil { if err := tryAuth(t, config); err != nil {
@ -142,6 +142,7 @@ func TestAuthMethodWrongPassword(t *testing.T) {
Password("wrong"), Password("wrong"),
PublicKeys(testSigners["rsa"]), PublicKeys(testSigners["rsa"]),
}, },
HostKeyCallback: InsecureIgnoreHostKey(),
} }
if err := tryAuth(t, config); err != nil { if err := tryAuth(t, config); err != nil {
@ -159,6 +160,7 @@ func TestAuthMethodKeyboardInteractive(t *testing.T) {
Auth: []AuthMethod{ Auth: []AuthMethod{
KeyboardInteractive(answers.Challenge), KeyboardInteractive(answers.Challenge),
}, },
HostKeyCallback: InsecureIgnoreHostKey(),
} }
if err := tryAuth(t, config); err != nil { if err := tryAuth(t, config); err != nil {
@ -204,6 +206,7 @@ func TestAuthMethodRSAandDSA(t *testing.T) {
Auth: []AuthMethod{ Auth: []AuthMethod{
PublicKeys(testSigners["dsa"], testSigners["rsa"]), PublicKeys(testSigners["dsa"], testSigners["rsa"]),
}, },
HostKeyCallback: InsecureIgnoreHostKey(),
} }
if err := tryAuth(t, config); err != nil { if err := tryAuth(t, config); err != nil {
t.Fatalf("client could not authenticate with rsa key: %v", err) t.Fatalf("client could not authenticate with rsa key: %v", err)
@ -220,6 +223,7 @@ func TestClientHMAC(t *testing.T) {
Config: Config{ Config: Config{
MACs: []string{mac}, MACs: []string{mac},
}, },
HostKeyCallback: InsecureIgnoreHostKey(),
} }
if err := tryAuth(t, config); err != nil { if err := tryAuth(t, config); err != nil {
t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err) t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err)
@ -255,6 +259,7 @@ func TestClientUnsupportedKex(t *testing.T) {
Config: Config{ Config: Config{
KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported
}, },
HostKeyCallback: InsecureIgnoreHostKey(),
} }
if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "common algorithm") { if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "common algorithm") {
t.Errorf("got %v, expected 'common algorithm'", err) t.Errorf("got %v, expected 'common algorithm'", err)
@ -274,22 +279,23 @@ func TestClientLoginCert(t *testing.T) {
} }
clientConfig := &ClientConfig{ clientConfig := &ClientConfig{
User: "user", User: "user",
HostKeyCallback: InsecureIgnoreHostKey(),
} }
clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner)) clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner))
t.Log("should succeed") // should succeed
if err := tryAuth(t, clientConfig); err != nil { if err := tryAuth(t, clientConfig); err != nil {
t.Errorf("cert login failed: %v", err) t.Errorf("cert login failed: %v", err)
} }
t.Log("corrupted signature") // corrupted signature
cert.Signature.Blob[0]++ cert.Signature.Blob[0]++
if err := tryAuth(t, clientConfig); err == nil { if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with corrupted sig") t.Errorf("cert login passed with corrupted sig")
} }
t.Log("revoked") // revoked
cert.Serial = 666 cert.Serial = 666
cert.SignCert(rand.Reader, testSigners["ecdsa"]) cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil { if err := tryAuth(t, clientConfig); err == nil {
@ -297,13 +303,13 @@ func TestClientLoginCert(t *testing.T) {
} }
cert.Serial = 1 cert.Serial = 1
t.Log("sign with wrong key") // sign with wrong key
cert.SignCert(rand.Reader, testSigners["dsa"]) cert.SignCert(rand.Reader, testSigners["dsa"])
if err := tryAuth(t, clientConfig); err == nil { if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with non-authoritative key") t.Errorf("cert login passed with non-authoritative key")
} }
t.Log("host cert") // host cert
cert.CertType = HostCert cert.CertType = HostCert
cert.SignCert(rand.Reader, testSigners["ecdsa"]) cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil { if err := tryAuth(t, clientConfig); err == nil {
@ -311,14 +317,14 @@ func TestClientLoginCert(t *testing.T) {
} }
cert.CertType = UserCert cert.CertType = UserCert
t.Log("principal specified") // principal specified
cert.ValidPrincipals = []string{"user"} cert.ValidPrincipals = []string{"user"}
cert.SignCert(rand.Reader, testSigners["ecdsa"]) cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err != nil { if err := tryAuth(t, clientConfig); err != nil {
t.Errorf("cert login failed: %v", err) t.Errorf("cert login failed: %v", err)
} }
t.Log("wrong principal specified") // wrong principal specified
cert.ValidPrincipals = []string{"fred"} cert.ValidPrincipals = []string{"fred"}
cert.SignCert(rand.Reader, testSigners["ecdsa"]) cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil { if err := tryAuth(t, clientConfig); err == nil {
@ -326,22 +332,22 @@ func TestClientLoginCert(t *testing.T) {
} }
cert.ValidPrincipals = nil cert.ValidPrincipals = nil
t.Log("added critical option") // added critical option
cert.CriticalOptions = map[string]string{"root-access": "yes"} cert.CriticalOptions = map[string]string{"root-access": "yes"}
cert.SignCert(rand.Reader, testSigners["ecdsa"]) cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil { if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with unrecognized critical option") t.Errorf("cert login passed with unrecognized critical option")
} }
t.Log("allowed source address") // allowed source address
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24"} cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24,::42/120"}
cert.SignCert(rand.Reader, testSigners["ecdsa"]) cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err != nil { if err := tryAuth(t, clientConfig); err != nil {
t.Errorf("cert login with source-address failed: %v", err) t.Errorf("cert login with source-address failed: %v", err)
} }
t.Log("disallowed source address") // disallowed source address
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42"} cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42,::42"}
cert.SignCert(rand.Reader, testSigners["ecdsa"]) cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil { if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login with source-address succeeded") t.Errorf("cert login with source-address succeeded")
@ -364,6 +370,7 @@ func testPermissionsPassing(withPermissions bool, t *testing.T) {
Auth: []AuthMethod{ Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]), PublicKeys(testSigners["rsa"]),
}, },
HostKeyCallback: InsecureIgnoreHostKey(),
} }
if withPermissions { if withPermissions {
clientConfig.User = "permissions" clientConfig.User = "permissions"
@ -410,6 +417,7 @@ func TestRetryableAuth(t *testing.T) {
}), 2), }), 2),
PublicKeys(testSigners["rsa"]), PublicKeys(testSigners["rsa"]),
}, },
HostKeyCallback: InsecureIgnoreHostKey(),
} }
if err := tryAuth(t, config); err != nil { if err := tryAuth(t, config); err != nil {
@ -431,7 +439,8 @@ func ExampleRetryableAuthMethod(t *testing.T) {
} }
config := &ClientConfig{ config := &ClientConfig{
User: user, HostKeyCallback: InsecureIgnoreHostKey(),
User: user,
Auth: []AuthMethod{ Auth: []AuthMethod{
RetryableAuthMethod(KeyboardInteractiveChallenge(Cb), NumberOfPrompts), RetryableAuthMethod(KeyboardInteractiveChallenge(Cb), NumberOfPrompts),
}, },
@ -451,7 +460,8 @@ func TestClientAuthNone(t *testing.T) {
serverConfig.AddHostKey(testSigners["rsa"]) serverConfig.AddHostKey(testSigners["rsa"])
clientConfig := &ClientConfig{ clientConfig := &ClientConfig{
User: user, User: user,
HostKeyCallback: InsecureIgnoreHostKey(),
} }
c1, c2, err := netPipe() c1, c2, err := netPipe()
@ -470,3 +480,100 @@ func TestClientAuthNone(t *testing.T) {
t.Fatalf("server: got %q, want %q", serverConn.User(), user) t.Fatalf("server: got %q, want %q", serverConn.User(), user)
} }
} }
// Test if authentication attempts are limited on server when MaxAuthTries is set
func TestClientAuthMaxAuthTries(t *testing.T) {
user := "testuser"
serverConfig := &ServerConfig{
MaxAuthTries: 2,
PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) {
if conn.User() == "testuser" && string(pass) == "right" {
return nil, nil
}
return nil, errors.New("password auth failed")
},
}
serverConfig.AddHostKey(testSigners["rsa"])
expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{
Reason: 2,
Message: "too many authentication failures",
})
for tries := 2; tries < 4; tries++ {
n := tries
clientConfig := &ClientConfig{
User: user,
Auth: []AuthMethod{
RetryableAuthMethod(PasswordCallback(func() (string, error) {
n--
if n == 0 {
return "right", nil
} else {
return "wrong", nil
}
}), tries),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go newServer(c1, serverConfig)
_, _, _, err = NewClientConn(c2, "", clientConfig)
if tries > 2 {
if err == nil {
t.Fatalf("client: got no error, want %s", expectedErr)
} else if err.Error() != expectedErr.Error() {
t.Fatalf("client: got %s, want %s", err, expectedErr)
}
} else {
if err != nil {
t.Fatalf("client: got %s, want no error", err)
}
}
}
}
// Test if authentication attempts are correctly limited on server
// when more public keys are provided then MaxAuthTries
func TestClientAuthMaxAuthTriesPublicKey(t *testing.T) {
signers := []Signer{}
for i := 0; i < 6; i++ {
signers = append(signers, testSigners["dsa"])
}
validConfig := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(append([]Signer{testSigners["rsa"]}, signers...)...),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, validConfig); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{
Reason: 2,
Message: "too many authentication failures",
})
invalidConfig := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(append(signers, testSigners["rsa"])...),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, invalidConfig); err == nil {
t.Fatalf("client: got no error, want %s", expectedErr)
} else if err.Error() != expectedErr.Error() {
t.Fatalf("client: got %s, want %s", err, expectedErr)
}
}

View File

@ -6,6 +6,7 @@ package ssh
import ( import (
"net" "net"
"strings"
"testing" "testing"
) )
@ -13,6 +14,7 @@ func testClientVersion(t *testing.T, config *ClientConfig, expected string) {
clientConn, serverConn := net.Pipe() clientConn, serverConn := net.Pipe()
defer clientConn.Close() defer clientConn.Close()
receivedVersion := make(chan string, 1) receivedVersion := make(chan string, 1)
config.HostKeyCallback = InsecureIgnoreHostKey()
go func() { go func() {
version, err := readVersion(serverConn) version, err := readVersion(serverConn)
if err != nil { if err != nil {
@ -37,3 +39,43 @@ func TestCustomClientVersion(t *testing.T) {
func TestDefaultClientVersion(t *testing.T) { func TestDefaultClientVersion(t *testing.T) {
testClientVersion(t, &ClientConfig{}, packageVersion) testClientVersion(t, &ClientConfig{}, packageVersion)
} }
func TestHostKeyCheck(t *testing.T) {
for _, tt := range []struct {
name string
wantError string
key PublicKey
}{
{"no callback", "must specify HostKeyCallback", nil},
{"correct key", "", testSigners["rsa"].PublicKey()},
{"mismatch", "mismatch", testSigners["ecdsa"].PublicKey()},
} {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
serverConf := &ServerConfig{
NoClientAuth: true,
}
serverConf.AddHostKey(testSigners["rsa"])
go NewServerConn(c1, serverConf)
clientConf := ClientConfig{
User: "user",
}
if tt.key != nil {
clientConf.HostKeyCallback = FixedHostKey(tt.key)
}
_, _, _, err = NewClientConn(c2, "", &clientConf)
if err != nil {
if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) {
t.Errorf("%s: got error %q, missing %q", tt.name, err.Error(), tt.wantError)
}
} else if tt.wantError != "" {
t.Errorf("%s: succeeded, but want error string %q", tt.name, tt.wantError)
}
}
}

View File

@ -9,6 +9,7 @@ import (
"crypto/rand" "crypto/rand"
"fmt" "fmt"
"io" "io"
"math"
"sync" "sync"
_ "crypto/sha1" _ "crypto/sha1"
@ -40,7 +41,7 @@ var supportedKexAlgos = []string{
kexAlgoDH14SHA1, kexAlgoDH1SHA1, kexAlgoDH14SHA1, kexAlgoDH1SHA1,
} }
// supportedKexAlgos specifies the supported host-key algorithms (i.e. methods // supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods
// of authenticating servers) in preference order. // of authenticating servers) in preference order.
var supportedHostKeyAlgos = []string{ var supportedHostKeyAlgos = []string{
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01,
@ -56,7 +57,7 @@ var supportedHostKeyAlgos = []string{
// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed // This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed
// because they have reached the end of their useful life. // because they have reached the end of their useful life.
var supportedMACs = []string{ var supportedMACs = []string{
"hmac-sha2-256", "hmac-sha1", "hmac-sha1-96", "hmac-sha2-256-etm@openssh.com", "hmac-sha2-256", "hmac-sha1", "hmac-sha1-96",
} }
var supportedCompressions = []string{compressionNone} var supportedCompressions = []string{compressionNone}
@ -104,6 +105,21 @@ type directionAlgorithms struct {
Compression string Compression string
} }
// rekeyBytes returns a rekeying intervals in bytes.
func (a *directionAlgorithms) rekeyBytes() int64 {
// According to RFC4344 block ciphers should rekey after
// 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is
// 128.
switch a.Cipher {
case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcmCipherID, aes128cbcID:
return 16 * (1 << 32)
}
// For others, stick with RFC4253 recommendation to rekey after 1 Gb of data.
return 1 << 30
}
type algorithms struct { type algorithms struct {
kex string kex string
hostKey string hostKey string
@ -171,7 +187,7 @@ type Config struct {
// The maximum number of bytes sent or received after which a // The maximum number of bytes sent or received after which a
// new key is negotiated. It must be at least 256. If // new key is negotiated. It must be at least 256. If
// unspecified, 1 gigabyte is used. // unspecified, a size suitable for the chosen cipher is used.
RekeyThreshold uint64 RekeyThreshold uint64
// The allowed key exchanges algorithms. If unspecified then a // The allowed key exchanges algorithms. If unspecified then a
@ -215,11 +231,12 @@ func (c *Config) SetDefaults() {
} }
if c.RekeyThreshold == 0 { if c.RekeyThreshold == 0 {
// RFC 4253, section 9 suggests rekeying after 1G. // cipher specific default
c.RekeyThreshold = 1 << 30 } else if c.RekeyThreshold < minRekeyThreshold {
}
if c.RekeyThreshold < minRekeyThreshold {
c.RekeyThreshold = minRekeyThreshold c.RekeyThreshold = minRekeyThreshold
} else if c.RekeyThreshold >= math.MaxInt64 {
// Avoid weirdness if somebody uses -1 as a threshold.
c.RekeyThreshold = math.MaxInt64
} }
} }

View File

@ -14,5 +14,8 @@ others.
References: References:
[PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD [PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1 [SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
This package does not fall under the stability promise of the Go language itself,
so its API may be changed when pressing needs arise.
*/ */
package ssh // import "golang.org/x/crypto/ssh" package ssh // import "golang.org/x/crypto/ssh"

View File

@ -5,12 +5,16 @@
package ssh_test package ssh_test
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
"net" "net"
"net/http" "net/http"
"os"
"path/filepath"
"strings"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal" "golang.org/x/crypto/ssh/terminal"
@ -90,8 +94,6 @@ func ExampleNewServerConn() {
// The incoming Request channel must be serviced. // The incoming Request channel must be serviced.
go ssh.DiscardRequests(reqs) go ssh.DiscardRequests(reqs)
// Service the incoming Channel channel.
// Service the incoming Channel channel. // Service the incoming Channel channel.
for newChannel := range chans { for newChannel := range chans {
// Channels have a type, depending on the application level // Channels have a type, depending on the application level
@ -131,16 +133,59 @@ func ExampleNewServerConn() {
} }
} }
func ExampleHostKeyCheck() {
// Every client must provide a host key check. Here is a
// simple-minded parse of OpenSSH's known_hosts file
host := "hostname"
file, err := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts"))
if err != nil {
log.Fatal(err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
var hostKey ssh.PublicKey
for scanner.Scan() {
fields := strings.Split(scanner.Text(), " ")
if len(fields) != 3 {
continue
}
if strings.Contains(fields[0], host) {
var err error
hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes())
if err != nil {
log.Fatalf("error parsing %q: %v", fields[2], err)
}
break
}
}
if hostKey == nil {
log.Fatalf("no hostkey for %s", host)
}
config := ssh.ClientConfig{
User: os.Getenv("USER"),
HostKeyCallback: ssh.FixedHostKey(hostKey),
}
_, err = ssh.Dial("tcp", host+":22", &config)
log.Println(err)
}
func ExampleDial() { func ExampleDial() {
var hostKey ssh.PublicKey
// An SSH client is represented with a ClientConn. // An SSH client is represented with a ClientConn.
// //
// To authenticate with the remote server you must pass at least one // To authenticate with the remote server you must pass at least one
// implementation of AuthMethod via the Auth field in ClientConfig. // implementation of AuthMethod via the Auth field in ClientConfig,
// and provide a HostKeyCallback.
config := &ssh.ClientConfig{ config := &ssh.ClientConfig{
User: "username", User: "username",
Auth: []ssh.AuthMethod{ Auth: []ssh.AuthMethod{
ssh.Password("yourpassword"), ssh.Password("yourpassword"),
}, },
HostKeyCallback: ssh.FixedHostKey(hostKey),
} }
client, err := ssh.Dial("tcp", "yourserver.com:22", config) client, err := ssh.Dial("tcp", "yourserver.com:22", config)
if err != nil { if err != nil {
@ -166,6 +211,7 @@ func ExampleDial() {
} }
func ExamplePublicKeys() { func ExamplePublicKeys() {
var hostKey ssh.PublicKey
// A public key may be used to authenticate against the remote // A public key may be used to authenticate against the remote
// server by using an unencrypted PEM-encoded private key file. // server by using an unencrypted PEM-encoded private key file.
// //
@ -188,6 +234,7 @@ func ExamplePublicKeys() {
// Use the PublicKeys method for remote authentication. // Use the PublicKeys method for remote authentication.
ssh.PublicKeys(signer), ssh.PublicKeys(signer),
}, },
HostKeyCallback: ssh.FixedHostKey(hostKey),
} }
// Connect to the remote server and perform the SSH handshake. // Connect to the remote server and perform the SSH handshake.
@ -199,11 +246,13 @@ func ExamplePublicKeys() {
} }
func ExampleClient_Listen() { func ExampleClient_Listen() {
var hostKey ssh.PublicKey
config := &ssh.ClientConfig{ config := &ssh.ClientConfig{
User: "username", User: "username",
Auth: []ssh.AuthMethod{ Auth: []ssh.AuthMethod{
ssh.Password("password"), ssh.Password("password"),
}, },
HostKeyCallback: ssh.FixedHostKey(hostKey),
} }
// Dial your ssh server. // Dial your ssh server.
conn, err := ssh.Dial("tcp", "localhost:22", config) conn, err := ssh.Dial("tcp", "localhost:22", config)
@ -226,12 +275,14 @@ func ExampleClient_Listen() {
} }
func ExampleSession_RequestPty() { func ExampleSession_RequestPty() {
var hostKey ssh.PublicKey
// Create client config // Create client config
config := &ssh.ClientConfig{ config := &ssh.ClientConfig{
User: "username", User: "username",
Auth: []ssh.AuthMethod{ Auth: []ssh.AuthMethod{
ssh.Password("password"), ssh.Password("password"),
}, },
HostKeyCallback: ssh.FixedHostKey(hostKey),
} }
// Connect to ssh server // Connect to ssh server
conn, err := ssh.Dial("tcp", "localhost:22", config) conn, err := ssh.Dial("tcp", "localhost:22", config)

View File

@ -19,6 +19,11 @@ import (
// messages are wrong when using ECDH. // messages are wrong when using ECDH.
const debugHandshake = false const debugHandshake = false
// chanSize sets the amount of buffering SSH connections. This is
// primarily for testing: setting chanSize=0 uncovers deadlocks more
// quickly.
const chanSize = 16
// keyingTransport is a packet based transport that supports key // keyingTransport is a packet based transport that supports key
// changes. It need not be thread-safe. It should pass through // changes. It need not be thread-safe. It should pass through
// msgNewKeys in both directions. // msgNewKeys in both directions.
@ -53,34 +58,60 @@ type handshakeTransport struct {
incoming chan []byte incoming chan []byte
readError error readError error
mu sync.Mutex
writeError error
sentInitPacket []byte
sentInitMsg *kexInitMsg
pendingPackets [][]byte // Used when a key exchange is in progress.
// If the read loop wants to schedule a kex, it pings this
// channel, and the write loop will send out a kex
// message.
requestKex chan struct{}
// If the other side requests or confirms a kex, its kexInit
// packet is sent here for the write loop to find it.
startKex chan *pendingKex
// data for host key checking // data for host key checking
hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error hostKeyCallback HostKeyCallback
dialAddress string dialAddress string
remoteAddr net.Addr remoteAddr net.Addr
readSinceKex uint64 // Algorithms agreed in the last key exchange.
algorithms *algorithms
// Protects the writing side of the connection readPacketsLeft uint32
mu sync.Mutex readBytesLeft int64
cond *sync.Cond
sentInitPacket []byte writePacketsLeft uint32
sentInitMsg *kexInitMsg writeBytesLeft int64
writtenSinceKex uint64
writeError error
// The session ID or nil if first kex did not complete yet. // The session ID or nil if first kex did not complete yet.
sessionID []byte sessionID []byte
} }
type pendingKex struct {
otherInit []byte
done chan error
}
func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport { func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
t := &handshakeTransport{ t := &handshakeTransport{
conn: conn, conn: conn,
serverVersion: serverVersion, serverVersion: serverVersion,
clientVersion: clientVersion, clientVersion: clientVersion,
incoming: make(chan []byte, 16), incoming: make(chan []byte, chanSize),
config: config, requestKex: make(chan struct{}, 1),
startKex: make(chan *pendingKex, 1),
config: config,
} }
t.cond = sync.NewCond(&t.mu) t.resetReadThresholds()
t.resetWriteThresholds()
// We always start with a mandatory key exchange.
t.requestKex <- struct{}{}
return t return t
} }
@ -95,6 +126,7 @@ func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byt
t.hostKeyAlgorithms = supportedHostKeyAlgos t.hostKeyAlgorithms = supportedHostKeyAlgos
} }
go t.readLoop() go t.readLoop()
go t.kexLoop()
return t return t
} }
@ -102,6 +134,7 @@ func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byt
t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
t.hostKeys = config.hostKeys t.hostKeys = config.hostKeys
go t.readLoop() go t.readLoop()
go t.kexLoop()
return t return t
} }
@ -109,6 +142,20 @@ func (t *handshakeTransport) getSessionID() []byte {
return t.sessionID return t.sessionID
} }
// waitSession waits for the session to be established. This should be
// the first thing to call after instantiating handshakeTransport.
func (t *handshakeTransport) waitSession() error {
p, err := t.readPacket()
if err != nil {
return err
}
if p[0] != msgNewKeys {
return fmt.Errorf("ssh: first packet should be msgNewKeys")
}
return nil
}
func (t *handshakeTransport) id() string { func (t *handshakeTransport) id() string {
if len(t.hostKeys) > 0 { if len(t.hostKeys) > 0 {
return "server" return "server"
@ -116,6 +163,20 @@ func (t *handshakeTransport) id() string {
return "client" return "client"
} }
func (t *handshakeTransport) printPacket(p []byte, write bool) {
action := "got"
if write {
action = "sent"
}
if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p))
} else {
msg, err := decode(p)
log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err)
}
}
func (t *handshakeTransport) readPacket() ([]byte, error) { func (t *handshakeTransport) readPacket() ([]byte, error) {
p, ok := <-t.incoming p, ok := <-t.incoming
if !ok { if !ok {
@ -125,8 +186,10 @@ func (t *handshakeTransport) readPacket() ([]byte, error) {
} }
func (t *handshakeTransport) readLoop() { func (t *handshakeTransport) readLoop() {
first := true
for { for {
p, err := t.readOnePacket() p, err := t.readOnePacket(first)
first = false
if err != nil { if err != nil {
t.readError = err t.readError = err
close(t.incoming) close(t.incoming)
@ -138,67 +201,217 @@ func (t *handshakeTransport) readLoop() {
t.incoming <- p t.incoming <- p
} }
// If we can't read, declare the writing part dead too. // Stop writers too.
t.mu.Lock() t.recordWriteError(t.readError)
defer t.mu.Unlock()
if t.writeError == nil { // Unblock the writer should it wait for this.
t.writeError = t.readError close(t.startKex)
}
t.cond.Broadcast() // Don't close t.requestKex; it's also written to from writePacket.
} }
func (t *handshakeTransport) readOnePacket() ([]byte, error) { func (t *handshakeTransport) pushPacket(p []byte) error {
if t.readSinceKex > t.config.RekeyThreshold { if debugHandshake {
if err := t.requestKeyChange(); err != nil { t.printPacket(p, true)
return nil, err }
return t.conn.writePacket(p)
}
func (t *handshakeTransport) getWriteError() error {
t.mu.Lock()
defer t.mu.Unlock()
return t.writeError
}
func (t *handshakeTransport) recordWriteError(err error) {
t.mu.Lock()
defer t.mu.Unlock()
if t.writeError == nil && err != nil {
t.writeError = err
}
}
func (t *handshakeTransport) requestKeyExchange() {
select {
case t.requestKex <- struct{}{}:
default:
// something already requested a kex, so do nothing.
}
}
func (t *handshakeTransport) resetWriteThresholds() {
t.writePacketsLeft = packetRekeyThreshold
if t.config.RekeyThreshold > 0 {
t.writeBytesLeft = int64(t.config.RekeyThreshold)
} else if t.algorithms != nil {
t.writeBytesLeft = t.algorithms.w.rekeyBytes()
} else {
t.writeBytesLeft = 1 << 30
}
}
func (t *handshakeTransport) kexLoop() {
write:
for t.getWriteError() == nil {
var request *pendingKex
var sent bool
for request == nil || !sent {
var ok bool
select {
case request, ok = <-t.startKex:
if !ok {
break write
}
case <-t.requestKex:
break
}
if !sent {
if err := t.sendKexInit(); err != nil {
t.recordWriteError(err)
break
}
sent = true
}
} }
if err := t.getWriteError(); err != nil {
if request != nil {
request.done <- err
}
break
}
// We're not servicing t.requestKex, but that is OK:
// we never block on sending to t.requestKex.
// We're not servicing t.startKex, but the remote end
// has just sent us a kexInitMsg, so it can't send
// another key change request, until we close the done
// channel on the pendingKex request.
err := t.enterKeyExchange(request.otherInit)
t.mu.Lock()
t.writeError = err
t.sentInitPacket = nil
t.sentInitMsg = nil
t.resetWriteThresholds()
// we have completed the key exchange. Since the
// reader is still blocked, it is safe to clear out
// the requestKex channel. This avoids the situation
// where: 1) we consumed our own request for the
// initial kex, and 2) the kex from the remote side
// caused another send on the requestKex channel,
clear:
for {
select {
case <-t.requestKex:
//
default:
break clear
}
}
request.done <- t.writeError
// kex finished. Push packets that we received while
// the kex was in progress. Don't look at t.startKex
// and don't increment writtenSinceKex: if we trigger
// another kex while we are still busy with the last
// one, things will become very confusing.
for _, p := range t.pendingPackets {
t.writeError = t.pushPacket(p)
if t.writeError != nil {
break
}
}
t.pendingPackets = t.pendingPackets[:0]
t.mu.Unlock()
} }
// drain startKex channel. We don't service t.requestKex
// because nobody does blocking sends there.
go func() {
for init := range t.startKex {
init.done <- t.writeError
}
}()
// Unblock reader.
t.conn.Close()
}
// The protocol uses uint32 for packet counters, so we can't let them
// reach 1<<32. We will actually read and write more packets than
// this, though: the other side may send more packets, and after we
// hit this limit on writing we will send a few more packets for the
// key exchange itself.
const packetRekeyThreshold = (1 << 31)
func (t *handshakeTransport) resetReadThresholds() {
t.readPacketsLeft = packetRekeyThreshold
if t.config.RekeyThreshold > 0 {
t.readBytesLeft = int64(t.config.RekeyThreshold)
} else if t.algorithms != nil {
t.readBytesLeft = t.algorithms.r.rekeyBytes()
} else {
t.readBytesLeft = 1 << 30
}
}
func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
p, err := t.conn.readPacket() p, err := t.conn.readPacket()
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.readSinceKex += uint64(len(p)) if t.readPacketsLeft > 0 {
if debugHandshake { t.readPacketsLeft--
if p[0] == msgChannelData || p[0] == msgChannelExtendedData { } else {
log.Printf("%s got data (packet %d bytes)", t.id(), len(p)) t.requestKeyExchange()
} else {
msg, err := decode(p)
log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err)
}
} }
if t.readBytesLeft > 0 {
t.readBytesLeft -= int64(len(p))
} else {
t.requestKeyExchange()
}
if debugHandshake {
t.printPacket(p, false)
}
if first && p[0] != msgKexInit {
return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
}
if p[0] != msgKexInit { if p[0] != msgKexInit {
return p, nil return p, nil
} }
t.mu.Lock()
firstKex := t.sessionID == nil firstKex := t.sessionID == nil
err = t.enterKeyExchangeLocked(p) kex := pendingKex{
if err != nil { done: make(chan error, 1),
// drop connection otherInit: p,
t.conn.Close()
t.writeError = err
} }
t.startKex <- &kex
err = <-kex.done
if debugHandshake { if debugHandshake {
log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err) log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err)
} }
// Unblock writers.
t.sentInitMsg = nil
t.sentInitPacket = nil
t.cond.Broadcast()
t.writtenSinceKex = 0
t.mu.Unlock()
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.readSinceKex = 0 t.resetReadThresholds()
// By default, a key exchange is hidden from higher layers by // By default, a key exchange is hidden from higher layers by
// translating it into msgIgnore. // translating it into msgIgnore.
@ -213,61 +426,16 @@ func (t *handshakeTransport) readOnePacket() ([]byte, error) {
return successPacket, nil return successPacket, nil
} }
// keyChangeCategory describes whether a key exchange is the first on a // sendKexInit sends a key change message.
// connection, or a subsequent one. func (t *handshakeTransport) sendKexInit() error {
type keyChangeCategory bool
const (
firstKeyExchange keyChangeCategory = true
subsequentKeyExchange keyChangeCategory = false
)
// sendKexInit sends a key change message, and returns the message
// that was sent. After initiating the key change, all writes will be
// blocked until the change is done, and a failed key change will
// close the underlying transport. This function is safe for
// concurrent use by multiple goroutines.
func (t *handshakeTransport) sendKexInit(isFirst keyChangeCategory) error {
var err error
t.mu.Lock() t.mu.Lock()
// If this is the initial key change, but we already have a sessionID, defer t.mu.Unlock()
// then do nothing because the key exchange has already completed
// asynchronously.
if !isFirst || t.sessionID == nil {
_, _, err = t.sendKexInitLocked(isFirst)
}
t.mu.Unlock()
if err != nil {
return err
}
if isFirst {
if packet, err := t.readPacket(); err != nil {
return err
} else if packet[0] != msgNewKeys {
return unexpectedMessageError(msgNewKeys, packet[0])
}
}
return nil
}
func (t *handshakeTransport) requestInitialKeyChange() error {
return t.sendKexInit(firstKeyExchange)
}
func (t *handshakeTransport) requestKeyChange() error {
return t.sendKexInit(subsequentKeyExchange)
}
// sendKexInitLocked sends a key change message. t.mu must be locked
// while this happens.
func (t *handshakeTransport) sendKexInitLocked(isFirst keyChangeCategory) (*kexInitMsg, []byte, error) {
// kexInits may be sent either in response to the other side,
// or because our side wants to initiate a key change, so we
// may have already sent a kexInit. In that case, don't send a
// second kexInit.
if t.sentInitMsg != nil { if t.sentInitMsg != nil {
return t.sentInitMsg, t.sentInitPacket, nil // kexInits may be sent either in response to the other side,
// or because our side wants to initiate a key change, so we
// may have already sent a kexInit. In that case, don't send a
// second kexInit.
return nil
} }
msg := &kexInitMsg{ msg := &kexInitMsg{
@ -295,53 +463,65 @@ func (t *handshakeTransport) sendKexInitLocked(isFirst keyChangeCategory) (*kexI
packetCopy := make([]byte, len(packet)) packetCopy := make([]byte, len(packet))
copy(packetCopy, packet) copy(packetCopy, packet)
if err := t.conn.writePacket(packetCopy); err != nil { if err := t.pushPacket(packetCopy); err != nil {
return nil, nil, err return err
} }
t.sentInitMsg = msg t.sentInitMsg = msg
t.sentInitPacket = packet t.sentInitPacket = packet
return msg, packet, nil
return nil
} }
func (t *handshakeTransport) writePacket(p []byte) error { func (t *handshakeTransport) writePacket(p []byte) error {
t.mu.Lock()
defer t.mu.Unlock()
if t.writtenSinceKex > t.config.RekeyThreshold {
t.sendKexInitLocked(subsequentKeyExchange)
}
for t.sentInitMsg != nil && t.writeError == nil {
t.cond.Wait()
}
if t.writeError != nil {
return t.writeError
}
t.writtenSinceKex += uint64(len(p))
switch p[0] { switch p[0] {
case msgKexInit: case msgKexInit:
return errors.New("ssh: only handshakeTransport can send kexInit") return errors.New("ssh: only handshakeTransport can send kexInit")
case msgNewKeys: case msgNewKeys:
return errors.New("ssh: only handshakeTransport can send newKeys") return errors.New("ssh: only handshakeTransport can send newKeys")
default:
return t.conn.writePacket(p)
} }
t.mu.Lock()
defer t.mu.Unlock()
if t.writeError != nil {
return t.writeError
}
if t.sentInitMsg != nil {
// Copy the packet so the writer can reuse the buffer.
cp := make([]byte, len(p))
copy(cp, p)
t.pendingPackets = append(t.pendingPackets, cp)
return nil
}
if t.writeBytesLeft > 0 {
t.writeBytesLeft -= int64(len(p))
} else {
t.requestKeyExchange()
}
if t.writePacketsLeft > 0 {
t.writePacketsLeft--
} else {
t.requestKeyExchange()
}
if err := t.pushPacket(p); err != nil {
t.writeError = err
}
return nil
} }
func (t *handshakeTransport) Close() error { func (t *handshakeTransport) Close() error {
return t.conn.Close() return t.conn.Close()
} }
// enterKeyExchange runs the key exchange. t.mu must be held while running this. func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) error {
if debugHandshake { if debugHandshake {
log.Printf("%s entered key exchange", t.id()) log.Printf("%s entered key exchange", t.id())
} }
myInit, myInitPacket, err := t.sendKexInitLocked(subsequentKeyExchange)
if err != nil {
return err
}
otherInit := &kexInitMsg{} otherInit := &kexInitMsg{}
if err := Unmarshal(otherInitPacket, otherInit); err != nil { if err := Unmarshal(otherInitPacket, otherInit); err != nil {
@ -352,20 +532,20 @@ func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) erro
clientVersion: t.clientVersion, clientVersion: t.clientVersion,
serverVersion: t.serverVersion, serverVersion: t.serverVersion,
clientKexInit: otherInitPacket, clientKexInit: otherInitPacket,
serverKexInit: myInitPacket, serverKexInit: t.sentInitPacket,
} }
clientInit := otherInit clientInit := otherInit
serverInit := myInit serverInit := t.sentInitMsg
if len(t.hostKeys) == 0 { if len(t.hostKeys) == 0 {
clientInit = myInit clientInit, serverInit = serverInit, clientInit
serverInit = otherInit
magics.clientKexInit = myInitPacket magics.clientKexInit = t.sentInitPacket
magics.serverKexInit = otherInitPacket magics.serverKexInit = otherInitPacket
} }
algs, err := findAgreedAlgorithms(clientInit, serverInit) var err error
t.algorithms, err = findAgreedAlgorithms(clientInit, serverInit)
if err != nil { if err != nil {
return err return err
} }
@ -388,16 +568,16 @@ func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) erro
} }
} }
kex, ok := kexAlgoMap[algs.kex] kex, ok := kexAlgoMap[t.algorithms.kex]
if !ok { if !ok {
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex) return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex)
} }
var result *kexResult var result *kexResult
if len(t.hostKeys) > 0 { if len(t.hostKeys) > 0 {
result, err = t.server(kex, algs, &magics) result, err = t.server(kex, t.algorithms, &magics)
} else { } else {
result, err = t.client(kex, algs, &magics) result, err = t.client(kex, t.algorithms, &magics)
} }
if err != nil { if err != nil {
@ -409,7 +589,9 @@ func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) erro
} }
result.SessionID = t.sessionID result.SessionID = t.sessionID
t.conn.prepareKeyChange(algs, result) if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil {
return err
}
if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
return err return err
} }
@ -449,11 +631,9 @@ func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *
return nil, err return nil, err
} }
if t.hostKeyCallback != nil { err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) if err != nil {
if err != nil { return nil, err
return nil, err
}
} }
return result, nil return result, nil

View File

@ -9,6 +9,7 @@ import (
"crypto/rand" "crypto/rand"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"reflect" "reflect"
"runtime" "runtime"
@ -41,7 +42,10 @@ func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error
func netPipe() (net.Conn, net.Conn, error) { func netPipe() (net.Conn, net.Conn, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
return nil, nil, err listener, err = net.Listen("tcp", "[::1]:0")
if err != nil {
return nil, nil, err
}
} }
defer listener.Close() defer listener.Close()
c1, err := net.Dial("tcp", listener.Addr().String()) c1, err := net.Dial("tcp", listener.Addr().String())
@ -58,14 +62,46 @@ func netPipe() (net.Conn, net.Conn, error) {
return c1, c2, nil return c1, c2, nil
} }
func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) { // noiseTransport inserts ignore messages to check that the read loop
// and the key exchange filters out these messages.
type noiseTransport struct {
keyingTransport
}
func (t *noiseTransport) writePacket(p []byte) error {
ignore := []byte{msgIgnore}
if err := t.keyingTransport.writePacket(ignore); err != nil {
return err
}
debug := []byte{msgDebug, 1, 2, 3}
if err := t.keyingTransport.writePacket(debug); err != nil {
return err
}
return t.keyingTransport.writePacket(p)
}
func addNoiseTransport(t keyingTransport) keyingTransport {
return &noiseTransport{t}
}
// handshakePair creates two handshakeTransports connected with each
// other. If the noise argument is true, both transports will try to
// confuse the other side by sending ignore and debug messages.
func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) {
a, b, err := netPipe() a, b, err := netPipe()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
trC := newTransport(a, rand.Reader, true) var trC, trS keyingTransport
trS := newTransport(b, rand.Reader, false)
trC = newTransport(a, rand.Reader, true)
trS = newTransport(b, rand.Reader, false)
if noise {
trC = addNoiseTransport(trC)
trS = addNoiseTransport(trS)
}
clientConf.SetDefaults() clientConf.SetDefaults()
v := []byte("version") v := []byte("version")
@ -77,6 +113,13 @@ func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTran
serverConf.SetDefaults() serverConf.SetDefaults()
server = newServerTransport(trS, v, v, serverConf) server = newServerTransport(trS, v, v, serverConf)
if err := server.waitSession(); err != nil {
return nil, nil, fmt.Errorf("server.waitSession: %v", err)
}
if err := client.waitSession(); err != nil {
return nil, nil, fmt.Errorf("client.waitSession: %v", err)
}
return client, server, nil return client, server, nil
} }
@ -84,8 +127,14 @@ func TestHandshakeBasic(t *testing.T) {
if runtime.GOOS == "plan9" { if runtime.GOOS == "plan9" {
t.Skip("see golang.org/issue/7237") t.Skip("see golang.org/issue/7237")
} }
checker := &testChecker{}
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") checker := &syncChecker{
waitCall: make(chan int, 10),
called: make(chan int, 10),
}
checker.waitCall <- 1
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
if err != nil { if err != nil {
t.Fatalf("handshakePair: %v", err) t.Fatalf("handshakePair: %v", err)
} }
@ -93,240 +142,195 @@ func TestHandshakeBasic(t *testing.T) {
defer trC.Close() defer trC.Close()
defer trS.Close() defer trS.Close()
// Let first kex complete normally.
<-checker.called
clientDone := make(chan int, 0)
gotHalf := make(chan int, 0)
const N = 20
go func() { go func() {
defer close(clientDone)
// Client writes a bunch of stuff, and does a key // Client writes a bunch of stuff, and does a key
// change in the middle. This should not confuse the // change in the middle. This should not confuse the
// handshake in progress // handshake in progress. We do this twice, so we test
for i := 0; i < 10; i++ { // that the packet buffer is reset correctly.
for i := 0; i < N; i++ {
p := []byte{msgRequestSuccess, byte(i)} p := []byte{msgRequestSuccess, byte(i)}
if err := trC.writePacket(p); err != nil { if err := trC.writePacket(p); err != nil {
t.Fatalf("sendPacket: %v", err) t.Fatalf("sendPacket: %v", err)
} }
if i == 5 { if (i % 10) == 5 {
<-gotHalf
// halfway through, we request a key change. // halfway through, we request a key change.
err := trC.sendKexInit(subsequentKeyExchange) trC.requestKeyExchange()
if err != nil {
t.Fatalf("sendKexInit: %v", err) // Wait until we can be sure the key
} // change has really started before we
// write more.
<-checker.called
}
if (i % 10) == 7 {
// write some packets until the kex
// completes, to test buffering of
// packets.
checker.waitCall <- 1
} }
} }
trC.Close()
}() }()
// Server checks that client messages come in cleanly // Server checks that client messages come in cleanly
i := 0 i := 0
for { err = nil
p, err := trS.readPacket() for ; i < N; i++ {
var p []byte
p, err = trS.readPacket()
if err != nil { if err != nil {
break break
} }
if p[0] == msgNewKeys { if (i % 10) == 5 {
continue gotHalf <- 1
} }
want := []byte{msgRequestSuccess, byte(i)} want := []byte{msgRequestSuccess, byte(i)}
if bytes.Compare(p, want) != 0 { if bytes.Compare(p, want) != 0 {
t.Errorf("message %d: got %q, want %q", i, p, want) t.Errorf("message %d: got %v, want %v", i, p, want)
} }
i++
} }
if i != 10 { <-clientDone
if err != nil && err != io.EOF {
t.Fatalf("server error: %v", err)
}
if i != N {
t.Errorf("received %d messages, want 10.", i) t.Errorf("received %d messages, want 10.", i)
} }
// If all went well, we registered exactly 1 key change. close(checker.called)
if len(checker.calls) != 1 { if _, ok := <-checker.called; ok {
t.Fatalf("got %d host key checks, want 1", len(checker.calls)) // If all went well, we registered exactly 2 key changes: one
} // that establishes the session, and one that we requested
// additionally.
pub := testSigners["ecdsa"].PublicKey() t.Fatalf("got another host key checks after 2 handshakes")
want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal())
if want != checker.calls[0] {
t.Errorf("got %q want %q for host key check", checker.calls[0], want)
}
}
func TestHandshakeError(t *testing.T) {
checker := &testChecker{}
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad")
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
// send a packet
packet := []byte{msgRequestSuccess, 42}
if err := trC.writePacket(packet); err != nil {
t.Errorf("writePacket: %v", err)
}
// Now request a key change.
err = trC.sendKexInit(subsequentKeyExchange)
if err != nil {
t.Errorf("sendKexInit: %v", err)
}
// the key change will fail, and afterwards we can't write.
if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil {
t.Errorf("writePacket after botched rekey succeeded.")
}
readback, err := trS.readPacket()
if err != nil {
t.Fatalf("server closed too soon: %v", err)
}
if bytes.Compare(readback, packet) != 0 {
t.Errorf("got %q want %q", readback, packet)
}
readback, err = trS.readPacket()
if err == nil {
t.Errorf("got a message %q after failed key change", readback)
} }
} }
func TestForceFirstKex(t *testing.T) { func TestForceFirstKex(t *testing.T) {
// like handshakePair, but must access the keyingTransport.
checker := &testChecker{} checker := &testChecker{}
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") clientConf := &ClientConfig{HostKeyCallback: checker.Check}
a, b, err := netPipe()
if err != nil { if err != nil {
t.Fatalf("handshakePair: %v", err) t.Fatalf("netPipe: %v", err)
} }
defer trC.Close() var trC, trS keyingTransport
defer trS.Close()
trC = newTransport(a, rand.Reader, true)
// This is the disallowed packet:
trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})) trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
// Rest of the setup.
trS = newTransport(b, rand.Reader, false)
clientConf.SetDefaults()
v := []byte("version")
client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
serverConf := &ServerConfig{}
serverConf.AddHostKey(testSigners["ecdsa"])
serverConf.AddHostKey(testSigners["rsa"])
serverConf.SetDefaults()
server := newServerTransport(trS, v, v, serverConf)
defer client.Close()
defer server.Close()
// We setup the initial key exchange, but the remote side // We setup the initial key exchange, but the remote side
// tries to send serviceRequestMsg in cleartext, which is // tries to send serviceRequestMsg in cleartext, which is
// disallowed. // disallowed.
err = trS.sendKexInit(firstKeyExchange) if err := server.waitSession(); err == nil {
if err == nil {
t.Errorf("server first kex init should reject unexpected packet") t.Errorf("server first kex init should reject unexpected packet")
} }
} }
func TestHandshakeTwice(t *testing.T) {
checker := &testChecker{}
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
// Both sides should ask for the first key exchange first.
err = trS.sendKexInit(firstKeyExchange)
if err != nil {
t.Errorf("server sendKexInit: %v", err)
}
err = trC.sendKexInit(firstKeyExchange)
if err != nil {
t.Errorf("client sendKexInit: %v", err)
}
sent := 0
// send a packet
packet := make([]byte, 5)
packet[0] = msgRequestSuccess
if err := trC.writePacket(packet); err != nil {
t.Errorf("writePacket: %v", err)
}
sent++
// Send another packet. Use a fresh one, since writePacket destroys.
packet = make([]byte, 5)
packet[0] = msgRequestSuccess
if err := trC.writePacket(packet); err != nil {
t.Errorf("writePacket: %v", err)
}
sent++
// 2nd key change.
err = trC.sendKexInit(subsequentKeyExchange)
if err != nil {
t.Errorf("sendKexInit: %v", err)
}
packet = make([]byte, 5)
packet[0] = msgRequestSuccess
if err := trC.writePacket(packet); err != nil {
t.Errorf("writePacket: %v", err)
}
sent++
packet = make([]byte, 5)
packet[0] = msgRequestSuccess
for i := 0; i < sent; i++ {
msg, err := trS.readPacket()
if err != nil {
t.Fatalf("server closed too soon: %v", err)
}
if bytes.Compare(msg, packet) != 0 {
t.Errorf("packet %d: got %q want %q", i, msg, packet)
}
}
if len(checker.calls) != 2 {
t.Errorf("got %d key changes, want 2", len(checker.calls))
}
}
func TestHandshakeAutoRekeyWrite(t *testing.T) { func TestHandshakeAutoRekeyWrite(t *testing.T) {
checker := &testChecker{} checker := &syncChecker{
called: make(chan int, 10),
waitCall: nil,
}
clientConf := &ClientConfig{HostKeyCallback: checker.Check} clientConf := &ClientConfig{HostKeyCallback: checker.Check}
clientConf.RekeyThreshold = 500 clientConf.RekeyThreshold = 500
trC, trS, err := handshakePair(clientConf, "addr") trC, trS, err := handshakePair(clientConf, "addr", false)
if err != nil { if err != nil {
t.Fatalf("handshakePair: %v", err) t.Fatalf("handshakePair: %v", err)
} }
defer trC.Close() defer trC.Close()
defer trS.Close() defer trS.Close()
for i := 0; i < 5; i++ { input := make([]byte, 251)
packet := make([]byte, 251) input[0] = msgRequestSuccess
packet[0] = msgRequestSuccess
if err := trC.writePacket(packet); err != nil { done := make(chan int, 1)
const numPacket = 5
go func() {
defer close(done)
j := 0
for ; j < numPacket; j++ {
if p, err := trS.readPacket(); err != nil {
break
} else if !bytes.Equal(input, p) {
t.Errorf("got packet type %d, want %d", p[0], input[0])
}
}
if j != numPacket {
t.Errorf("got %d, want 5 messages", j)
}
}()
<-checker.called
for i := 0; i < numPacket; i++ {
p := make([]byte, len(input))
copy(p, input)
if err := trC.writePacket(p); err != nil {
t.Errorf("writePacket: %v", err) t.Errorf("writePacket: %v", err)
} }
} if i == 2 {
// Make sure the kex is in progress.
j := 0 <-checker.called
for ; j < 5; j++ {
_, err := trS.readPacket()
if err != nil {
break
} }
}
if j != 5 {
t.Errorf("got %d, want 5 messages", j)
}
if len(checker.calls) != 2 {
t.Errorf("got %d key changes, wanted 2", len(checker.calls))
} }
<-done
} }
type syncChecker struct { type syncChecker struct {
called chan int waitCall chan int
called chan int
} }
func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
t.called <- 1 c.called <- 1
if c.waitCall != nil {
<-c.waitCall
}
return nil return nil
} }
func TestHandshakeAutoRekeyRead(t *testing.T) { func TestHandshakeAutoRekeyRead(t *testing.T) {
sync := &syncChecker{make(chan int, 2)} sync := &syncChecker{
called: make(chan int, 2),
waitCall: nil,
}
clientConf := &ClientConfig{ clientConf := &ClientConfig{
HostKeyCallback: sync.Check, HostKeyCallback: sync.Check,
} }
clientConf.RekeyThreshold = 500 clientConf.RekeyThreshold = 500
trC, trS, err := handshakePair(clientConf, "addr") trC, trS, err := handshakePair(clientConf, "addr", false)
if err != nil { if err != nil {
t.Fatalf("handshakePair: %v", err) t.Fatalf("handshakePair: %v", err)
} }
@ -338,12 +342,19 @@ func TestHandshakeAutoRekeyRead(t *testing.T) {
if err := trS.writePacket(packet); err != nil { if err := trS.writePacket(packet); err != nil {
t.Fatalf("writePacket: %v", err) t.Fatalf("writePacket: %v", err)
} }
// While we read out the packet, a key change will be // While we read out the packet, a key change will be
// initiated. // initiated.
if _, err := trC.readPacket(); err != nil { done := make(chan int, 1)
t.Fatalf("readPacket(client): %v", err) go func() {
} defer close(done)
if _, err := trC.readPacket(); err != nil {
t.Fatalf("readPacket(client): %v", err)
}
}()
<-done
<-sync.called <-sync.called
} }
@ -357,6 +368,7 @@ type errorKeyingTransport struct {
func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error { func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
return nil return nil
} }
func (n *errorKeyingTransport) getSessionID() []byte { func (n *errorKeyingTransport) getSessionID() []byte {
return nil return nil
} }
@ -383,20 +395,32 @@ func (n *errorKeyingTransport) readPacket() ([]byte, error) {
func TestHandshakeErrorHandlingRead(t *testing.T) { func TestHandshakeErrorHandlingRead(t *testing.T) {
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
testHandshakeErrorHandlingN(t, i, -1) testHandshakeErrorHandlingN(t, i, -1, false)
} }
} }
func TestHandshakeErrorHandlingWrite(t *testing.T) { func TestHandshakeErrorHandlingWrite(t *testing.T) {
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
testHandshakeErrorHandlingN(t, -1, i) testHandshakeErrorHandlingN(t, -1, i, false)
}
}
func TestHandshakeErrorHandlingReadCoupled(t *testing.T) {
for i := 0; i < 20; i++ {
testHandshakeErrorHandlingN(t, i, -1, true)
}
}
func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) {
for i := 0; i < 20; i++ {
testHandshakeErrorHandlingN(t, -1, i, true)
} }
} }
// testHandshakeErrorHandlingN runs handshakes, injecting errors. If // testHandshakeErrorHandlingN runs handshakes, injecting errors. If
// handshakeTransport deadlocks, the go runtime will detect it and // handshakeTransport deadlocks, the go runtime will detect it and
// panic. // panic.
func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) { func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) {
msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)}) msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
a, b := memPipe() a, b := memPipe()
@ -409,37 +433,58 @@ func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) {
serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'}) serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
serverConn.hostKeys = []Signer{key} serverConn.hostKeys = []Signer{key}
go serverConn.readLoop() go serverConn.readLoop()
go serverConn.kexLoop()
clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold} clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
clientConf.SetDefaults() clientConf.SetDefaults()
clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'}) clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()} clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
clientConn.hostKeyCallback = InsecureIgnoreHostKey()
go clientConn.readLoop() go clientConn.readLoop()
go clientConn.kexLoop()
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(4)
for _, hs := range []packetConn{serverConn, clientConn} { for _, hs := range []packetConn{serverConn, clientConn} {
go func(c packetConn) { if !coupled {
for { wg.Add(2)
err := c.writePacket(msg) go func(c packetConn) {
if err != nil { for i := 0; ; i++ {
break str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8)
err := c.writePacket(Marshal(&serviceRequestMsg{str}))
if err != nil {
break
}
} }
} wg.Done()
wg.Done() c.Close()
}(hs) }(hs)
go func(c packetConn) { go func(c packetConn) {
for { for {
_, err := c.readPacket() _, err := c.readPacket()
if err != nil { if err != nil {
break break
}
} }
} wg.Done()
wg.Done() }(hs)
}(hs) } else {
} wg.Add(1)
go func(c packetConn) {
for {
_, err := c.readPacket()
if err != nil {
break
}
if err := c.writePacket(msg); err != nil {
break
}
}
wg.Done()
}(hs)
}
}
wg.Wait() wg.Wait()
} }
@ -448,7 +493,7 @@ func TestDisconnect(t *testing.T) {
t.Skip("see golang.org/issue/7237") t.Skip("see golang.org/issue/7237")
} }
checker := &testChecker{} checker := &testChecker{}
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
if err != nil { if err != nil {
t.Fatalf("handshakePair: %v", err) t.Fatalf("handshakePair: %v", err)
} }
@ -484,3 +529,31 @@ func TestDisconnect(t *testing.T) {
t.Errorf("readPacket 3 succeeded") t.Errorf("readPacket 3 succeeded")
} }
} }
func TestHandshakeRekeyDefault(t *testing.T) {
clientConf := &ClientConfig{
Config: Config{
Ciphers: []string{"aes128-ctr"},
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
trC, trS, err := handshakePair(clientConf, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
trC.writePacket([]byte{msgRequestSuccess, 0, 0})
trC.Close()
rgb := (1024 + trC.readBytesLeft) >> 30
wgb := (1024 + trC.writeBytesLeft) >> 30
if rgb != 64 {
t.Errorf("got rekey after %dG read, want 64G", rgb)
}
if wgb != 64 {
t.Errorf("got rekey after %dG write, want 64G", wgb)
}
}

View File

@ -798,8 +798,8 @@ func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) {
P *big.Int P *big.Int
Q *big.Int Q *big.Int
G *big.Int G *big.Int
Priv *big.Int
Pub *big.Int Pub *big.Int
Priv *big.Int
} }
rest, err := asn1.Unmarshal(der, &k) rest, err := asn1.Unmarshal(der, &k)
if err != nil { if err != nil {
@ -816,15 +816,15 @@ func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) {
Q: k.Q, Q: k.Q,
G: k.G, G: k.G,
}, },
Y: k.Priv, Y: k.Pub,
}, },
X: k.Pub, X: k.Priv,
}, nil }, nil
} }
// Implemented based on the documentation at // Implemented based on the documentation at
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key // https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key
func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) { func parseOpenSSHPrivateKey(key []byte) (crypto.PrivateKey, error) {
magic := append([]byte("openssh-key-v1"), 0) magic := append([]byte("openssh-key-v1"), 0)
if !bytes.Equal(magic, key[0:len(magic)]) { if !bytes.Equal(magic, key[0:len(magic)]) {
return nil, errors.New("ssh: invalid openssh private key format") return nil, errors.New("ssh: invalid openssh private key format")
@ -844,14 +844,15 @@ func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) {
return nil, err return nil, err
} }
if w.KdfName != "none" || w.CipherName != "none" {
return nil, errors.New("ssh: cannot decode encrypted private keys")
}
pk1 := struct { pk1 := struct {
Check1 uint32 Check1 uint32
Check2 uint32 Check2 uint32
Keytype string Keytype string
Pub []byte Rest []byte `ssh:"rest"`
Priv []byte
Comment string
Pad []byte `ssh:"rest"`
}{} }{}
if err := Unmarshal(w.PrivKeyBlock, &pk1); err != nil { if err := Unmarshal(w.PrivKeyBlock, &pk1); err != nil {
@ -862,24 +863,75 @@ func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) {
return nil, errors.New("ssh: checkint mismatch") return nil, errors.New("ssh: checkint mismatch")
} }
// we only handle ed25519 keys currently // we only handle ed25519 and rsa keys currently
if pk1.Keytype != KeyAlgoED25519 { switch pk1.Keytype {
case KeyAlgoRSA:
// https://github.com/openssh/openssh-portable/blob/master/sshkey.c#L2760-L2773
key := struct {
N *big.Int
E *big.Int
D *big.Int
Iqmp *big.Int
P *big.Int
Q *big.Int
Comment string
Pad []byte `ssh:"rest"`
}{}
if err := Unmarshal(pk1.Rest, &key); err != nil {
return nil, err
}
for i, b := range key.Pad {
if int(b) != i+1 {
return nil, errors.New("ssh: padding not as expected")
}
}
pk := &rsa.PrivateKey{
PublicKey: rsa.PublicKey{
N: key.N,
E: int(key.E.Int64()),
},
D: key.D,
Primes: []*big.Int{key.P, key.Q},
}
if err := pk.Validate(); err != nil {
return nil, err
}
pk.Precompute()
return pk, nil
case KeyAlgoED25519:
key := struct {
Pub []byte
Priv []byte
Comment string
Pad []byte `ssh:"rest"`
}{}
if err := Unmarshal(pk1.Rest, &key); err != nil {
return nil, err
}
if len(key.Priv) != ed25519.PrivateKeySize {
return nil, errors.New("ssh: private key unexpected length")
}
for i, b := range key.Pad {
if int(b) != i+1 {
return nil, errors.New("ssh: padding not as expected")
}
}
pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize))
copy(pk, key.Priv)
return &pk, nil
default:
return nil, errors.New("ssh: unhandled key type") return nil, errors.New("ssh: unhandled key type")
} }
for i, b := range pk1.Pad {
if int(b) != i+1 {
return nil, errors.New("ssh: padding not as expected")
}
}
if len(pk1.Priv) != ed25519.PrivateKeySize {
return nil, errors.New("ssh: private key unexpected length")
}
pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize))
copy(pk, pk1.Priv)
return &pk, nil
} }
// FingerprintLegacyMD5 returns the user presentation of the key's // FingerprintLegacyMD5 returns the user presentation of the key's

View File

@ -0,0 +1,546 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package knownhosts implements a parser for the OpenSSH
// known_hosts host key database.
package knownhosts
import (
"bufio"
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"os"
"strings"
"golang.org/x/crypto/ssh"
)
// See the sshd manpage
// (http://man.openbsd.org/sshd#SSH_KNOWN_HOSTS_FILE_FORMAT) for
// background.
type addr struct{ host, port string }
func (a *addr) String() string {
h := a.host
if strings.Contains(h, ":") {
h = "[" + h + "]"
}
return h + ":" + a.port
}
type matcher interface {
match([]addr) bool
}
type hostPattern struct {
negate bool
addr addr
}
func (p *hostPattern) String() string {
n := ""
if p.negate {
n = "!"
}
return n + p.addr.String()
}
type hostPatterns []hostPattern
func (ps hostPatterns) match(addrs []addr) bool {
matched := false
for _, p := range ps {
for _, a := range addrs {
m := p.match(a)
if !m {
continue
}
if p.negate {
return false
}
matched = true
}
}
return matched
}
// See
// https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/addrmatch.c
// The matching of * has no regard for separators, unlike filesystem globs
func wildcardMatch(pat []byte, str []byte) bool {
for {
if len(pat) == 0 {
return len(str) == 0
}
if len(str) == 0 {
return false
}
if pat[0] == '*' {
if len(pat) == 1 {
return true
}
for j := range str {
if wildcardMatch(pat[1:], str[j:]) {
return true
}
}
return false
}
if pat[0] == '?' || pat[0] == str[0] {
pat = pat[1:]
str = str[1:]
} else {
return false
}
}
}
func (l *hostPattern) match(a addr) bool {
return wildcardMatch([]byte(l.addr.host), []byte(a.host)) && l.addr.port == a.port
}
type keyDBLine struct {
cert bool
matcher matcher
knownKey KnownKey
}
func serialize(k ssh.PublicKey) string {
return k.Type() + " " + base64.StdEncoding.EncodeToString(k.Marshal())
}
func (l *keyDBLine) match(addrs []addr) bool {
return l.matcher.match(addrs)
}
type hostKeyDB struct {
// Serialized version of revoked keys
revoked map[string]*KnownKey
lines []keyDBLine
}
func newHostKeyDB() *hostKeyDB {
db := &hostKeyDB{
revoked: make(map[string]*KnownKey),
}
return db
}
func keyEq(a, b ssh.PublicKey) bool {
return bytes.Equal(a.Marshal(), b.Marshal())
}
// IsAuthorityForHost can be used as a callback in ssh.CertChecker
func (db *hostKeyDB) IsHostAuthority(remote ssh.PublicKey, address string) bool {
h, p, err := net.SplitHostPort(address)
if err != nil {
return false
}
a := addr{host: h, port: p}
for _, l := range db.lines {
if l.cert && keyEq(l.knownKey.Key, remote) && l.match([]addr{a}) {
return true
}
}
return false
}
// IsRevoked can be used as a callback in ssh.CertChecker
func (db *hostKeyDB) IsRevoked(key *ssh.Certificate) bool {
_, ok := db.revoked[string(key.Marshal())]
return ok
}
const markerCert = "@cert-authority"
const markerRevoked = "@revoked"
func nextWord(line []byte) (string, []byte) {
i := bytes.IndexAny(line, "\t ")
if i == -1 {
return string(line), nil
}
return string(line[:i]), bytes.TrimSpace(line[i:])
}
func parseLine(line []byte) (marker, host string, key ssh.PublicKey, err error) {
if w, next := nextWord(line); w == markerCert || w == markerRevoked {
marker = w
line = next
}
host, line = nextWord(line)
if len(line) == 0 {
return "", "", nil, errors.New("knownhosts: missing host pattern")
}
// ignore the keytype as it's in the key blob anyway.
_, line = nextWord(line)
if len(line) == 0 {
return "", "", nil, errors.New("knownhosts: missing key type pattern")
}
keyBlob, _ := nextWord(line)
keyBytes, err := base64.StdEncoding.DecodeString(keyBlob)
if err != nil {
return "", "", nil, err
}
key, err = ssh.ParsePublicKey(keyBytes)
if err != nil {
return "", "", nil, err
}
return marker, host, key, nil
}
func (db *hostKeyDB) parseLine(line []byte, filename string, linenum int) error {
marker, pattern, key, err := parseLine(line)
if err != nil {
return err
}
if marker == markerRevoked {
db.revoked[string(key.Marshal())] = &KnownKey{
Key: key,
Filename: filename,
Line: linenum,
}
return nil
}
entry := keyDBLine{
cert: marker == markerCert,
knownKey: KnownKey{
Filename: filename,
Line: linenum,
Key: key,
},
}
if pattern[0] == '|' {
entry.matcher, err = newHashedHost(pattern)
} else {
entry.matcher, err = newHostnameMatcher(pattern)
}
if err != nil {
return err
}
db.lines = append(db.lines, entry)
return nil
}
func newHostnameMatcher(pattern string) (matcher, error) {
var hps hostPatterns
for _, p := range strings.Split(pattern, ",") {
if len(p) == 0 {
continue
}
var a addr
var negate bool
if p[0] == '!' {
negate = true
p = p[1:]
}
if len(p) == 0 {
return nil, errors.New("knownhosts: negation without following hostname")
}
var err error
if p[0] == '[' {
a.host, a.port, err = net.SplitHostPort(p)
if err != nil {
return nil, err
}
} else {
a.host, a.port, err = net.SplitHostPort(p)
if err != nil {
a.host = p
a.port = "22"
}
}
hps = append(hps, hostPattern{
negate: negate,
addr: a,
})
}
return hps, nil
}
// KnownKey represents a key declared in a known_hosts file.
type KnownKey struct {
Key ssh.PublicKey
Filename string
Line int
}
func (k *KnownKey) String() string {
return fmt.Sprintf("%s:%d: %s", k.Filename, k.Line, serialize(k.Key))
}
// KeyError is returned if we did not find the key in the host key
// database, or there was a mismatch. Typically, in batch
// applications, this should be interpreted as failure. Interactive
// applications can offer an interactive prompt to the user.
type KeyError struct {
// Want holds the accepted host keys. For each key algorithm,
// there can be one hostkey. If Want is empty, the host is
// unknown. If Want is non-empty, there was a mismatch, which
// can signify a MITM attack.
Want []KnownKey
}
func (u *KeyError) Error() string {
if len(u.Want) == 0 {
return "knownhosts: key is unknown"
}
return "knownhosts: key mismatch"
}
// RevokedError is returned if we found a key that was revoked.
type RevokedError struct {
Revoked KnownKey
}
func (r *RevokedError) Error() string {
return "knownhosts: key is revoked"
}
// check checks a key against the host database. This should not be
// used for verifying certificates.
func (db *hostKeyDB) check(address string, remote net.Addr, remoteKey ssh.PublicKey) error {
if revoked := db.revoked[string(remoteKey.Marshal())]; revoked != nil {
return &RevokedError{Revoked: *revoked}
}
host, port, err := net.SplitHostPort(remote.String())
if err != nil {
return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", remote, err)
}
addrs := []addr{
{host, port},
}
if address != "" {
host, port, err := net.SplitHostPort(address)
if err != nil {
return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", address, err)
}
addrs = append(addrs, addr{host, port})
}
return db.checkAddrs(addrs, remoteKey)
}
// checkAddrs checks if we can find the given public key for any of
// the given addresses. If we only find an entry for the IP address,
// or only the hostname, then this still succeeds.
func (db *hostKeyDB) checkAddrs(addrs []addr, remoteKey ssh.PublicKey) error {
// TODO(hanwen): are these the right semantics? What if there
// is just a key for the IP address, but not for the
// hostname?
// Algorithm => key.
knownKeys := map[string]KnownKey{}
for _, l := range db.lines {
if l.match(addrs) {
typ := l.knownKey.Key.Type()
if _, ok := knownKeys[typ]; !ok {
knownKeys[typ] = l.knownKey
}
}
}
keyErr := &KeyError{}
for _, v := range knownKeys {
keyErr.Want = append(keyErr.Want, v)
}
// Unknown remote host.
if len(knownKeys) == 0 {
return keyErr
}
// If the remote host starts using a different, unknown key type, we
// also interpret that as a mismatch.
if known, ok := knownKeys[remoteKey.Type()]; !ok || !keyEq(known.Key, remoteKey) {
return keyErr
}
return nil
}
// The Read function parses file contents.
func (db *hostKeyDB) Read(r io.Reader, filename string) error {
scanner := bufio.NewScanner(r)
lineNum := 0
for scanner.Scan() {
lineNum++
line := scanner.Bytes()
line = bytes.TrimSpace(line)
if len(line) == 0 || line[0] == '#' {
continue
}
if err := db.parseLine(line, filename, lineNum); err != nil {
return fmt.Errorf("knownhosts: %s:%d: %v", filename, lineNum, err)
}
}
return scanner.Err()
}
// New creates a host key callback from the given OpenSSH host key
// files. The returned callback is for use in
// ssh.ClientConfig.HostKeyCallback. Hashed hostnames are not supported.
func New(files ...string) (ssh.HostKeyCallback, error) {
db := newHostKeyDB()
for _, fn := range files {
f, err := os.Open(fn)
if err != nil {
return nil, err
}
defer f.Close()
if err := db.Read(f, fn); err != nil {
return nil, err
}
}
var certChecker ssh.CertChecker
certChecker.IsHostAuthority = db.IsHostAuthority
certChecker.IsRevoked = db.IsRevoked
certChecker.HostKeyFallback = db.check
return certChecker.CheckHostKey, nil
}
// Normalize normalizes an address into the form used in known_hosts
func Normalize(address string) string {
host, port, err := net.SplitHostPort(address)
if err != nil {
host = address
port = "22"
}
entry := host
if port != "22" {
entry = "[" + entry + "]:" + port
} else if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
entry = "[" + entry + "]"
}
return entry
}
// Line returns a line to add append to the known_hosts files.
func Line(addresses []string, key ssh.PublicKey) string {
var trimmed []string
for _, a := range addresses {
trimmed = append(trimmed, Normalize(a))
}
return strings.Join(trimmed, ",") + " " + serialize(key)
}
// HashHostname hashes the given hostname. The hostname is not
// normalized before hashing.
func HashHostname(hostname string) string {
// TODO(hanwen): check if we can safely normalize this always.
salt := make([]byte, sha1.Size)
_, err := rand.Read(salt)
if err != nil {
panic(fmt.Sprintf("crypto/rand failure %v", err))
}
hash := hashHost(hostname, salt)
return encodeHash(sha1HashType, salt, hash)
}
func decodeHash(encoded string) (hashType string, salt, hash []byte, err error) {
if len(encoded) == 0 || encoded[0] != '|' {
err = errors.New("knownhosts: hashed host must start with '|'")
return
}
components := strings.Split(encoded, "|")
if len(components) != 4 {
err = fmt.Errorf("knownhosts: got %d components, want 3", len(components))
return
}
hashType = components[1]
if salt, err = base64.StdEncoding.DecodeString(components[2]); err != nil {
return
}
if hash, err = base64.StdEncoding.DecodeString(components[3]); err != nil {
return
}
return
}
func encodeHash(typ string, salt []byte, hash []byte) string {
return strings.Join([]string{"",
typ,
base64.StdEncoding.EncodeToString(salt),
base64.StdEncoding.EncodeToString(hash),
}, "|")
}
// See https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
func hashHost(hostname string, salt []byte) []byte {
mac := hmac.New(sha1.New, salt)
mac.Write([]byte(hostname))
return mac.Sum(nil)
}
type hashedHost struct {
salt []byte
hash []byte
}
const sha1HashType = "1"
func newHashedHost(encoded string) (*hashedHost, error) {
typ, salt, hash, err := decodeHash(encoded)
if err != nil {
return nil, err
}
// The type field seems for future algorithm agility, but it's
// actually hardcoded in openssh currently, see
// https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
if typ != sha1HashType {
return nil, fmt.Errorf("knownhosts: got hash type %s, must be '1'", typ)
}
return &hashedHost{salt: salt, hash: hash}, nil
}
func (h *hashedHost) match(addrs []addr) bool {
for _, a := range addrs {
if bytes.Equal(hashHost(Normalize(a.String()), h.salt), h.hash) {
return true
}
}
return false
}

View File

@ -0,0 +1,329 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package knownhosts
import (
"bytes"
"fmt"
"net"
"reflect"
"testing"
"golang.org/x/crypto/ssh"
)
const edKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGBAarftlLeoyf+v+nVchEZII/vna2PCV8FaX4vsF5BX"
const alternateEdKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIIXffBYeYL+WVzVru8npl5JHt2cjlr4ornFTWzoij9sx"
const ecKeyStr = "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBNLCu01+wpXe3xB5olXCN4SqU2rQu0qjSRKJO4Bg+JRCPU+ENcgdA5srTU8xYDz/GEa4dzK5ldPw4J/gZgSXCMs="
var ecKey, alternateEdKey, edKey ssh.PublicKey
var testAddr = &net.TCPAddr{
IP: net.IP{198, 41, 30, 196},
Port: 22,
}
var testAddr6 = &net.TCPAddr{
IP: net.IP{198, 41, 30, 196,
1, 2, 3, 4,
1, 2, 3, 4,
1, 2, 3, 4,
},
Port: 22,
}
func init() {
var err error
ecKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(ecKeyStr))
if err != nil {
panic(err)
}
edKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(edKeyStr))
if err != nil {
panic(err)
}
alternateEdKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(alternateEdKeyStr))
if err != nil {
panic(err)
}
}
func testDB(t *testing.T, s string) *hostKeyDB {
db := newHostKeyDB()
if err := db.Read(bytes.NewBufferString(s), "testdb"); err != nil {
t.Fatalf("Read: %v", err)
}
return db
}
func TestRevoked(t *testing.T) {
db := testDB(t, "\n\n@revoked * "+edKeyStr+"\n")
want := &RevokedError{
Revoked: KnownKey{
Key: edKey,
Filename: "testdb",
Line: 3,
},
}
if err := db.check("", &net.TCPAddr{
Port: 42,
}, edKey); err == nil {
t.Fatal("no error for revoked key")
} else if !reflect.DeepEqual(want, err) {
t.Fatalf("got %#v, want %#v", want, err)
}
}
func TestHostAuthority(t *testing.T) {
for _, m := range []struct {
authorityFor string
address string
good bool
}{
{authorityFor: "localhost", address: "localhost:22", good: true},
{authorityFor: "localhost", address: "localhost", good: false},
{authorityFor: "localhost", address: "localhost:1234", good: false},
{authorityFor: "[localhost]:1234", address: "localhost:1234", good: true},
{authorityFor: "[localhost]:1234", address: "localhost:22", good: false},
{authorityFor: "[localhost]:1234", address: "localhost", good: false},
} {
db := testDB(t, `@cert-authority `+m.authorityFor+` `+edKeyStr)
if ok := db.IsHostAuthority(db.lines[0].knownKey.Key, m.address); ok != m.good {
t.Errorf("IsHostAuthority: authority %s, address %s, wanted good = %v, got good = %v",
m.authorityFor, m.address, m.good, ok)
}
}
}
func TestBracket(t *testing.T) {
db := testDB(t, `[git.eclipse.org]:29418,[198.41.30.196]:29418 `+edKeyStr)
if err := db.check("git.eclipse.org:29418", &net.TCPAddr{
IP: net.IP{198, 41, 30, 196},
Port: 29418,
}, edKey); err != nil {
t.Errorf("got error %v, want none", err)
}
if err := db.check("git.eclipse.org:29419", &net.TCPAddr{
Port: 42,
}, edKey); err == nil {
t.Fatalf("no error for unknown address")
} else if ke, ok := err.(*KeyError); !ok {
t.Fatalf("got type %T, want *KeyError", err)
} else if len(ke.Want) > 0 {
t.Fatalf("got Want %v, want []", ke.Want)
}
}
func TestNewKeyType(t *testing.T) {
str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
db := testDB(t, str)
if err := db.check("", testAddr, ecKey); err == nil {
t.Fatalf("no error for unknown address")
} else if ke, ok := err.(*KeyError); !ok {
t.Fatalf("got type %T, want *KeyError", err)
} else if len(ke.Want) == 0 {
t.Fatalf("got empty KeyError.Want")
}
}
func TestSameKeyType(t *testing.T) {
str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
db := testDB(t, str)
if err := db.check("", testAddr, alternateEdKey); err == nil {
t.Fatalf("no error for unknown address")
} else if ke, ok := err.(*KeyError); !ok {
t.Fatalf("got type %T, want *KeyError", err)
} else if len(ke.Want) == 0 {
t.Fatalf("got empty KeyError.Want")
} else if got, want := ke.Want[0].Key.Marshal(), edKey.Marshal(); !bytes.Equal(got, want) {
t.Fatalf("got key %q, want %q", got, want)
}
}
func TestIPAddress(t *testing.T) {
str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
db := testDB(t, str)
if err := db.check("", testAddr, edKey); err != nil {
t.Errorf("got error %q, want none", err)
}
}
func TestIPv6Address(t *testing.T) {
str := fmt.Sprintf("%s %s", testAddr6, edKeyStr)
db := testDB(t, str)
if err := db.check("", testAddr6, edKey); err != nil {
t.Errorf("got error %q, want none", err)
}
}
func TestBasic(t *testing.T) {
str := fmt.Sprintf("#comment\n\nserver.org,%s %s\notherhost %s", testAddr, edKeyStr, ecKeyStr)
db := testDB(t, str)
if err := db.check("server.org:22", testAddr, edKey); err != nil {
t.Errorf("got error %q, want none", err)
}
want := KnownKey{
Key: edKey,
Filename: "testdb",
Line: 3,
}
if err := db.check("server.org:22", testAddr, ecKey); err == nil {
t.Errorf("succeeded, want KeyError")
} else if ke, ok := err.(*KeyError); !ok {
t.Errorf("got %T, want *KeyError", err)
} else if len(ke.Want) != 1 {
t.Errorf("got %v, want 1 entry", ke)
} else if !reflect.DeepEqual(ke.Want[0], want) {
t.Errorf("got %v, want %v", ke.Want[0], want)
}
}
func TestNegate(t *testing.T) {
str := fmt.Sprintf("%s,!server.org %s", testAddr, edKeyStr)
db := testDB(t, str)
if err := db.check("server.org:22", testAddr, ecKey); err == nil {
t.Errorf("succeeded")
} else if ke, ok := err.(*KeyError); !ok {
t.Errorf("got error type %T, want *KeyError", err)
} else if len(ke.Want) != 0 {
t.Errorf("got expected keys %d (first of type %s), want []", len(ke.Want), ke.Want[0].Key.Type())
}
}
func TestWildcard(t *testing.T) {
str := fmt.Sprintf("server*.domain %s", edKeyStr)
db := testDB(t, str)
want := &KeyError{
Want: []KnownKey{{
Filename: "testdb",
Line: 1,
Key: edKey,
}},
}
got := db.check("server.domain:22", &net.TCPAddr{}, ecKey)
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s, want %s", got, want)
}
}
func TestLine(t *testing.T) {
for in, want := range map[string]string{
"server.org": "server.org " + edKeyStr,
"server.org:22": "server.org " + edKeyStr,
"server.org:23": "[server.org]:23 " + edKeyStr,
"[c629:1ec4:102:304:102:304:102:304]:22": "[c629:1ec4:102:304:102:304:102:304] " + edKeyStr,
"[c629:1ec4:102:304:102:304:102:304]:23": "[c629:1ec4:102:304:102:304:102:304]:23 " + edKeyStr,
} {
if got := Line([]string{in}, edKey); got != want {
t.Errorf("Line(%q) = %q, want %q", in, got, want)
}
}
}
func TestWildcardMatch(t *testing.T) {
for _, c := range []struct {
pat, str string
want bool
}{
{"a?b", "abb", true},
{"ab", "abc", false},
{"abc", "ab", false},
{"a*b", "axxxb", true},
{"a*b", "axbxb", true},
{"a*b", "axbxbc", false},
{"a*?", "axbxc", true},
{"a*b*", "axxbxxxxxx", true},
{"a*b*c", "axxbxxxxxxc", true},
{"a*b*?", "axxbxxxxxxc", true},
{"a*b*z", "axxbxxbxxxz", true},
{"a*b*z", "axxbxxzxxxz", true},
{"a*b*z", "axxbxxzxxx", false},
} {
got := wildcardMatch([]byte(c.pat), []byte(c.str))
if got != c.want {
t.Errorf("wildcardMatch(%q, %q) = %v, want %v", c.pat, c.str, got, c.want)
}
}
}
// TODO(hanwen): test coverage for certificates.
const testHostname = "hostname"
// generated with keygen -H -f
const encodedTestHostnameHash = "|1|IHXZvQMvTcZTUU29+2vXFgx8Frs=|UGccIWfRVDwilMBnA3WJoRAC75Y="
func TestHostHash(t *testing.T) {
testHostHash(t, testHostname, encodedTestHostnameHash)
}
func TestHashList(t *testing.T) {
encoded := HashHostname(testHostname)
testHostHash(t, testHostname, encoded)
}
func testHostHash(t *testing.T, hostname, encoded string) {
typ, salt, hash, err := decodeHash(encoded)
if err != nil {
t.Fatalf("decodeHash: %v", err)
}
if got := encodeHash(typ, salt, hash); got != encoded {
t.Errorf("got encoding %s want %s", got, encoded)
}
if typ != sha1HashType {
t.Fatalf("got hash type %q, want %q", typ, sha1HashType)
}
got := hashHost(hostname, salt)
if !bytes.Equal(got, hash) {
t.Errorf("got hash %x want %x", got, hash)
}
}
func TestNormalize(t *testing.T) {
for in, want := range map[string]string{
"127.0.0.1:22": "127.0.0.1",
"[127.0.0.1]:22": "127.0.0.1",
"[127.0.0.1]:23": "[127.0.0.1]:23",
"127.0.0.1:23": "[127.0.0.1]:23",
"[a.b.c]:22": "a.b.c",
"[abcd:abcd:abcd:abcd]": "[abcd:abcd:abcd:abcd]",
"[abcd:abcd:abcd:abcd]:22": "[abcd:abcd:abcd:abcd]",
"[abcd:abcd:abcd:abcd]:23": "[abcd:abcd:abcd:abcd]:23",
} {
got := Normalize(in)
if got != want {
t.Errorf("Normalize(%q) = %q, want %q", in, got, want)
}
}
}
func TestHashedHostkeyCheck(t *testing.T) {
str := fmt.Sprintf("%s %s", HashHostname(testHostname), edKeyStr)
db := testDB(t, str)
if err := db.check(testHostname+":22", testAddr, edKey); err != nil {
t.Errorf("check(%s): %v", testHostname, err)
}
want := &KeyError{
Want: []KnownKey{{
Filename: "testdb",
Line: 1,
Key: edKey,
}},
}
if got := db.check(testHostname+":22", testAddr, alternateEdKey); !reflect.DeepEqual(got, want) {
t.Errorf("got error %v, want %v", got, want)
}
}

View File

@ -15,6 +15,7 @@ import (
type macMode struct { type macMode struct {
keySize int keySize int
etm bool
new func(key []byte) hash.Hash new func(key []byte) hash.Hash
} }
@ -45,13 +46,16 @@ func (t truncatingMAC) Size() int {
func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() } func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() }
var macModes = map[string]*macMode{ var macModes = map[string]*macMode{
"hmac-sha2-256": {32, func(key []byte) hash.Hash { "hmac-sha2-256-etm@openssh.com": {32, true, func(key []byte) hash.Hash {
return hmac.New(sha256.New, key) return hmac.New(sha256.New, key)
}}, }},
"hmac-sha1": {20, func(key []byte) hash.Hash { "hmac-sha2-256": {32, false, func(key []byte) hash.Hash {
return hmac.New(sha256.New, key)
}},
"hmac-sha1": {20, false, func(key []byte) hash.Hash {
return hmac.New(sha1.New, key) return hmac.New(sha1.New, key)
}}, }},
"hmac-sha1-96": {20, func(key []byte) hash.Hash { "hmac-sha1-96": {20, false, func(key []byte) hash.Hash {
return truncatingMAC{12, hmac.New(sha1.New, key)} return truncatingMAC{12, hmac.New(sha1.New, key)}
}}, }},
} }

View File

@ -116,9 +116,9 @@ func (m *mux) Wait() error {
func newMux(p packetConn) *mux { func newMux(p packetConn) *mux {
m := &mux{ m := &mux{
conn: p, conn: p,
incomingChannels: make(chan NewChannel, 16), incomingChannels: make(chan NewChannel, chanSize),
globalResponses: make(chan interface{}, 1), globalResponses: make(chan interface{}, 1),
incomingRequests: make(chan *Request, 16), incomingRequests: make(chan *Request, chanSize),
errCond: newCond(), errCond: newCond(),
} }
if debugMux { if debugMux {

View File

@ -499,4 +499,7 @@ func TestDebug(t *testing.T) {
if debugHandshake { if debugHandshake {
t.Error("handshake debug switched on") t.Error("handshake debug switched on")
} }
if debugTransport {
t.Error("transport debug switched on")
}
} }

View File

@ -10,6 +10,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"strings"
) )
// The Permissions type holds fine-grained permissions that are // The Permissions type holds fine-grained permissions that are
@ -44,6 +45,12 @@ type ServerConfig struct {
// authenticating. // authenticating.
NoClientAuth bool NoClientAuth bool
// MaxAuthTries specifies the maximum number of authentication attempts
// permitted per connection. If set to a negative number, the number of
// attempts are unlimited. If set to zero, the number of attempts are limited
// to 6.
MaxAuthTries int
// PasswordCallback, if non-nil, is called when a user // PasswordCallback, if non-nil, is called when a user
// attempts to authenticate using a password. // attempts to authenticate using a password.
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error) PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
@ -142,6 +149,10 @@ type ServerConn struct {
func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) { func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) {
fullConf := *config fullConf := *config
fullConf.SetDefaults() fullConf.SetDefaults()
if fullConf.MaxAuthTries == 0 {
fullConf.MaxAuthTries = 6
}
s := &connection{ s := &connection{
sshConn: sshConn{conn: c}, sshConn: sshConn{conn: c},
} }
@ -188,7 +199,7 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error)
tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */) tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */)
s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config) s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config)
if err := s.transport.requestInitialKeyChange(); err != nil { if err := s.transport.waitSession(); err != nil {
return nil, err return nil, err
} }
@ -231,7 +242,7 @@ func isAcceptableAlgo(algo string) bool {
return false return false
} }
func checkSourceAddress(addr net.Addr, sourceAddr string) error { func checkSourceAddress(addr net.Addr, sourceAddrs string) error {
if addr == nil { if addr == nil {
return errors.New("ssh: no address known for client, but source-address match required") return errors.New("ssh: no address known for client, but source-address match required")
} }
@ -241,18 +252,20 @@ func checkSourceAddress(addr net.Addr, sourceAddr string) error {
return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr) return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr)
} }
if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil { for _, sourceAddr := range strings.Split(sourceAddrs, ",") {
if allowedIP.Equal(tcpAddr.IP) { if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil {
return nil if allowedIP.Equal(tcpAddr.IP) {
} return nil
} else { }
_, ipNet, err := net.ParseCIDR(sourceAddr) } else {
if err != nil { _, ipNet, err := net.ParseCIDR(sourceAddr)
return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err) if err != nil {
} return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err)
}
if ipNet.Contains(tcpAddr.IP) { if ipNet.Contains(tcpAddr.IP) {
return nil return nil
}
} }
} }
@ -260,12 +273,27 @@ func checkSourceAddress(addr net.Addr, sourceAddr string) error {
} }
func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) { func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
var err error sessionID := s.transport.getSessionID()
var cache pubKeyCache var cache pubKeyCache
var perms *Permissions var perms *Permissions
authFailures := 0
userAuthLoop: userAuthLoop:
for { for {
if authFailures >= config.MaxAuthTries && config.MaxAuthTries > 0 {
discMsg := &disconnectMsg{
Reason: 2,
Message: "too many authentication failures",
}
if err := s.transport.writePacket(Marshal(discMsg)); err != nil {
return nil, err
}
return nil, discMsg
}
var userAuthReq userAuthRequestMsg var userAuthReq userAuthRequestMsg
if packet, err := s.transport.readPacket(); err != nil { if packet, err := s.transport.readPacket(); err != nil {
return nil, err return nil, err
@ -286,6 +314,11 @@ userAuthLoop:
if config.NoClientAuth { if config.NoClientAuth {
authErr = nil authErr = nil
} }
// allow initial attempt of 'none' without penalty
if authFailures == 0 {
authFailures--
}
case "password": case "password":
if config.PasswordCallback == nil { if config.PasswordCallback == nil {
authErr = errors.New("ssh: password auth not configured") authErr = errors.New("ssh: password auth not configured")
@ -357,6 +390,7 @@ userAuthLoop:
if isQuery { if isQuery {
// The client can query if the given public key // The client can query if the given public key
// would be okay. // would be okay.
if len(payload) > 0 { if len(payload) > 0 {
return nil, parseError(msgUserAuthRequest) return nil, parseError(msgUserAuthRequest)
} }
@ -385,7 +419,7 @@ userAuthLoop:
if !isAcceptableAlgo(sig.Format) { if !isAcceptableAlgo(sig.Format) {
break break
} }
signedData := buildDataSignedForAuth(s.transport.getSessionID(), userAuthReq, algoBytes, pubKeyData) signedData := buildDataSignedForAuth(sessionID, userAuthReq, algoBytes, pubKeyData)
if err := pubKey.Verify(signedData, sig); err != nil { if err := pubKey.Verify(signedData, sig); err != nil {
return nil, err return nil, err
@ -406,6 +440,8 @@ userAuthLoop:
break userAuthLoop break userAuthLoop
} }
authFailures++
var failureMsg userAuthFailureMsg var failureMsg userAuthFailureMsg
if config.PasswordCallback != nil { if config.PasswordCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "password") failureMsg.Methods = append(failureMsg.Methods, "password")
@ -421,12 +457,12 @@ userAuthLoop:
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
} }
if err = s.transport.writePacket(Marshal(&failureMsg)); err != nil { if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil {
return nil, err return nil, err
} }
} }
if err = s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil { if err := s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil {
return nil, err return nil, err
} }
return perms, nil return perms, nil

View File

@ -59,7 +59,8 @@ func dial(handler serverType, t *testing.T) *Client {
}() }()
config := &ClientConfig{ config := &ClientConfig{
User: "testuser", User: "testuser",
HostKeyCallback: InsecureIgnoreHostKey(),
} }
conn, chans, reqs, err := NewClientConn(c2, "", config) conn, chans, reqs, err := NewClientConn(c2, "", config)
@ -641,7 +642,8 @@ func TestSessionID(t *testing.T) {
} }
serverConf.AddHostKey(testSigners["ecdsa"]) serverConf.AddHostKey(testSigners["ecdsa"])
clientConf := &ClientConfig{ clientConf := &ClientConfig{
User: "user", HostKeyCallback: InsecureIgnoreHostKey(),
User: "user",
} }
go func() { go func() {
@ -747,7 +749,9 @@ func TestHostKeyAlgorithms(t *testing.T) {
// By default, we get the preferred algorithm, which is ECDSA 256. // By default, we get the preferred algorithm, which is ECDSA 256.
clientConf := &ClientConfig{} clientConf := &ClientConfig{
HostKeyCallback: InsecureIgnoreHostKey(),
}
connect(clientConf, KeyAlgoECDSA256) connect(clientConf, KeyAlgoECDSA256)
// Client asks for RSA explicitly. // Client asks for RSA explicitly.

115
x/crypto/ssh/streamlocal.go Normal file
View File

@ -0,0 +1,115 @@
package ssh
import (
"errors"
"io"
"net"
)
// streamLocalChannelOpenDirectMsg is a struct used for SSH_MSG_CHANNEL_OPEN message
// with "direct-streamlocal@openssh.com" string.
//
// See openssh-portable/PROTOCOL, section 2.4. connection: Unix domain socket forwarding
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L235
type streamLocalChannelOpenDirectMsg struct {
socketPath string
reserved0 string
reserved1 uint32
}
// forwardedStreamLocalPayload is a struct used for SSH_MSG_CHANNEL_OPEN message
// with "forwarded-streamlocal@openssh.com" string.
type forwardedStreamLocalPayload struct {
SocketPath string
Reserved0 string
}
// streamLocalChannelForwardMsg is a struct used for SSH2_MSG_GLOBAL_REQUEST message
// with "streamlocal-forward@openssh.com"/"cancel-streamlocal-forward@openssh.com" string.
type streamLocalChannelForwardMsg struct {
socketPath string
}
// ListenUnix is similar to ListenTCP but uses a Unix domain socket.
func (c *Client) ListenUnix(socketPath string) (net.Listener, error) {
m := streamLocalChannelForwardMsg{
socketPath,
}
// send message
ok, _, err := c.SendRequest("streamlocal-forward@openssh.com", true, Marshal(&m))
if err != nil {
return nil, err
}
if !ok {
return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer")
}
ch := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"})
return &unixListener{socketPath, c, ch}, nil
}
func (c *Client) dialStreamLocal(socketPath string) (Channel, error) {
msg := streamLocalChannelOpenDirectMsg{
socketPath: socketPath,
}
ch, in, err := c.OpenChannel("direct-streamlocal@openssh.com", Marshal(&msg))
if err != nil {
return nil, err
}
go DiscardRequests(in)
return ch, err
}
type unixListener struct {
socketPath string
conn *Client
in <-chan forward
}
// Accept waits for and returns the next connection to the listener.
func (l *unixListener) Accept() (net.Conn, error) {
s, ok := <-l.in
if !ok {
return nil, io.EOF
}
ch, incoming, err := s.newCh.Accept()
if err != nil {
return nil, err
}
go DiscardRequests(incoming)
return &chanConn{
Channel: ch,
laddr: &net.UnixAddr{
Name: l.socketPath,
Net: "unix",
},
raddr: &net.UnixAddr{
Name: "@",
Net: "unix",
},
}, nil
}
// Close closes the listener.
func (l *unixListener) Close() error {
// this also closes the listener.
l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"})
m := streamLocalChannelForwardMsg{
l.socketPath,
}
ok, _, err := l.conn.SendRequest("cancel-streamlocal-forward@openssh.com", true, Marshal(&m))
if err == nil && !ok {
err = errors.New("ssh: cancel-streamlocal-forward@openssh.com failed")
}
return err
}
// Addr returns the listener's network address.
func (l *unixListener) Addr() net.Addr {
return &net.UnixAddr{
Name: l.socketPath,
Net: "unix",
}
}

View File

@ -20,12 +20,20 @@ import (
// addr. Incoming connections will be available by calling Accept on // addr. Incoming connections will be available by calling Accept on
// the returned net.Listener. The listener must be serviced, or the // the returned net.Listener. The listener must be serviced, or the
// SSH connection may hang. // SSH connection may hang.
// N must be "tcp", "tcp4", "tcp6", or "unix".
func (c *Client) Listen(n, addr string) (net.Listener, error) { func (c *Client) Listen(n, addr string) (net.Listener, error) {
laddr, err := net.ResolveTCPAddr(n, addr) switch n {
if err != nil { case "tcp", "tcp4", "tcp6":
return nil, err laddr, err := net.ResolveTCPAddr(n, addr)
if err != nil {
return nil, err
}
return c.ListenTCP(laddr)
case "unix":
return c.ListenUnix(addr)
default:
return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
} }
return c.ListenTCP(laddr)
} }
// Automatic port allocation is broken with OpenSSH before 6.0. See // Automatic port allocation is broken with OpenSSH before 6.0. See
@ -116,7 +124,7 @@ func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
} }
// Register this forward, using the port number we obtained. // Register this forward, using the port number we obtained.
ch := c.forwards.add(*laddr) ch := c.forwards.add(laddr)
return &tcpListener{laddr, c, ch}, nil return &tcpListener{laddr, c, ch}, nil
} }
@ -131,7 +139,7 @@ type forwardList struct {
// forwardEntry represents an established mapping of a laddr on a // forwardEntry represents an established mapping of a laddr on a
// remote ssh server to a channel connected to a tcpListener. // remote ssh server to a channel connected to a tcpListener.
type forwardEntry struct { type forwardEntry struct {
laddr net.TCPAddr laddr net.Addr
c chan forward c chan forward
} }
@ -139,16 +147,16 @@ type forwardEntry struct {
// arguments to add/remove/lookup should be address as specified in // arguments to add/remove/lookup should be address as specified in
// the original forward-request. // the original forward-request.
type forward struct { type forward struct {
newCh NewChannel // the ssh client channel underlying this forward newCh NewChannel // the ssh client channel underlying this forward
raddr *net.TCPAddr // the raddr of the incoming connection raddr net.Addr // the raddr of the incoming connection
} }
func (l *forwardList) add(addr net.TCPAddr) chan forward { func (l *forwardList) add(addr net.Addr) chan forward {
l.Lock() l.Lock()
defer l.Unlock() defer l.Unlock()
f := forwardEntry{ f := forwardEntry{
addr, laddr: addr,
make(chan forward, 1), c: make(chan forward, 1),
} }
l.entries = append(l.entries, f) l.entries = append(l.entries, f)
return f.c return f.c
@ -176,44 +184,69 @@ func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
func (l *forwardList) handleChannels(in <-chan NewChannel) { func (l *forwardList) handleChannels(in <-chan NewChannel) {
for ch := range in { for ch := range in {
var payload forwardedTCPPayload var (
if err := Unmarshal(ch.ExtraData(), &payload); err != nil { laddr net.Addr
ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error()) raddr net.Addr
continue err error
} )
switch channelType := ch.ChannelType(); channelType {
case "forwarded-tcpip":
var payload forwardedTCPPayload
if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
continue
}
// RFC 4254 section 7.2 specifies that incoming // RFC 4254 section 7.2 specifies that incoming
// addresses should list the address, in string // addresses should list the address, in string
// format. It is implied that this should be an IP // format. It is implied that this should be an IP
// address, as it would be impossible to connect to it // address, as it would be impossible to connect to it
// otherwise. // otherwise.
laddr, err := parseTCPAddr(payload.Addr, payload.Port) laddr, err = parseTCPAddr(payload.Addr, payload.Port)
if err != nil { if err != nil {
ch.Reject(ConnectionFailed, err.Error()) ch.Reject(ConnectionFailed, err.Error())
continue continue
} }
raddr, err := parseTCPAddr(payload.OriginAddr, payload.OriginPort) raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
if err != nil { if err != nil {
ch.Reject(ConnectionFailed, err.Error()) ch.Reject(ConnectionFailed, err.Error())
continue continue
} }
if ok := l.forward(*laddr, *raddr, ch); !ok { case "forwarded-streamlocal@openssh.com":
var payload forwardedStreamLocalPayload
if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
continue
}
laddr = &net.UnixAddr{
Name: payload.SocketPath,
Net: "unix",
}
raddr = &net.UnixAddr{
Name: "@",
Net: "unix",
}
default:
panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
}
if ok := l.forward(laddr, raddr, ch); !ok {
// Section 7.2, implementations MUST reject spurious incoming // Section 7.2, implementations MUST reject spurious incoming
// connections. // connections.
ch.Reject(Prohibited, "no forward for address") ch.Reject(Prohibited, "no forward for address")
continue continue
} }
} }
} }
// remove removes the forward entry, and the channel feeding its // remove removes the forward entry, and the channel feeding its
// listener. // listener.
func (l *forwardList) remove(addr net.TCPAddr) { func (l *forwardList) remove(addr net.Addr) {
l.Lock() l.Lock()
defer l.Unlock() defer l.Unlock()
for i, f := range l.entries { for i, f := range l.entries {
if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port { if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() {
l.entries = append(l.entries[:i], l.entries[i+1:]...) l.entries = append(l.entries[:i], l.entries[i+1:]...)
close(f.c) close(f.c)
return return
@ -231,12 +264,12 @@ func (l *forwardList) closeAll() {
l.entries = nil l.entries = nil
} }
func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool { func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool {
l.Lock() l.Lock()
defer l.Unlock() defer l.Unlock()
for _, f := range l.entries { for _, f := range l.entries {
if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port { if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() {
f.c <- forward{ch, &raddr} f.c <- forward{newCh: ch, raddr: raddr}
return true return true
} }
} }
@ -262,7 +295,7 @@ func (l *tcpListener) Accept() (net.Conn, error) {
} }
go DiscardRequests(incoming) go DiscardRequests(incoming)
return &tcpChanConn{ return &chanConn{
Channel: ch, Channel: ch,
laddr: l.laddr, laddr: l.laddr,
raddr: s.raddr, raddr: s.raddr,
@ -277,7 +310,7 @@ func (l *tcpListener) Close() error {
} }
// this also closes the listener. // this also closes the listener.
l.conn.forwards.remove(*l.laddr) l.conn.forwards.remove(l.laddr)
ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m)) ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
if err == nil && !ok { if err == nil && !ok {
err = errors.New("ssh: cancel-tcpip-forward failed") err = errors.New("ssh: cancel-tcpip-forward failed")
@ -293,29 +326,52 @@ func (l *tcpListener) Addr() net.Addr {
// Dial initiates a connection to the addr from the remote host. // Dial initiates a connection to the addr from the remote host.
// The resulting connection has a zero LocalAddr() and RemoteAddr(). // The resulting connection has a zero LocalAddr() and RemoteAddr().
func (c *Client) Dial(n, addr string) (net.Conn, error) { func (c *Client) Dial(n, addr string) (net.Conn, error) {
// Parse the address into host and numeric port. var ch Channel
host, portString, err := net.SplitHostPort(addr) switch n {
if err != nil { case "tcp", "tcp4", "tcp6":
return nil, err // Parse the address into host and numeric port.
host, portString, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return nil, err
}
ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port))
if err != nil {
return nil, err
}
// Use a zero address for local and remote address.
zeroAddr := &net.TCPAddr{
IP: net.IPv4zero,
Port: 0,
}
return &chanConn{
Channel: ch,
laddr: zeroAddr,
raddr: zeroAddr,
}, nil
case "unix":
var err error
ch, err = c.dialStreamLocal(addr)
if err != nil {
return nil, err
}
return &chanConn{
Channel: ch,
laddr: &net.UnixAddr{
Name: "@",
Net: "unix",
},
raddr: &net.UnixAddr{
Name: addr,
Net: "unix",
},
}, nil
default:
return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
} }
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return nil, err
}
// Use a zero address for local and remote address.
zeroAddr := &net.TCPAddr{
IP: net.IPv4zero,
Port: 0,
}
ch, err := c.dial(net.IPv4zero.String(), 0, host, int(port))
if err != nil {
return nil, err
}
return &tcpChanConn{
Channel: ch,
laddr: zeroAddr,
raddr: zeroAddr,
}, nil
} }
// DialTCP connects to the remote address raddr on the network net, // DialTCP connects to the remote address raddr on the network net,
@ -332,7 +388,7 @@ func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &tcpChanConn{ return &chanConn{
Channel: ch, Channel: ch,
laddr: laddr, laddr: laddr,
raddr: raddr, raddr: raddr,
@ -366,26 +422,26 @@ type tcpChan struct {
Channel // the backing channel Channel // the backing channel
} }
// tcpChanConn fulfills the net.Conn interface without // chanConn fulfills the net.Conn interface without
// the tcpChan having to hold laddr or raddr directly. // the tcpChan having to hold laddr or raddr directly.
type tcpChanConn struct { type chanConn struct {
Channel Channel
laddr, raddr net.Addr laddr, raddr net.Addr
} }
// LocalAddr returns the local network address. // LocalAddr returns the local network address.
func (t *tcpChanConn) LocalAddr() net.Addr { func (t *chanConn) LocalAddr() net.Addr {
return t.laddr return t.laddr
} }
// RemoteAddr returns the remote network address. // RemoteAddr returns the remote network address.
func (t *tcpChanConn) RemoteAddr() net.Addr { func (t *chanConn) RemoteAddr() net.Addr {
return t.raddr return t.raddr
} }
// SetDeadline sets the read and write deadlines associated // SetDeadline sets the read and write deadlines associated
// with the connection. // with the connection.
func (t *tcpChanConn) SetDeadline(deadline time.Time) error { func (t *chanConn) SetDeadline(deadline time.Time) error {
if err := t.SetReadDeadline(deadline); err != nil { if err := t.SetReadDeadline(deadline); err != nil {
return err return err
} }
@ -396,12 +452,14 @@ func (t *tcpChanConn) SetDeadline(deadline time.Time) error {
// A zero value for t means Read will not time out. // A zero value for t means Read will not time out.
// After the deadline, the error from Read will implement net.Error // After the deadline, the error from Read will implement net.Error
// with Timeout() == true. // with Timeout() == true.
func (t *tcpChanConn) SetReadDeadline(deadline time.Time) error { func (t *chanConn) SetReadDeadline(deadline time.Time) error {
// for compatibility with previous version,
// the error message contains "tcpChan"
return errors.New("ssh: tcpChan: deadline not supported") return errors.New("ssh: tcpChan: deadline not supported")
} }
// SetWriteDeadline exists to satisfy the net.Conn interface // SetWriteDeadline exists to satisfy the net.Conn interface
// but is not implemented by this type. It always returns an error. // but is not implemented by this type. It always returns an error.
func (t *tcpChanConn) SetWriteDeadline(deadline time.Time) error { func (t *chanConn) SetWriteDeadline(deadline time.Time) error {
return errors.New("ssh: tcpChan: deadline not supported") return errors.New("ssh: tcpChan: deadline not supported")
} }

View File

@ -596,7 +596,7 @@ func (t *Terminal) writeLine(line []rune) {
} }
} }
// writeWithCRLF writes buf to w but replaces all occurances of \n with \r\n. // writeWithCRLF writes buf to w but replaces all occurrences of \n with \r\n.
func writeWithCRLF(w io.Writer, buf []byte) (n int, err error) { func writeWithCRLF(w io.Writer, buf []byte) (n int, err error) {
for len(buf) > 0 { for len(buf) > 0 {
i := bytes.IndexByte(buf, '\n') i := bytes.IndexByte(buf, '\n')
@ -772,8 +772,6 @@ func (t *Terminal) readLine() (line string, err error) {
t.remainder = t.inBuf[:n+len(t.remainder)] t.remainder = t.inBuf[:n+len(t.remainder)]
} }
panic("unreachable") // for Go 1.0.
} }
// SetPrompt sets the prompt to be used when reading subsequent lines. // SetPrompt sets the prompt to be used when reading subsequent lines.
@ -922,3 +920,32 @@ func (s *stRingBuffer) NthPreviousEntry(n int) (value string, ok bool) {
} }
return s.entries[index], true return s.entries[index], true
} }
// readPasswordLine reads from reader until it finds \n or io.EOF.
// The slice returned does not include the \n.
// readPasswordLine also ignores any \r it finds.
func readPasswordLine(reader io.Reader) ([]byte, error) {
var buf [1]byte
var ret []byte
for {
n, err := reader.Read(buf[:])
if n > 0 {
switch buf[0] {
case '\n':
return ret, nil
case '\r':
// remove \r from passwords on Windows
default:
ret = append(ret, buf[0])
}
continue
}
if err != nil {
if err == io.EOF && len(ret) > 0 {
return ret, nil
}
return ret, err
}
}
}

View File

@ -270,6 +270,50 @@ func TestTerminalSetSize(t *testing.T) {
} }
} }
func TestReadPasswordLineEnd(t *testing.T) {
var tests = []struct {
input string
want string
}{
{"\n", ""},
{"\r\n", ""},
{"test\r\n", "test"},
{"testtesttesttes\n", "testtesttesttes"},
{"testtesttesttes\r\n", "testtesttesttes"},
{"testtesttesttesttest\n", "testtesttesttesttest"},
{"testtesttesttesttest\r\n", "testtesttesttesttest"},
}
for _, test := range tests {
buf := new(bytes.Buffer)
if _, err := buf.WriteString(test.input); err != nil {
t.Fatal(err)
}
have, err := readPasswordLine(buf)
if err != nil {
t.Errorf("readPasswordLine(%q) failed: %v", test.input, err)
continue
}
if string(have) != test.want {
t.Errorf("readPasswordLine(%q) returns %q, but %q is expected", test.input, string(have), test.want)
continue
}
if _, err = buf.WriteString(test.input); err != nil {
t.Fatal(err)
}
have, err = readPasswordLine(buf)
if err != nil {
t.Errorf("readPasswordLine(%q) failed: %v", test.input, err)
continue
}
if string(have) != test.want {
t.Errorf("readPasswordLine(%q) returns %q, but %q is expected", test.input, string(have), test.want)
continue
}
}
}
func TestMakeRawState(t *testing.T) { func TestMakeRawState(t *testing.T) {
fd := int(os.Stdout.Fd()) fd := int(os.Stdout.Fd())
if !IsTerminal(fd) { if !IsTerminal(fd) {

View File

@ -17,7 +17,6 @@
package terminal // import "golang.org/x/crypto/ssh/terminal" package terminal // import "golang.org/x/crypto/ssh/terminal"
import ( import (
"io"
"syscall" "syscall"
"unsafe" "unsafe"
) )
@ -72,8 +71,10 @@ func GetState(fd int) (*State, error) {
// Restore restores the terminal connected to the given file descriptor to a // Restore restores the terminal connected to the given file descriptor to a
// previous state. // previous state.
func Restore(fd int, state *State) error { func Restore(fd int, state *State) error {
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&state.termios)), 0, 0, 0) if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&state.termios)), 0, 0, 0); err != 0 {
return err return err
}
return nil
} }
// GetSize returns the dimensions of the given terminal. // GetSize returns the dimensions of the given terminal.
@ -86,6 +87,13 @@ func GetSize(fd int) (width, height int, err error) {
return int(dimensions[1]), int(dimensions[0]), nil return int(dimensions[1]), int(dimensions[0]), nil
} }
// passwordReader is an io.Reader that reads from a specific file descriptor.
type passwordReader int
func (r passwordReader) Read(buf []byte) (int, error) {
return syscall.Read(int(r), buf)
}
// ReadPassword reads a line of input from a terminal without local echo. This // ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice // is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n. // returned does not include the \n.
@ -107,27 +115,5 @@ func ReadPassword(fd int) ([]byte, error) {
syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0) syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0)
}() }()
var buf [16]byte return readPasswordLine(passwordReader(fd))
var ret []byte
for {
n, err := syscall.Read(fd, buf[:])
if err != nil {
return nil, err
}
if n == 0 {
if len(ret) == 0 {
return nil, io.EOF
}
break
}
if buf[n-1] == '\n' {
n--
}
ret = append(ret, buf[:n]...)
if n < len(buf) {
break
}
}
return ret, nil
} }

View File

@ -14,14 +14,12 @@ import (
// State contains the state of a terminal. // State contains the state of a terminal.
type State struct { type State struct {
termios syscall.Termios state *unix.Termios
} }
// IsTerminal returns true if the given file descriptor is a terminal. // IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd int) bool { func IsTerminal(fd int) bool {
// see: http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libbc/libc/gen/common/isatty.c _, err := unix.IoctlGetTermio(fd, unix.TCGETA)
var termio unix.Termio
err := unix.IoctlSetTermio(fd, unix.TCGETA, &termio)
return err == nil return err == nil
} }
@ -71,3 +69,60 @@ func ReadPassword(fd int) ([]byte, error) {
return ret, nil return ret, nil
} }
// MakeRaw puts the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
// see http://cr.illumos.org/~webrev/andy_js/1060/
func MakeRaw(fd int) (*State, error) {
oldTermiosPtr, err := unix.IoctlGetTermios(fd, unix.TCGETS)
if err != nil {
return nil, err
}
oldTermios := *oldTermiosPtr
newTermios := oldTermios
newTermios.Iflag &^= syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP | syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON
newTermios.Oflag &^= syscall.OPOST
newTermios.Lflag &^= syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN
newTermios.Cflag &^= syscall.CSIZE | syscall.PARENB
newTermios.Cflag |= syscall.CS8
newTermios.Cc[unix.VMIN] = 1
newTermios.Cc[unix.VTIME] = 0
if err := unix.IoctlSetTermios(fd, unix.TCSETS, &newTermios); err != nil {
return nil, err
}
return &State{
state: oldTermiosPtr,
}, nil
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, oldState *State) error {
return unix.IoctlSetTermios(fd, unix.TCSETS, oldState.state)
}
// GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal.
func GetState(fd int) (*State, error) {
oldTermiosPtr, err := unix.IoctlGetTermios(fd, unix.TCGETS)
if err != nil {
return nil, err
}
return &State{
state: oldTermiosPtr,
}, nil
}
// GetSize returns the dimensions of the given terminal.
func GetSize(fd int) (width, height int, err error) {
ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ)
if err != nil {
return 0, 0, err
}
return int(ws.Col), int(ws.Row), nil
}

View File

@ -17,7 +17,6 @@
package terminal package terminal
import ( import (
"io"
"syscall" "syscall"
"unsafe" "unsafe"
) )
@ -123,6 +122,13 @@ func GetSize(fd int) (width, height int, err error) {
return int(info.size.x), int(info.size.y), nil return int(info.size.x), int(info.size.y), nil
} }
// passwordReader is an io.Reader that reads from a specific Windows HANDLE.
type passwordReader int
func (r passwordReader) Read(buf []byte) (int, error) {
return syscall.Read(syscall.Handle(r), buf)
}
// ReadPassword reads a line of input from a terminal without local echo. This // ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice // is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n. // returned does not include the \n.
@ -145,30 +151,5 @@ func ReadPassword(fd int) ([]byte, error) {
syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(old), 0) syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(old), 0)
}() }()
var buf [16]byte return readPasswordLine(passwordReader(fd))
var ret []byte
for {
n, err := syscall.Read(syscall.Handle(fd), buf[:])
if err != nil {
return nil, err
}
if n == 0 {
if len(ret) == 0 {
return nil, io.EOF
}
break
}
if buf[n-1] == '\n' {
n--
}
if n > 0 && buf[n-1] == '\r' {
n--
}
ret = append(ret, buf[:n]...)
if n < len(buf) {
break
}
}
return ret, nil
} }

View File

@ -36,7 +36,8 @@ func TestCertLogin(t *testing.T) {
} }
conf := &ssh.ClientConfig{ conf := &ssh.ClientConfig{
User: username(), User: username(),
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
} }
conf.Auth = append(conf.Auth, ssh.PublicKeys(certSigner)) conf.Auth = append(conf.Auth, ssh.PublicKeys(certSigner))
client, err := s.TryDial(conf) client, err := s.TryDial(conf)

View File

@ -0,0 +1,128 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !windows
package test
// direct-tcpip and direct-streamlocal functional tests
import (
"fmt"
"io"
"io/ioutil"
"net"
"strings"
"testing"
)
type dialTester interface {
TestServerConn(t *testing.T, c net.Conn)
TestClientConn(t *testing.T, c net.Conn)
}
func testDial(t *testing.T, n, listenAddr string, x dialTester) {
server := newServer(t)
defer server.Shutdown()
sshConn := server.Dial(clientConfig())
defer sshConn.Close()
l, err := net.Listen(n, listenAddr)
if err != nil {
t.Fatalf("Listen: %v", err)
}
defer l.Close()
testData := fmt.Sprintf("hello from %s, %s", n, listenAddr)
go func() {
for {
c, err := l.Accept()
if err != nil {
break
}
x.TestServerConn(t, c)
io.WriteString(c, testData)
c.Close()
}
}()
conn, err := sshConn.Dial(n, l.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
x.TestClientConn(t, conn)
defer conn.Close()
b, err := ioutil.ReadAll(conn)
if err != nil {
t.Fatalf("ReadAll: %v", err)
}
t.Logf("got %q", string(b))
if string(b) != testData {
t.Fatalf("expected %q, got %q", testData, string(b))
}
}
type tcpDialTester struct {
listenAddr string
}
func (x *tcpDialTester) TestServerConn(t *testing.T, c net.Conn) {
host := strings.Split(x.listenAddr, ":")[0]
prefix := host + ":"
if !strings.HasPrefix(c.LocalAddr().String(), prefix) {
t.Fatalf("expected to start with %q, got %q", prefix, c.LocalAddr().String())
}
if !strings.HasPrefix(c.RemoteAddr().String(), prefix) {
t.Fatalf("expected to start with %q, got %q", prefix, c.RemoteAddr().String())
}
}
func (x *tcpDialTester) TestClientConn(t *testing.T, c net.Conn) {
// we use zero addresses. see *Client.Dial.
if c.LocalAddr().String() != "0.0.0.0:0" {
t.Fatalf("expected \"0.0.0.0:0\", got %q", c.LocalAddr().String())
}
if c.RemoteAddr().String() != "0.0.0.0:0" {
t.Fatalf("expected \"0.0.0.0:0\", got %q", c.RemoteAddr().String())
}
}
func TestDialTCP(t *testing.T) {
x := &tcpDialTester{
listenAddr: "127.0.0.1:0",
}
testDial(t, "tcp", x.listenAddr, x)
}
type unixDialTester struct {
listenAddr string
}
func (x *unixDialTester) TestServerConn(t *testing.T, c net.Conn) {
if c.LocalAddr().String() != x.listenAddr {
t.Fatalf("expected %q, got %q", x.listenAddr, c.LocalAddr().String())
}
if c.RemoteAddr().String() != "@" {
t.Fatalf("expected \"@\", got %q", c.RemoteAddr().String())
}
}
func (x *unixDialTester) TestClientConn(t *testing.T, c net.Conn) {
if c.RemoteAddr().String() != x.listenAddr {
t.Fatalf("expected %q, got %q", x.listenAddr, c.RemoteAddr().String())
}
if c.LocalAddr().String() != "@" {
t.Fatalf("expected \"@\", got %q", c.LocalAddr().String())
}
}
func TestDialUnix(t *testing.T) {
addr, cleanup := newTempSocket(t)
defer cleanup()
x := &unixDialTester{
listenAddr: addr,
}
testDial(t, "unix", x.listenAddr, x)
}

View File

@ -16,13 +16,17 @@ import (
"time" "time"
) )
func TestPortForward(t *testing.T) { type closeWriter interface {
CloseWrite() error
}
func testPortForward(t *testing.T, n, listenAddr string) {
server := newServer(t) server := newServer(t)
defer server.Shutdown() defer server.Shutdown()
conn := server.Dial(clientConfig()) conn := server.Dial(clientConfig())
defer conn.Close() defer conn.Close()
sshListener, err := conn.Listen("tcp", "localhost:0") sshListener, err := conn.Listen(n, listenAddr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -41,14 +45,14 @@ func TestPortForward(t *testing.T) {
}() }()
forwardedAddr := sshListener.Addr().String() forwardedAddr := sshListener.Addr().String()
tcpConn, err := net.Dial("tcp", forwardedAddr) netConn, err := net.Dial(n, forwardedAddr)
if err != nil { if err != nil {
t.Fatalf("TCP dial failed: %v", err) t.Fatalf("net dial failed: %v", err)
} }
readChan := make(chan []byte) readChan := make(chan []byte)
go func() { go func() {
data, _ := ioutil.ReadAll(tcpConn) data, _ := ioutil.ReadAll(netConn)
readChan <- data readChan <- data
}() }()
@ -62,14 +66,14 @@ func TestPortForward(t *testing.T) {
for len(sent) < 1000*1000 { for len(sent) < 1000*1000 {
// Send random sized chunks // Send random sized chunks
m := rand.Intn(len(data)) m := rand.Intn(len(data))
n, err := tcpConn.Write(data[:m]) n, err := netConn.Write(data[:m])
if err != nil { if err != nil {
break break
} }
sent = append(sent, data[:n]...) sent = append(sent, data[:n]...)
} }
if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil { if err := netConn.(closeWriter).CloseWrite(); err != nil {
t.Errorf("tcpConn.CloseWrite: %v", err) t.Errorf("netConn.CloseWrite: %v", err)
} }
read := <-readChan read := <-readChan
@ -86,19 +90,29 @@ func TestPortForward(t *testing.T) {
} }
// Check that the forward disappeared. // Check that the forward disappeared.
tcpConn, err = net.Dial("tcp", forwardedAddr) netConn, err = net.Dial(n, forwardedAddr)
if err == nil { if err == nil {
tcpConn.Close() netConn.Close()
t.Errorf("still listening to %s after closing", forwardedAddr) t.Errorf("still listening to %s after closing", forwardedAddr)
} }
} }
func TestAcceptClose(t *testing.T) { func TestPortForwardTCP(t *testing.T) {
testPortForward(t, "tcp", "localhost:0")
}
func TestPortForwardUnix(t *testing.T) {
addr, cleanup := newTempSocket(t)
defer cleanup()
testPortForward(t, "unix", addr)
}
func testAcceptClose(t *testing.T, n, listenAddr string) {
server := newServer(t) server := newServer(t)
defer server.Shutdown() defer server.Shutdown()
conn := server.Dial(clientConfig()) conn := server.Dial(clientConfig())
sshListener, err := conn.Listen("tcp", "localhost:0") sshListener, err := conn.Listen(n, listenAddr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -124,13 +138,23 @@ func TestAcceptClose(t *testing.T) {
} }
} }
func TestAcceptCloseTCP(t *testing.T) {
testAcceptClose(t, "tcp", "localhost:0")
}
func TestAcceptCloseUnix(t *testing.T) {
addr, cleanup := newTempSocket(t)
defer cleanup()
testAcceptClose(t, "unix", addr)
}
// Check that listeners exit if the underlying client transport dies. // Check that listeners exit if the underlying client transport dies.
func TestPortForwardConnectionClose(t *testing.T) { func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) {
server := newServer(t) server := newServer(t)
defer server.Shutdown() defer server.Shutdown()
conn := server.Dial(clientConfig()) conn := server.Dial(clientConfig())
sshListener, err := conn.Listen("tcp", "localhost:0") sshListener, err := conn.Listen(n, listenAddr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -158,3 +182,13 @@ func TestPortForwardConnectionClose(t *testing.T) {
t.Logf("quit as expected (error %v)", err) t.Logf("quit as expected (error %v)", err)
} }
} }
func TestPortForwardConnectionCloseTCP(t *testing.T) {
testPortForwardConnectionClose(t, "tcp", "localhost:0")
}
func TestPortForwardConnectionCloseUnix(t *testing.T) {
addr, cleanup := newTempSocket(t)
defer cleanup()
testPortForwardConnectionClose(t, "unix", addr)
}

View File

@ -266,3 +266,13 @@ func newServer(t *testing.T) *server {
}, },
} }
} }
func newTempSocket(t *testing.T) (string, func()) {
dir, err := ioutil.TempDir("", "socket")
if err != nil {
t.Fatal(err)
}
deferFunc := func() { os.RemoveAll(dir) }
addr := filepath.Join(dir, "sock")
return addr, deferFunc
}

View File

@ -48,6 +48,22 @@ AAAEAaYmXltfW6nhRo3iWGglRB48lYq0z0Q3I3KyrdutEr6j7d/uFLuDlRbBc4ZVOsx+Gb
HKuOrPtLHFvHsjWPwO+/AAAAE2dhcnRvbm1AZ2FydG9ubS14cHMBAg== HKuOrPtLHFvHsjWPwO+/AAAAE2dhcnRvbm1AZ2FydG9ubS14cHMBAg==
-----END OPENSSH PRIVATE KEY----- -----END OPENSSH PRIVATE KEY-----
`), `),
"rsa-openssh-format": []byte(`-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAlwAAAAdzc2gtcn
NhAAAAAwEAAQAAAIEAwa48yfWFi3uIdqzuf9X7C2Zxfea/Iaaw0zIwHudpF8U92WVIiC5l
oEuW1+OaVi3UWfIEjWMV1tHGysrHOwtwc34BPCJqJknUQO/KtDTBTJ4Pryhw1bWPC999Lz
a+yrCTdNQYBzoROXKExZgPFh9pTMi5wqpHDuOQ2qZFIEI3lT0AAAIQWL0H31i9B98AAAAH
c3NoLXJzYQAAAIEAwa48yfWFi3uIdqzuf9X7C2Zxfea/Iaaw0zIwHudpF8U92WVIiC5loE
uW1+OaVi3UWfIEjWMV1tHGysrHOwtwc34BPCJqJknUQO/KtDTBTJ4Pryhw1bWPC999Lza+
yrCTdNQYBzoROXKExZgPFh9pTMi5wqpHDuOQ2qZFIEI3lT0AAAADAQABAAAAgCThyTGsT4
IARDxVMhWl6eiB2ZrgFgWSeJm/NOqtppWgOebsIqPMMg4UVuVFsl422/lE3RkPhVkjGXgE
pWvZAdCnmLmApK8wK12vF334lZhZT7t3Z9EzJps88PWEHo7kguf285HcnUM7FlFeissJdk
kXly34y7/3X/a6Tclm+iABAAAAQE0xR/KxZ39slwfMv64Rz7WKk1PPskaryI29aHE3mKHk
pY2QA+P3QlrKxT/VWUMjHUbNNdYfJm48xu0SGNMRdKMAAABBAORh2NP/06JUV3J9W/2Hju
X1ViJuqqcQnJPVzpgSL826EC2xwOECTqoY8uvFpUdD7CtpksIxNVqRIhuNOlz0lqEAAABB
ANkaHTTaPojClO0dKJ/Zjs7pWOCGliebBYprQ/Y4r9QLBkC/XaWMS26gFIrjgC7D2Rv+rZ
wSD0v0RcmkITP1ZR0AAAAYcHF1ZXJuYUBMdWNreUh5ZHJvLmxvY2FsAQID
-----END OPENSSH PRIVATE KEY-----`),
"user": []byte(`-----BEGIN EC PRIVATE KEY----- "user": []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEILYCAeq8f7V4vSSypRw7pxy8yz3V5W4qg8kSC3zJhqpQoAoGCCqGSM49 MHcCAQEEILYCAeq8f7V4vSSypRw7pxy8yz3V5W4qg8kSC3zJhqpQoAoGCCqGSM49
AwEHoUQDQgAEYcO2xNKiRUYOLEHM7VYAp57HNyKbOdYtHD83Z4hzNPVC4tM5mdGD AwEHoUQDQgAEYcO2xNKiRUYOLEHM7VYAp57HNyKbOdYtHD83Z4hzNPVC4tM5mdGD

View File

@ -8,8 +8,13 @@ import (
"bufio" "bufio"
"errors" "errors"
"io" "io"
"log"
) )
// debugTransport if set, will print packet types as they go over the
// wire. No message decoding is done, to minimize the impact on timing.
const debugTransport = false
const ( const (
gcmCipherID = "aes128-gcm@openssh.com" gcmCipherID = "aes128-gcm@openssh.com"
aes128cbcID = "aes128-cbc" aes128cbcID = "aes128-cbc"
@ -22,7 +27,9 @@ type packetConn interface {
// Encrypt and send a packet of data to the remote peer. // Encrypt and send a packet of data to the remote peer.
writePacket(packet []byte) error writePacket(packet []byte) error
// Read a packet from the connection // Read a packet from the connection. The read is blocking,
// i.e. if error is nil, then the returned byte slice is
// always non-empty.
readPacket() ([]byte, error) readPacket() ([]byte, error)
// Close closes the write-side of the connection. // Close closes the write-side of the connection.
@ -38,7 +45,7 @@ type transport struct {
bufReader *bufio.Reader bufReader *bufio.Reader
bufWriter *bufio.Writer bufWriter *bufio.Writer
rand io.Reader rand io.Reader
isClient bool
io.Closer io.Closer
} }
@ -84,9 +91,38 @@ func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) err
return nil return nil
} }
func (t *transport) printPacket(p []byte, write bool) {
if len(p) == 0 {
return
}
who := "server"
if t.isClient {
who = "client"
}
what := "read"
if write {
what = "write"
}
log.Println(what, who, p[0])
}
// Read and decrypt next packet. // Read and decrypt next packet.
func (t *transport) readPacket() ([]byte, error) { func (t *transport) readPacket() (p []byte, err error) {
return t.reader.readPacket(t.bufReader) for {
p, err = t.reader.readPacket(t.bufReader)
if err != nil {
break
}
if len(p) == 0 || (p[0] != msgIgnore && p[0] != msgDebug) {
break
}
}
if debugTransport {
t.printPacket(p, false)
}
return p, err
} }
func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) { func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
@ -129,6 +165,9 @@ func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
} }
func (t *transport) writePacket(packet []byte) error { func (t *transport) writePacket(packet []byte) error {
if debugTransport {
t.printPacket(packet, true)
}
return t.writer.writePacket(t.bufWriter, t.rand, packet) return t.writer.writePacket(t.bufWriter, t.rand, packet)
} }
@ -169,6 +208,8 @@ func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transp
}, },
Closer: rwc, Closer: rwc,
} }
t.isClient = isClient
if isClient { if isClient {
t.reader.dir = serverKeys t.reader.dir = serverKeys
t.writer.dir = clientKeys t.writer.dir = clientKeys
@ -226,6 +267,7 @@ func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (pac
c := &streamPacketCipher{ c := &streamPacketCipher{
mac: macModes[algs.MAC].new(macKey), mac: macModes[algs.MAC].new(macKey),
etm: macModes[algs.MAC].etm,
} }
c.macResult = make([]byte, c.mac.Size()) c.macResult = make([]byte, c.mac.Size())