package main
import (
"fmt"
"net"
"os"
"os/signal"
"path"
"strconv"
"strings"
"sync"
"syscall"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc"
pluginAPI "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
)
// Plugin-wide settings.
const (
	// resourceName is the resource name sent to the kubelet in the
	// registration request.
	resourceName = "dummy.com/npu"

	// pluginSock is the file name of the plugin's unix socket. It is joined
	// with pluginAPI.DevicePluginPath to build the listening path and is also
	// sent to the kubelet as the registration endpoint.
	// An absolute-path variant was previously tried and is kept for reference:
	//pluginSock = "/var/lib/kubelet/device-plugins/dummy.sock"
	pluginSock = "dummy.sock"

	// devicesCount is the number of dummy devices created at initialization.
	devicesCount = 64

	// connectTimeout bounds the blocking self-dial that Start uses to confirm
	// that the gRPC server accepts connections.
	connectTimeout = 5 * time.Second
)
// DummyDevicePlugin holds the state of the dummy device plugin: the devices it
// advertises, the gRPC server that serves them and the channel used to request
// shutdown.
type DummyDevicePlugin struct {
// devices maps a device ID to its device record.
devices map[string]*pluginAPI.Device
// socket is the path of the unix socket the gRPC server listens on.
socket string
// server is the running gRPC server; nil before Start and after StopServer.
server *grpc.Server
// mutex is held by Allocate while it changes device health.
mutex sync.Mutex
// sigs receives the OS signals subscribed in init, plus the SIGTERM that exit
// sends; main blocks on it to know when to shut down.
sigs chan os.Signal
}
// NewDevicePlugin creates a DummyDevicePlugin whose socket lives under
// pluginAPI.DevicePluginPath and initializes it through init.
//
// The signature has no error result, so an initialization failure is reported
// on stdout and the plugin is returned as it is.
func NewDevicePlugin() *DummyDevicePlugin {
	// init allocates the device table, so it is not pre-allocated here.
	dp := &DummyDevicePlugin{
		socket: path.Join(pluginAPI.DevicePluginPath, pluginSock),
	}
	if err := dp.init(); err != nil {
		fmt.Println("Failed to initialize device plugin:", err)
	}
	return dp
}
// init fills the device table with devicesCount healthy devices, whose IDs are
// the decimal strings "0" to devicesCount-1, and subscribes dp.sigs to SIGHUP,
// SIGINT, SIGTERM and SIGQUIT. It always returns nil.
func (dp *DummyDevicePlugin) init() error {
	fmt.Println("Initializing device plugin")

	dp.devices = make(map[string]*pluginAPI.Device, devicesCount)
	for idx := 0; idx < devicesCount; idx++ {
		id := strconv.Itoa(idx)
		dp.devices[id] = &pluginAPI.Device{ID: id, Health: pluginAPI.Healthy}
	}

	// Signals are delivered on a channel with room for one pending signal;
	// main receives from it to know when to exit.
	dp.sigs = make(chan os.Signal, 1)
	signal.Notify(dp.sigs,
		syscall.SIGHUP,
		syscall.SIGINT,
		syscall.SIGTERM,
		syscall.SIGQUIT,
	)
	return nil
}
// discoverDevices returns the devices currently marked Healthy, keyed by
// device ID, and logs how many were found.
//
// The device table is read under dp.mutex because Allocate changes device
// health under that lock while other gRPC handlers call this method
// concurrently. Callers must not hold dp.mutex when calling it. The returned
// map is a fresh one, but its values are the shared device records.
func (dp *DummyDevicePlugin) discoverDevices() map[string]*pluginAPI.Device {
	dp.mutex.Lock()
	healthyDevices := make(map[string]*pluginAPI.Device, len(dp.devices))
	for _, dev := range dp.devices {
		if dev.Health == pluginAPI.Healthy {
			healthyDevices[dev.ID] = dev
		}
	}
	dp.mutex.Unlock()

	fmt.Println("Healthy devices found:", len(healthyDevices))
	return healthyDevices
}
// Start removes any stale socket file, starts the gRPC server on dp.socket,
// waits up to connectTimeout until the server accepts a connection, and then
// launches the health-check loop.
//
// It returns an error that includes the underlying cause when the old socket
// cannot be removed, the socket cannot be listened on, or the readiness dial
// fails. In the last case the half-started server is stopped again.
func (dp *DummyDevicePlugin) Start() error {
	if err := dp.cleanup(); err != nil {
		return fmt.Errorf("failed to clean existing socket file: %v", err)
	}

	listener, err := net.Listen("unix", dp.socket)
	if err != nil {
		return fmt.Errorf("failed to listen on plugin socket: %v", err)
	}

	server := grpc.NewServer()
	dp.server = server
	pluginAPI.RegisterDevicePluginServer(server, dp)
	go func() {
		// Serve blocks for the lifetime of the server; report why it ended
		// instead of dropping the error. The local server variable is used
		// because StopServer resets dp.server to nil.
		if serveErr := server.Serve(listener); serveErr != nil {
			fmt.Println("Device plugin gRPC server stopped serving:", serveErr)
		}
	}()
	fmt.Println("Device plugin gRPC server begins to serve at:", dp.socket)

	// Wait for server to start by launching a blocking connection.
	// The dial options are the ones the file already relies on.
	conn, err := grpc.Dial(dp.socket, grpc.WithInsecure(), grpc.WithBlock(),
		grpc.WithTimeout(connectTimeout),
		grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
			return net.DialTimeout("unix", addr, timeout)
		}),
	)
	if err != nil {
		// Do not leave a server running that nobody can reach.
		if stopErr := dp.StopServer(); stopErr != nil {
			fmt.Println("Failed to stop device plugin gRPC server:", stopErr)
		}
		return fmt.Errorf("failed to dial to plugin socket: %v", err)
	}
	// The connection was only a readiness probe.
	conn.Close()

	go dp.healthCheck()
	fmt.Println("Device plugin gRPC server is ready")
	return nil
}
// StopServer stops the gRPC server, if one is set, and removes the plugin
// socket file. It returns nil without doing anything when no server is set;
// otherwise it returns the result of removing the socket.
func (dp *DummyDevicePlugin) StopServer() error {
	running := dp.server
	if running == nil {
		return nil
	}

	running.Stop()
	dp.server = nil
	fmt.Println("Device plugin gRPC server is stopped")

	removeErr := dp.cleanup()
	return removeErr
}
// healthCheck is a placeholder for device health monitoring: it only sleeps in
// one-minute intervals and never returns.
// TODO: monitor and update devices
func (dp *DummyDevicePlugin) healthCheck() error {
	const interval = time.Minute
	for {
		time.Sleep(interval)
	}
}
// cleanup deletes the plugin socket file. A socket file that does not exist is
// not an error; any other removal failure is returned.
func (dp *DummyDevicePlugin) cleanup() error {
	err := os.Remove(dp.socket)
	if err == nil || os.IsNotExist(err) {
		return nil
	}
	return err
}
// exit asks main to shut the plugin down by queueing SIGTERM on dp.sigs.
//
// The send is non-blocking. When the channel cannot take another signal, one
// is already queued and a shutdown is pending anyway; blocking here would only
// hang the calling gRPC handler.
func (dp *DummyDevicePlugin) exit() {
	select {
	case dp.sigs <- syscall.SIGTERM:
	default:
		// A signal is already pending; nothing more to do.
	}
}
// Register announces this plugin to the kubelet. It opens a gRPC connection to
// pluginAPI.KubeletSocket over a unix socket and sends a RegisterRequest that
// carries the API version, the plugin socket name and the resource name.
// It returns an error when the connection cannot be set up or the Register
// call fails.
func Register() error {
	unixDialer := func(addr string, timeout time.Duration) (net.Conn, error) {
		return net.DialTimeout("unix", addr, timeout)
	}

	conn, err := grpc.Dial(pluginAPI.KubeletSocket, grpc.WithInsecure(), grpc.WithDialer(unixDialer))
	if err != nil {
		return fmt.Errorf("failed to connect to kubelet: %v", err)
	}
	defer conn.Close()

	request := &pluginAPI.RegisterRequest{
		Version:      pluginAPI.Version,
		Endpoint:     pluginSock, // name of the socket the device plugin listens on
		ResourceName: resourceName,
	}
	registration := pluginAPI.NewRegistrationClient(conn)
	if _, err := registration.Register(context.Background(), request); err != nil {
		return fmt.Errorf("failed to register to kubelet: %v", err)
	}

	fmt.Println("Device plugin registers to kubelet")
	return nil
}
// ListAndWatch streams the list of healthy devices to the kubelet and re-sends
// it every 10 seconds, so that a device whose health changed appears in or
// disappears from the next list.
//
// When a send fails the plugin can no longer talk to the kubelet: a restart is
// requested through dp.exit and the send error is returned. Returning ends the
// stream instead of retrying on a broken one forever.
func (dp *DummyDevicePlugin) ListAndWatch(empty *pluginAPI.Empty, stream pluginAPI.DevicePlugin_ListAndWatchServer) error {
	fmt.Println("ListAndWatch starts")
	for {
		// Build response
		healthyDevices := dp.discoverDevices()
		resp := &pluginAPI.ListAndWatchResponse{
			Devices: make([]*pluginAPI.Device, 0, len(healthyDevices)),
		}
		for _, dev := range healthyDevices {
			resp.Devices = append(resp.Devices, dev)
		}

		// Send response
		if err := stream.Send(resp); err != nil {
			fmt.Println("Failed to send devices to kubelet:", err)
			fmt.Println("Since it is failed to communicate with kubelet, let's restart device plugin")
			dp.exit()
			return err
		}
		time.Sleep(10 * time.Second)
	}
}
// Allocate is called during container creation, so that the device plugin can
// run device specific operations and instruct the kubelet of the steps to make
// the device available in the container.
//
// For every container request it reserves as many healthy devices as the
// request lists device IDs, marks them Unhealthy so they are not handed out
// again, and returns their IDs in the DUMMY_VISIBLE_DEVICES environment
// variable (comma separated). A request for zero devices reserves nothing.
//
// If a container request cannot be satisfied, devices already reserved for
// earlier containers of the same call are released again and an error is
// returned.
//
// NOTE(review): the devices handed out are picked by the plugin; the specific
// IDs in req.DevicesIDs are only counted, not honored. Confirm this is the
// intended behavior of the dummy plugin.
func (dp *DummyDevicePlugin) Allocate(ctx context.Context, reqs *pluginAPI.AllocateRequest) (*pluginAPI.AllocateResponse, error) {
	fmt.Println("Allocate starts")
	ret := pluginAPI.AllocateResponse{}
	fmt.Println("Recv request:", reqs.ContainerRequests)

	// reserved collects every device taken by this call, for rollback.
	var reserved []string
	for _, req := range reqs.ContainerRequests {
		fmt.Println("Recv request DevicesIDs:", req.DevicesIDs)

		want := len(req.DevicesIDs)
		ids, healthy, ok := dp.reserveDevices(want)
		if !ok {
			dp.releaseDevices(reserved)
			fmt.Println("Number of available devices is less than request devices:", healthy, want)
			return nil, fmt.Errorf("invalid allocate request of devices count: %d", want)
		}
		reserved = append(reserved, ids...)

		// For NV, it passes devices to ENV NVIDIA_VISIBLE_DEVICES
		fmt.Println("Allocate devices:", ids)
		ret.ContainerResponses = append(ret.ContainerResponses, &pluginAPI.ContainerAllocateResponse{
			Envs: map[string]string{"DUMMY_VISIBLE_DEVICES": strings.Join(ids, ",")},
		})
	}
	return &ret, nil
}

// reserveDevices picks want healthy devices and marks them Unhealthy, all while
// holding dp.mutex so that concurrent callers cannot pick the same device.
// It returns the reserved IDs, the number of healthy devices it saw, and
// whether the reservation succeeded; on failure no device is changed.
func (dp *DummyDevicePlugin) reserveDevices(want int) ([]string, int, bool) {
	dp.mutex.Lock()
	defer dp.mutex.Unlock()

	healthy := 0
	picked := make([]*pluginAPI.Device, 0, want)
	for _, dev := range dp.devices {
		if dev.Health != pluginAPI.Healthy {
			continue
		}
		healthy++
		if len(picked) < want {
			picked = append(picked, dev)
		}
	}
	fmt.Println("Healthy devices found:", healthy)
	if healthy < want {
		return nil, healthy, false
	}

	ids := make([]string, 0, len(picked))
	for _, dev := range picked {
		dev.Health = pluginAPI.Unhealthy
		ids = append(ids, dev.ID)
	}
	return ids, healthy, true
}

// releaseDevices marks the devices with the given IDs Healthy again, undoing a
// reservation made by reserveDevices. Unknown IDs are ignored.
func (dp *DummyDevicePlugin) releaseDevices(ids []string) {
	dp.mutex.Lock()
	defer dp.mutex.Unlock()

	for _, id := range ids {
		if dev, ok := dp.devices[id]; ok {
			dev.Health = pluginAPI.Healthy
		}
	}
}
// GetPreferredAllocation is meant to return a preferred set of devices to
// allocate from a list of available ones, as a hint for the devicemanager.
// This plugin expresses no preference: it ignores its arguments and always
// returns an empty response with a nil error.
func (dp *DummyDevicePlugin) GetPreferredAllocation(_ context.Context, _ *pluginAPI.PreferredAllocationRequest) (*pluginAPI.PreferredAllocationResponse, error) {
	noPreference := new(pluginAPI.PreferredAllocationResponse)
	return noPreference, nil
}
// GetDevicePluginOptions returns the options to be communicated to the device
// manager. This plugin always returns an empty DevicePluginOptions value and a
// nil error.
func (dp *DummyDevicePlugin) GetDevicePluginOptions(context.Context, *pluginAPI.Empty) (*pluginAPI.DevicePluginOptions, error) {
	options := new(pluginAPI.DevicePluginOptions)
	return options, nil
}
// PreStartContainer is the hook called before each container start when the
// device plugin asked for it during registration, so the plugin can run device
// specific operations such as resetting the device. This plugin does nothing
// here: it always returns an empty response and a nil error.
func (dp *DummyDevicePlugin) PreStartContainer(context.Context, *pluginAPI.PreStartContainerRequest) (*pluginAPI.PreStartContainerResponse, error) {
	response := new(pluginAPI.PreStartContainerResponse)
	return response, nil
}
// main starts the plugin's gRPC server, registers the plugin with the kubelet,
// then blocks until a signal arrives on dp.sigs and stops the server.
//
// A failure to start the server or to register is fatal: the process exits
// with status 1 instead of carrying on and reporting success.
func main() {
	dp := NewDevicePlugin()

	// Start grpc server
	if err := dp.Start(); err != nil {
		fmt.Println("Failed to start device plugin:", err)
		os.Exit(1)
	}
	fmt.Println("Start to server at:", dp.socket)

	// Registers with Kubelet.
	if err := Register(); err != nil {
		fmt.Println("Failed to register device plugin:", err)
		// Take the server down again so the socket file is not left behind.
		if stopErr := dp.StopServer(); stopErr != nil {
			fmt.Println("Failed to stop device plugin:", stopErr)
		}
		os.Exit(1)
	}
	fmt.Println("Device plugin is registered")

	// TODO: Watch kubelet sock file
	//err = dp.watchKubelet()
	//if err != nil {
	//	fmt.Println("Failed to watch kubelet:", err)
	//}

	s := <-dp.sigs
	fmt.Println("Receive signal and will exit:", s)
	if err := dp.StopServer(); err != nil {
		fmt.Println("Failed to stop device plugin:", err)
		os.Exit(1)
	}
}
// "Comments" - apparently stray page text that is not part of the program.