mirror of
				https://github.com/go-gitea/gitea
				synced 2025-10-31 03:18:24 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			416 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Go
		
	
	
		
			Executable File
		
	
	
	
	
			
		
		
	
	
			416 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Go
		
	
	
		
			Executable File
		
	
	
	
	
| // Copyright 2013 The Go Authors. All rights reserved.
 | |
| // Use of this source code is governed by a BSD-style
 | |
| // license that can be found in the LICENSE file.
 | |
| 
 | |
| package ssh
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"crypto/rand"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"net"
 | |
| 	"runtime"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 	"testing"
 | |
| )
 | |
| 
 | |
| type testChecker struct {
 | |
| 	calls []string
 | |
| }
 | |
| 
 | |
| func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
 | |
| 	if dialAddr == "bad" {
 | |
| 		return fmt.Errorf("dialAddr is bad")
 | |
| 	}
 | |
| 
 | |
| 	if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
 | |
| 		return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
 | |
| 	}
 | |
| 
 | |
| 	t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
// netPipe returns the two ends of a loopback TCP connection (dialer first,
// accepted side second). Unlike net.Pipe it is backed by a real socket and is
// therefore buffered, so both ends may start with a write without
// deadlocking. The listener is always closed before returning, and if Accept
// fails the already-dialed end is closed too.
func netPipe() (net.Conn, net.Conn, error) {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return nil, nil, err
	}
	defer ln.Close()

	dialed, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		return nil, nil, err
	}

	accepted, acceptErr := ln.Accept()
	if acceptErr != nil {
		dialed.Close()
		return nil, nil, acceptErr
	}
	return dialed, accepted, nil
}
 | |
| 
 | |
| func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) {
 | |
| 	a, b, err := netPipe()
 | |
| 	if err != nil {
 | |
| 		return nil, nil, err
 | |
| 	}
 | |
| 
 | |
| 	trC := newTransport(a, rand.Reader, true)
 | |
| 	trS := newTransport(b, rand.Reader, false)
 | |
| 	clientConf.SetDefaults()
 | |
| 
 | |
| 	v := []byte("version")
 | |
| 	client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())
 | |
| 
 | |
| 	serverConf := &ServerConfig{}
 | |
| 	serverConf.AddHostKey(testSigners["ecdsa"])
 | |
| 	serverConf.AddHostKey(testSigners["rsa"])
 | |
| 	serverConf.SetDefaults()
 | |
| 	server = newServerTransport(trS, v, v, serverConf)
 | |
| 
 | |
| 	return client, server, nil
 | |
| }
 | |
| 
 | |
| func TestHandshakeBasic(t *testing.T) {
 | |
| 	if runtime.GOOS == "plan9" {
 | |
| 		t.Skip("see golang.org/issue/7237")
 | |
| 	}
 | |
| 	checker := &testChecker{}
 | |
| 	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("handshakePair: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	defer trC.Close()
 | |
| 	defer trS.Close()
 | |
| 
 | |
| 	go func() {
 | |
| 		// Client writes a bunch of stuff, and does a key
 | |
| 		// change in the middle. This should not confuse the
 | |
| 		// handshake in progress
 | |
| 		for i := 0; i < 10; i++ {
 | |
| 			p := []byte{msgRequestSuccess, byte(i)}
 | |
| 			if err := trC.writePacket(p); err != nil {
 | |
| 				t.Fatalf("sendPacket: %v", err)
 | |
| 			}
 | |
| 			if i == 5 {
 | |
| 				// halfway through, we request a key change.
 | |
| 				_, _, err := trC.sendKexInit()
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("sendKexInit: %v", err)
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 		trC.Close()
 | |
| 	}()
 | |
| 
 | |
| 	// Server checks that client messages come in cleanly
 | |
| 	i := 0
 | |
| 	for {
 | |
| 		p, err := trS.readPacket()
 | |
| 		if err != nil {
 | |
| 			break
 | |
| 		}
 | |
| 		if p[0] == msgNewKeys {
 | |
| 			continue
 | |
| 		}
 | |
| 		want := []byte{msgRequestSuccess, byte(i)}
 | |
| 		if bytes.Compare(p, want) != 0 {
 | |
| 			t.Errorf("message %d: got %q, want %q", i, p, want)
 | |
| 		}
 | |
| 		i++
 | |
| 	}
 | |
| 	if i != 10 {
 | |
| 		t.Errorf("received %d messages, want 10.", i)
 | |
| 	}
 | |
| 
 | |
| 	// If all went well, we registered exactly 1 key change.
 | |
| 	if len(checker.calls) != 1 {
 | |
| 		t.Fatalf("got %d host key checks, want 1", len(checker.calls))
 | |
| 	}
 | |
| 
 | |
| 	pub := testSigners["ecdsa"].PublicKey()
 | |
| 	want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal())
 | |
| 	if want != checker.calls[0] {
 | |
| 		t.Errorf("got %q want %q for host key check", checker.calls[0], want)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestHandshakeError(t *testing.T) {
 | |
| 	checker := &testChecker{}
 | |
| 	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad")
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("handshakePair: %v", err)
 | |
| 	}
 | |
| 	defer trC.Close()
 | |
| 	defer trS.Close()
 | |
| 
 | |
| 	// send a packet
 | |
| 	packet := []byte{msgRequestSuccess, 42}
 | |
| 	if err := trC.writePacket(packet); err != nil {
 | |
| 		t.Errorf("writePacket: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// Now request a key change.
 | |
| 	_, _, err = trC.sendKexInit()
 | |
| 	if err != nil {
 | |
| 		t.Errorf("sendKexInit: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// the key change will fail, and afterwards we can't write.
 | |
| 	if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil {
 | |
| 		t.Errorf("writePacket after botched rekey succeeded.")
 | |
| 	}
 | |
| 
 | |
| 	readback, err := trS.readPacket()
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("server closed too soon: %v", err)
 | |
| 	}
 | |
| 	if bytes.Compare(readback, packet) != 0 {
 | |
| 		t.Errorf("got %q want %q", readback, packet)
 | |
| 	}
 | |
| 	readback, err = trS.readPacket()
 | |
| 	if err == nil {
 | |
| 		t.Errorf("got a message %q after failed key change", readback)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestHandshakeTwice(t *testing.T) {
 | |
| 	checker := &testChecker{}
 | |
| 	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("handshakePair: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	defer trC.Close()
 | |
| 	defer trS.Close()
 | |
| 
 | |
| 	// send a packet
 | |
| 	packet := make([]byte, 5)
 | |
| 	packet[0] = msgRequestSuccess
 | |
| 	if err := trC.writePacket(packet); err != nil {
 | |
| 		t.Errorf("writePacket: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// Now request a key change.
 | |
| 	_, _, err = trC.sendKexInit()
 | |
| 	if err != nil {
 | |
| 		t.Errorf("sendKexInit: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// Send another packet. Use a fresh one, since writePacket destroys.
 | |
| 	packet = make([]byte, 5)
 | |
| 	packet[0] = msgRequestSuccess
 | |
| 	if err := trC.writePacket(packet); err != nil {
 | |
| 		t.Errorf("writePacket: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// 2nd key change.
 | |
| 	_, _, err = trC.sendKexInit()
 | |
| 	if err != nil {
 | |
| 		t.Errorf("sendKexInit: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	packet = make([]byte, 5)
 | |
| 	packet[0] = msgRequestSuccess
 | |
| 	if err := trC.writePacket(packet); err != nil {
 | |
| 		t.Errorf("writePacket: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	packet = make([]byte, 5)
 | |
| 	packet[0] = msgRequestSuccess
 | |
| 	for i := 0; i < 5; i++ {
 | |
| 		msg, err := trS.readPacket()
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("server closed too soon: %v", err)
 | |
| 		}
 | |
| 		if msg[0] == msgNewKeys {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		if bytes.Compare(msg, packet) != 0 {
 | |
| 			t.Errorf("packet %d: got %q want %q", i, msg, packet)
 | |
| 		}
 | |
| 	}
 | |
| 	if len(checker.calls) != 2 {
 | |
| 		t.Errorf("got %d key changes, want 2", len(checker.calls))
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestHandshakeAutoRekeyWrite(t *testing.T) {
 | |
| 	checker := &testChecker{}
 | |
| 	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
 | |
| 	clientConf.RekeyThreshold = 500
 | |
| 	trC, trS, err := handshakePair(clientConf, "addr")
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("handshakePair: %v", err)
 | |
| 	}
 | |
| 	defer trC.Close()
 | |
| 	defer trS.Close()
 | |
| 
 | |
| 	for i := 0; i < 5; i++ {
 | |
| 		packet := make([]byte, 251)
 | |
| 		packet[0] = msgRequestSuccess
 | |
| 		if err := trC.writePacket(packet); err != nil {
 | |
| 			t.Errorf("writePacket: %v", err)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	j := 0
 | |
| 	for ; j < 5; j++ {
 | |
| 		_, err := trS.readPacket()
 | |
| 		if err != nil {
 | |
| 			break
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if j != 5 {
 | |
| 		t.Errorf("got %d, want 5 messages", j)
 | |
| 	}
 | |
| 
 | |
| 	if len(checker.calls) != 2 {
 | |
| 		t.Errorf("got %d key changes, wanted 2", len(checker.calls))
 | |
| 	}
 | |
| }
 | |
| 
 | |
| type syncChecker struct {
 | |
| 	called chan int
 | |
| }
 | |
| 
 | |
| func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
 | |
| 	t.called <- 1
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func TestHandshakeAutoRekeyRead(t *testing.T) {
 | |
| 	sync := &syncChecker{make(chan int, 2)}
 | |
| 	clientConf := &ClientConfig{
 | |
| 		HostKeyCallback: sync.Check,
 | |
| 	}
 | |
| 	clientConf.RekeyThreshold = 500
 | |
| 
 | |
| 	trC, trS, err := handshakePair(clientConf, "addr")
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("handshakePair: %v", err)
 | |
| 	}
 | |
| 	defer trC.Close()
 | |
| 	defer trS.Close()
 | |
| 
 | |
| 	packet := make([]byte, 501)
 | |
| 	packet[0] = msgRequestSuccess
 | |
| 	if err := trS.writePacket(packet); err != nil {
 | |
| 		t.Fatalf("writePacket: %v", err)
 | |
| 	}
 | |
| 	// While we read out the packet, a key change will be
 | |
| 	// initiated.
 | |
| 	if _, err := trC.readPacket(); err != nil {
 | |
| 		t.Fatalf("readPacket(client): %v", err)
 | |
| 	}
 | |
| 
 | |
| 	<-sync.called
 | |
| }
 | |
| 
 | |
| // errorKeyingTransport generates errors after a given number of
 | |
| // read/write operations.
 | |
| type errorKeyingTransport struct {
 | |
| 	packetConn
 | |
| 	readLeft, writeLeft int
 | |
| }
 | |
| 
 | |
| func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
 | |
| 	return nil
 | |
| }
 | |
| func (n *errorKeyingTransport) getSessionID() []byte {
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (n *errorKeyingTransport) writePacket(packet []byte) error {
 | |
| 	if n.writeLeft == 0 {
 | |
| 		n.Close()
 | |
| 		return errors.New("barf")
 | |
| 	}
 | |
| 
 | |
| 	n.writeLeft--
 | |
| 	return n.packetConn.writePacket(packet)
 | |
| }
 | |
| 
 | |
| func (n *errorKeyingTransport) readPacket() ([]byte, error) {
 | |
| 	if n.readLeft == 0 {
 | |
| 		n.Close()
 | |
| 		return nil, errors.New("barf")
 | |
| 	}
 | |
| 
 | |
| 	n.readLeft--
 | |
| 	return n.packetConn.readPacket()
 | |
| }
 | |
| 
 | |
| func TestHandshakeErrorHandlingRead(t *testing.T) {
 | |
| 	for i := 0; i < 20; i++ {
 | |
| 		testHandshakeErrorHandlingN(t, i, -1)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestHandshakeErrorHandlingWrite(t *testing.T) {
 | |
| 	for i := 0; i < 20; i++ {
 | |
| 		testHandshakeErrorHandlingN(t, -1, i)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // testHandshakeErrorHandlingN runs handshakes, injecting errors. If
 | |
| // handshakeTransport deadlocks, the go runtime will detect it and
 | |
| // panic.
 | |
| func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) {
 | |
| 	msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
 | |
| 
 | |
| 	a, b := memPipe()
 | |
| 	defer a.Close()
 | |
| 	defer b.Close()
 | |
| 
 | |
| 	key := testSigners["ecdsa"]
 | |
| 	serverConf := Config{RekeyThreshold: minRekeyThreshold}
 | |
| 	serverConf.SetDefaults()
 | |
| 	serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
 | |
| 	serverConn.hostKeys = []Signer{key}
 | |
| 	go serverConn.readLoop()
 | |
| 
 | |
| 	clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
 | |
| 	clientConf.SetDefaults()
 | |
| 	clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
 | |
| 	clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
 | |
| 	go clientConn.readLoop()
 | |
| 
 | |
| 	var wg sync.WaitGroup
 | |
| 	wg.Add(4)
 | |
| 
 | |
| 	for _, hs := range []packetConn{serverConn, clientConn} {
 | |
| 		go func(c packetConn) {
 | |
| 			for {
 | |
| 				err := c.writePacket(msg)
 | |
| 				if err != nil {
 | |
| 					break
 | |
| 				}
 | |
| 			}
 | |
| 			wg.Done()
 | |
| 		}(hs)
 | |
| 		go func(c packetConn) {
 | |
| 			for {
 | |
| 				_, err := c.readPacket()
 | |
| 				if err != nil {
 | |
| 					break
 | |
| 				}
 | |
| 			}
 | |
| 			wg.Done()
 | |
| 		}(hs)
 | |
| 	}
 | |
| 
 | |
| 	wg.Wait()
 | |
| }
 |