Commit 7df18457 authored by qiuqunfeng's avatar qiuqunfeng
Browse files

Implement WAF rules calculation and update logic

- Add calculateCrdWafRules method to compute WAF rules based on enabled categories.
- Introduce updateRulesForCrd method to update WAF service rules in the cluster.
- Modify UpdateRule method to include rules update after fetching the WAF service record.
- Remove commented-out GetRuleCategories method for code clarity.
parent 43bf5d54
......@@ -382,15 +382,6 @@ func (s *wafService) UpdateMode(ctx context.Context, req *UpdateModeReq) (*WafSe
}, nil
}
// func (s *wafService) GetRuleCategories(ctx context.Context) ([]WafRuleCategory, error) {
// var categories []WafRuleCategory
// err := s.db.Table("waf_rule_categories").Find(&categories).Error
// if err != nil {
// return nil, err
// }
// return categories, nil
// }
func (s *wafService) GetRules(ctx context.Context, categoryID string) ([]WafRule, error) {
var rules []WafRule
err := s.db.Table("waf_rules").Where("category_id = ?", categoryID).Find(&rules).Error
......@@ -621,9 +612,80 @@ func (s *wafService) DeleteGatewayWaf(ctx context.Context, req *GatewateInfo) er
return nil
}
func (s *wafService) calculateCrdWafRules(ctx context.Context, req *RuleRequest, wafService *model.WafService) ([]v1alpha1.Rule, error) {
rules := []v1alpha1.Rule{}
ruleCategories := []model.WafRuleCategory{}
if err := s.db.WithContext(ctx).Model(&model.WafRuleCategory{}).Where("status = ?", 0).Find(&ruleCategories).Error; err != nil {
return nil, fmt.Errorf("failed to get rule categories: %v", err)
}
// Determine which rule categories to enable
var enabledCategories []model.WafRuleCategory
if wafService.RuleCategoryStatus != nil && len(wafService.RuleCategoryStatus.CategoryID) > 0 {
// Only include categories not already enabled
for _, category := range ruleCategories {
for _, id := range wafService.RuleCategoryStatus.CategoryID {
if id == category.CategoryID {
enabledCategories = append(enabledCategories, category)
continue
}
}
}
} else {
// Enable all categories if none specified
enabledCategories = ruleCategories
}
for _, category := range enabledCategories {
for _, rule := range category.Rules {
rules = append(rules, v1alpha1.Rule{
ID: rule.ID,
Level: rule.Level,
Name: rule.Name,
Type: rule.Type,
Description: rule.Description,
Expr: rule.Expr,
Mode: rule.Mode,
})
}
}
return rules, nil
}
func (s *wafService) updateRulesForCrd(ctx context.Context, req *RuleRequest, wafService *model.WafService) error {
client := s.clusterClientManager.GetClient(req.RegionCode)
if client == nil {
return fmt.Errorf("failed to get cluster client")
}
serviceList, err := client.Versioned.WafV1alpha1().Services(req.Namespace).List(ctx, metav1.ListOptions{LabelSelector: fmt.Sprintf("apigateway_name=%s", req.GatewayName)})
if err != nil {
return fmt.Errorf("failed to get WAF service: %v", err)
}
if len(serviceList.Items) == 0 {
log.Info().Msgf("WAF service not found for gateway %s", req.GatewayName)
return nil
}
rules, err := s.calculateCrdWafRules(ctx, req, wafService)
if err != nil {
return fmt.Errorf("failed to calculate WAF rules: %v", err)
}
for _, service := range serviceList.Items {
service.Spec.Rules = rules
_, err = client.Versioned.WafV1alpha1().Services(req.Namespace).Update(ctx, &service, metav1.UpdateOptions{})
if err != nil {
return fmt.Errorf("failed to update WAF service: %v", err)
}
}
return nil
}
func (s *wafService) UpdateRule(ctx context.Context, req *RuleRequest) error {
wafService := &model.WafService{}
err := s.db.Model(&model.WafService{}).Where("gateway_name = ?", req.GatewayName).First(wafService).Error
err := s.db.Model(&model.WafService{}).Where("gateway_name = ? and namespace = ? and region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).First(wafService).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
// Create new WAF service record if not found
......@@ -652,6 +714,10 @@ func (s *wafService) UpdateRule(ctx context.Context, req *RuleRequest) error {
return fmt.Errorf("failed to update WAF service mode: %v", err)
}
}
err = s.updateRulesForCrd(ctx, req, wafService)
if err != nil {
return fmt.Errorf("failed to update WAF rules: %v", err)
}
return nil
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment