Include allowed string values in generated services.
This commit is contained in:
		@@ -23,7 +23,9 @@ import (
 | 
			
		||||
	"github.com/huin/goupnp/v2alpha/description/typedesc"
 | 
			
		||||
	"github.com/huin/goupnp/v2alpha/description/xmlsrvdesc"
 | 
			
		||||
	"github.com/huin/goupnp/v2alpha/soap"
 | 
			
		||||
	"github.com/huin/goupnp/v2alpha/soap/types"
 | 
			
		||||
	"golang.org/x/exp/maps"
 | 
			
		||||
 | 
			
		||||
	soaptypes "github.com/huin/goupnp/v2alpha/soap/types"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
@@ -93,7 +95,7 @@ func run() error {
 | 
			
		||||
 | 
			
		||||
	// Use default type map for now. Addtional types could be use instead or
 | 
			
		||||
	// as well as necessary for extended types.
 | 
			
		||||
	typeMap := types.TypeMap().Clone()
 | 
			
		||||
	typeMap := soaptypes.TypeMap().Clone()
 | 
			
		||||
	typeMap[soapActionInterface] = typedesc.TypeDesc{
 | 
			
		||||
		GoType: reflect.TypeOf((*soap.Action)(nil)).Elem(),
 | 
			
		||||
	}
 | 
			
		||||
@@ -160,7 +162,8 @@ func processService(
 | 
			
		||||
		return fmt.Errorf("transforming service description: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	imps, err := accumulateImports(sd, typeMap)
 | 
			
		||||
	imps := newImports()
 | 
			
		||||
	types, err := accumulateTypes(sd, typeMap, imps)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
@@ -169,6 +172,7 @@ func processService(
 | 
			
		||||
	err = tmpl.ExecuteTemplate(buf, "service", tmplArgs{
 | 
			
		||||
		Manifest: srvManifest,
 | 
			
		||||
		Imps:     imps,
 | 
			
		||||
		Types:    types,
 | 
			
		||||
		SCPD:     sd,
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -221,14 +225,44 @@ type ServiceManifest struct {
 | 
			
		||||
type tmplArgs struct {
 | 
			
		||||
	Manifest *ServiceManifest
 | 
			
		||||
	Imps     *imports
 | 
			
		||||
	Types    *types
 | 
			
		||||
	SCPD     *srvdesc.SCPD
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type imports struct {
 | 
			
		||||
	// Maps from a type name like "ui4" to the `alias.name` for the import.
 | 
			
		||||
	TypeByName map[string]typeDesc
 | 
			
		||||
	// Each required import line, ordered by path.
 | 
			
		||||
	ImportLines []importItem
 | 
			
		||||
	// aliasByPath maps from import path to its imported alias.
 | 
			
		||||
	aliasByPath map[string]string
 | 
			
		||||
	// nextAlias is the number for the next import alias.
 | 
			
		||||
	nextAlias int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newImports() *imports {
 | 
			
		||||
	return &imports{
 | 
			
		||||
		aliasByPath: make(map[string]string),
 | 
			
		||||
		nextAlias:   1,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (imps *imports) getAliasForPath(path string) string {
 | 
			
		||||
	if alias, ok := imps.aliasByPath[path]; ok {
 | 
			
		||||
		return alias
 | 
			
		||||
	}
 | 
			
		||||
	alias := fmt.Sprintf("pkg%d", imps.nextAlias)
 | 
			
		||||
	imps.nextAlias++
 | 
			
		||||
	imps.ImportLines = append(imps.ImportLines, importItem{
 | 
			
		||||
		Alias: alias,
 | 
			
		||||
		Path:  path,
 | 
			
		||||
	})
 | 
			
		||||
	imps.aliasByPath[path] = alias
 | 
			
		||||
	return alias
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type types struct {
 | 
			
		||||
	// Maps from a type name like "ui4" to the `alias.name` for the import.
 | 
			
		||||
	TypeByName    map[string]typeDesc
 | 
			
		||||
	StringVarDefs []stringVarDef
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type typeDesc struct {
 | 
			
		||||
@@ -241,17 +275,41 @@ type typeDesc struct {
 | 
			
		||||
	Name string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type stringVarDef struct {
 | 
			
		||||
	Name          string
 | 
			
		||||
	AllowedValues []string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type importItem struct {
 | 
			
		||||
	Alias string
 | 
			
		||||
	Path  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func accumulateImports(srvDesc *srvdesc.SCPD, typeMap typedesc.TypeMap) (*imports, error) {
 | 
			
		||||
	typeNames := make(map[string]bool)
 | 
			
		||||
	typeNames[soapActionInterface] = true
 | 
			
		||||
// accumulateTypes creates type information, and adds any required imports for
 | 
			
		||||
// them.
 | 
			
		||||
func accumulateTypes(
 | 
			
		||||
	srvDesc *srvdesc.SCPD,
 | 
			
		||||
	typeMap typedesc.TypeMap,
 | 
			
		||||
	imps *imports,
 | 
			
		||||
) (*types, error) {
 | 
			
		||||
	typeNames := make(map[string]struct{})
 | 
			
		||||
	typeNames[soapActionInterface] = struct{}{}
 | 
			
		||||
 | 
			
		||||
	err := visitTypesSCPD(srvDesc, func(typeName string) {
 | 
			
		||||
		typeNames[typeName] = true
 | 
			
		||||
	var stringVarDefs []stringVarDef
 | 
			
		||||
	sortedVarNames := maps.Keys(srvDesc.VariableByName)
 | 
			
		||||
	sort.Strings(sortedVarNames)
 | 
			
		||||
	for _, svName := range sortedVarNames {
 | 
			
		||||
		sv := srvDesc.VariableByName[svName]
 | 
			
		||||
		if sv.DataType == "string" && len(sv.AllowedValues) > 0 {
 | 
			
		||||
			stringVarDefs = append(stringVarDefs, stringVarDef{
 | 
			
		||||
				Name:          svName,
 | 
			
		||||
				AllowedValues: srvDesc.VariableByName[svName].AllowedValues,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := visitTypesSCPD(srvDesc, func(sv *srvdesc.StateVariable) {
 | 
			
		||||
		typeNames[sv.DataType] = struct{}{}
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@@ -259,7 +317,7 @@ func accumulateImports(srvDesc *srvdesc.SCPD, typeMap typedesc.TypeMap) (*import
 | 
			
		||||
 | 
			
		||||
	// Have sorted list of import package paths. Partly for aesthetics of generated code, but also
 | 
			
		||||
	// to have stable-generated aliases.
 | 
			
		||||
	paths := make(map[string]bool)
 | 
			
		||||
	paths := make(map[string]struct{})
 | 
			
		||||
	for typeName := range typeNames {
 | 
			
		||||
		t, ok := typeMap[typeName]
 | 
			
		||||
		if !ok {
 | 
			
		||||
@@ -267,29 +325,17 @@ func accumulateImports(srvDesc *srvdesc.SCPD, typeMap typedesc.TypeMap) (*import
 | 
			
		||||
		}
 | 
			
		||||
		pkgPath := t.GoType.PkgPath()
 | 
			
		||||
		if pkgPath == "" {
 | 
			
		||||
			// Builtin type, ignore.
 | 
			
		||||
			// Builtin type, no import needed.
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		paths[pkgPath] = true
 | 
			
		||||
	}
 | 
			
		||||
	sortedPaths := make([]string, 0, len(paths))
 | 
			
		||||
	for path := range paths {
 | 
			
		||||
		sortedPaths = append(sortedPaths, path)
 | 
			
		||||
		paths[pkgPath] = struct{}{}
 | 
			
		||||
	}
 | 
			
		||||
	sortedPaths := maps.Keys(paths)
 | 
			
		||||
	sort.Strings(sortedPaths)
 | 
			
		||||
 | 
			
		||||
	// Generate import aliases.
 | 
			
		||||
	index := 1
 | 
			
		||||
	aliasByPath := make(map[string]string, len(paths))
 | 
			
		||||
	importLines := make([]importItem, 0, len(paths))
 | 
			
		||||
	// Generate import aliases in deterministic order.
 | 
			
		||||
	for _, path := range sortedPaths {
 | 
			
		||||
		alias := fmt.Sprintf("pkg%d", index)
 | 
			
		||||
		index++
 | 
			
		||||
		importLines = append(importLines, importItem{
 | 
			
		||||
			Alias: alias,
 | 
			
		||||
			Path:  path,
 | 
			
		||||
		})
 | 
			
		||||
		aliasByPath[path] = alias
 | 
			
		||||
		imps.getAliasForPath(path)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Populate typeByName.
 | 
			
		||||
@@ -297,28 +343,27 @@ func accumulateImports(srvDesc *srvdesc.SCPD, typeMap typedesc.TypeMap) (*import
 | 
			
		||||
	for typeName := range typeNames {
 | 
			
		||||
		goType := typeMap[typeName]
 | 
			
		||||
		pkgPath := goType.GoType.PkgPath()
 | 
			
		||||
		alias := aliasByPath[pkgPath]
 | 
			
		||||
		td := typeDesc{
 | 
			
		||||
			Name: goType.GoType.Name(),
 | 
			
		||||
		}
 | 
			
		||||
		if alias == "" {
 | 
			
		||||
		if pkgPath == "" {
 | 
			
		||||
			// Builtin type.
 | 
			
		||||
			td.AbsRef = td.Name
 | 
			
		||||
			td.Ref = td.Name
 | 
			
		||||
		} else {
 | 
			
		||||
			td.AbsRef = strconv.Quote(pkgPath) + "." + td.Name
 | 
			
		||||
			td.Ref = alias + "." + td.Name
 | 
			
		||||
			td.Ref = imps.getAliasForPath(pkgPath) + "." + td.Name
 | 
			
		||||
		}
 | 
			
		||||
		typeByName[typeName] = td
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &imports{
 | 
			
		||||
		TypeByName:  typeByName,
 | 
			
		||||
		ImportLines: importLines,
 | 
			
		||||
	return &types{
 | 
			
		||||
		TypeByName:    typeByName,
 | 
			
		||||
		StringVarDefs: stringVarDefs,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type typeVisitor func(typeName string)
 | 
			
		||||
type typeVisitor func(sv *srvdesc.StateVariable)
 | 
			
		||||
 | 
			
		||||
// visitTypesSCPD calls `visitor` with each data type name (e.g. "ui4") referenced
 | 
			
		||||
// by action arguments.`
 | 
			
		||||
@@ -337,14 +382,14 @@ func visitTypesAction(action *srvdesc.Action, visitor typeVisitor) error {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		visitor(sv.DataType)
 | 
			
		||||
		visitor(sv)
 | 
			
		||||
	}
 | 
			
		||||
	for _, arg := range action.OutArgs {
 | 
			
		||||
		sv, err := arg.RelatedStateVariable()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		visitor(sv.DataType)
 | 
			
		||||
		visitor(sv)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user