diff --git a/go/internal/feast/metrics/metrics.go b/go/internal/feast/metrics/metrics.go index 804eef6fa1b..d4783f257b7 100644 --- a/go/internal/feast/metrics/metrics.go +++ b/go/internal/feast/metrics/metrics.go @@ -30,7 +30,6 @@ var ( TimeHistogramType = reflect.TypeOf((*TimeHistogram)(nil)).Elem() ) - func RegisterTimeHistogram(name, help, namespace string, labelNames []string, tag reflect.StructTag) (func(prometheus.Labels) interface{}, prometheus.Collector, error) { f, collector, err := prometheusvanilla.BuildHistogram(name, help, namespace, labelNames, tag) if err != nil { diff --git a/go/internal/feast/onlinestore/postgresonlinestore.go b/go/internal/feast/onlinestore/postgresonlinestore.go index 4813f341db7..4077a9e06fa 100644 --- a/go/internal/feast/onlinestore/postgresonlinestore.go +++ b/go/internal/feast/onlinestore/postgresonlinestore.go @@ -194,4 +194,4 @@ func buildPostgresConnString(config map[string]interface{}) string { } return connURL.String() -} \ No newline at end of file +} diff --git a/go/internal/feast/server/http_server.go b/go/internal/feast/server/http_server.go index adfd40110e7..2f60444c849 100644 --- a/go/internal/feast/server/http_server.go +++ b/go/internal/feast/server/http_server.go @@ -8,6 +8,7 @@ import ( "runtime" "strconv" "time" + "crypto/tls" "github.com/feast-dev/feast/go/internal/feast" "github.com/feast-dev/feast/go/internal/feast/model" @@ -396,6 +397,37 @@ func (s *httpServer) Serve(host string, port int) error { return err } +func (s *httpServer) ServeTLS(host string, port int, certFile string, keyFile string) error { + mux := http.NewServeMux() + mux.Handle("/get-online-features", metricsMiddleware(recoverMiddleware(http.HandlerFunc(s.getOnlineFeatures)))) + mux.Handle("/health", metricsMiddleware(http.HandlerFunc(healthCheckHandler))) + s.server = &http.Server{ + Addr: fmt.Sprintf("%s:%d", host, port), + Handler: mux, + ReadTimeout: 5 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 15 * time.Second, + TLSConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + // For production, use proper certificates + // Prefer server's cipher suites + PreferServerCipherSuites: true, + CurvePreferences: []tls.CurveID{ + tls.CurveP256, + tls.X25519MLKEM768, + tls.SecP256r1MLKEM768, + }, + }, + } + err := s.server.ListenAndServeTLS(certFile, keyFile) + // Don't return the error if it's caused by graceful shutdown using Stop() + if err == http.ErrServerClosed { + return nil + } + log.Fatal().Stack().Err(err).Msg("Failed to start HTTPS server") + return err +} + func healthCheckHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) fmt.Fprintf(w, "Healthy") diff --git a/go/main.go b/go/main.go index f49a27efa46..a0e72720332 100644 --- a/go/main.go +++ b/go/main.go @@ -11,6 +11,7 @@ import ( "strings" "sync" "syscall" + "time" "github.com/feast-dev/feast/go/internal/feast" "github.com/feast-dev/feast/go/internal/feast/registry" @@ -47,6 +48,10 @@ func (s *RealServerStarter) StartHttpServer(fs *feast.FeatureStore, host string, return StartHttpServer(fs, host, port, metricsPort, writeLoggedFeaturesCallback, loggingOpts) } +func (s *RealServerStarter) StartHttpsServer(fs *feast.FeatureStore, host string, port int, metricsPort int, certFile string, keyFile string, writeLoggedFeaturesCallback logging.OfflineStoreWriteCallback, loggingOpts *logging.LoggingOptions) error { + return StartHttpsServer(fs, host, port, metricsPort, certFile, keyFile, writeLoggedFeaturesCallback, loggingOpts) +} + func (s *RealServerStarter) StartGrpcServer(fs *feast.FeatureStore, host string, port int, metricsPort int, writeLoggedFeaturesCallback logging.OfflineStoreWriteCallback, loggingOpts *logging.LoggingOptions) error { return StartGrpcServer(fs, host, port, metricsPort, writeLoggedFeaturesCallback, loggingOpts) } @@ -58,18 +63,22 @@ func main() { port := 8080 metricsPort := 9090 server := RealServerStarter{} + certFile := "" + keyFile := "" // Current Directory repoPath, err := os.Getwd() if err != nil { log.Error().Stack().Err(err).Msg("Failed to get current directory") } - flag.StringVar(&serverType, "type", serverType, "Specify the server type (http or grpc)") + flag.StringVar(&serverType, "type", serverType, "Specify the server type (http, https or grpc)") flag.StringVar(&repoPath, "chdir", repoPath, "Repository path where feature store yaml file is stored") flag.StringVar(&host, "host", host, "Specify a host for the server") flag.IntVar(&port, "port", port, "Specify a port for the server") flag.IntVar(&metricsPort, "metrics-port", metricsPort, "Specify a port for the metrics server") + flag.StringVar(&certFile, "tls-cert-file", "", "Path to the TLS certificate file") + flag.StringVar(&keyFile, "tls-key-file", "", "Path to the TLS key file") flag.Parse() // Initialize tracer @@ -119,8 +128,10 @@ func main() { err = server.StartHttpServer(fs, host, port, metricsPort, nil, loggingOptions) } else if serverType == "grpc" { err = server.StartGrpcServer(fs, host, port, metricsPort, nil, loggingOptions) + } else if serverType == "https" { + err = server.StartHttpsServer(fs, host, port, metricsPort, certFile, keyFile, nil, loggingOptions) } else { - fmt.Println("Unknown server type. Please specify 'http' or 'grpc'.") + fmt.Println("Unknown server type. Please specify 'http' or 'grpc' or 'https'.") } if err != nil { @@ -234,13 +245,16 @@ func StartHttpServer(fs *feast.FeatureStore, host string, port int, metricsPort } ser := server.NewHttpServer(fs, loggingService) log.Info().Msgf("Starting a HTTP server on host %s, port %d", host, port) + // Start metrics server - metricsServer := &http.Server{Addr: fmt.Sprintf(":%d", metricsPort)} + mux := http.NewServeMux() + mux.Handle("/metrics", promhttp.Handler()) + metricsServer := &http.Server{ + Addr: fmt.Sprintf(":%d", metricsPort), + Handler: mux, + } go func() { log.Info().Msgf("Starting metrics server on port %d", metricsPort) - mux := http.NewServeMux() - mux.Handle("/metrics", promhttp.Handler()) - metricsServer.Handler = mux if err := metricsServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Error().Err(err).Msg("Failed to start metrics server") } @@ -248,6 +262,7 @@ func StartHttpServer(fs *feast.FeatureStore, host string, port int, metricsPort stop := make(chan os.Signal, 1) signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) + defer signal.Stop(stop) var wg sync.WaitGroup wg.Add(1) @@ -263,7 +278,9 @@ func StartHttpServer(fs *feast.FeatureStore, host string, port int, metricsPort log.Error().Err(err).Msg("Error when stopping the HTTP server") } log.Info().Msg("Stopping metrics server...") - if err := metricsServer.Shutdown(context.Background()); err != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := metricsServer.Shutdown(ctx); err != nil { log.Error().Err(err).Msg("Error stopping metrics server") } if loggingService != nil { @@ -320,3 +337,72 @@ func newTracerProvider(exp sdktrace.SpanExporter) (*sdktrace.TracerProvider, err sdktrace.WithResource(r), ), nil } + +// StartHttpsServer starts HTTP server with TLS. Requires TLS_CERT_FILE and TLS_KEY_FILE env vars. +func StartHttpsServer(fs *feast.FeatureStore, host string, port int, metricsPort int, certFile string, keyFile string, writeLoggedFeaturesCallback logging.OfflineStoreWriteCallback, loggingOpts *logging.LoggingOptions) error { + if certFile == "" || keyFile == "" { + return fmt.Errorf("TLS_CERT_FILE and TLS_KEY_FILE must be set") + } + + loggingService, err := constructLoggingService(fs, writeLoggedFeaturesCallback, loggingOpts) + if err != nil { + return err + } + ser := server.NewHttpServer(fs, loggingService) + log.Info().Msgf("Starting a HTTPS server on host %s, port %d", host, port) + + // Start metrics server + mux := http.NewServeMux() + mux.Handle("/metrics", promhttp.Handler()) + metricsServer := &http.Server{ + Addr: fmt.Sprintf(":%d", metricsPort), + Handler: mux, + } + go func() { + log.Info().Msgf("Starting metrics server on port %d", metricsPort) + if err := metricsServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Error().Err(err).Msg("Failed to start metrics server") + } + }() + + stop := make(chan os.Signal, 1) + signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) + defer signal.Stop(stop) + + var wg sync.WaitGroup + wg.Add(1) + serverExited := make(chan struct{}) + go func() { + defer wg.Done() + select { + case <-stop: + // Received SIGINT/SIGTERM. Perform graceful shutdown. + log.Info().Msg("Stopping the HTTP server...") + err := ser.Stop() + if err != nil { + log.Error().Err(err).Msg("Error when stopping the HTTP server") + } + log.Info().Msg("Stopping metrics server...") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := metricsServer.Shutdown(ctx); err != nil { + log.Error().Err(err).Msg("Error stopping metrics server") + } + if loggingService != nil { + loggingService.Stop() + } + log.Info().Msg("HTTP server terminated") + case <-serverExited: + // Server exited (e.g. startup error), ensure metrics server is stopped + metricsServer.Shutdown(context.Background()) + if loggingService != nil { + loggingService.Stop() + } + } + }() + + err = ser.ServeTLS(host, port, certFile, keyFile) + close(serverExited) + wg.Wait() + return err +}