Commit 17e6a17f authored by qiuqunfeng's avatar qiuqunfeng
Browse files

Add hosts column to gateway_listeners and update related WAF service methods

- Modified SQL migration to add 'hosts' column to gateway_listeners table
- Updated model, service, and controller to support hosts in gateway listener operations
- Added new GetWafGatewayInfo method to retrieve WAF gateway information
- Refactored WAF service methods to handle hosts and improve flexibility
parent 08f46675
...@@ -12,6 +12,7 @@ func SetWafRouter(e *gin.Engine, clusterClientManager *utils.ClusterClientManage ...@@ -12,6 +12,7 @@ func SetWafRouter(e *gin.Engine, clusterClientManager *utils.ClusterClientManage
wafController := controller.NewWafController(clusterClientManager, db) wafController := controller.NewWafController(clusterClientManager, db)
v1.GET("/:region_code/:namespace/:gateway_name", wafController.Waf) v1.GET("/:region_code/:namespace/:gateway_name", wafController.Waf)
// v1.POST("/gateway", wafController.GetWafGatewayInfo)
v1.POST("/", wafController.CreateWaf) v1.POST("/", wafController.CreateWaf)
v1.PUT("mode", wafController.UpdateMode) v1.PUT("mode", wafController.UpdateMode)
v1.PUT("rules", wafController.UpdateRule) v1.PUT("rules", wafController.UpdateRule)
......
...@@ -42,7 +42,8 @@ CREATE TABLE gateway_listeners ( ...@@ -42,7 +42,8 @@ CREATE TABLE gateway_listeners (
namespace VARCHAR(255) NOT NULL, namespace VARCHAR(255) NOT NULL,
region_code VARCHAR(50) NOT NULL, region_code VARCHAR(50) NOT NULL,
port INTEGER NOT NULL, port INTEGER NOT NULL,
enable BOOLEAN NOT NULL enable BOOLEAN NOT NULL,
hosts TEXT
); );
-- Add indexes for better query performance -- Add indexes for better query performance
......
...@@ -2,6 +2,7 @@ package controller ...@@ -2,6 +2,7 @@ package controller
import ( import (
"context" "context"
"errors"
"strconv" "strconv"
"time" "time"
...@@ -40,6 +41,30 @@ func (c *WafController) Waf(ctx *gin.Context) { ...@@ -40,6 +41,30 @@ func (c *WafController) Waf(ctx *gin.Context) {
utils.AssembleResponse(ctx, resp, nil) utils.AssembleResponse(ctx, resp, nil)
} }
func (c *WafController) GetWafGatewayInfo(ctx *gin.Context) {
ctx1, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cookie := ctx.Request.Header.Get("Cookie")
if cookie == "" {
utils.AssembleResponse(ctx, nil, errors.New("cookie is required"))
return
}
var req service.GetWafGatewayInfoReq
if err := ctx.BindJSON(&req); err != nil {
utils.AssembleResponse(ctx, nil, err)
return
}
info, err := c.service.GetWafGatewayInfo(ctx1, &req)
if err != nil {
utils.AssembleResponse(ctx, nil, err)
return
}
utils.AssembleResponse(ctx, info, nil)
}
func (c *WafController) CreateWaf(ctx *gin.Context) { func (c *WafController) CreateWaf(ctx *gin.Context) {
ctx1, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx1, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
...@@ -148,6 +173,7 @@ func (c *WafController) EnableGatewayWaf(ctx *gin.Context) { ...@@ -148,6 +173,7 @@ func (c *WafController) EnableGatewayWaf(ctx *gin.Context) {
utils.AssembleResponse(ctx, nil, err) utils.AssembleResponse(ctx, nil, err)
return return
} }
req.Cookie = ctx.Request.Header.Get("Cookie")
err := c.service.EnableGatewayWaf(ctx1, &req) err := c.service.EnableGatewayWaf(ctx1, &req)
if err != nil { if err != nil {
......
...@@ -56,7 +56,7 @@ type WafService struct { ...@@ -56,7 +56,7 @@ type WafService struct {
RuleNum int `gorm:"column:rule_num"` RuleNum int `gorm:"column:rule_num"`
AttackNum int `gorm:"column:attack_num"` AttackNum int `gorm:"column:attack_num"`
RuleCategoryStatus *RuleCategoryStatus `gorm:"column:rule_category_status;type:json"` RuleCategoryStatus *RuleCategoryStatus `gorm:"column:rule_category_status;type:json"`
Host HostList `gorm:"column:host"` // Host HostList `gorm:"column:host"`
} }
func (WafService) TableName() string { func (WafService) TableName() string {
...@@ -123,6 +123,7 @@ type GatewayListener struct { ...@@ -123,6 +123,7 @@ type GatewayListener struct {
RegionCode string `gorm:"column:region_code"` RegionCode string `gorm:"column:region_code"`
Port int `gorm:"column:port"` Port int `gorm:"column:port"`
Enable bool `gorm:"column:enable"` Enable bool `gorm:"column:enable"`
Hosts HostList `gorm:"column:hosts"`
} }
func (GatewayListener) TableName() string { func (GatewayListener) TableName() string {
......
...@@ -5,6 +5,7 @@ import "context" ...@@ -5,6 +5,7 @@ import "context"
type Service interface { type Service interface {
// QueryIP(ip string) (*model.IPInfo, error) // QueryIP(ip string) (*model.IPInfo, error)
GetWaf(ctx context.Context, regionCode, namespace, gatewayName string) (*WafService, error) GetWaf(ctx context.Context, regionCode, namespace, gatewayName string) (*WafService, error)
GetWafGatewayInfo(ctx context.Context, req *GetWafGatewayInfoReq) (*WafService, error)
CreateWaf(ctx context.Context, req *CreateWafReq) (*WafService, error) CreateWaf(ctx context.Context, req *CreateWafReq) (*WafService, error)
UpdateMode(ctx context.Context, req *UpdateModeReq) (*WafService, error) UpdateMode(ctx context.Context, req *UpdateModeReq) (*WafService, error)
UpdateRule(ctx context.Context, req *RuleRequest) error UpdateRule(ctx context.Context, req *RuleRequest) error
......
...@@ -175,7 +175,12 @@ type GatewateInfo struct { ...@@ -175,7 +175,12 @@ type GatewateInfo struct {
Namespace string `json:"namespace"` Namespace string `json:"namespace"`
RegionCode string `json:"region_code"` RegionCode string `json:"region_code"`
ApiGatewayCrn string `json:"gateway_crn"` ApiGatewayCrn string `json:"gateway_crn"`
Hosts []string `json:"hosts"` // Hosts []string `json:"hosts"`
}
type GetWafGatewayInfoReq struct {
GatewateInfo
Cookie string `json:"cookie"`
} }
type CreateWafReq struct { type CreateWafReq struct {
...@@ -213,12 +218,14 @@ type UpdateModeReq struct { ...@@ -213,12 +218,14 @@ type UpdateModeReq struct {
type EnableListenerWafReq struct { type EnableListenerWafReq struct {
GatewateInfo GatewateInfo
Enable bool `json:"enable"` Enable bool `json:"enable"`
Hosts []string `json:"hosts"`
Port int `json:"port"` Port int `json:"port"`
} }
type EnableGatewayWafReq struct { type EnableGatewayWafReq struct {
GatewateInfo GatewateInfo
Enable bool `json:"enable"` Enable bool `json:"enable"`
Cookie string `json:"cookie"`
} }
type WafRule struct { type WafRule struct {
...@@ -266,3 +273,29 @@ type DeleteListenerReq struct { ...@@ -266,3 +273,29 @@ type DeleteListenerReq struct {
GatewateInfo GatewateInfo
Port int `json:"port"` Port int `json:"port"`
} }
type GatewayRespListenerData struct {
GatewayName string `json:"gateway_name"`
Namespace string `json:"namespace"`
ListenerName string `json:"listener_name"`
ApiGatewayCrn string `json:"apigateway_crn"`
CreateAccountName string `json:"create_account_name"`
CreateAccountID string `json:"create_account_id"`
Hosts []string `json:"hosts"`
Port int `json:"port"`
}
type GatewayResponseBase struct {
Code int `json:"code"`
Message string `json:"message"`
}
type GatewayListenerResponse struct {
GatewayResponseBase
Data GatewayRespListenerData `json:"data"`
}
type GatewayListenerResponseList struct {
GatewayResponseBase
Data []GatewayRespListenerData `json:"data"`
}
...@@ -5,7 +5,9 @@ import ( ...@@ -5,7 +5,9 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/url"
"os" "os"
"strconv" "strconv"
"strings" "strings"
...@@ -39,7 +41,6 @@ func (s *wafService) GetWaf(ctx context.Context, regionCode, namespace, gatewayN ...@@ -39,7 +41,6 @@ func (s *wafService) GetWaf(ctx context.Context, regionCode, namespace, gatewayN
Namespace: namespace, Namespace: namespace,
GatewayName: gatewayName, GatewayName: gatewayName,
Mode: string(WafModeAlert), Mode: string(WafModeAlert),
Host: model.HostList([]string{"*"}),
} }
if err := s.db.Create(wafService).Error; err != nil { if err := s.db.Create(wafService).Error; err != nil {
return nil, fmt.Errorf("failed to create WAF service: %v", err) return nil, fmt.Errorf("failed to create WAF service: %v", err)
...@@ -56,6 +57,55 @@ func (s *wafService) GetWaf(ctx context.Context, regionCode, namespace, gatewayN ...@@ -56,6 +57,55 @@ func (s *wafService) GetWaf(ctx context.Context, regionCode, namespace, gatewayN
}, nil }, nil
} }
func (s *wafService) GetWafGatewayInfo(ctx context.Context, req *GetWafGatewayInfoReq) (*WafService, error) {
wafService := &model.WafService{}
err := s.db.Model(&model.WafService{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).First(wafService).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
httpRequst := http.Request{
Method: http.MethodPost,
URL: &url.URL{Scheme: "https", Host: "console.tensorsecurity.com", Path: "/api/v1/waf/gateway"},
Header: http.Header{
"Cookie": []string{req.Cookie},
},
Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{"gateway_name": "%s", "namespace": "%s", "region_code": "%s"}`, req.GatewayName, req.Namespace, req.RegionCode))),
}
resp, err := http.DefaultClient.Do(&httpRequst)
if err != nil {
return nil, fmt.Errorf("failed to get WAF service: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read WAF service: %v", err)
}
var wafService model.WafService
err = json.Unmarshal(body, &wafService)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal WAF service: %v", err)
}
wafService.ID = 0
wafService.RuleCategoryStatus = nil
wafService.RuleNum = 0
wafService.AttackNum = 0
// wafService.Host = model.HostList([]string{"*"})
wafService.Mode = string(WafModeAlert)
err = s.db.Create(wafService).Error
if err != nil {
return nil, fmt.Errorf("failed to create WAF service: %v", err)
}
} else {
return nil, fmt.Errorf("failed to query WAF service: %v", err)
}
}
return &WafService{
GatewayName: wafService.GatewayName,
Mode: wafService.Mode,
RuleNum: wafService.RuleNum,
AttackNum: wafService.AttackNum,
}, nil
}
func (s *wafService) getRulesForService(req *CreateWafReq) ([]v1alpha1.Rule, error) { func (s *wafService) getRulesForService(req *CreateWafReq) ([]v1alpha1.Rule, error) {
rules := []v1alpha1.Rule{} rules := []v1alpha1.Rule{}
ruleCategories := []model.WafRuleCategory{} ruleCategories := []model.WafRuleCategory{}
...@@ -74,7 +124,6 @@ func (s *wafService) getRulesForService(req *CreateWafReq) ([]v1alpha1.Rule, err ...@@ -74,7 +124,6 @@ func (s *wafService) getRulesForService(req *CreateWafReq) ([]v1alpha1.Rule, err
Namespace: req.Namespace, Namespace: req.Namespace,
GatewayName: req.GatewayName, GatewayName: req.GatewayName,
Mode: string(WafModeAlert), Mode: string(WafModeAlert),
Host: model.HostList(req.Host),
} }
if err := s.db.Create(wafService).Error; err != nil { if err := s.db.Create(wafService).Error; err != nil {
return nil, fmt.Errorf("failed to create WAF service: %v", err) return nil, fmt.Errorf("failed to create WAF service: %v", err)
...@@ -151,64 +200,70 @@ func (s *wafService) CreateWaf(ctx context.Context, req *CreateWafReq) (*WafServ ...@@ -151,64 +200,70 @@ func (s *wafService) CreateWaf(ctx context.Context, req *CreateWafReq) (*WafServ
} }
// Get enabled rule categories from DB // Get enabled rule categories from DB
var ruleCategories []model.WafRuleCategory // var ruleCategories []model.WafRuleCategory
if err := s.db.Model(&model.WafRuleCategory{}).Where("status = ?", 0).Find(&ruleCategories).Error; err != nil { // if err := s.db.Model(&model.WafRuleCategory{}).Where("status = ?", 0).Find(&ruleCategories).Error; err != nil {
return nil, fmt.Errorf("failed to get rule categories: %v", err) // return nil, fmt.Errorf("failed to get rule categories: %v", err)
} // }
// Get existing WAF service config if any // // Get existing WAF service config if any
wafService := &model.WafService{} // wafService := &model.WafService{}
err := s.db.Model(&model.WafService{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).First(wafService).Error // err := s.db.Model(&model.WafService{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).First(wafService).Error
// if err != nil {
// if err == gorm.ErrRecordNotFound {
// // Create new WAF service record if not found
// wafService = &model.WafService{
// RegionCode: req.RegionCode,
// Namespace: req.Namespace,
// GatewayName: req.GatewayName,
// Mode: string(WafModeAlert),
// // Host: model.HostList(req.Host),
// }
// if err := s.db.Create(wafService).Error; err != nil {
// return nil, fmt.Errorf("failed to create WAF service: %v", err)
// }
// } else {
// return nil, fmt.Errorf("failed to query WAF service: %v", err)
// }
// }
// // Determine which rule categories to enable
// var enabledCategories []model.WafRuleCategory
// if wafService.RuleCategoryStatus != nil && len(wafService.RuleCategoryStatus.CategoryID) > 0 {
// // Only include categories not already enabled
// for _, category := range ruleCategories {
// for _, id := range wafService.RuleCategoryStatus.CategoryID {
// if id == category.CategoryID {
// enabledCategories = append(enabledCategories, category)
// continue
// }
// }
// }
// } else {
// // Enable all categories if none specified
// enabledCategories = ruleCategories
// }
// // Add rules from enabled categories
// for _, category := range enabledCategories {
// for _, rule := range category.Rules {
// service.Spec.Rules = append(service.Spec.Rules, v1alpha1.Rule{
// ID: rule.ID,
// Level: rule.Level,
// Name: rule.Name,
// Type: rule.Type,
// Description: rule.Description,
// Expr: rule.Expr,
// Mode: rule.Mode,
// })
// }
// }
rules, err := s.getRulesForService(req)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { return nil, fmt.Errorf("failed to get rules for service: %v", err)
// Create new WAF service record if not found
wafService = &model.WafService{
RegionCode: req.RegionCode,
Namespace: req.Namespace,
GatewayName: req.GatewayName,
Mode: string(WafModeAlert),
Host: model.HostList(req.Host),
}
if err := s.db.Create(wafService).Error; err != nil {
return nil, fmt.Errorf("failed to create WAF service: %v", err)
}
} else {
return nil, fmt.Errorf("failed to query WAF service: %v", err)
}
}
// Determine which rule categories to enable
var enabledCategories []model.WafRuleCategory
if wafService.RuleCategoryStatus != nil && len(wafService.RuleCategoryStatus.CategoryID) > 0 {
// Only include categories not already enabled
for _, category := range ruleCategories {
for _, id := range wafService.RuleCategoryStatus.CategoryID {
if id == category.CategoryID {
enabledCategories = append(enabledCategories, category)
continue
}
}
}
} else {
// Enable all categories if none specified
enabledCategories = ruleCategories
}
// Add rules from enabled categories
for _, category := range enabledCategories {
for _, rule := range category.Rules {
service.Spec.Rules = append(service.Spec.Rules, v1alpha1.Rule{
ID: rule.ID,
Level: rule.Level,
Name: rule.Name,
Type: rule.Type,
Description: rule.Description,
Expr: rule.Expr,
Mode: rule.Mode,
})
}
} }
service.Spec.Rules = rules
// Create the WAF service in Kubernetes // Create the WAF service in Kubernetes
client := s.clusterClientManager.GetClient(req.RegionCode) client := s.clusterClientManager.GetClient(req.RegionCode)
...@@ -389,6 +444,7 @@ func (s *wafService) EnableListenerWaf(ctx context.Context, req *EnableListenerW ...@@ -389,6 +444,7 @@ func (s *wafService) EnableListenerWaf(ctx context.Context, req *EnableListenerW
RegionCode: req.RegionCode, RegionCode: req.RegionCode,
Port: int(req.Port), Port: int(req.Port),
Enable: req.Enable, Enable: req.Enable,
Hosts: req.Hosts,
} }
err = s.db.Model(&model.GatewayListener{}).Create(listener).Error err = s.db.Model(&model.GatewayListener{}).Create(listener).Error
if err != nil { if err != nil {
...@@ -405,24 +461,23 @@ func (s *wafService) EnableListenerWaf(ctx context.Context, req *EnableListenerW ...@@ -405,24 +461,23 @@ func (s *wafService) EnableListenerWaf(ctx context.Context, req *EnableListenerW
return err return err
} }
wafService := &model.WafService{} // wafService := &model.WafService{}
err = s.db.Model(&model.WafService{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).First(wafService).Error // err = s.db.Model(&model.WafService{}).Where("gateway_name = ? AND namespace = ? AND region_code = ?", req.GatewayName, req.Namespace, req.RegionCode).First(wafService).Error
if err != nil { // if err != nil {
if err == gorm.ErrRecordNotFound { // if err == gorm.ErrRecordNotFound {
wafService = &model.WafService{ // wafService = &model.WafService{
GatewayName: req.GatewayName, // GatewayName: req.GatewayName,
Namespace: req.Namespace, // Namespace: req.Namespace,
RegionCode: req.RegionCode, // RegionCode: req.RegionCode,
Mode: string(WafModeAlert), // Mode: string(WafModeAlert),
Host: model.HostList([]string{"*"}), // }
} // if err := s.db.Create(wafService).Error; err != nil {
if err := s.db.Create(wafService).Error; err != nil { // return err
return err // }
} // } else {
} else { // return err
return err // }
} // }
}
if listener.Enable { if listener.Enable {
log.Info().Msgf("Create WAF for listener %s", listener.GatewayName) log.Info().Msgf("Create WAF for listener %s", listener.GatewayName)
...@@ -433,7 +488,7 @@ func (s *wafService) EnableListenerWaf(ctx context.Context, req *EnableListenerW ...@@ -433,7 +488,7 @@ func (s *wafService) EnableListenerWaf(ctx context.Context, req *EnableListenerW
RegionCode: req.RegionCode, RegionCode: req.RegionCode,
}, },
Port: uint32(req.Port), Port: uint32(req.Port),
Host: wafService.Host, Host: req.Hosts,
}) })
if err != nil { if err != nil {
return err return err
...@@ -452,36 +507,49 @@ func (s *wafService) EnableListenerWaf(ctx context.Context, req *EnableListenerW ...@@ -452,36 +507,49 @@ func (s *wafService) EnableListenerWaf(ctx context.Context, req *EnableListenerW
return err return err
} }
} }
return nil return nil
} }
func (s *wafService) EnableGatewayWaf(ctx context.Context, req *EnableGatewayWafReq) error { func getGatewayNameFromCrn(crn string) string {
if req.Enable { // crn:ucs::apigateway:lf-tst7:214613666997:instance/testaaa
// Get listener list from API parts := strings.Split(crn, "/")
return parts[len(parts)-1]
}
func (s *wafService) ListListenerFromApiGateway(ctx context.Context, apiGatewayCrn string, regionCode string, cookie string) ([]GatewayRespListenerData, error) {
body, err := json.Marshal(map[string]string{ body, err := json.Marshal(map[string]string{
"apigateway_crn": req.ApiGatewayCrn, "apigateway_crn": apiGatewayCrn,
"region_code": req.RegionCode, "region_code": regionCode,
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to marshal request body: %v", err) return nil, fmt.Errorf("failed to marshal request body: %v", err)
} }
resp, err := http.Post("https://csm.console.test.tg.unicom.local/apigatewaymng/listener/lf-tst7/list_listeners", "application/json", bytes.NewBuffer(body)) request, err := http.NewRequest("POST", "https://csm.console.test.tg.unicom.local/apigatewaymng/listener/lf-tst7/list_listeners", bytes.NewBuffer(body))
if err != nil { if err != nil {
return fmt.Errorf("failed to get listener list: %v", err) return nil, fmt.Errorf("failed to create request: %v", err)
}
request.Header.Set("Cookie", cookie)
resp, err := http.DefaultClient.Do(request)
if err != nil {
return nil, fmt.Errorf("failed to get listener list: %v", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
// Parse response // Parse response
var listeners []struct { var response GatewayListenerResponseList
ApigatewayCrn string `json:"apigateway_crn"`
Port uint32 `json:"port"` if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
Hosts []string `json:"hosts"` return nil, fmt.Errorf("failed to parse listener list: %v", err)
}
if err := json.NewDecoder(resp.Body).Decode(&listeners); err != nil {
return fmt.Errorf("failed to parse listener list: %v", err)
} }
return response.Data, nil
}
func (s *wafService) EnableGatewayWaf(ctx context.Context, req *EnableGatewayWafReq) error {
if req.Enable {
listeners, err := s.ListListenerFromApiGateway(ctx, req.ApiGatewayCrn, req.RegionCode, req.Cookie)
if err != nil {
return fmt.Errorf("failed to get listener list: %v", err)
}
// Create WAF for each listener // Create WAF for each listener
for _, listener := range listeners { for _, listener := range listeners {
if _, err := s.CreateWaf(ctx, &CreateWafReq{ if _, err := s.CreateWaf(ctx, &CreateWafReq{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment