Include allowed string values in generated services.
This commit is contained in:
@ -23,7 +23,9 @@ import (
|
||||
"github.com/huin/goupnp/v2alpha/description/typedesc"
|
||||
"github.com/huin/goupnp/v2alpha/description/xmlsrvdesc"
|
||||
"github.com/huin/goupnp/v2alpha/soap"
|
||||
"github.com/huin/goupnp/v2alpha/soap/types"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
soaptypes "github.com/huin/goupnp/v2alpha/soap/types"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -93,7 +95,7 @@ func run() error {
|
||||
|
||||
// Use default type map for now. Addtional types could be use instead or
|
||||
// as well as necessary for extended types.
|
||||
typeMap := types.TypeMap().Clone()
|
||||
typeMap := soaptypes.TypeMap().Clone()
|
||||
typeMap[soapActionInterface] = typedesc.TypeDesc{
|
||||
GoType: reflect.TypeOf((*soap.Action)(nil)).Elem(),
|
||||
}
|
||||
@ -160,7 +162,8 @@ func processService(
|
||||
return fmt.Errorf("transforming service description: %w", err)
|
||||
}
|
||||
|
||||
imps, err := accumulateImports(sd, typeMap)
|
||||
imps := newImports()
|
||||
types, err := accumulateTypes(sd, typeMap, imps)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -169,6 +172,7 @@ func processService(
|
||||
err = tmpl.ExecuteTemplate(buf, "service", tmplArgs{
|
||||
Manifest: srvManifest,
|
||||
Imps: imps,
|
||||
Types: types,
|
||||
SCPD: sd,
|
||||
})
|
||||
if err != nil {
|
||||
@ -221,14 +225,44 @@ type ServiceManifest struct {
|
||||
type tmplArgs struct {
|
||||
Manifest *ServiceManifest
|
||||
Imps *imports
|
||||
Types *types
|
||||
SCPD *srvdesc.SCPD
|
||||
}
|
||||
|
||||
type imports struct {
|
||||
// Maps from a type name like "ui4" to the `alias.name` for the import.
|
||||
TypeByName map[string]typeDesc
|
||||
// Each required import line, ordered by path.
|
||||
ImportLines []importItem
|
||||
// aliasByPath maps from import path to its imported alias.
|
||||
aliasByPath map[string]string
|
||||
// nextAlias is the number for the next import alias.
|
||||
nextAlias int
|
||||
}
|
||||
|
||||
func newImports() *imports {
|
||||
return &imports{
|
||||
aliasByPath: make(map[string]string),
|
||||
nextAlias: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (imps *imports) getAliasForPath(path string) string {
|
||||
if alias, ok := imps.aliasByPath[path]; ok {
|
||||
return alias
|
||||
}
|
||||
alias := fmt.Sprintf("pkg%d", imps.nextAlias)
|
||||
imps.nextAlias++
|
||||
imps.ImportLines = append(imps.ImportLines, importItem{
|
||||
Alias: alias,
|
||||
Path: path,
|
||||
})
|
||||
imps.aliasByPath[path] = alias
|
||||
return alias
|
||||
}
|
||||
|
||||
type types struct {
|
||||
// Maps from a type name like "ui4" to the `alias.name` for the import.
|
||||
TypeByName map[string]typeDesc
|
||||
StringVarDefs []stringVarDef
|
||||
}
|
||||
|
||||
type typeDesc struct {
|
||||
@ -241,17 +275,41 @@ type typeDesc struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type stringVarDef struct {
|
||||
Name string
|
||||
AllowedValues []string
|
||||
}
|
||||
|
||||
type importItem struct {
|
||||
Alias string
|
||||
Path string
|
||||
}
|
||||
|
||||
func accumulateImports(srvDesc *srvdesc.SCPD, typeMap typedesc.TypeMap) (*imports, error) {
|
||||
typeNames := make(map[string]bool)
|
||||
typeNames[soapActionInterface] = true
|
||||
// accumulateTypes creates type information, and adds any required imports for
|
||||
// them.
|
||||
func accumulateTypes(
|
||||
srvDesc *srvdesc.SCPD,
|
||||
typeMap typedesc.TypeMap,
|
||||
imps *imports,
|
||||
) (*types, error) {
|
||||
typeNames := make(map[string]struct{})
|
||||
typeNames[soapActionInterface] = struct{}{}
|
||||
|
||||
err := visitTypesSCPD(srvDesc, func(typeName string) {
|
||||
typeNames[typeName] = true
|
||||
var stringVarDefs []stringVarDef
|
||||
sortedVarNames := maps.Keys(srvDesc.VariableByName)
|
||||
sort.Strings(sortedVarNames)
|
||||
for _, svName := range sortedVarNames {
|
||||
sv := srvDesc.VariableByName[svName]
|
||||
if sv.DataType == "string" && len(sv.AllowedValues) > 0 {
|
||||
stringVarDefs = append(stringVarDefs, stringVarDef{
|
||||
Name: svName,
|
||||
AllowedValues: srvDesc.VariableByName[svName].AllowedValues,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
err := visitTypesSCPD(srvDesc, func(sv *srvdesc.StateVariable) {
|
||||
typeNames[sv.DataType] = struct{}{}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -259,7 +317,7 @@ func accumulateImports(srvDesc *srvdesc.SCPD, typeMap typedesc.TypeMap) (*import
|
||||
|
||||
// Have sorted list of import package paths. Partly for aesthetics of generated code, but also
|
||||
// to have stable-generated aliases.
|
||||
paths := make(map[string]bool)
|
||||
paths := make(map[string]struct{})
|
||||
for typeName := range typeNames {
|
||||
t, ok := typeMap[typeName]
|
||||
if !ok {
|
||||
@ -267,29 +325,17 @@ func accumulateImports(srvDesc *srvdesc.SCPD, typeMap typedesc.TypeMap) (*import
|
||||
}
|
||||
pkgPath := t.GoType.PkgPath()
|
||||
if pkgPath == "" {
|
||||
// Builtin type, ignore.
|
||||
// Builtin type, no import needed.
|
||||
continue
|
||||
}
|
||||
paths[pkgPath] = true
|
||||
}
|
||||
sortedPaths := make([]string, 0, len(paths))
|
||||
for path := range paths {
|
||||
sortedPaths = append(sortedPaths, path)
|
||||
paths[pkgPath] = struct{}{}
|
||||
}
|
||||
sortedPaths := maps.Keys(paths)
|
||||
sort.Strings(sortedPaths)
|
||||
|
||||
// Generate import aliases.
|
||||
index := 1
|
||||
aliasByPath := make(map[string]string, len(paths))
|
||||
importLines := make([]importItem, 0, len(paths))
|
||||
// Generate import aliases in deterministic order.
|
||||
for _, path := range sortedPaths {
|
||||
alias := fmt.Sprintf("pkg%d", index)
|
||||
index++
|
||||
importLines = append(importLines, importItem{
|
||||
Alias: alias,
|
||||
Path: path,
|
||||
})
|
||||
aliasByPath[path] = alias
|
||||
imps.getAliasForPath(path)
|
||||
}
|
||||
|
||||
// Populate typeByName.
|
||||
@ -297,28 +343,27 @@ func accumulateImports(srvDesc *srvdesc.SCPD, typeMap typedesc.TypeMap) (*import
|
||||
for typeName := range typeNames {
|
||||
goType := typeMap[typeName]
|
||||
pkgPath := goType.GoType.PkgPath()
|
||||
alias := aliasByPath[pkgPath]
|
||||
td := typeDesc{
|
||||
Name: goType.GoType.Name(),
|
||||
}
|
||||
if alias == "" {
|
||||
if pkgPath == "" {
|
||||
// Builtin type.
|
||||
td.AbsRef = td.Name
|
||||
td.Ref = td.Name
|
||||
} else {
|
||||
td.AbsRef = strconv.Quote(pkgPath) + "." + td.Name
|
||||
td.Ref = alias + "." + td.Name
|
||||
td.Ref = imps.getAliasForPath(pkgPath) + "." + td.Name
|
||||
}
|
||||
typeByName[typeName] = td
|
||||
}
|
||||
|
||||
return &imports{
|
||||
TypeByName: typeByName,
|
||||
ImportLines: importLines,
|
||||
return &types{
|
||||
TypeByName: typeByName,
|
||||
StringVarDefs: stringVarDefs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type typeVisitor func(typeName string)
|
||||
type typeVisitor func(sv *srvdesc.StateVariable)
|
||||
|
||||
// visitTypesSCPD calls `visitor` with each data type name (e.g. "ui4") referenced
|
||||
// by action arguments.`
|
||||
@ -337,14 +382,14 @@ func visitTypesAction(action *srvdesc.Action, visitor typeVisitor) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
visitor(sv.DataType)
|
||||
visitor(sv)
|
||||
}
|
||||
for _, arg := range action.OutArgs {
|
||||
sv, err := arg.RelatedStateVariable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
visitor(sv.DataType)
|
||||
visitor(sv)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user