2023-05-05 15:07:29 +00:00
|
|
|
package oci
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"encoding/json"
|
|
|
|
"fmt"
|
|
|
|
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
|
|
|
|
v1 "github.com/opencontainers/image-spec/specs-go/v1"
|
2023-05-28 12:39:49 +00:00
|
|
|
"go.uber.org/zap"
|
2023-05-05 15:07:29 +00:00
|
|
|
"oras.land/oras-go/v2"
|
2023-05-28 12:39:49 +00:00
|
|
|
"oras.land/oras-go/v2/content"
|
2023-05-05 15:07:29 +00:00
|
|
|
"oras.land/oras-go/v2/content/file"
|
2023-05-28 12:39:49 +00:00
|
|
|
"oras.land/oras-go/v2/registry"
|
2023-05-05 15:07:29 +00:00
|
|
|
"oras.land/oras-go/v2/registry/remote"
|
2023-05-28 12:39:49 +00:00
|
|
|
"path"
|
|
|
|
"strconv"
|
2023-05-05 15:07:29 +00:00
|
|
|
)
|
|
|
|
|
2023-05-28 12:39:49 +00:00
|
|
|
func PullOciImage(ctx context.Context, regName, repoName, tag, modelsDir string) (modelPath string, modelType tools.ModelType, imgWidth, imgHeight int, horizon int, err error) {
|
2023-05-05 15:07:29 +00:00
|
|
|
|
2023-05-28 12:39:49 +00:00
|
|
|
repo, err := getRepository(ctx, regName, repoName)
|
2023-05-05 15:07:29 +00:00
|
|
|
if err != nil {
|
2023-05-28 12:39:49 +00:00
|
|
|
err = fmt.Errorf("unable to fetch oci artifact from '%s/%s: %w", regName, repoName, err)
|
2023-05-05 15:07:29 +00:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2023-05-28 12:39:49 +00:00
|
|
|
manifest, err := fetchManifest(ctx, repo, tag)
|
2023-05-05 15:07:29 +00:00
|
|
|
if err != nil {
|
2023-05-28 12:39:49 +00:00
|
|
|
err = fmt.Errorf("unable to fetch manifest '%s/%s:%s': %w", regName, repoName, tag, err)
|
2023-05-05 15:07:29 +00:00
|
|
|
return
|
|
|
|
}
|
2023-05-28 12:39:49 +00:00
|
|
|
zap.S().Infof("Manifest: %v", manifest)
|
2023-05-05 15:07:29 +00:00
|
|
|
|
2023-05-28 12:39:49 +00:00
|
|
|
// 0. Create a file store
|
|
|
|
modelStore := path.Join(modelsDir, manifest.Annotations["category"])
|
|
|
|
fs, err := file.New(modelStore)
|
2023-05-05 15:07:29 +00:00
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
2023-05-28 12:39:49 +00:00
|
|
|
defer fs.Close()
|
2023-05-05 15:07:29 +00:00
|
|
|
|
2023-05-28 12:39:49 +00:00
|
|
|
// 2. Copy from the remote repoName to the file store
|
2023-05-05 15:07:29 +00:00
|
|
|
_, err = oras.Copy(ctx, repo, tag, fs, tag, oras.DefaultCopyOptions)
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
modelType = tools.ParseModelType(manifest.Annotations["type"])
|
|
|
|
imgWidth, err = strconv.Atoi(manifest.Annotations["img_width"])
|
|
|
|
if err != nil {
|
|
|
|
err = fmt.Errorf("unable to convert image width '%v' to integer: %w", manifest.Annotations["img_width"], err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
imgHeight, err = strconv.Atoi(manifest.Annotations["img_height"])
|
|
|
|
if err != nil {
|
|
|
|
err = fmt.Errorf("unable to convert image height '%v' to integer: %w", manifest.Annotations["img_height"], err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
if _, ok := manifest.Annotations["horizon"]; ok {
|
|
|
|
horizon, err = strconv.Atoi(manifest.Annotations["horizon"])
|
|
|
|
if err != nil {
|
|
|
|
err = fmt.Errorf("unable to convert horizon '%v' to integer: %v", manifest.Annotations["horizon"], err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
horizon = 0
|
|
|
|
}
|
|
|
|
modelPath = path.Join(modelStore, manifest.Layers[0].Annotations["org.opencontainers.image.title"])
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2023-05-28 12:39:49 +00:00
|
|
|
func getRepository(ctx context.Context, registryName string, repoName string) (registry.Repository, error) {
|
|
|
|
|
|
|
|
reg, err := remote.NewRegistry(registryName)
|
2023-05-05 15:07:29 +00:00
|
|
|
if err != nil {
|
2023-05-28 12:39:49 +00:00
|
|
|
return nil, fmt.Errorf("bad registry '%v': %w", registryName, err)
|
2023-05-05 15:07:29 +00:00
|
|
|
}
|
2023-05-28 12:39:49 +00:00
|
|
|
reg.RepositoryOptions.PlainHTTP = true
|
|
|
|
|
|
|
|
// For debug
|
|
|
|
//reg.Repositories(ctx, "", func(repos []string) error {
|
|
|
|
// for _, r := range repos {
|
|
|
|
// zap.S().Debugf("found repo %v", r)
|
|
|
|
// }
|
|
|
|
// return nil
|
|
|
|
//})
|
|
|
|
|
|
|
|
repo, err := reg.Repository(ctx, repoName)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("unable to instanciate new repository: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// For debug
|
|
|
|
/*
|
|
|
|
repo.Tags(ctx, "", func(tags []string) error {
|
|
|
|
for _, t := range tags {
|
|
|
|
zap.S().Debugf("found tag '%v'", t)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
})
|
|
|
|
*/
|
|
|
|
return repo, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func fetchManifest(ctx context.Context, repo registry.Repository, tag string) (*v1.Manifest, error) {
|
2023-05-05 15:07:29 +00:00
|
|
|
|
|
|
|
descriptor, err := repo.Resolve(ctx, tag)
|
2023-05-28 12:39:49 +00:00
|
|
|
zap.S().Debugf("model descriptor: %#v", descriptor)
|
2023-05-05 15:07:29 +00:00
|
|
|
if err != nil {
|
2023-05-28 12:39:49 +00:00
|
|
|
return nil, fmt.Errorf("unexpected error on tag resolving: %w", err)
|
2023-05-05 15:07:29 +00:00
|
|
|
}
|
|
|
|
rc, err := repo.Fetch(ctx, descriptor)
|
|
|
|
if err != nil {
|
2023-05-28 12:39:49 +00:00
|
|
|
return nil, fmt.Errorf("unable to fetch manifest for image '%s:%s': %w", repo, tag, err)
|
2023-05-05 15:07:29 +00:00
|
|
|
}
|
|
|
|
defer rc.Close() // don't forget to close
|
|
|
|
pulledBlob, err := content.ReadAll(rc, descriptor)
|
|
|
|
if err != nil {
|
2023-05-28 12:39:49 +00:00
|
|
|
return nil, fmt.Errorf("unable to read manifest content for image '%s:%s': %w", repo, tag, err)
|
2023-05-05 15:07:29 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
var manifest v1.Manifest
|
|
|
|
err = json.Unmarshal(pulledBlob, &manifest)
|
|
|
|
if err != nil {
|
2023-05-28 12:39:49 +00:00
|
|
|
return nil, fmt.Errorf("unable to unmarsh json manifest content for image '%s:%s': %w", repo, tag, err)
|
2023-05-05 15:07:29 +00:00
|
|
|
}
|
|
|
|
return &manifest, nil
|
|
|
|
}
|