diff --git a/go.mod b/go.mod index a75ae73..d0c4c68 100644 --- a/go.mod +++ b/go.mod @@ -5,5 +5,6 @@ go 1.14 require ( github.com/huin/goutil v0.0.0-20170803182201-1ca381bf3150 golang.org/x/net v0.0.0-20181011144130-49bb7cea24b1 + golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a golang.org/x/text v0.3.0 // indirect ) diff --git a/go.sum b/go.sum index 3e75869..43cd370 100644 --- a/go.sum +++ b/go.sum @@ -2,5 +2,7 @@ github.com/huin/goutil v0.0.0-20170803182201-1ca381bf3150 h1:vlNjIqmUZ9CMAWsbURY github.com/huin/goutil v0.0.0-20170803182201-1ca381bf3150/go.mod h1:PpLOETDnJ0o3iZrZfqZzyLl6l7F3c6L1oWn7OICBi6o= golang.org/x/net v0.0.0-20181011144130-49bb7cea24b1 h1:Y/KGZSOdz/2r0WJ9Mkmz6NJBusp0kiNx1Cn82lzJQ6w= golang.org/x/net v0.0.0-20181011144130-49bb7cea24b1/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a h1:WXEvlFVvvGxCJLG6REjsT03iWnKLEWinaScsxF2Vm2o= +golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/goupnp.go b/goupnp.go index eb92267..23027f5 100644 --- a/goupnp.go +++ b/goupnp.go @@ -21,10 +21,8 @@ import ( "net/url" "time" - "golang.org/x/net/html/charset" - - "github.com/huin/goupnp/httpu" "github.com/huin/goupnp/ssdp" + "golang.org/x/net/html/charset" ) // ContextError is an error that wraps an error with some context information. @@ -33,6 +31,20 @@ type ContextError struct { Err error } +func ctxError(err error, msg string) ContextError { + return ContextError{ + Context: msg, + Err: err, + } +} + +func ctxErrorf(err error, msg string, args ...interface{}) ContextError { + return ContextError{ + Context: fmt.Sprintf(msg, args...), + Err: err, + } +} + func (err ContextError) Error() string { return fmt.Sprintf("%s: %v", err.Context, err.Err) } @@ -61,12 +73,12 @@ type MaybeRootDevice struct { // while attempting to send the query. An error or RootDevice is returned for // each discovered RootDevice. func DiscoverDevices(searchTarget string) ([]MaybeRootDevice, error) { - httpu, err := httpu.NewHTTPUClient() + hc, hcCleanup, err := httpuClient() if err != nil { return nil, err } - defer httpu.Close() - responses, err := ssdp.SSDPRawSearch(httpu, string(searchTarget), 2, 3) + defer hcCleanup() + responses, err := ssdp.SSDPRawSearch(hc, string(searchTarget), 2, 3) if err != nil { return nil, err } diff --git a/httpu/httpu.go b/httpu/httpu.go index 44b0c58..3367c86 100644 --- a/httpu/httpu.go +++ b/httpu/httpu.go @@ -12,6 +12,20 @@ import ( "time" ) +// ClientInterface is the general interface provided to perform HTTP-over-UDP +// requests. +type ClientInterface interface { + // Do performs a request. The timeout is how long to wait for before returning + // the responses that were received. An error is only returned for failing to + // send the request. Failures in receipt simply do not add to the resulting + // responses. + Do( + req *http.Request, + timeout time.Duration, + numSends int, + ) ([]*http.Response, error) +} + // HTTPUClient is a client for dealing with HTTPU (HTTP over UDP). Its typical // function is for HTTPMU, and particularly SSDP. type HTTPUClient struct { @@ -19,6 +33,8 @@ type HTTPUClient struct { conn net.PacketConn } +var _ ClientInterface = &HTTPUClient{} + // NewHTTPUClient creates a new HTTPUClient, opening up a new UDP socket for the // purpose. func NewHTTPUClient() (*HTTPUClient, error) { @@ -51,14 +67,15 @@ func (httpu *HTTPUClient) Close() error { return httpu.conn.Close() } -// Do performs a request. The timeout is how long to wait for before returning -// the responses that were received. An error is only returned for failing to -// send the request. Failures in receipt simply do not add to the resulting -// responses. +// Do implements ClientInterface.Do. // // Note that at present only one concurrent connection will happen per // HTTPUClient. -func (httpu *HTTPUClient) Do(req *http.Request, timeout time.Duration, numSends int) ([]*http.Response, error) { +func (httpu *HTTPUClient) Do( + req *http.Request, + timeout time.Duration, + numSends int, +) ([]*http.Response, error) { httpu.connLock.Lock() defer httpu.connLock.Unlock() diff --git a/httpu/multiclient.go b/httpu/multiclient.go new file mode 100644 index 0000000..463ab7a --- /dev/null +++ b/httpu/multiclient.go @@ -0,0 +1,70 @@ +package httpu + +import ( + "net/http" + "time" + + "golang.org/x/sync/errgroup" +) + +// MultiClient dispatches requests out to all the delegated clients. +type MultiClient struct { + // The HTTPU clients to delegate to. + delegates []ClientInterface +} + +var _ ClientInterface = &MultiClient{} + +// NewMultiClient creates a new MultiClient that delegates to all the given +// clients. +func NewMultiClient(delegates []ClientInterface) *MultiClient { + return &MultiClient{ + delegates: delegates, + } +} + +// Do implements ClientInterface.Do. +func (mc *MultiClient) Do( + req *http.Request, + timeout time.Duration, + numSends int, +) ([]*http.Response, error) { + tasks := &errgroup.Group{} + + results := make(chan []*http.Response) + tasks.Go(func() error { + defer close(results) + return mc.sendRequests(results, req, timeout, numSends) + }) + + var responses []*http.Response + tasks.Go(func() error { + for rs := range results { + responses = append(responses, rs...) + } + return nil + }) + + return responses, tasks.Wait() +} + +func (mc *MultiClient) sendRequests( + results chan<-[]*http.Response, + req *http.Request, + timeout time.Duration, + numSends int, +) error { + tasks := &errgroup.Group{} + for _, d := range mc.delegates { + d := d // copy for closure + tasks.Go(func() error { + responses, err := d.Do(req, timeout, numSends) + if err != nil { + return err + } + results <- responses + return nil + }) + } + return tasks.Wait() +} diff --git a/network.go b/network.go new file mode 100644 index 0000000..a3c2265 --- /dev/null +++ b/network.go @@ -0,0 +1,75 @@ +package goupnp + +import ( + "io" + "net" + + "github.com/huin/goupnp/httpu" +) + +// httpuClient creates a HTTPU client that multiplexes to all multicast-capable +// IPv4 addresses on the host. Returns a function to clean up once the client is +// no longer required. +func httpuClient() (httpu.ClientInterface, func(), error) { + addrs, err := localIPv4MCastAddrs() + if err != nil { + return nil, nil, ctxError(err, "requesting host IPv4 addresses") + } + + closers := make([]io.Closer, 0, len(addrs)) + delegates := make([]httpu.ClientInterface, 0, len(addrs)) + for _, addr := range addrs { + c, err := httpu.NewHTTPUClientAddr(addr) + if err != nil { + return nil, nil, ctxErrorf(err, + "creating HTTPU client for address %s", addr) + } + closers = append(closers, c) + delegates = append(delegates, c) + } + + closer := func() { + for _, c := range closers { + c.Close() + } + } + + return httpu.NewMultiClient(delegates), closer, nil +} + +// localIPv2MCastAddrs returns the set of IPv4 addresses on multicast-able +// network interfaces. +func localIPv4MCastAddrs() ([]string, error) { + ifaces, err := net.Interfaces() + if err != nil { + return nil, ctxError(err, "requesting host interfaces") + } + + // Find the set of addresses to listen on. + var addrs []string + for _, iface := range ifaces { + if iface.Flags&net.FlagMulticast == 0 { + // Does not support multicast. + continue + } + ifaceAddrs, err := iface.Addrs() + if err != nil { + return nil, ctxErrorf(err, + "finding addresses on interface %s", iface.Name) + } + for _, netAddr := range ifaceAddrs { + addr, ok := netAddr.(*net.IPNet) + if !ok { + // Not an IPNet address. + continue + } + if addr.IP.To4() == nil { + // Not IPv4. + continue + } + addrs = append(addrs, addr.IP.String()) + } + } + + return addrs, nil +} diff --git a/ssdp/ssdp.go b/ssdp/ssdp.go index 4c03b25..9279363 100644 --- a/ssdp/ssdp.go +++ b/ssdp/ssdp.go @@ -7,8 +7,6 @@ import ( "net/url" "strconv" "time" - - "github.com/huin/goupnp/httpu" ) const ( @@ -27,6 +25,15 @@ const ( UPNPRootDevice = "upnp:rootdevice" ) +// HTTPUClient is the interface required to perform HTTP-over-UDP requests. +type HTTPUClient interface { + Do( + req *http.Request, + timeout time.Duration, + numSends int, + ) ([]*http.Response, error) +} + // SSDPRawSearch performs a fairly raw SSDP search request, and returns the // unique response(s) that it receives. Each response has the requested // searchTarget, a USN, and a valid location. maxWaitSeconds states how long to @@ -34,13 +41,16 @@ const ( // implementation waits an additional 100ms for responses to arrive), 2 is a // reasonable value for this. numSends is the number of requests to send - 3 is // a reasonable value for this. -func SSDPRawSearch(httpu *httpu.HTTPUClient, searchTarget string, maxWaitSeconds int, numSends int) ([]*http.Response, error) { +func SSDPRawSearch( + httpu HTTPUClient, + searchTarget string, + maxWaitSeconds int, + numSends int, +) ([]*http.Response, error) { if maxWaitSeconds < 1 { return nil, errors.New("ssdp: maxWaitSeconds must be >= 1") } - seenUsns := make(map[string]bool) - var responses []*http.Response req := http.Request{ Method: methodSearch, // TODO: Support both IPv4 and IPv6. @@ -62,6 +72,8 @@ func SSDPRawSearch(httpu *httpu.HTTPUClient, searchTarget string, maxWaitSeconds isExactSearch := searchTarget != SSDPAll && searchTarget != UPNPRootDevice + seenUSNs := make(map[string]bool) + var responses []*http.Response for _, response := range allResponses { if response.StatusCode != 200 { log.Printf("ssdp: got response status code %q in search response", response.Status) @@ -70,18 +82,18 @@ func SSDPRawSearch(httpu *httpu.HTTPUClient, searchTarget string, maxWaitSeconds if st := response.Header.Get("ST"); isExactSearch && st != searchTarget { continue } - location, err := response.Location() - if err != nil { - log.Printf("ssdp: no usable location in search response (discarding): %v", err) - continue - } usn := response.Header.Get("USN") if usn == "" { - log.Printf("ssdp: empty/missing USN in search response (using location instead): %v", err) + // Empty/missing USN in search response - using location instead. + location, err := response.Location() + if err != nil { + // No usable location in search response - discard. + continue + } usn = location.String() } - if _, alreadySeen := seenUsns[usn]; !alreadySeen { - seenUsns[usn] = true + if _, alreadySeen := seenUSNs[usn]; !alreadySeen { + seenUSNs[usn] = true responses = append(responses, response) } }