Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 96 additions & 4 deletions runs/service/run_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ func (s *RunService) CreateRun(
}

// Compute storage URIs before DB insert so they're persisted in the ActionSpec
s.fillDefaultInputs(ctx, req.Msg)
inputPrefix := buildInputPrefix(s.storagePrefix, org, project, domain, name)
runOutputBase := buildRunOutputBase(s.storagePrefix, org, project, domain, name)

Expand Down Expand Up @@ -212,16 +213,107 @@ func (s *RunService) CreateRun(
// We should do this after we persist the state to the DB, and when run service fully relies on DB and do not get
// status from the state client directly.

// Build response (simplified - you'd convert the full Run model)
resp := &workflow.CreateRunResponse{
return connect.NewResponse(s.buildCreateRunResponse(run)), nil
}

// buildCreateRunResponse builds CreateRun-specific response payload.
// Keep this separate from convertRunToProto so ListRuns/WatchRuns contract remains unchanged.
func (s *RunService) buildCreateRunResponse(run *models.Run) *workflow.CreateRunResponse {
if run == nil {
return &workflow.CreateRunResponse{}
}

runID := &common.RunIdentifier{
Org: run.Org,
Project: run.Project,
Domain: run.Domain,
Name: run.Name,
}

return &workflow.CreateRunResponse{
Run: &workflow.Run{
Action: &workflow.Action{
Id: actionID,
Id: &common.ActionIdentifier{
Run: runID,
Name: run.Name,
},
Metadata: &workflow.ActionMetadata{},
Status: &workflow.ActionStatus{
Phase: common.ActionPhase(run.Phase),
StartTime: timestamppb.New(run.CreatedAt),
Attempts: 1,
CacheStatus: core.CatalogCacheStatus_CACHE_DISABLED,
},
},
},
}
}

return connect.NewResponse(resp), nil
// fillDefaultInputs merges task default inputs into CreateRunRequest.Inputs without
// overriding user-provided values.
func (s *RunService) fillDefaultInputs(ctx context.Context, req *workflow.CreateRunRequest) {
taskSpec := s.getCreateRunTaskSpec(ctx, req)
if taskSpec == nil || len(taskSpec.GetDefaultInputs()) == 0 {
return
}

if req.Inputs == nil {
req.Inputs = &task.Inputs{}
}

existing := make(map[string]struct{}, len(req.Inputs.GetLiterals()))
for _, input := range req.Inputs.GetLiterals() {
existing[input.GetName()] = struct{}{}
}

for _, defaultInput := range taskSpec.GetDefaultInputs() {
if defaultInput == nil || defaultInput.GetName() == "" {
continue
}
if _, ok := existing[defaultInput.GetName()]; ok {
continue
}

defaultLiteral := defaultInput.GetParameter().GetDefault()
if defaultLiteral == nil {
continue
}

cloned, ok := proto.Clone(defaultLiteral).(*core.Literal)
if !ok || cloned == nil {
cloned = defaultLiteral
}
req.Inputs.Literals = append(req.Inputs.Literals, &task.NamedLiteral{
Name: defaultInput.GetName(),
Value: cloned,
})
existing[defaultInput.GetName()] = struct{}{}
}
}

// getCreateRunTaskSpec returns the TaskSpec from a CreateRunRequest for default
// input merging. If only task_id is provided, it tries to load and decode the
// deployed task spec from the repository.
func (s *RunService) getCreateRunTaskSpec(ctx context.Context, req *workflow.CreateRunRequest) *task.TaskSpec {
switch t := req.GetTask().(type) {
case *workflow.CreateRunRequest_TaskSpec:
return t.TaskSpec
case *workflow.CreateRunRequest_TaskId:
taskModel, err := s.repo.TaskRepo().GetTask(ctx, transformers.ToTaskKey(t.TaskId))
if err != nil {
logger.Warnf(ctx, "CreateRun: cannot load task %s/%s/%s/%s:%s for default inputs: %v",
t.TaskId.GetOrg(), t.TaskId.GetProject(), t.TaskId.GetDomain(), t.TaskId.GetName(), t.TaskId.GetVersion(), err)
return nil
}
var spec task.TaskSpec
if err := proto.Unmarshal(taskModel.TaskSpec, &spec); err != nil {
logger.Warnf(ctx, "CreateRun: failed to decode task spec for default inputs: %v", err)
return nil
}
return &spec
default:
return nil
}
}

// AbortRun aborts a run
Expand Down
148 changes: 148 additions & 0 deletions runs/service/run_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,154 @@ func newTestService(t *testing.T) (*repoMocks.ActionRepo, *mockActionsClient, *R
return actionRepo, actionsClient, svc
}

func TestFillDefaultInputs(t *testing.T) {
svc := &RunService{}

req := &workflow.CreateRunRequest{
Task: &workflow.CreateRunRequest_TaskSpec{
TaskSpec: &task.TaskSpec{
DefaultInputs: []*task.NamedParameter{
{
Name: "x",
Parameter: &core.Parameter{
Behavior: &core.Parameter_Default{
Default: &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{
Primitive: &core.Primitive{
Value: &core.Primitive_Integer{Integer: 42},
},
},
},
},
},
},
},
},
{
Name: "y",
Parameter: &core.Parameter{
Behavior: &core.Parameter_Default{
Default: &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{
Primitive: &core.Primitive{
Value: &core.Primitive_StringValue{StringValue: "default"},
},
},
},
},
},
},
},
},
},
},
},
Inputs: &task.Inputs{
Literals: []*task.NamedLiteral{
{
Name: "x",
Value: &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{
Primitive: &core.Primitive{
Value: &core.Primitive_Integer{Integer: 7},
},
},
},
},
},
},
},
},
}

svc.fillDefaultInputs(context.Background(), req)

assert.Len(t, req.Inputs.Literals, 2)
got := make(map[string]*core.Literal, len(req.Inputs.Literals))
for _, nl := range req.Inputs.Literals {
got[nl.Name] = nl.Value
}
assert.Equal(t, int64(7), got["x"].GetScalar().GetPrimitive().GetInteger(), "provided input should not be overwritten")
assert.Equal(t, "default", got["y"].GetScalar().GetPrimitive().GetStringValue(), "missing input should be filled from default")
}

func TestCreateRunResponseIncludesMetadataAndStatus(t *testing.T) {
actionRepo := &repoMocks.ActionRepo{}
actionsClient := &mockActionsClient{}
repo := &repoMocks.Repository{}
store := &storageMocks.ComposedProtobufStore{}
dataStore := &storage.DataStore{ComposedProtobufStore: store}

repo.On("ActionRepo").Return(actionRepo)

svc := &RunService{
repo: repo,
actionsClient: actionsClient,
storagePrefix: "s3://flyte-data",
dataStore: dataStore,
}

runID := &common.RunIdentifier{
Org: "test-org",
Project: "test-project",
Domain: "test-domain",
Name: "rtest12345",
}
createdAt := time.Now().UTC().Truncate(time.Second)

store.On("WriteProtobuf", mock.Anything, mock.Anything, storage.Options{}, mock.Anything).Return(nil).Once()

actionRepo.On("CreateRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(&models.Run{
Org: runID.Org,
Project: runID.Project,
Domain: runID.Domain,
Name: runID.Name,
Phase: int32(common.ActionPhase_ACTION_PHASE_QUEUED),
CreatedAt: createdAt,
}, nil)

actionsClient.On("Enqueue", mock.Anything, mock.Anything).
Return(connect.NewResponse(&actions.EnqueueResponse{}), nil)

resp, err := svc.CreateRun(context.Background(), connect.NewRequest(&workflow.CreateRunRequest{
Id: &workflow.CreateRunRequest_RunId{
RunId: runID,
},
Task: &workflow.CreateRunRequest_TaskSpec{
TaskSpec: &task.TaskSpec{},
},
}))
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.NotNil(t, resp.Msg.GetRun())
assert.NotNil(t, resp.Msg.GetRun().GetAction())
assert.NotNil(t, resp.Msg.GetRun().GetAction().GetId())
assert.Equal(t, runID.Name, resp.Msg.GetRun().GetAction().GetId().GetName())
assert.NotNil(t, resp.Msg.GetRun().GetAction().GetMetadata())

status := resp.Msg.GetRun().GetAction().GetStatus()
assert.NotNil(t, status)
assert.Equal(t, common.ActionPhase_ACTION_PHASE_QUEUED, status.GetPhase())
assert.NotNil(t, status.GetStartTime())
assert.True(t, status.GetStartTime().AsTime().Equal(createdAt))
assert.Equal(t, uint32(1), status.GetAttempts())
assert.Equal(t, core.CatalogCacheStatus_CACHE_DISABLED, status.GetCacheStatus())
assert.Nil(t, status.EndTime)
assert.Nil(t, status.DurationMs)

repo.AssertExpectations(t)
actionRepo.AssertExpectations(t)
actionsClient.AssertExpectations(t)
store.AssertExpectations(t)
}

func TestAbortRun(t *testing.T) {
runID := &common.RunIdentifier{
Org: "test-org",
Expand Down
Loading