diff --git a/goupnp.go b/goupnp.go index 6bdc9ec..a489f4f 100644 --- a/goupnp.go +++ b/goupnp.go @@ -2,112 +2,121 @@ package goupnp import ( "encoding/xml" - "errors" "fmt" "log" - "io" - "os" "net/http" - "net/url" ) const ( - ssdpUDP4Addr = "239.255.255.250:1900" - - methodSearch = "M-SEARCH" // Search Target for InternetGatewayDevice. - stIgd = "urn:schemas-upnp-org:device:InternetGatewayDevice:1" - hdrMan = `"ssdp:discover"` + SearchTargetIGD = "urn:schemas-upnp-org:device:InternetGatewayDevice:1" ) // DiscoverIGD attempts to find Internet Gateway Devices. // // TODO: Fix implementation to discover multiple. Currently it will find a // maximum of one. -func DiscoverIGD() ([]IGD, error) { - hc := http.Client{ - Transport: udpRoundTripper{}, - CheckRedirect: func(r *http.Request, via []*http.Request) error { - return errors.New("goupnp: unexpected HTTP redirect") - }, - Jar: nil, - } - - request := http.Request{ - Method: methodSearch, - // TODO: Support both IPv4 and IPv6. - Host: ssdpUDP4Addr, - URL: &url.URL{Opaque: "*"}, - Header: http.Header{ - // Putting headers in here avoids them being title-cased. - // (The UPnP discovery protocol uses case-sensitive headers) - "HOST": []string{ssdpUDP4Addr}, - "MX": []string{"2"}, // TODO: Variable max wait time. - "MAN": []string{hdrMan}, - "ST": []string{stIgd}, - }, - } - - response, err := hc.Do(&request) +func DiscoverIGD() ([]*IGD, error) { + httpu, err := NewHTTPUClient() if err != nil { return nil, err } + responses, err := SSDPRawSearch(httpu, SearchTargetIGD, 2, 3) - // Any errors past this point are simply "no result found". We log the - // errors, but report no results. In a future version of this implementation, - // multiple *good* results can be returned. - - if response.StatusCode != 200 { - log.Printf("goupnp: response code %d %q from UPnP discovery", - response.StatusCode, response.Status) - return nil, nil - } - if st := response.Header.Get("ST"); st != stIgd { - log.Printf("goupnp: got unexpected search target result %q", st) - return nil, nil + results := make([]*IGD, 0, len(responses)) + for _, response := range responses { + loc, err := response.Location() + if err != nil { + log.Printf("goupnp: unexpected bad location from search: %v", err) + continue + } + igd, err := requestIgd(loc.String()) + if err != nil { + log.Printf("goupnp: error requesting IGD: %v", err) + continue + } + results = append(results, igd) } - location, err := response.Location() - if err != nil { - log.Printf("goupnp: missing location in response") - return nil, nil - } - - igd, err := requestIgd(location.String()) - if err != nil { - log.Printf("goupnp: error requesting IGD: %v", err) - return nil, nil - } - - return []IGD{igd}, nil + return results, nil } // IGD defines the interface for an Internet Gateway Device. -type IGD interface { +type IGD struct { + xml xmlRootDevice } -type igd struct { - serviceUrl string -} - -func requestIgd(serviceUrl string) (IGD, error) { +func requestIgd(serviceUrl string) (*IGD, error) { resp, err := http.Get(serviceUrl) if err != nil { return nil, err } defer resp.Body.Close() - decoder := xml.NewDecoder(io.TeeReader(resp.Body, os.Stdout)) + if resp.StatusCode != 200 { + return nil, fmt.Errorf("goupnp: got response status %s from IGD at %q", + resp.Status, serviceUrl) + } + + decoder := xml.NewDecoder(resp.Body) decoder.DefaultSpace = deviceXmlNs - var root xmlRootDevice - if err = decoder.Decode(&root); err != nil { + var xml xmlRootDevice + if err = decoder.Decode(&xml); err != nil { return nil, err } - log.Printf("%+v", root) + log.Printf("%+v", xml) - return igd{serviceUrl}, nil + return &IGD{xml}, nil } -func (device *igd) String() string { - return fmt.Sprintf("goupnp.IGD @ %s", device.serviceUrl) +func (igd *IGD) Device() *Device { + return &Device{ + igd.xml.URLBase, + igd.xml.Device, + } +} + +func (igd *IGD) String() string { + return fmt.Sprintf("IGD{UDN: %q friendlyName: %q}", + igd.xml.Device.UDN, igd.xml.Device.FriendlyName) +} + +type Device struct { + urlBase string + xml xmlDevice +} + +func (device *Device) String() string { + return fmt.Sprintf("Device{friendlyName: %q}", device.xml.FriendlyName) +} + +func (device *Device) Devices() []*Device { + devices := make([]*Device, len(device.xml.Devices)) + for i, childXml := range device.xml.Devices { + devices[i] = &Device{ + device.urlBase, + childXml, + } + } + return devices +} + +func (device *Device) Services() []*Service { + srvs := make([]*Service, len(device.xml.Services)) + for i, childXml := range device.xml.Services { + srvs[i] = &Service{ + device.urlBase, + childXml, + } + } + return srvs +} + +type Service struct { + urlBase string + xml xmlService +} + +func (srv *Service) String() string { + return fmt.Sprintf("Service{serviceType: %q}", srv.xml.ServiceType) } diff --git a/httpu.go b/httpu.go new file mode 100644 index 0000000..cd332c9 --- /dev/null +++ b/httpu.go @@ -0,0 +1,98 @@ +package goupnp + +import ( + "bufio" + "bytes" + "fmt" + "log" + "net" + "net/http" + "sync" + "time" +) + +// HTTPUClient is a client for dealing with HTTPU (HTTP over UDP). Its typical +// function is for HTTPMU, and particularly SSDP. +type HTTPUClient struct { + connLock sync.Mutex // Protects use of conn. + conn net.PacketConn +} + +// NewHTTPUClient creates a new HTTPUClient, opening up a new UDP socket for the +// purpose. +func NewHTTPUClient() (*HTTPUClient, error) { + conn, err := net.ListenPacket("udp", ":0") + if err != nil { + return nil, err + } + return &HTTPUClient{conn: conn}, nil +} + +// Close shuts down the client. The client will no longer be useful following +// this. +func (httpu *HTTPUClient) Close() error { + httpu.connLock.Lock() + defer httpu.connLock.Unlock() + return httpu.conn.Close() +} + +// Do performs a request. The timeout is how long to wait for before returning +// the responses that were received. An error is only returned for failing to +// send the request. Failures in receipt simply do not add to the resulting +// responses. +// +// Note that at present only one concurrent connection will happen per +// HTTPUClient. +func (httpu *HTTPUClient) Do(req *http.Request, timeout time.Duration, numSends int) ([]*http.Response, error) { + httpu.connLock.Lock() + defer httpu.connLock.Unlock() + + var requestBuf bytes.Buffer + if err := req.Write(&requestBuf); err != nil { + return nil, err + } + destAddr, err := net.ResolveUDPAddr("udp", req.Host) + if err != nil { + return nil, err + } + if err = httpu.conn.SetDeadline(time.Now().Add(timeout)); err != nil { + return nil, err + } + + // Send request. + for i := 0; i < numSends; i++ { + if n, err := httpu.conn.WriteTo(requestBuf.Bytes(), destAddr); err != nil { + return nil, err + } else if n < len(requestBuf.Bytes()) { + return nil, fmt.Errorf("httpu: wrote %d bytes rather than full %d in request", + n, len(requestBuf.Bytes())) + } + } + + // Await responses until timeout. + var responses []*http.Response + responseBytes := make([]byte, 2048) + for { + // 2048 bytes should be sufficient for most networks. + n, _, err := httpu.conn.ReadFrom(responseBytes) + if err != nil { + if err, ok := err.(net.Error); ok && err.Timeout() { + break + } + log.Print("httpu: error while receiving response: %v", err) + // Sleep in case this is a persistent error to avoid pegging CPU until deadline. + time.Sleep(10 * time.Millisecond) + continue + } + + // Parse response. + response, err := http.ReadResponse(bufio.NewReader(bytes.NewBuffer(responseBytes[:n])), req) + if err != nil { + log.Print("httpu: error while parsing response: %v", err) + continue + } + + responses = append(responses, response) + } + return responses, err +} diff --git a/ssdp.go b/ssdp.go new file mode 100644 index 0000000..9e0fe05 --- /dev/null +++ b/ssdp.go @@ -0,0 +1,76 @@ +package goupnp + +import ( + "errors" + "log" + "net/http" + "net/url" + "strconv" + "time" +) + +const ( + ssdpDiscover = `"ssdp:discover"` + methodSearch = "M-SEARCH" + ssdpUDP4Addr = "239.255.255.250:1900" +) + +// SSDPRawSearch performs a fairly raw SSDP search request, and returns the +// unique response(s) that it receives. Each response has the requested +// searchTarget, a USN, and a valid location. maxWaitSeconds states how long to +// wait for responses in seconds, and must be a minimum of 1 (the +// implementation waits an additional 100ms for responses to arrive), 2 is a +// reasonable value for this. numSends is the number of requests to send - 3 is +// a reasonable value for this. +func SSDPRawSearch(httpu *HTTPUClient, searchTarget string, maxWaitSeconds int, numSends int) ([]*http.Response, error) { + if maxWaitSeconds < 1 { + return nil, errors.New("ssdp: maxWaitSeconds must be >= 1") + } + + seenUsns := make(map[string]bool) + var responses []*http.Response + req := http.Request{ + Method: methodSearch, + // TODO: Support both IPv4 and IPv6. + Host: ssdpUDP4Addr, + URL: &url.URL{Opaque: "*"}, + Header: http.Header{ + // Putting headers in here avoids them being title-cased. + // (The UPnP discovery protocol uses case-sensitive headers) + "HOST": []string{ssdpUDP4Addr}, + "MX": []string{strconv.FormatInt(int64(maxWaitSeconds), 10)}, + "MAN": []string{ssdpDiscover}, + "ST": []string{searchTarget}, + }, + } + allResponses, err := httpu.Do(&req, time.Duration(maxWaitSeconds)*time.Second+100*time.Millisecond, numSends) + if err != nil { + return nil, err + } + for _, response := range allResponses { + if response.StatusCode != 200 { + log.Printf("ssdp: got response status code %q in search response", response.Status) + continue + } + if st := response.Header.Get("ST"); st != searchTarget { + log.Printf("ssdp: got unexpected search target result %q", st) + return nil, nil + } + location, err := response.Location() + if err != nil { + log.Printf("ssdp: no usable location in search response (discarding): %v", err) + continue + } + usn := response.Header.Get("USN") + if usn == "" { + log.Printf("ssdp: empty/missing USN in search response (using location instead): %v", err) + usn = location.String() + } + if _, alreadySeen := seenUsns[usn]; !alreadySeen { + seenUsns[usn] = true + responses = append(responses, response) + } + } + + return responses, nil +} diff --git a/udproundtripper.go b/udproundtripper.go deleted file mode 100644 index ceeeaa8..0000000 --- a/udproundtripper.go +++ /dev/null @@ -1,69 +0,0 @@ -package goupnp - -import ( - "bufio" - "bytes" - "fmt" - "net" - "net/http" - "time" -) - -// TODO: RoundTripper is probably the wrong interface, as there could be -// multiple responses to a request. - -type udpRoundTripper struct { - // If zero, defaults to 3 second deadline (a zero deadline makes no sense). - Deadline time.Duration - MaxWaitSeconds int -} - -func (urt udpRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { - conn, err := net.ListenPacket("udp", ":0") - if err != nil { - return nil, err - } - defer conn.Close() - - var requestBuf bytes.Buffer - if err := r.Write(&requestBuf); err != nil { - return nil, err - } - destAddr, err := net.ResolveUDPAddr("udp", r.Host) - if err != nil { - return nil, err - } - - deadline := urt.Deadline - if urt.Deadline == 0 { - deadline = 3 * time.Second - } - - if err = conn.SetDeadline(time.Now().Add(deadline)); err != nil { - return nil, err - } - - // Send request. - if n, err := conn.WriteTo(requestBuf.Bytes(), destAddr); err != nil { - return nil, err - } else if n < len(requestBuf.Bytes()) { - return nil, fmt.Errorf("goupnp: wrote %d bytes rather than full %d in request", - n, len(requestBuf.Bytes())) - } - - // Await response. - responseBytes := make([]byte, 2048) - n, _, err := conn.ReadFrom(responseBytes) - if err != nil { - return nil, err - } - responseBytes = responseBytes[:n] - - // Parse response. - response, err := http.ReadResponse(bufio.NewReader(bytes.NewBuffer(responseBytes)), r) - if err != nil { - return nil, err - } - - return response, err -} diff --git a/xml.go b/xml.go index 5300a78..1403f26 100644 --- a/xml.go +++ b/xml.go @@ -11,10 +11,10 @@ const ( ) type xmlRootDevice struct { - Name xml.Name `xml:"root` + Name xml.Name `xml:"root` SpecVersion xmlSpecVersion `xml:"specVersion"` - URLBase string `xml:"URLBase"` - Device xmlDevice `xml:"device"` + URLBase string `xml:"URLBase"` + Device xmlDevice `xml:"device"` } type xmlSpecVersion struct { @@ -23,20 +23,20 @@ type xmlSpecVersion struct { } type xmlDevice struct { - DeviceType string `xml:"deviceType"` - FriendlyName string `xml:"friendlyName"` - Manufacturer string `xml:"manufacturer"` - ManufacturerURL string `xml:"manufacturerURL"` - ModelDescription string `xml:"modelDescription"` - ModelName string `xml:"modelName"` - ModelNumber string `xml:"modelNumber"` - ModelURL string `xml:"modelURL"` - SerialNumber string `xml:"serialNumber"` - UDN string `xml:"UDN"` - UPC string `xml:"UPC,omitempty"` - Icons []xmlIcon `xml:"iconList>icon,omitempty"` - Services []xmlService `xml:"serviceList>service,omitempty"` - Devices []xmlDevice `xml:"deviceList>device,omitempty"` + DeviceType string `xml:"deviceType"` + FriendlyName string `xml:"friendlyName"` + Manufacturer string `xml:"manufacturer"` + ManufacturerURL string `xml:"manufacturerURL"` + ModelDescription string `xml:"modelDescription"` + ModelName string `xml:"modelName"` + ModelNumber string `xml:"modelNumber"` + ModelURL string `xml:"modelURL"` + SerialNumber string `xml:"serialNumber"` + UDN string `xml:"UDN"` + UPC string `xml:"UPC,omitempty"` + Icons []xmlIcon `xml:"iconList>icon,omitempty"` + Services []xmlService `xml:"serviceList>service,omitempty"` + Devices []xmlDevice `xml:"deviceList>device,omitempty"` // Extra observed elements: PresentationURL string `xml:"presentationURL"` @@ -44,16 +44,16 @@ type xmlDevice struct { type xmlIcon struct { Mimetype string `xml:"mimetype"` - Width int32 `xml:"width"` - Height int32 `xml:"height"` - Depth int32 `xml:"depth"` - URL string `xml:"url"` + Width int32 `xml:"width"` + Height int32 `xml:"height"` + Depth int32 `xml:"depth"` + URL string `xml:"url"` } type xmlService struct { ServiceType string `xml:"serviceType"` - ServiceId string `xml:"serviceId"` - SCPDURL string `xml:"SCPDURL"` - ControlURL string `xml:"controlURL"` + ServiceId string `xml:"serviceId"` + SCPDURL string `xml:"SCPDURL"` + ControlURL string `xml:"controlURL"` EventSubURL string `xml:"eventSubURL"` }