package service import ( "bytes" "context" "crypto/tls" "encoding/json" "fmt" "io" "net/http" "net/url" "os" "strconv" "strings" "sync" "github.com/rs/zerolog/log" "gitlab.com/tensorsecurity-rd/waf-console/internal/model" "gitlab.com/tensorsecurity-rd/waf-console/internal/utils" "gitlab.com/tensorsecurity-rd/waf-console/pkg/apis/waf.security.io/v1alpha1" "gorm.io/gorm" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/sets" ) type wafService struct { clusterClientManager *utils.ClusterClientManager db *gorm.DB gatewayUrl string } func NewWafService(clusterClientManager *utils.ClusterClientManager, db *gorm.DB, gatewayUrl string) Service { return &wafService{clusterClientManager: clusterClientManager, db: db, gatewayUrl: gatewayUrl} } func (s *wafService) GetWaf(ctx context.Context, regionCode, namespace, gatewayName string) (*WafService, error) { wafService := &model.WafService{} err := s.db.Model(&model.WafService{}).Where("gateway_name = ? AND region_code = ? AND namespace = ?", gatewayName, regionCode, namespace).First(wafService).Error if err != nil { if err == gorm.ErrRecordNotFound { // Create new WAF service record if not found wafService = &model.WafService{ RegionCode: regionCode, Namespace: namespace, GatewayName: gatewayName, Mode: string(WafModeAlert), } if err := s.db.Create(wafService).Error; err != nil { return nil, fmt.Errorf("failed to create WAF service: %v", err) } } else { return nil, fmt.Errorf("failed to query WAF service: %v", err) } } listenerWafs, err := s.ListListenerWafStatus(ctx, &GatewateInfo{ GatewayName: gatewayName, Namespace: namespace, RegionCode: regionCode, }) if err != nil { return nil, fmt.Errorf("failed to list listener WAF status: %v", err) } listeners := []string{} for _, listener := range listenerWafs { hosts := strings.Join(listener.Hosts, "@") listeners = append(listeners, fmt.Sprintf("%s-%d", hosts, listener.Port)) } return &WafService{ GatewayName: wafService.GatewayName, Mode: wafService.Mode, RuleNum: wafService.RuleNum, AttackNum: wafService.AttackNum, Listeners: listeners, }, nil } func (s *wafService) GetWafGatewayInfo(ctx context.Context, req *GetWafGatewayInfoReq) (*WafService, error) { wafService := &model.WafService{} err := s.db.Model(&model.WafService{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).First(wafService).Error if err != nil { if err == gorm.ErrRecordNotFound { httpRequst := http.Request{ Method: http.MethodPost, URL: &url.URL{Scheme: "https", Host: "console.tensorsecurity.com", Path: "/api/v1/waf/gateway"}, Header: http.Header{ "Cookie": []string{req.Cookie}, }, Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{"gateway_name": "%s", "namespace": "%s", "region_code": "%s"}`, req.GatewayName, req.Namespace, req.RegionCode))), } resp, err := http.DefaultClient.Do(&httpRequst) if err != nil { return nil, fmt.Errorf("failed to get WAF service: %v", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read WAF service: %v", err) } var wafService model.WafService err = json.Unmarshal(body, &wafService) if err != nil { return nil, fmt.Errorf("failed to unmarshal WAF service: %v", err) } wafService.ID = 0 wafService.RuleCategoryStatus = nil wafService.RuleNum = 0 wafService.AttackNum = 0 // wafService.Host = model.HostList([]string{"*"}) wafService.Mode = string(WafModeAlert) err = s.db.Create(wafService).Error if err != nil { return nil, fmt.Errorf("failed to create WAF service: %v", err) } } else { return nil, fmt.Errorf("failed to query WAF service: %v", err) } } return &WafService{ GatewayName: wafService.GatewayName, Mode: wafService.Mode, RuleNum: wafService.RuleNum, AttackNum: wafService.AttackNum, }, nil } func (s *wafService) getRulesForService(req *CreateWafReq) ([]v1alpha1.Rule, error) { rules := []v1alpha1.Rule{} ruleCategories := []model.WafRuleCategory{} if err := s.db.Model(&model.WafRuleCategory{}).Where("status = ?", 0).Find(&ruleCategories).Error; err != nil { return nil, fmt.Errorf("failed to get rule categories: %v", err) } // Get existing WAF service config if any wafService := &model.WafService{} err := s.db.Model(&model.WafService{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).First(wafService).Error if err != nil { if err == gorm.ErrRecordNotFound { // Create new WAF service record if not found wafService = &model.WafService{ RegionCode: req.RegionCode, Namespace: req.Namespace, GatewayName: req.GatewayName, Mode: string(WafModeAlert), } if err := s.db.Create(wafService).Error; err != nil { return nil, fmt.Errorf("failed to create WAF service: %v", err) } } else { return nil, fmt.Errorf("failed to query WAF service: %v", err) } } // Determine which rule categories to enable var enabledCategories []model.WafRuleCategory if wafService.RuleCategoryStatus != nil && len(wafService.RuleCategoryStatus.CategoryID) > 0 { // Only include categories not already enabled for _, category := range ruleCategories { for _, id := range wafService.RuleCategoryStatus.CategoryID { if id == category.CategoryID { enabledCategories = append(enabledCategories, category) continue } } } } else { // Enable all categories if none specified enabledCategories = ruleCategories } for _, category := range enabledCategories { for _, rule := range category.Rules { rules = append(rules, v1alpha1.Rule{ ID: rule.ID, Level: rule.Level, Name: rule.Name, Type: rule.Type, Description: rule.Description, Expr: rule.Expr, Mode: rule.Mode, }) } } return rules, nil } func (s *wafService) CreateWaf(ctx context.Context, req *CreateWafReq) (*WafService, error) { // Create the WAF service resource name := fmt.Sprintf("%s-%d", req.GatewayName, req.Port) service := &v1alpha1.Service{ ObjectMeta: metav1.ObjectMeta{ Name: name, Namespace: req.Namespace, Labels: map[string]string{ "apigateway_name": req.GatewayName, }, }, Spec: v1alpha1.ServiceSpec{ HostNames: req.Host, ServiceName: req.GatewayName, Port: req.Port, Workload: v1alpha1.WorkloadRef{ Kind: req.GatewayName, Name: req.ListenerName, Namespace: req.Namespace, ClusterKey: req.RegionCode, }, Uri: &v1alpha1.StringMatch{ Prefix: "/", }, LogConfig: &v1alpha1.LogConfig{ Enable: 1, Level: "info", }, Mode: string(req.Mode), }, } rules, err := s.getRulesForService(req) if err != nil { return nil, fmt.Errorf("failed to get rules for service: %v", err) } service.Spec.Rules = rules if len(service.Spec.Rules) == 0 { return nil, fmt.Errorf("cannot create WAF service with no rules") } // Create the WAF service in Kubernetes client := s.clusterClientManager.GetClient(req.RegionCode) if client == nil { return nil, fmt.Errorf("failed to get cluster client for region %s", req.RegionCode) } if _, err := client.WafV1alpha1().Services(req.Namespace).Create(ctx, service, metav1.CreateOptions{}); err != nil { return nil, fmt.Errorf("failed to create WAF service: %v", err) } return &WafService{ GatewayName: req.GatewayName, Mode: service.Spec.Mode, RuleNum: len(service.Spec.Rules), AttackNum: 0, }, nil } func (s *wafService) DeleteListenerWaf(ctx context.Context, req *DeleteListenerReq) error { client := s.clusterClientManager.GetClient(req.RegionCode) if client == nil { return fmt.Errorf("failed to get cluster client for region %s", req.RegionCode) } name := fmt.Sprintf("%s-%d", req.GatewayName, req.Port) if err := client.WafV1alpha1().Services(req.Namespace).Delete(ctx, name, metav1.DeleteOptions{}); err != nil { return fmt.Errorf("failed to delete WAF service: %v", err) } return nil } func (s *wafService) UpdateMode(ctx context.Context, req *UpdateModeReq) (*WafService, error) { // Check if WAF service exists wafService := &model.WafService{} err := s.db.Model(&model.WafService{}).Where("gateway_name = ?", req.GatewayName).First(wafService).Error if err != nil { if err == gorm.ErrRecordNotFound { // Create new WAF service record if not found wafService = &model.WafService{ RegionCode: req.RegionCode, Namespace: req.Namespace, GatewayName: req.GatewayName, Mode: string(req.Mode), } if err := s.db.Create(wafService).Error; err != nil { return nil, fmt.Errorf("failed to create WAF service: %v", err) } } else { return nil, fmt.Errorf("failed to query WAF service: %v", err) } } else { // Update mode if service exists if err := s.db.Model(wafService).Update("mode", string(req.Mode)).Error; err != nil { return nil, fmt.Errorf("failed to update WAF service mode: %v", err) } } // Update mode for each listener client := s.clusterClientManager.GetClient(req.RegionCode) if client == nil { return nil, fmt.Errorf("failed to get cluster client for region %s", req.RegionCode) } listenerList, err := client.WafV1alpha1().Services(req.Namespace).List(ctx, metav1.ListOptions{LabelSelector: fmt.Sprintf("apigateway_name=%s", req.GatewayName)}) if err != nil { return nil, fmt.Errorf("failed to get listener list: %v", err) } var wg sync.WaitGroup for _, listener := range listenerList.Items { wg.Add(1) listener := listener // Create new variable for goroutine listener.Spec.Mode = string(req.Mode) go func() { defer wg.Done() log.Info().Msgf("update WAF service mode: %v", listener.Name) _, err := client.WafV1alpha1().Services(req.Namespace).Update(ctx, &listener, metav1.UpdateOptions{}) if err != nil { log.Error().Msgf("failed to update WAF service mode: %v", err) } }() } wg.Wait() return &WafService{ GatewayName: req.GatewayName, Mode: string(req.Mode), }, nil } func (s *wafService) GetRuleCategories(ctx context.Context) ([]WafRuleCategory, error) { var categories []WafRuleCategory err := s.db.Table("waf_rule_categories").Find(&categories).Error if err != nil { return nil, err } return categories, nil } func (s *wafService) GetRules(ctx context.Context, categoryID string) ([]WafRule, error) { var rules []WafRule err := s.db.Table("waf_rules").Where("category_id = ?", categoryID).Find(&rules).Error if err != nil { return nil, err } return rules, nil } func (s *wafService) GetRule(ctx context.Context, ruleID int) (*WafRule, error) { var rule WafRule err := s.db.Table("waf_rules").Where("id = ?", ruleID).First(&rule).Error if err != nil { return nil, err } return &rule, nil } func (s *wafService) SaveRuleCategoryToDB(ctx context.Context) error { var categories []WafRuleCategory jsonFile, err := os.ReadFile("rules/waf-rules.json") if err != nil { return fmt.Errorf("error reading yaml file: %v", err) } // err = yaml.Unmarshal(yamlFile, &categories) err = json.Unmarshal(jsonFile, &categories) if err != nil { return fmt.Errorf("error unmarshaling yaml: %v", err) } for _, category := range categories { rules := []model.WafRule{} for _, rule := range category.Rules { rules = append(rules, model.WafRule{ ID: rule.ID, CategoryID: category.CategoryID, Level: rule.Level, Name: rule.Name, Type: rule.Type, Description: rule.Description, Expr: rule.Expr, Mode: rule.Mode, }) } model := model.WafRuleCategory{ CategoryID: category.CategoryID, Status: category.Status, CategoryEN: category.Catagory.EN, CategoryZH: category.Catagory.Zh, DescriptionEN: category.Description.EN, DescriptionZH: category.Description.Zh, Rules: model.RuleList(rules), } err = s.db.Table("waf_rule_categories").Create(&model).Error if err != nil { return err } } return nil } func (s *wafService) DeleteListener(ctx context.Context, req *DeleteListenerReq) error { listener := &model.GatewayListener{} err := s.db.Model(&model.GatewayListener{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).First(listener).Error if err != nil { return err } err = s.db.Model(&model.GatewayListener{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).Delete(listener).Error if err != nil { return err } return nil } func (s *wafService) EnableListenerWaf(ctx context.Context, req *EnableListenerWafReq) error { // listener := &model.GatewayListener{} // err := s.db.Model(&model.GatewayListener{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).First(listener).Error // if err != nil { // if err == gorm.ErrRecordNotFound { // listener = &model.GatewayListener{ // GatewayName: req.GatewayName, // Namespace: req.Namespace, // RegionCode: req.RegionCode, // Port: int(req.Port), // Enable: req.Enable, // Hosts: req.Hosts, // } // err = s.db.Model(&model.GatewayListener{}).Create(listener).Error // if err != nil { // return err // } // } else { // return err // } // } // listener.Enable = req.Enable // err = s.db.Model(&model.GatewayListener{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).Update("enable", req.Enable).Error // if err != nil { // return err // } if req.Enable { log.Info().Msgf("Create WAF for listener %s", req.GatewayName) _, err := s.CreateWaf(ctx, &CreateWafReq{ GatewateInfo: GatewateInfo{ GatewayName: req.GatewayName, Namespace: req.Namespace, RegionCode: req.RegionCode, }, Port: uint32(req.Port), Host: req.Hosts, Mode: req.Mode, ListenerName: req.ListenerName, }) if err != nil { return err } } else { log.Info().Msgf("Delete WAF for listener %s", req.GatewayName) err := s.DeleteListenerWaf(ctx, &DeleteListenerReq{ GatewateInfo: GatewateInfo{ GatewayName: req.GatewayName, Namespace: req.Namespace, RegionCode: req.RegionCode, }, Port: req.Port, }) if err != nil { return err } } return nil } func getGatewayNameFromCrn(crn string) string { // crn:ucs::apigateway:lf-tst7:214613666997:instance/testaaa parts := strings.Split(crn, "/") return parts[len(parts)-1] } func (s *wafService) listListenerFromApiGateway(ctx context.Context, apiGatewayCrn string, regionCode string, cookie string) ([]GatewayRespListenerData, error) { body, err := json.Marshal(map[string]string{ "apigateway_crn": apiGatewayCrn, "region_code": regionCode, }) if err != nil { return nil, fmt.Errorf("failed to marshal request body: %v", err) } request, err := http.NewRequestWithContext(ctx, "POST", "https://csm.console.test.tg.unicom.local/apigatewaymng/listener/lf-tst7/list_listeners", bytes.NewBuffer(body)) if err != nil { return nil, fmt.Errorf("failed to create request: %v", err) } request.Header.Set("Cookie", cookie) // Create custom transport with TLS config tr := &http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, // Skip certificate verification for test environment }, } client := &http.Client{Transport: tr} resp, err := client.Do(request) if err != nil { return nil, fmt.Errorf("failed to get listener list: %v", err) } defer resp.Body.Close() log.Info().Msgf("resp: %v", resp) // Parse response var response GatewayListenerResponseList if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { return nil, fmt.Errorf("failed to parse listener list: %v", err) } log.Info().Msgf("response: %v", response) return response.Data, nil } func (s *wafService) EnableGatewayWaf(ctx context.Context, req *EnableGatewayWafReq) error { if req.Enable { listeners, err := s.listListenerFromApiGateway(ctx, req.ApiGatewayCrn, req.RegionCode, req.Cookie) if err != nil { return fmt.Errorf("failed to get listener list: %v", err) } log.Info().Msgf("listeners: %v", listeners) // Create WAF for each listener for _, listener := range listeners { gatewayName := getGatewayNameFromCrn(listener.ApiGatewayCrn) namespace := fmt.Sprintf("%s-%s", listener.CreateAccountName, listener.CreateAccountID) if _, err := s.CreateWaf(ctx, &CreateWafReq{ GatewateInfo: GatewateInfo{ GatewayName: gatewayName, Namespace: namespace, RegionCode: req.RegionCode, }, Port: uint32(listener.Port), Host: listener.Hosts, }); err != nil { return fmt.Errorf("failed to create WAF for listener %d: %v", listener.Port, err) } } } else { s.DeleteGatewayWaf(ctx, &GatewateInfo{ GatewayName: req.GatewayName, Namespace: req.Namespace, RegionCode: req.RegionCode, }) } return nil } func (s *wafService) DeleteGatewayWaf(ctx context.Context, req *GatewateInfo) error { client := s.clusterClientManager.GetClient(req.RegionCode) if client == nil { return fmt.Errorf("failed to get cluster client") } labelSelector := fmt.Sprintf("apigateway_name=%s", req.GatewayName) if err := client.WafV1alpha1().Services(req.Namespace).DeleteCollection(ctx, metav1.DeleteOptions{}, metav1.ListOptions{LabelSelector: labelSelector}); err != nil { return fmt.Errorf("failed to delete WAF service: %v", err) } return nil } func (s *wafService) UpdateRule(ctx context.Context, req *RuleRequest) error { wafService := &model.WafService{} err := s.db.Model(&model.WafService{}).Where("gateway_name = ?", req.GatewayName).First(wafService).Error if err != nil { if err == gorm.ErrRecordNotFound { // Create new WAF service record if not found wafService = &model.WafService{ RegionCode: req.RegionCode, Namespace: req.Namespace, GatewayName: req.GatewayName, Mode: string(WafModeAlert), RuleCategoryStatus: &model.RuleCategoryStatus{ CategoryID: req.CategoryID, Status: req.Status, }, } if err := s.db.Create(wafService).Error; err != nil { return fmt.Errorf("failed to create WAF service: %v", err) } } else { return fmt.Errorf("failed to query WAF service: %v", err) } } else { // Update mode if service exists if err := s.db.Model(wafService).Update("rule_category_status", model.RuleCategoryStatus{ CategoryID: req.CategoryID, Status: req.Status, }).Error; err != nil { return fmt.Errorf("failed to update WAF service mode: %v", err) } } return nil } func (s *wafService) ListListenerWafStatus(ctx context.Context, req *GatewateInfo) ([]*GatewayListener, error) { client := s.clusterClientManager.GetClient(req.RegionCode) if client == nil { return nil, fmt.Errorf("failed to get cluster client") } listenerList, err := client.WafV1alpha1().Services(req.Namespace).List(ctx, metav1.ListOptions{LabelSelector: fmt.Sprintf("apigateway_name=%s", req.GatewayName)}) if err != nil { return nil, fmt.Errorf("failed to get listener list: %v", err) } listenerStatusList := []*GatewayListener{} for _, listener := range listenerList.Items { n := strings.LastIndex(listener.Name, "-") if n == -1 { return nil, fmt.Errorf("failed to get listener port: %v", listener.Name) } listenerPort := listener.Name[n+1:] listenerPortInt, err := strconv.Atoi(listenerPort) if err != nil { return nil, fmt.Errorf("failed to parse listener port: %v", err) } // hosts := strings.Join(listener.Spec.HostNames, "@") // log.Info().Msgf("hosts: %v", hosts) listenerStatusList = append(listenerStatusList, &GatewayListener{ GatewayName: req.GatewayName, Namespace: req.Namespace, RegionCode: req.RegionCode, Port: listenerPortInt, Hosts: listener.Spec.HostNames, }) } // for _, port := range portList { // listenerStatusList = append(listenerStatusList, &GatewayListener{ // GatewayName: req.GatewayName, // Namespace: req.Namespace, // RegionCode: req.RegionCode, // Port: port, // Enable: true, // }) // } return listenerStatusList, nil } func (s *wafService) EnableListenerWafs(ctx context.Context, req *EnableListenerWafsReq) error { client := s.clusterClientManager.GetClient(req.RegionCode) if client == nil { return fmt.Errorf("failed to get cluster client") } listenerList, err := client.WafV1alpha1().Services(req.Namespace).List(ctx, metav1.ListOptions{LabelSelector: fmt.Sprintf("apigateway_name=%s", req.GatewayName)}) if err != nil { log.Error().Msgf("failed to get listener list: %v", err) return err } portList := []int{} for _, listener := range listenerList.Items { n := strings.LastIndex(listener.Name, "-") if n == -1 { return fmt.Errorf("failed to get listener port: %v", listener.Name) } listenerPort := listener.Name[n+1:] listenerPortInt, err := strconv.Atoi(listenerPort) if err != nil { return fmt.Errorf("failed to parse listener port: %v", err) } portList = append(portList, listenerPortInt) } currentPortSet := sets.NewInt(portList...) desiredPortSet := sets.NewInt() wafMap := map[int]ListenerWaf{} for _, listener := range req.Listeners { // get port from listener.HostsAndPort, like hosts1@127.0.0.1@abc.com-8080 index := strings.LastIndex(listener.HostsAndPort, "-") if index == -1 { return fmt.Errorf("failed to get listener port: %v", listener) } port := listener.HostsAndPort[index+1:] portInt, err := strconv.Atoi(port) if err != nil { return fmt.Errorf("failed to parse listener port: %v", err) } desiredPortSet.Insert(portInt) log.Info().Msgf("listener: %v", listener.Name) hosts := strings.Split(listener.HostsAndPort[:index], "@") wafMap[portInt] = ListenerWaf{ Hosts: hosts, HostsAndPort: listener.HostsAndPort, Name: listener.Name, } } // enable WAF for ports that are in the desired port set but not in the current port set addingPortSet := desiredPortSet.Difference(currentPortSet) // Get mode from waf_services table wafService := &model.WafService{} err = s.db.Model(&model.WafService{}).Where("gateway_name = ?", req.GatewayName).First(wafService).Error if err != nil { if err == gorm.ErrRecordNotFound { return fmt.Errorf("waf service not found for gateway %s", req.GatewayName) } return fmt.Errorf("failed to query waf service: %v", err) } mode := WafMode(wafService.Mode) for _, port := range addingPortSet.List() { err := s.EnableListenerWaf(ctx, &EnableListenerWafReq{ GatewateInfo: GatewateInfo{ GatewayName: req.GatewayName, Namespace: req.Namespace, RegionCode: req.RegionCode, }, Port: port, Hosts: wafMap[port].Hosts, Enable: true, Mode: mode, ListenerName: wafMap[port].Name, }) if err != nil { return fmt.Errorf("failed to enable listener WAF: %v", err) } } // delete WAF for ports that are not in the desired port set deletingPortSet := currentPortSet.Difference(desiredPortSet) for _, port := range deletingPortSet.List() { err := s.DeleteListenerWaf(ctx, &DeleteListenerReq{ GatewateInfo: GatewateInfo{ GatewayName: req.GatewayName, Namespace: req.Namespace, RegionCode: req.RegionCode, }, Port: port, }) if err != nil { return fmt.Errorf("failed to delete listener WAF: %v", err) } } return nil }