package winio import ( "fmt" "io" "net" "os" "syscall" "time" "unsafe" "github.com/Microsoft/go-winio/pkg/guid" ) //sys bind(s syscall.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind const ( afHvSock = 34 // AF_HYPERV socketError = ^uintptr(0) ) // An HvsockAddr is an address for a AF_HYPERV socket. type HvsockAddr struct { VMID guid.GUID ServiceID guid.GUID } type rawHvsockAddr struct { Family uint16 _ uint16 VMID guid.GUID ServiceID guid.GUID } // Network returns the address's network name, "hvsock". func (addr *HvsockAddr) Network() string { return "hvsock" } func (addr *HvsockAddr) String() string { return fmt.Sprintf("%s:%s", &addr.VMID, &addr.ServiceID) } // VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port. func VsockServiceID(port uint32) guid.GUID { g, _ := guid.FromString("00000000-facb-11e6-bd58-64006a7986d3") g.Data1 = port return g } func (addr *HvsockAddr) raw() rawHvsockAddr { return rawHvsockAddr{ Family: afHvSock, VMID: addr.VMID, ServiceID: addr.ServiceID, } } func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) { addr.VMID = raw.VMID addr.ServiceID = raw.ServiceID } // HvsockListener is a socket listener for the AF_HYPERV address family. type HvsockListener struct { sock *win32File addr HvsockAddr } // HvsockConn is a connected socket of the AF_HYPERV address family. type HvsockConn struct { sock *win32File local, remote HvsockAddr } func newHvSocket() (*win32File, error) { fd, err := syscall.Socket(afHvSock, syscall.SOCK_STREAM, 1) if err != nil { return nil, os.NewSyscallError("socket", err) } f, err := makeWin32File(fd) if err != nil { syscall.Close(fd) return nil, err } f.socket = true return f, nil } // ListenHvsock listens for connections on the specified hvsock address. func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) { l := &HvsockListener{addr: *addr} sock, err := newHvSocket() if err != nil { return nil, l.opErr("listen", err) } sa := addr.raw() err = bind(sock.handle, unsafe.Pointer(&sa), int32(unsafe.Sizeof(sa))) if err != nil { return nil, l.opErr("listen", os.NewSyscallError("socket", err)) } err = syscall.Listen(sock.handle, 16) if err != nil { return nil, l.opErr("listen", os.NewSyscallError("listen", err)) } return &HvsockListener{sock: sock, addr: *addr}, nil } func (l *HvsockListener) opErr(op string, err error) error { return &net.OpError{Op: op, Net: "hvsock", Addr: &l.addr, Err: err} } // Addr returns the listener's network address. func (l *HvsockListener) Addr() net.Addr { return &l.addr } // Accept waits for the next connection and returns it. func (l *HvsockListener) Accept() (_ net.Conn, err error) { sock, err := newHvSocket() if err != nil { return nil, l.opErr("accept", err) } defer func() { if sock != nil { sock.Close() } }() c, err := l.sock.prepareIo() if err != nil { return nil, l.opErr("accept", err) } defer l.sock.wg.Done() // AcceptEx, per documentation, requires an extra 16 bytes per address. const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{})) var addrbuf [addrlen * 2]byte var bytes uint32 err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0, addrlen, addrlen, &bytes, &c.o) _, err = l.sock.asyncIo(c, nil, bytes, err) if err != nil { return nil, l.opErr("accept", os.NewSyscallError("acceptex", err)) } conn := &HvsockConn{ sock: sock, } conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0]))) conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen]))) sock = nil return conn, nil } // Close closes the listener, causing any pending Accept calls to fail. func (l *HvsockListener) Close() error { return l.sock.Close() } /* Need to finish ConnectEx handling func DialHvsock(ctx context.Context, addr *HvsockAddr) (*HvsockConn, error) { sock, err := newHvSocket() if err != nil { return nil, err } defer func() { if sock != nil { sock.Close() } }() c, err := sock.prepareIo() if err != nil { return nil, err } defer sock.wg.Done() var bytes uint32 err = windows.ConnectEx(windows.Handle(sock.handle), sa, nil, 0, &bytes, &c.o) _, err = sock.asyncIo(ctx, c, nil, bytes, err) if err != nil { return nil, err } conn := &HvsockConn{ sock: sock, remote: *addr, } sock = nil return conn, nil } */ func (conn *HvsockConn) opErr(op string, err error) error { return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err} } func (conn *HvsockConn) Read(b []byte) (int, error) { c, err := conn.sock.prepareIo() if err != nil { return 0, conn.opErr("read", err) } defer conn.sock.wg.Done() buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))} var flags, bytes uint32 err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil) n, err := conn.sock.asyncIo(c, &conn.sock.readDeadline, bytes, err) if err != nil { if _, ok := err.(syscall.Errno); ok { err = os.NewSyscallError("wsarecv", err) } return 0, conn.opErr("read", err) } else if n == 0 { err = io.EOF } return n, err } func (conn *HvsockConn) Write(b []byte) (int, error) { t := 0 for len(b) != 0 { n, err := conn.write(b) if err != nil { return t + n, err } t += n b = b[n:] } return t, nil } func (conn *HvsockConn) write(b []byte) (int, error) { c, err := conn.sock.prepareIo() if err != nil { return 0, conn.opErr("write", err) } defer conn.sock.wg.Done() buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))} var bytes uint32 err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil) n, err := conn.sock.asyncIo(c, &conn.sock.writeDeadline, bytes, err) if err != nil { if _, ok := err.(syscall.Errno); ok { err = os.NewSyscallError("wsasend", err) } return 0, conn.opErr("write", err) } return n, err } // Close closes the socket connection, failing any pending read or write calls. func (conn *HvsockConn) Close() error { return conn.sock.Close() } func (conn *HvsockConn) shutdown(how int) error { err := syscall.Shutdown(conn.sock.handle, syscall.SHUT_RD) if err != nil { return os.NewSyscallError("shutdown", err) } return nil } // CloseRead shuts down the read end of the socket. func (conn *HvsockConn) CloseRead() error { err := conn.shutdown(syscall.SHUT_RD) if err != nil { return conn.opErr("close", err) } return nil } // CloseWrite shuts down the write end of the socket, notifying the other endpoint that // no more data will be written. func (conn *HvsockConn) CloseWrite() error { err := conn.shutdown(syscall.SHUT_WR) if err != nil { return conn.opErr("close", err) } return nil } // LocalAddr returns the local address of the connection. func (conn *HvsockConn) LocalAddr() net.Addr { return &conn.local } // RemoteAddr returns the remote address of the connection. func (conn *HvsockConn) RemoteAddr() net.Addr { return &conn.remote } // SetDeadline implements the net.Conn SetDeadline method. func (conn *HvsockConn) SetDeadline(t time.Time) error { conn.SetReadDeadline(t) conn.SetWriteDeadline(t) return nil } // SetReadDeadline implements the net.Conn SetReadDeadline method. func (conn *HvsockConn) SetReadDeadline(t time.Time) error { return conn.sock.SetReadDeadline(t) } // SetWriteDeadline implements the net.Conn SetWriteDeadline method. func (conn *HvsockConn) SetWriteDeadline(t time.Time) error { return conn.sock.SetWriteDeadline(t) }