From d5f9ab1f203eb56595d0e83a5a44e7990ab81f78 Mon Sep 17 00:00:00 2001 From: John Beisley Date: Sun, 5 Sep 2021 11:29:02 +0100 Subject: [PATCH] SOAP types now implement more standard interfaces. Specifically TextMarshaler and TextUnmarshaller. These adapt better to use in xml/encoding. --- v2/soap/types/types.go | 383 ++++++++++++++++++++---------------- v2/soap/types/types_test.go | 9 +- 2 files changed, 223 insertions(+), 169 deletions(-) diff --git a/v2/soap/types/types.go b/v2/soap/types/types.go index 6983671..cc3d585 100644 --- a/v2/soap/types/types.go +++ b/v2/soap/types/types.go @@ -7,6 +7,8 @@ package types import ( + "bytes" + "encoding" "encoding/base64" "encoding/hex" "errors" @@ -14,35 +16,45 @@ import ( "net/url" "regexp" "strconv" - "strings" "time" "unicode/utf8" ) type SOAPValue interface { - Marshal() (string, error) - Unmarshal(s string) error + encoding.TextMarshaler + encoding.TextUnmarshaler } type UI1 uint8 var _ SOAPValue = new(UI1) +// toStringNoError converts `v` to a string, returning empty string on error. +// This should only be used for String() implementations if no error can be +// returned by v.MarshalText(). +func toStringNoError(v encoding.TextMarshaler) string { + b, err := v.MarshalText() + if err != nil { + return "" + } + return string(b) +} + func NewUI1(v uint8) *UI1 { v2 := UI1(v) return &v2 } func (v *UI1) String() string { - return strconv.FormatUint(uint64(*v), 10) + return toStringNoError(v) } -func (v *UI1) Marshal() (string, error) { - return v.String(), nil +func (v *UI1) MarshalText() ([]byte, error) { + return strconv.AppendUint(nil, uint64(*v), 10), nil } -func (v *UI1) Unmarshal(s string) error { - v2, err := strconv.ParseUint(s, 10, 8) +func (v *UI1) UnmarshalText(b []byte) error { + v2, err := strconv.ParseUint(string(b), 10, 8) *v = UI1(v2) return err } @@ -57,15 +69,15 @@ func NewUI2(v uint16) *UI2 { } func (v *UI2) String() string { - return strconv.FormatUint(uint64(*v), 10) + return toStringNoError(v) } -func (v *UI2) Marshal() (string, error) { - return v.String(), nil +func (v *UI2) MarshalText() ([]byte, error) { + return strconv.AppendUint(nil, uint64(*v), 10), nil } -func (v *UI2) Unmarshal(s string) error { - v2, err := strconv.ParseUint(s, 10, 16) +func (v *UI2) UnmarshalText(b []byte) error { + v2, err := strconv.ParseUint(string(b), 10, 16) *v = UI2(v2) return err } @@ -80,15 +92,15 @@ func NewUI4(v uint32) *UI4 { } func (v *UI4) String() string { - return strconv.FormatUint(uint64(*v), 10) + return toStringNoError(v) } -func (v *UI4) Marshal() (string, error) { - return v.String(), nil +func (v *UI4) MarshalText() ([]byte, error) { + return strconv.AppendUint(nil, uint64(*v), 10), nil } -func (v *UI4) Unmarshal(s string) error { - v2, err := strconv.ParseUint(s, 10, 32) +func (v *UI4) UnmarshalText(b []byte) error { + v2, err := strconv.ParseUint(string(b), 10, 32) *v = UI4(v2) return err } @@ -103,15 +115,15 @@ func NewUI8(v uint64) *UI8 { } func (v *UI8) String() string { - return strconv.FormatUint(uint64(*v), 10) + return toStringNoError(v) } -func (v *UI8) Marshal() (string, error) { - return v.String(), nil +func (v *UI8) MarshalText() ([]byte, error) { + return strconv.AppendUint(nil, uint64(*v), 10), nil } -func (v *UI8) Unmarshal(s string) error { - v2, err := strconv.ParseUint(s, 10, 64) +func (v *UI8) UnmarshalText(b []byte) error { + v2, err := strconv.ParseUint(string(b), 10, 64) *v = UI8(v2) return err } @@ -126,15 +138,15 @@ func NewI1(v int8) *I1 { } func (v *I1) String() string { - return strconv.FormatInt(int64(*v), 10) + return toStringNoError(v) } -func (v *I1) Marshal() (string, error) { - return v.String(), nil +func (v *I1) MarshalText() ([]byte, error) { + return strconv.AppendInt(nil, int64(*v), 10), nil } -func (v *I1) Unmarshal(s string) error { - v2, err := strconv.ParseInt(s, 10, 8) +func (v *I1) UnmarshalText(b []byte) error { + v2, err := strconv.ParseInt(string(b), 10, 8) *v = I1(v2) return err } @@ -149,15 +161,15 @@ func NewI2(v int16) *I2 { } func (v *I2) String() string { - return strconv.FormatInt(int64(*v), 10) + return toStringNoError(v) } -func (v *I2) Marshal() (string, error) { - return v.String(), nil +func (v *I2) MarshalText() ([]byte, error) { + return strconv.AppendInt(nil, int64(*v), 10), nil } -func (v *I2) Unmarshal(s string) error { - v2, err := strconv.ParseInt(s, 10, 16) +func (v *I2) UnmarshalText(b []byte) error { + v2, err := strconv.ParseInt(string(b), 10, 16) *v = I2(v2) return err } @@ -172,15 +184,15 @@ func NewI4(v int32) *I4 { } func (v *I4) String() string { - return strconv.FormatInt(int64(*v), 10) + return toStringNoError(v) } -func (v *I4) Marshal() (string, error) { - return v.String(), nil +func (v *I4) MarshalText() ([]byte, error) { + return strconv.AppendInt(nil, int64(*v), 10), nil } -func (v *I4) Unmarshal(s string) error { - v2, err := strconv.ParseInt(s, 10, 32) +func (v *I4) UnmarshalText(b []byte) error { + v2, err := strconv.ParseInt(string(b), 10, 32) *v = I4(v2) return err } @@ -195,15 +207,15 @@ func NewI8(v int64) *I8 { } func (v *I8) String() string { - return strconv.FormatInt(int64(*v), 10) + return toStringNoError(v) } -func (v *I8) Marshal() (string, error) { - return v.String(), nil +func (v *I8) MarshalText() ([]byte, error) { + return strconv.AppendInt(nil, int64(*v), 10), nil } -func (v *I8) Unmarshal(s string) error { - v2, err := strconv.ParseInt(s, 10, 64) +func (v *I8) UnmarshalText(b []byte) error { + v2, err := strconv.ParseInt(string(b), 10, 64) *v = I8(v2) return err } @@ -218,15 +230,15 @@ func NewR4(v float32) *R4 { } func (v *R4) String() string { - return strconv.FormatFloat(float64(*v), 'G', -1, 32) + return toStringNoError(v) } -func (v *R4) Marshal() (string, error) { - return v.String(), nil +func (v *R4) MarshalText() ([]byte, error) { + return strconv.AppendFloat(nil, float64(*v), 'g', -1, 32), nil } -func (v *R4) Unmarshal(s string) error { - v2, err := strconv.ParseFloat(s, 32) +func (v *R4) UnmarshalText(b []byte) error { + v2, err := strconv.ParseFloat(string(b), 32) *v = R4(v2) return err } @@ -241,15 +253,15 @@ func NewR8(v float64) *R8 { } func (v *R8) String() string { - return strconv.FormatFloat(float64(*v), 'G', -1, 64) + return toStringNoError(v) } -func (v *R8) Marshal() (string, error) { - return v.String(), nil +func (v *R8) MarshalText() ([]byte, error) { + return strconv.AppendFloat(nil, float64(*v), 'g', -1, 64), nil } -func (v *R8) Unmarshal(s string) error { - v2, err := strconv.ParseFloat(s, 64) +func (v *R8) UnmarshalText(b []byte) error { + v2, err := strconv.ParseFloat(string(b), 64) *v = R8(v2) return err } @@ -347,21 +359,23 @@ func (v Fixed14_4) Float64() float64 { return float64(v.Fractional) / Fixed14_4Denominator } -func (v Fixed14_4) String() string { +func (v *Fixed14_4) String() string { + return toStringNoError(v) +} + +func (v *Fixed14_4) MarshalText() ([]byte, error) { intPart, fracPart := v.Parts() if fracPart < 0 { fracPart = -fracPart } - return fmt.Sprintf("%d.%04d", intPart, fracPart) + return []byte(fmt.Sprintf("%d.%04d", intPart, fracPart)), nil } -func (v *Fixed14_4) Marshal() (string, error) { - return v.String(), nil -} +var decimalByte = []byte{'.'} -func (v *Fixed14_4) Unmarshal(s string) error { - parts := strings.SplitN(s, ".", 2) - intPart, err := strconv.ParseInt(parts[0], 10, 64) +func (v *Fixed14_4) UnmarshalText(b []byte) error { + parts := bytes.SplitN(b, decimalByte, 2) + intPart, err := strconv.ParseInt(string(parts[0]), 10, 64) if err != nil { return err } @@ -372,20 +386,20 @@ func (v *Fixed14_4) Unmarshal(s string) error { for _, r := range fracStr { if r < '0' || r > '9' { - return fmt.Errorf("found non-digit in fractional component of %q", s) + return fmt.Errorf("found non-digit in fractional component of %q", string(b)) } } // Take only the 4 most significant digits of the fractional component. fracStr = fracStr[:min(len(fracStr), 4)] - fracPart, err = strconv.ParseInt(fracStr, 10, 16) + fracPart, err = strconv.ParseInt(string(fracStr), 10, 16) if err != nil { return err } if fracPart < 0 { // This shouldn't happen by virtue of earlier digit-only check. - return fmt.Errorf("got negative fractional component in %q", s) + return fmt.Errorf("got negative fractional component in %q", string(b)) } switch len(fracStr) { @@ -417,25 +431,28 @@ func NewChar(v rune) *Char { } func (v *Char) String() string { - return string(*v) + return string(rune(*v)) } -func (v *Char) Marshal() (string, error) { +func (v *Char) MarshalText() ([]byte, error) { if *v == 0 { - return "", errors.New("soap char: rune 0 is not allowed") + return nil, errors.New("soap char: rune 0 is not allowed") } - return v.String(), nil + result := make([]byte, utf8.RuneLen(rune(*v))) + n := utf8.EncodeRune(result, rune(*v)) + result = result[0:n] + return result, nil } -func (v *Char) Unmarshal(s string) error { - if len(s) == 0 { +func (v *Char) UnmarshalText(b []byte) error { + if len(b) == 0 { return errors.New("soap char: got empty string") } - v2, n := utf8.DecodeRune([]byte(s)) - if n != len(s) { - return fmt.Errorf("soap char: value %q is not a single rune", s) + r, n := utf8.DecodeRune(b) + if n != len(b) { + return fmt.Errorf("soap char: value %q is not a single rune", string(b)) } - *v = Char(v2) + *v = Char(r) return nil } @@ -448,17 +465,17 @@ func NewString(v string) *String { return &v2 } -func (v *String) Marshal() (string, error) { - return string(*v), nil +func (v *String) MarshalText() ([]byte, error) { + return []byte(*v), nil } -func (v *String) Unmarshal(s string) error { - *v = String(s) +func (v *String) UnmarshalText(b []byte) error { + *v = String(b) return nil } -func parseInt(s string, err *error) int { - v, parseErr := strconv.ParseInt(s, 10, 64) +func parseInt(b []byte, err *error) int { + v, parseErr := strconv.ParseInt(string(b), 10, 64) if parseErr != nil { *err = parseErr } @@ -473,22 +490,22 @@ var dateRegexps = []*regexp.Regexp{ } type prefixRemainder struct { - prefix string - remainder string + prefix []byte + remainder []byte } -// prefixUntilAny returns a prefix of the leading string prior to any +// prefixUntilAny returns a prefix of the leading bytes prior to any // characters in `chars`, and the remainder. If no character from `chars` is -// present in `s`, then returns `s` as `prefix`, and empty remainder. +// present in `b`, then returns `b` as `prefix`, and empty `remainder`. // -// prefixUntilAny("123/abc", "/") => {"123", "/abc"} -// prefixUntilAny("123", "/") => {"123", ""} -func prefixUntilAny(s string, chars string) prefixRemainder { - i := strings.IndexAny(s, chars) +// prefixUntilAny([]byte("123/abc"), "/") => {[]byte("123"), []byte("/abc")} +// prefixUntilAny([]byte("123"), "/") => {[]byte("123"), []byte("")} +func prefixUntilAny(b []byte, chars string) prefixRemainder { + i := bytes.IndexAny(b, chars) if i == -1 { - return prefixRemainder{prefix: s, remainder: ""} + return prefixRemainder{prefix: b, remainder: nil} } - return prefixRemainder{prefix: s[:i], remainder: s[i:]} + return prefixRemainder{prefix: b[:i], remainder: b[i:]} } // TimeOfDay is used in cases where SOAP "time" or "time.tz" is used. @@ -533,8 +550,8 @@ func (tod TimeOfDay) ToDuration() time.Duration { time.Duration(tod.Second)*time.Second } -func (tod TimeOfDay) String() string { - return fmt.Sprintf("%02d:%02d:%02d", tod.Hour, tod.Minute, tod.Second) +func (tod *TimeOfDay) String() string { + return toStringNoError(tod) } // IsValid returns true iff v is positive and <= 24 hours. @@ -556,11 +573,11 @@ func (tod *TimeOfDay) clear() { tod.Second = 0 } -func (tod *TimeOfDay) Marshal() (string, error) { +func (tod *TimeOfDay) MarshalText() ([]byte, error) { if err := tod.CheckValid(); err != nil { - return "", err + return nil, err } - return tod.String(), nil + return []byte(fmt.Sprintf("%02d:%02d:%02d", tod.Hour, tod.Minute, tod.Second)), nil } var timeRegexps = []*regexp.Regexp{ @@ -570,17 +587,17 @@ var timeRegexps = []*regexp.Regexp{ regexp.MustCompile(`^(\d{2}):(\d{2}):(\d{2})$`), } -func (tod *TimeOfDay) Unmarshal(s string) error { +func (tod *TimeOfDay) UnmarshalText(b []byte) error { tod.clear() - var parts []string + var parts [][]byte for _, re := range timeRegexps { - parts = re.FindStringSubmatch(s) + parts = re.FindSubmatch(b) if parts != nil { break } } if parts == nil { - return fmt.Errorf("value %q is not in ISO8601 time format", s) + return fmt.Errorf("value %q is not in ISO8601 time format", string(b)) } var err error @@ -588,7 +605,7 @@ func (tod *TimeOfDay) Unmarshal(s string) error { tod.Minute = int8(parseInt(parts[2], &err)) tod.Second = int8(parseInt(parts[3], &err)) if err != nil { - return fmt.Errorf("value %q is not in ISO8601 time format: %v", s, err) + return fmt.Errorf("value %q is not in ISO8601 time format: %v", string(b), err) } return tod.CheckValid() @@ -605,8 +622,8 @@ type TimeOfDayTZ struct { var _ SOAPValue = &TimeOfDayTZ{} -func (todz TimeOfDayTZ) String() string { - return fmt.Sprintf("%v%v", todz.TimeOfDay, todz.TZ) +func (todz *TimeOfDayTZ) String() string { + return toStringNoError(todz) } // clear removes data from v, setting to default values. @@ -615,17 +632,22 @@ func (todz *TimeOfDayTZ) clear() { todz.TZ.clear() } -func (todz *TimeOfDayTZ) Marshal() (string, error) { - return todz.String(), nil +func (todz *TimeOfDayTZ) MarshalText() ([]byte, error) { + result, err := todz.TimeOfDay.MarshalText() + if err != nil { + return nil, err + } + result = append(result, []byte(todz.TZ.String())...) + return result, nil } -func (todz *TimeOfDayTZ) Unmarshal(s string) error { +func (todz *TimeOfDayTZ) UnmarshalText(b []byte) error { todz.clear() - parts := prefixUntilAny(s, "Z+-") - if err := todz.TimeOfDay.Unmarshal(parts.prefix); err != nil { + parts := prefixUntilAny(b, "Z+-") + if err := todz.TimeOfDay.UnmarshalText(parts.prefix); err != nil { return err } - return todz.TZ.unmarshal(parts.remainder) + return todz.TZ.unmarshalText(parts.remainder) } // Date maps to the SOAP "date" type. Marshaling and Unmarshalling does *not* @@ -656,8 +678,8 @@ func (d Date) ToTime(loc *time.Location) time.Time { return time.Date(d.Year, d.Month, d.Day, 0, 0, 0, 0, loc) } -func (d Date) String() string { - return fmt.Sprintf("%04d-%02d-%02d", d.Year, d.Month, d.Day) +func (d *Date) String() string { + return toStringNoError(d) } // CheckValid returns an error if the date components are out of range. @@ -676,21 +698,21 @@ func (d *Date) clear() { d.Day = 0 } -func (d *Date) Marshal() (string, error) { - return d.String(), nil +func (d *Date) MarshalText() ([]byte, error) { + return []byte(fmt.Sprintf("%04d-%02d-%02d", d.Year, d.Month, d.Day)), nil } -func (d *Date) Unmarshal(s string) error { +func (d *Date) UnmarshalText(b []byte) error { d.clear() - var parts []string + var parts [][]byte for _, re := range dateRegexps { - parts = re.FindStringSubmatch(s) + parts = re.FindSubmatch(b) if parts != nil { break } } if parts == nil { - return fmt.Errorf("error parsing date: value %q is not in a recognized ISO8601 date format", s) + return fmt.Errorf("error parsing date: value %q is not in a recognized ISO8601 date format", string(b)) } var err error @@ -705,7 +727,7 @@ func (d *Date) Unmarshal(s string) error { } if err != nil { - return fmt.Errorf("error parsing date %q: %v", s, err) + return fmt.Errorf("error parsing date %q: %v", string(b), err) } return nil @@ -726,8 +748,8 @@ func DateTimeFromTime(v time.Time) DateTime { return dt } -func (dt DateTime) String() string { - return fmt.Sprintf("%vT%v", dt.Date, dt.TimeOfDay) +func (dt *DateTime) String() string { + return toStringNoError(dt) } func (dt DateTime) ToTime(loc *time.Location) time.Time { @@ -742,26 +764,38 @@ func (dt *DateTime) clear() { dt.TimeOfDay.clear() } -func (dt *DateTime) Marshal() (string, error) { - return dt.String(), nil +func (dt *DateTime) MarshalText() ([]byte, error) { + var result []byte + d, err := dt.Date.MarshalText() + if err != nil { + return nil, err + } + t, err := dt.TimeOfDay.MarshalText() + if err != nil { + return nil, err + } + result = append(result, d...) + result = append(result, 'T') + result = append(result, t...) + return result, nil } -func (dt *DateTime) Unmarshal(s string) error { +func (dt *DateTime) UnmarshalText(b []byte) error { dt.clear() - parts := prefixUntilAny(s, "T") - if err := dt.Date.Unmarshal(parts.prefix); err != nil { + parts := prefixUntilAny(b, "T") + if err := dt.Date.UnmarshalText(parts.prefix); err != nil { return err } - if parts.remainder == "" { + if len(parts.remainder) == 0 { return nil } if parts.remainder[0] != 'T' { - return fmt.Errorf("missing 'T' time separator in dateTime %q", s) + return fmt.Errorf("missing 'T' time separator in dateTime %q", string(b)) } - return dt.TimeOfDay.Unmarshal(parts.remainder[1:]) + return dt.TimeOfDay.UnmarshalText(parts.remainder[1:]) } // DateTime maps to SOAP type "dateTime.tz". @@ -781,8 +815,8 @@ func DateTimeTZFromTime(t time.Time) DateTimeTZ { } } -func (dtz DateTimeTZ) String() string { - return dtz.Date.String() + "T" + dtz.TimeOfDay.String() + dtz.TZ.String() +func (dtz *DateTimeTZ) String() string { + return toStringNoError(dtz) } // Time converts `dtz` to time.Time, using defaultLoc as the default location if @@ -807,33 +841,46 @@ func (dtz *DateTimeTZ) clear() { dtz.TZ.clear() } -func (dtz *DateTimeTZ) Marshal() (string, error) { - return dtz.String(), nil +func (dtz *DateTimeTZ) MarshalText() ([]byte, error) { + var result []byte + d, err := dtz.Date.MarshalText() + if err != nil { + return nil, err + } + t, err := dtz.TimeOfDay.MarshalText() + if err != nil { + return nil, err + } + result = append(result, d...) + result = append(result, 'T') + result = append(result, t...) + result = append(result, []byte(dtz.TZ.String())...) + return result, nil } -func (dtz *DateTimeTZ) Unmarshal(s string) error { +func (dtz *DateTimeTZ) UnmarshalText(b []byte) error { dtz.clear() - dateParts := prefixUntilAny(s, "T") - if err := dtz.Date.Unmarshal(dateParts.prefix); err != nil { + dateParts := prefixUntilAny(b, "T") + if err := dtz.Date.UnmarshalText(dateParts.prefix); err != nil { return err } - if dateParts.remainder == "" { + if len(dateParts.remainder) == 0 { return nil } // Trim the leading "T" between date and time. remainder := dateParts.remainder[1:] timeParts := prefixUntilAny(remainder, "Z+-") - if err := dtz.TimeOfDay.Unmarshal(timeParts.prefix); err != nil { + if err := dtz.TimeOfDay.UnmarshalText(timeParts.prefix); err != nil { return err } - if timeParts.remainder == "" { + if len(timeParts.remainder) == 0 { return nil } - return dtz.TZ.unmarshal(timeParts.remainder) + return dtz.TZ.unmarshalText(timeParts.remainder) } // TZD is a timezone designator. Not a full SOAP time in itself, but used as @@ -902,29 +949,29 @@ func (tzd *TZD) clear() { // (+|-)(hh):(mm) var timezoneRegexp = regexp.MustCompile(`^([+-])(\d{2}):(\d{2})$`) -func (tzd *TZD) unmarshal(s string) error { +func (tzd *TZD) unmarshalText(b []byte) error { tzd.clear() - if s == "" { + if len(b) == 0 { return nil } tzd.HasTZ = true - if s == "Z" { + if len(b) == 1 && b[0] == 'Z' { return nil } - parts := timezoneRegexp.FindStringSubmatch(s) + parts := timezoneRegexp.FindSubmatch(b) if parts == nil { - return fmt.Errorf("value %q is not in ISO8601 timezone format", s) + return fmt.Errorf("value %q is not in ISO8601 timezone format", string(b)) } var err error tzd.Offset = parseInt(parts[2], &err) * 3600 tzd.Offset += parseInt(parts[3], &err) * 60 - if parts[1] == "-" { + if len(parts[1]) == 1 && parts[1][0] == '-' { tzd.Offset = -tzd.Offset } if err != nil { - err = fmt.Errorf("value %q is not in ISO8601 timezone format: %v", s, err) + err = fmt.Errorf("value %q is not in ISO8601 timezone format: %v", string(b), err) } return nil @@ -946,21 +993,21 @@ func (v *Boolean) String() string { return "false" } -func (v *Boolean) Marshal() (string, error) { +func (v *Boolean) MarshalText() ([]byte, error) { if *v { - return "1", nil + return []byte{'1'}, nil } - return "0", nil + return []byte{'0'}, nil } -func (v *Boolean) Unmarshal(s string) error { - switch s { +func (v *Boolean) UnmarshalText(b []byte) error { + switch string(b) { case "0", "false", "no": *v = false case "1", "true", "yes": *v = true default: - return fmt.Errorf("soap boolean: %q is not a valid boolean value", s) + return fmt.Errorf("soap boolean: %q is not a valid boolean value", string(b)) } return nil } @@ -979,13 +1026,16 @@ func (v *BinBase64) String() string { return base64.StdEncoding.EncodeToString(*v) } -func (v *BinBase64) Marshal() (string, error) { - return v.String(), nil +func (v *BinBase64) MarshalText() ([]byte, error) { + result := make([]byte, base64.StdEncoding.EncodedLen(len(*v))) + base64.StdEncoding.Encode(result, []byte(*v)) + return result, nil } -func (v *BinBase64) Unmarshal(s string) error { - v2, err := base64.StdEncoding.DecodeString(s) - *v = v2 +func (v *BinBase64) UnmarshalText(b []byte) error { + *v = make(BinBase64, base64.StdEncoding.DecodedLen(len(b))) + n, err := base64.StdEncoding.Decode([]byte(*v), b) + *v = (*v)[:n] return err } @@ -1003,13 +1053,16 @@ func (v *BinHex) String() string { return hex.EncodeToString(*v) } -func (v *BinHex) Marshal() (string, error) { - return v.String(), nil +func (v *BinHex) MarshalText() ([]byte, error) { + result := make([]byte, hex.EncodedLen(len(*v))) + hex.Encode(result, []byte(*v)) + return result, nil } -func (v *BinHex) Unmarshal(s string) error { - v2, err := hex.DecodeString(s) - *v = v2 +func (v *BinHex) UnmarshalText(b []byte) error { + *v = make(BinHex, hex.DecodedLen(len(b))) + n, err := hex.Decode(*v, b) + *v = (*v)[:n] return err } @@ -1026,12 +1079,12 @@ func (v *URI) ToURL() *url.URL { return (*url.URL)(v) } -func (v *URI) Marshal() (string, error) { - return (*url.URL)(v).String(), nil +func (v *URI) MarshalText() ([]byte, error) { + return []byte((*url.URL)(v).String()), nil } -func (v *URI) Unmarshal(s string) error { - v2, err := url.Parse(s) +func (v *URI) UnmarshalText(b []byte) error { + v2, err := url.Parse(string(b)) if err != nil { return err } diff --git a/v2/soap/types/types_test.go b/v2/soap/types/types_test.go index d5246ae..c3007e0 100644 --- a/v2/soap/types/types_test.go +++ b/v2/soap/types/types_test.go @@ -372,10 +372,11 @@ func Test(t *testing.T) { for i, mt := range tt.marshalTests { mt := mt t.Run(fmt.Sprintf("marshalTest#%d_%v", i, mt.input), func(t *testing.T) { - got, err := mt.input.Marshal() + gotBytes, err := mt.input.MarshalText() if err != nil { t.Errorf("got unexpected error: %v", err) } + got := string(gotBytes) if got != mt.want { t.Errorf("got %q, want: %q", got, mt.want) } @@ -384,7 +385,7 @@ func Test(t *testing.T) { for i, input := range tt.marshalErrs { input := input t.Run(fmt.Sprintf("marshalErr#%d_%v", i, input), func(t *testing.T) { - got, err := input.Marshal() + got, err := input.MarshalText() if err == nil { t.Errorf("got %q, want error", got) } @@ -394,7 +395,7 @@ func Test(t *testing.T) { ut := ut t.Run(fmt.Sprintf("unmarshalTest#%d_%q", i, ut.input), func(t *testing.T) { got := tt.makeValue() - if err := got.Unmarshal(ut.input); err != nil { + if err := got.UnmarshalText([]byte(ut.input)); err != nil { t.Errorf("got unexpected error: %v", err) } if !tt.isEqual(got, ut.want) { @@ -406,7 +407,7 @@ func Test(t *testing.T) { input := input t.Run(fmt.Sprintf("unmarshalErrs#%d_%q", i, input), func(t *testing.T) { got := tt.makeValue() - if err := got.Unmarshal(input); err == nil { + if err := got.UnmarshalText([]byte(input)); err == nil { t.Errorf("got %v, want error", got) } })