package service import ( "bytes" "context" "encoding/json" "fmt" "io" "net/http" "net/url" "os" "strconv" "strings" "github.com/rs/zerolog/log" "gitlab.com/tensorsecurity-rd/waf-console/internal/model" "gitlab.com/tensorsecurity-rd/waf-console/internal/utils" "gitlab.com/tensorsecurity-rd/waf-console/pkg/apis/waf.security.io/v1alpha1" "gopkg.in/yaml.v3" "gorm.io/gorm" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) type wafService struct { clusterClientManager *utils.ClusterClientManager db *gorm.DB } func NewWafService(clusterClientManager *utils.ClusterClientManager, db *gorm.DB) Service { return &wafService{clusterClientManager: clusterClientManager, db: db} } func (s *wafService) GetWaf(ctx context.Context, regionCode, namespace, gatewayName string) (*WafService, error) { wafService := &model.WafService{} err := s.db.Model(&model.WafService{}).Where("gateway_name = ? AND region_code = ? AND namespace = ?", gatewayName, regionCode, namespace).First(wafService).Error if err != nil { if err == gorm.ErrRecordNotFound { // Create new WAF service record if not found wafService = &model.WafService{ RegionCode: regionCode, Namespace: namespace, GatewayName: gatewayName, Mode: string(WafModeAlert), } if err := s.db.Create(wafService).Error; err != nil { return nil, fmt.Errorf("failed to create WAF service: %v", err) } } else { return nil, fmt.Errorf("failed to query WAF service: %v", err) } } return &WafService{ GatewayName: wafService.GatewayName, Mode: wafService.Mode, RuleNum: wafService.RuleNum, AttackNum: wafService.AttackNum, }, nil } func (s *wafService) GetWafGatewayInfo(ctx context.Context, req *GetWafGatewayInfoReq) (*WafService, error) { wafService := &model.WafService{} err := s.db.Model(&model.WafService{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).First(wafService).Error if err != nil { if err == gorm.ErrRecordNotFound { httpRequst := http.Request{ Method: http.MethodPost, URL: &url.URL{Scheme: "https", Host: "console.tensorsecurity.com", Path: "/api/v1/waf/gateway"}, Header: http.Header{ "Cookie": []string{req.Cookie}, }, Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{"gateway_name": "%s", "namespace": "%s", "region_code": "%s"}`, req.GatewayName, req.Namespace, req.RegionCode))), } resp, err := http.DefaultClient.Do(&httpRequst) if err != nil { return nil, fmt.Errorf("failed to get WAF service: %v", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read WAF service: %v", err) } var wafService model.WafService err = json.Unmarshal(body, &wafService) if err != nil { return nil, fmt.Errorf("failed to unmarshal WAF service: %v", err) } wafService.ID = 0 wafService.RuleCategoryStatus = nil wafService.RuleNum = 0 wafService.AttackNum = 0 // wafService.Host = model.HostList([]string{"*"}) wafService.Mode = string(WafModeAlert) err = s.db.Create(wafService).Error if err != nil { return nil, fmt.Errorf("failed to create WAF service: %v", err) } } else { return nil, fmt.Errorf("failed to query WAF service: %v", err) } } return &WafService{ GatewayName: wafService.GatewayName, Mode: wafService.Mode, RuleNum: wafService.RuleNum, AttackNum: wafService.AttackNum, }, nil } func (s *wafService) getRulesForService(req *CreateWafReq) ([]v1alpha1.Rule, error) { rules := []v1alpha1.Rule{} ruleCategories := []model.WafRuleCategory{} if err := s.db.Model(&model.WafRuleCategory{}).Where("status = ?", 0).Find(&ruleCategories).Error; err != nil { return nil, fmt.Errorf("failed to get rule categories: %v", err) } // Get existing WAF service config if any wafService := &model.WafService{} err := s.db.Model(&model.WafService{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).First(wafService).Error if err != nil { if err == gorm.ErrRecordNotFound { // Create new WAF service record if not found wafService = &model.WafService{ RegionCode: req.RegionCode, Namespace: req.Namespace, GatewayName: req.GatewayName, Mode: string(WafModeAlert), } if err := s.db.Create(wafService).Error; err != nil { return nil, fmt.Errorf("failed to create WAF service: %v", err) } } else { return nil, fmt.Errorf("failed to query WAF service: %v", err) } } // Determine which rule categories to enable var enabledCategories []model.WafRuleCategory if wafService.RuleCategoryStatus != nil && len(wafService.RuleCategoryStatus.CategoryID) > 0 { // Only include categories not already enabled for _, category := range ruleCategories { for _, id := range wafService.RuleCategoryStatus.CategoryID { if id == category.CategoryID { enabledCategories = append(enabledCategories, category) continue } } } } else { // Enable all categories if none specified enabledCategories = ruleCategories } for _, category := range enabledCategories { for _, rule := range category.Rules { rules = append(rules, v1alpha1.Rule{ ID: rule.ID, Level: rule.Level, Name: rule.Name, Type: rule.Type, Description: rule.Description, Expr: rule.Expr, Mode: rule.Mode, }) } } return rules, nil } func (s *wafService) CreateWaf(ctx context.Context, req *CreateWafReq) (*WafService, error) { // Create the WAF service resource name := fmt.Sprintf("%s-%d", req.GatewayName, req.Port) service := &v1alpha1.Service{ ObjectMeta: metav1.ObjectMeta{ Name: name, Namespace: req.Namespace, Labels: map[string]string{ "apigateway_name": req.GatewayName, }, }, Spec: v1alpha1.ServiceSpec{ HostNames: req.Host, ServiceName: req.GatewayName, Port: req.Port, Workload: v1alpha1.WorkloadRef{ Kind: "Deployment", Name: req.GatewayName, Namespace: req.Namespace, }, Uri: &v1alpha1.StringMatch{ Prefix: "/", }, LogConfig: &v1alpha1.LogConfig{ Enable: 1, Level: "info", }, Mode: "block", }, } // Get enabled rule categories from DB // var ruleCategories []model.WafRuleCategory // if err := s.db.Model(&model.WafRuleCategory{}).Where("status = ?", 0).Find(&ruleCategories).Error; err != nil { // return nil, fmt.Errorf("failed to get rule categories: %v", err) // } // // Get existing WAF service config if any // wafService := &model.WafService{} // err := s.db.Model(&model.WafService{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).First(wafService).Error // if err != nil { // if err == gorm.ErrRecordNotFound { // // Create new WAF service record if not found // wafService = &model.WafService{ // RegionCode: req.RegionCode, // Namespace: req.Namespace, // GatewayName: req.GatewayName, // Mode: string(WafModeAlert), // // Host: model.HostList(req.Host), // } // if err := s.db.Create(wafService).Error; err != nil { // return nil, fmt.Errorf("failed to create WAF service: %v", err) // } // } else { // return nil, fmt.Errorf("failed to query WAF service: %v", err) // } // } // // Determine which rule categories to enable // var enabledCategories []model.WafRuleCategory // if wafService.RuleCategoryStatus != nil && len(wafService.RuleCategoryStatus.CategoryID) > 0 { // // Only include categories not already enabled // for _, category := range ruleCategories { // for _, id := range wafService.RuleCategoryStatus.CategoryID { // if id == category.CategoryID { // enabledCategories = append(enabledCategories, category) // continue // } // } // } // } else { // // Enable all categories if none specified // enabledCategories = ruleCategories // } // // Add rules from enabled categories // for _, category := range enabledCategories { // for _, rule := range category.Rules { // service.Spec.Rules = append(service.Spec.Rules, v1alpha1.Rule{ // ID: rule.ID, // Level: rule.Level, // Name: rule.Name, // Type: rule.Type, // Description: rule.Description, // Expr: rule.Expr, // Mode: rule.Mode, // }) // } // } rules, err := s.getRulesForService(req) if err != nil { return nil, fmt.Errorf("failed to get rules for service: %v", err) } service.Spec.Rules = rules // Create the WAF service in Kubernetes client := s.clusterClientManager.GetClient(req.RegionCode) if client == nil { return nil, fmt.Errorf("failed to get cluster client: %v", err) } if _, err := client.WafV1alpha1().Services(req.Namespace).Create(ctx, service, metav1.CreateOptions{}); err != nil { return nil, fmt.Errorf("failed to create WAF service: %v", err) } return nil, nil } func (s *wafService) DeleteListenerWaf(ctx context.Context, req *DeleteListenerReq) error { client := s.clusterClientManager.GetClient(req.RegionCode) if client == nil { return fmt.Errorf("failed to get cluster client") } name := fmt.Sprintf("%s-%d", req.GatewayName, req.Port) if err := client.WafV1alpha1().Services(req.Namespace).Delete(ctx, name, metav1.DeleteOptions{}); err != nil { return fmt.Errorf("failed to delete WAF service: %v", err) } return nil } func (s *wafService) UpdateMode(ctx context.Context, req *UpdateModeReq) (*WafService, error) { // Check if WAF service exists wafService := &model.WafService{} err := s.db.Model(&model.WafService{}).Where("gateway_name = ?", req.GatewayName).First(wafService).Error if err != nil { if err == gorm.ErrRecordNotFound { // Create new WAF service record if not found wafService = &model.WafService{ RegionCode: req.RegionCode, Namespace: req.Namespace, GatewayName: req.GatewayName, Mode: string(req.Mode), } if err := s.db.Create(wafService).Error; err != nil { return nil, fmt.Errorf("failed to create WAF service: %v", err) } } else { return nil, fmt.Errorf("failed to query WAF service: %v", err) } } else { // Update mode if service exists if err := s.db.Model(wafService).Update("mode", string(req.Mode)).Error; err != nil { return nil, fmt.Errorf("failed to update WAF service mode: %v", err) } } return nil, nil } func (s *wafService) GetRuleCategories(ctx context.Context) ([]WafRuleCategory, error) { var categories []WafRuleCategory err := s.db.Table("waf_rule_categories").Find(&categories).Error if err != nil { return nil, err } return categories, nil } func (s *wafService) GetRules(ctx context.Context, categoryID string) ([]WafRule, error) { var rules []WafRule err := s.db.Table("waf_rules").Where("category_id = ?", categoryID).Find(&rules).Error if err != nil { return nil, err } return rules, nil } func (s *wafService) GetRule(ctx context.Context, ruleID int) (*WafRule, error) { var rule WafRule err := s.db.Table("waf_rules").Where("id = ?", ruleID).First(&rule).Error if err != nil { return nil, err } return &rule, nil } func (s *wafService) SaveRuleCategoryToDB(ctx context.Context) error { var categories []WafRuleCategory yamlFile, err := os.ReadFile("rules/waf-rules.yaml") if err != nil { return fmt.Errorf("error reading yaml file: %v", err) } err = yaml.Unmarshal(yamlFile, &categories) if err != nil { return fmt.Errorf("error unmarshaling yaml: %v", err) } for _, category := range categories { rules := []model.WafRule{} for _, rule := range category.Rules { rules = append(rules, model.WafRule{ ID: rule.ID, CategoryID: category.CategoryID, Level: rule.Level, Name: rule.Name, Type: rule.Type, Description: rule.Description, Expr: rule.Expr, Mode: rule.Mode, }) } model := model.WafRuleCategory{ CategoryID: category.CategoryID, Status: category.Status, CategoryEN: category.Catagory.EN, CategoryZH: category.Catagory.Zh, DescriptionEN: category.Description.EN, DescriptionZH: category.Description.Zh, Rules: model.RuleList(rules), } err = s.db.Table("waf_rule_categories").Create(&model).Error if err != nil { return err } } return nil } func (s *wafService) CreateListener(ctx context.Context, req *CreateListenerReq) (*GatewayListener, error) { listener := &model.GatewayListener{} err := s.db.Model(&model.GatewayListener{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).First(listener).Error if err != nil { if err == gorm.ErrRecordNotFound { listener = &model.GatewayListener{ GatewayName: req.GatewayName, Namespace: req.Namespace, RegionCode: req.RegionCode, Port: req.Port, Enable: true, } err = s.db.Model(&model.GatewayListener{}).Create(listener).Error if err != nil { return nil, err } } else { return nil, err } } return &GatewayListener{ GatewayName: listener.GatewayName, Namespace: listener.Namespace, RegionCode: listener.RegionCode, Port: listener.Port, Enable: listener.Enable, }, nil } func (s *wafService) DeleteListener(ctx context.Context, req *DeleteListenerReq) error { listener := &model.GatewayListener{} err := s.db.Model(&model.GatewayListener{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).First(listener).Error if err != nil { return err } err = s.db.Model(&model.GatewayListener{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).Delete(listener).Error if err != nil { return err } return nil } func (s *wafService) EnableListenerWaf(ctx context.Context, req *EnableListenerWafReq) error { listener := &model.GatewayListener{} err := s.db.Model(&model.GatewayListener{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).First(listener).Error if err != nil { if err == gorm.ErrRecordNotFound { listener = &model.GatewayListener{ GatewayName: req.GatewayName, Namespace: req.Namespace, RegionCode: req.RegionCode, Port: int(req.Port), Enable: req.Enable, Hosts: req.Hosts, } err = s.db.Model(&model.GatewayListener{}).Create(listener).Error if err != nil { return err } } else { return err } } listener.Enable = req.Enable err = s.db.Model(&model.GatewayListener{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).Update("enable", req.Enable).Error if err != nil { return err } // wafService := &model.WafService{} // err = s.db.Model(&model.WafService{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).First(wafService).Error // if err != nil { // if err == gorm.ErrRecordNotFound { // wafService = &model.WafService{ // GatewayName: req.GatewayName, // Namespace: req.Namespace, // RegionCode: req.RegionCode, // Mode: string(WafModeAlert), // } // if err := s.db.Create(wafService).Error; err != nil { // return err // } // } else { // return err // } // } if listener.Enable { log.Info().Msgf("Create WAF for listener %s", listener.GatewayName) _, err := s.CreateWaf(ctx, &CreateWafReq{ GatewateInfo: GatewateInfo{ GatewayName: req.GatewayName, Namespace: req.Namespace, RegionCode: req.RegionCode, }, Port: uint32(req.Port), Host: req.Hosts, }) if err != nil { return err } } else { log.Info().Msgf("Delete WAF for listener %s", listener.GatewayName) err = s.DeleteListenerWaf(ctx, &DeleteListenerReq{ GatewateInfo: GatewateInfo{ GatewayName: req.GatewayName, Namespace: req.Namespace, RegionCode: req.RegionCode, }, Port: req.Port, }) if err != nil { return err } } return nil } func getGatewayNameFromCrn(crn string) string { // crn:ucs::apigateway:lf-tst7:214613666997:instance/testaaa parts := strings.Split(crn, "/") return parts[len(parts)-1] } func (s *wafService) ListListenerFromApiGateway(ctx context.Context, apiGatewayCrn string, regionCode string, cookie string) ([]GatewayRespListenerData, error) { body, err := json.Marshal(map[string]string{ "apigateway_crn": apiGatewayCrn, "region_code": regionCode, }) if err != nil { return nil, fmt.Errorf("failed to marshal request body: %v", err) } request, err := http.NewRequest("POST", "https://csm.console.test.tg.unicom.local/apigatewaymng/listener/lf-tst7/list_listeners", bytes.NewBuffer(body)) if err != nil { return nil, fmt.Errorf("failed to create request: %v", err) } request.Header.Set("Cookie", cookie) resp, err := http.DefaultClient.Do(request) if err != nil { return nil, fmt.Errorf("failed to get listener list: %v", err) } defer resp.Body.Close() // Parse response var response GatewayListenerResponseList if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { return nil, fmt.Errorf("failed to parse listener list: %v", err) } return response.Data, nil } func (s *wafService) EnableGatewayWaf(ctx context.Context, req *EnableGatewayWafReq) error { if req.Enable { listeners, err := s.ListListenerFromApiGateway(ctx, req.ApiGatewayCrn, req.RegionCode, req.Cookie) if err != nil { return fmt.Errorf("failed to get listener list: %v", err) } log.Info().Msgf("listeners: %v", listeners) // Create WAF for each listener for _, listener := range listeners { gatewayName := getGatewayNameFromCrn(listener.ApiGatewayCrn) namespace := fmt.Sprintf("%s-%s", listener.CreateAccountName, listener.CreateAccountID) if _, err := s.CreateWaf(ctx, &CreateWafReq{ GatewateInfo: GatewateInfo{ GatewayName: gatewayName, Namespace: namespace, RegionCode: req.RegionCode, }, Port: uint32(listener.Port), Host: listener.Hosts, }); err != nil { return fmt.Errorf("failed to create WAF for listener %d: %v", listener.Port, err) } } } else { s.DeleteGatewayWaf(ctx, &GatewateInfo{ GatewayName: req.GatewayName, Namespace: req.Namespace, RegionCode: req.RegionCode, }) } return nil } func (s *wafService) DeleteGatewayWaf(ctx context.Context, req *GatewateInfo) error { client := s.clusterClientManager.GetClient(req.RegionCode) if client == nil { return fmt.Errorf("failed to get cluster client") } labelSelector := fmt.Sprintf("apigateway_name=%s", req.GatewayName) if err := client.WafV1alpha1().Services(req.Namespace).DeleteCollection(ctx, metav1.DeleteOptions{}, metav1.ListOptions{LabelSelector: labelSelector}); err != nil { return fmt.Errorf("failed to delete WAF service: %v", err) } return nil } func (s *wafService) UpdateRule(ctx context.Context, req *RuleRequest) error { wafService := &model.WafService{} err := s.db.Model(&model.WafService{}).Where("gateway_name = ?", req.GatewayName).First(wafService).Error if err != nil { if err == gorm.ErrRecordNotFound { // Create new WAF service record if not found wafService = &model.WafService{ RegionCode: req.RegionCode, Namespace: req.Namespace, GatewayName: req.GatewayName, Mode: string(WafModeAlert), RuleCategoryStatus: &model.RuleCategoryStatus{ CategoryID: req.CategoryID, Status: req.Status, }, } if err := s.db.Create(wafService).Error; err != nil { return fmt.Errorf("failed to create WAF service: %v", err) } } else { return fmt.Errorf("failed to query WAF service: %v", err) } } else { // Update mode if service exists if err := s.db.Model(wafService).Update("rule_category_status", model.RuleCategoryStatus{ CategoryID: req.CategoryID, Status: req.Status, }).Error; err != nil { return fmt.Errorf("failed to update WAF service mode: %v", err) } } return nil } func (s *wafService) ListListenerWafStatus(ctx context.Context, req *GatewateInfo) ([]*GatewayListener, error) { client := s.clusterClientManager.GetClient(req.RegionCode) if client == nil { return nil, fmt.Errorf("failed to get cluster client") } listenerList, err := client.WafV1alpha1().Services(req.Namespace).List(ctx, metav1.ListOptions{LabelSelector: fmt.Sprintf("apigateway_name=%s", req.GatewayName)}) if err != nil { return nil, fmt.Errorf("failed to get listener list: %v", err) } listenerStatusList := []*GatewayListener{} portList := []int{} for _, listener := range listenerList.Items { n := strings.LastIndex(listener.Name, "-") if n == -1 { return nil, fmt.Errorf("failed to get listener port: %v", listener.Name) } listenerPort := listener.Name[n+1:] listenerPortInt, err := strconv.Atoi(listenerPort) if err != nil { return nil, fmt.Errorf("failed to parse listener port: %v", err) } portList = append(portList, listenerPortInt) } for _, port := range portList { listenerStatusList = append(listenerStatusList, &GatewayListener{ GatewayName: req.GatewayName, Namespace: req.Namespace, RegionCode: req.RegionCode, Port: port, Enable: true, }) } return listenerStatusList, nil }