Task Execution Graph
Tasks in Airtrain are subclasses of Skills, inheriting all skill capabilities while adding execution graph functionality. This means tasks can implement all skill methods (process, validate_input, validate_output, evaluate) while also defining complex execution flows.
Task as a Skill
class Task(Skill):
    """
    Tasks extend Skills with execution graph capabilities.
    Inherits all Skill methods and adds graph execution logic.
    """
    input_schema: Schema
    output_schema: Schema

    # Inherited from Skill
    def process(self, input_data):
        """Override Skill's process method with graph execution"""
        return self.execute_graph(input_data)

    def validate_input(self, input_data):
        """Can be overridden or use Skill's default"""
        pass

    def validate_output(self, output_data):
        """Can be overridden or use Skill's default"""
        pass

    def evaluate(self, test_dataset=None):
        """Can be overridden or use Skill's default"""
        pass

    # Task-specific methods
    def build_graph(self):
        """Define the execution graph"""
        pass

    def execute_graph(self, input_data):
        """Execute the graph with input data"""
        pass
The task execution graph represents the core logic flow of a task, orchestrating how different skills interact and process data. It defines the sequence and relationships between skills while managing permissions and data transformations.
Graph Structure
from airtrain import TaskGraph, SkillNode, ConnectorNode

class DataAnalysisGraph(TaskGraph):
    input_schema = DataAnalysisInput
    output_schema = DataAnalysisOutput

    def build_graph(self):
        # Define nodes and connections
        fetch_data = SkillNode(FetchDataSkill)
        clean_data = SkillNode(DataCleaningSkill)
        analyze_data = SkillNode(AnalysisSkill)

        # Connect nodes
        self.connect(fetch_data, clean_data)
        self.connect(clean_data, analyze_data)
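Once built, such a graph runs through the Task entry point shown earlier. A minimal usage sketch; the `source` field is hypothetical and depends on how `DataAnalysisInput` is actually defined:

# Hypothetical input payload; field names depend on your schema definition
graph = DataAnalysisGraph()
result = graph.process(DataAnalysisInput(source="sales.csv"))
print(result)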
Graph Components
Skill Nodes
Regular nodes that execute specific skills:
class SkillNode:
    def __init__(self, skill: Skill):
        self.skill = skill
        self.input_schema = skill.input_schema
        self.output_schema = skill.output_schema
Connector Nodes
Special nodes that transform data between incompatible schemas:
import pandas as pd

class JsonToDataFrameConnector(ConnectorNode):
    input_schema = JsonSchema
    output_schema = DataFrameSchema

    def transform(self, data):
        """Convert JSON-like records into a DataFrame"""
        return pd.DataFrame(data)
Execution Flow
1. Input Validation
def validate_input(self, context: ExecutionContext):
    """
    Validate input before graph execution.

    Args:
        context: Execution context containing input data

    Raises:
        InputValidationError: If validation fails
    """
    try:
        self.input_schema.validate(context.input_data)
    except ValidationError as e:
        raise InputValidationError(str(e))
2. Graph Execution
def execute(self, context: ExecutionContext):
    """
    Execute the task graph.

    Args:
        context: Execution context with input data and permissions

    Returns:
        Execution results
    """
    try:
        self.validate_input(context)
        result = self._execute_nodes(context)
        self.validate_output(result)
        return result
    except Exception as e:
        self.handle_execution_error(e)
        raise  # Re-raise after logging so callers still see the failure
3. Output Validation
def validate_output(self, result: Any):
    """
    Validate graph execution output.

    Args:
        result: Execution result

    Raises:
        OutputValidationError: If validation fails
    """
    try:
        self.output_schema.validate(result)
    except ValidationError as e:
        raise OutputValidationError(str(e))
Permission Management
Permissions are managed at the graph level (legacy approach):
class AnalysisGraph(TaskGraph):
    required_permissions = [
        DatabasePermission,
        AnalyticsPermission,
    ]

    def check_permissions(self, context):
        """Verify all required permissions are available"""
        for permission in self.required_permissions:
            if not context.has_permission(permission):
                raise PermissionError(f"Missing required permission: {permission}")
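The check typically runs at the top of `execute`, before input validation. A minimal sketch, assuming `ExecutionContext` accepts granted permissions at construction (that constructor signature is an assumption, not the documented API):

# Hypothetical construction; actual ExecutionContext fields may differ
context = ExecutionContext(
    input_data={"query": "SELECT ..."},
    permissions=[DatabasePermission, AnalyticsPermission],
)
AnalysisGraph().check_permissions(context)  # raises PermissionError if anything is missing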
Error Handling
A comprehensive error handling hook logs failures and can forward them to an external tracker:

import logging
from datetime import datetime

logger = logging.getLogger(__name__)

def handle_execution_error(self, error: Exception):
    """
    Handle errors during graph execution.

    Args:
        error: The caught exception
    """
    error_context = {
        'graph_id': self.id,
        'timestamp': datetime.now(),
        'error_type': type(error).__name__,
        'error_message': str(error),
    }

    # Log to the default logger
    logger.error(f"Graph execution failed: {error_context}")

    # Optional: send to external error tracking
    if self.error_tracker:
        self.error_tracker.capture_exception(error, context=error_context)
Graph Visualization
Tasks can generate visual representations of their execution graphs:
import networkx as nx

def visualize(self):
    """Generate a visual representation of the graph"""
    graph = nx.DiGraph()
    for node in self.nodes:
        graph.add_node(node.id, label=node.name)
    for edge in self.edges:
        graph.add_edge(edge.source.id, edge.target.id)
    # nx.DiGraph has no draw() method; render via networkx's matplotlib helper
    nx.draw(graph, with_labels=True)
    return graph
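Since `nx.draw` renders through matplotlib, the result can be displayed or saved in the usual way (`task` below stands in for any Task instance):

import matplotlib.pyplot as plt

task.visualize()
plt.savefig("task_graph.png")  # or plt.show() in an interactive session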
Example Implementation
Here's a complete example of a task execution graph:
class DataProcessingGraph(TaskGraph):
    input_schema = RawDataSchema
    output_schema = ProcessedDataSchema

    def build_graph(self):
        # Create nodes
        json_input = SkillNode(JsonParserSkill)
        # ConnectorNode subclasses are instantiated directly
        df_converter = JsonToDataFrameConnector()
        data_cleaner = SkillNode(DataCleaningSkill)
        analyzer = SkillNode(AnalysisSkill)
        formatter = SkillNode(OutputFormatterSkill)

        # Connect nodes
        self.connect(json_input, df_converter)
        self.connect(df_converter, data_cleaner)
        self.connect(data_cleaner, analyzer)
        self.connect(analyzer, formatter)
Best Practices
- Graph Design
  - Keep graphs modular
  - Use connector nodes for transformations
  - Validate data between critical nodes
  - Document node dependencies
- Error Handling
  - Implement comprehensive error catching
  - Provide detailed error context
  - Use appropriate logging levels
  - Consider retry mechanisms (see the sketch after this list)
- Performance
  - Optimize node execution order
  - Consider parallel execution where possible
  - Monitor execution times
  - Cache intermediate results when appropriate
- Maintenance
  - Document graph structure
  - Version control graph definitions
  - Monitor error patterns
  - Review performance regularly
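As referenced above, a simple retry wrapper is one way to make flaky nodes more resilient. This is a minimal sketch using only the standard library, not part of the Airtrain API:

import time
import functools

def with_retries(max_attempts=3, delay_seconds=1.0):
    """Retry a callable a fixed number of times before giving up."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            last_error = None
            for attempt in range(max_attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    last_error = e
                    if attempt < max_attempts - 1:
                        time.sleep(delay_seconds * (attempt + 1))  # linear backoff
            raise last_error
        return wrapper
    return decorator

Applied to a skill's `process` method, such a decorator retries transient failures transparently.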
Next Steps
Learn how to create complex execution graphs, implement custom connector nodes, and optimize graph performance.
Control Flow
Tasks support complex control flow patterns using conditional nodes and loops:
Conditional Execution
class AnalysisGraph(TaskGraph):
    def build_graph(self):
        data_fetch = SkillNode(FetchDataSkill)

        # Create conditional branch
        if_node = IfNode(
            condition=lambda data: len(data) > 1000,
            true_branch=SkillNode(LargeDataProcessorSkill),
            false_branch=SkillNode(SmallDataProcessorSkill),
        )

        self.connect(data_fetch, if_node)
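Under the hood, a conditional node only needs to evaluate its predicate and delegate to one branch. A minimal sketch of that dispatch, assuming nodes expose an `execute(data)` method (Airtrain's internals may differ):

class IfNode:
    def __init__(self, condition, true_branch, false_branch=None):
        self.condition = condition
        self.true_branch = true_branch
        self.false_branch = false_branch

    def execute(self, data):
        # Evaluate the predicate against the incoming data
        if self.condition(data):
            return self.true_branch.execute(data)
        if self.false_branch is not None:
            return self.false_branch.execute(data)
        return data  # pass through unchanged when no false branch is given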
Loop Execution
class BatchProcessingGraph(TaskGraph):
    def build_graph(self):
        data_source = SkillNode(DataSourceSkill)

        # While loop based on a condition
        while_node = WhileNode(
            condition=lambda data: data.has_next_batch(),
            body=[
                SkillNode(ProcessBatchSkill),
                SkillNode(UpdateProgressSkill),
            ],
        )

        # For loop over fixed-size chunks
        for_node = ForNode(
            iterator=lambda data: data.get_chunks(size=100),
            body=SkillNode(ChunkProcessorSkill),
        )

        self.connect(data_source, while_node)
        self.connect(while_node, for_node)
Combined Schemas
Schemas can be combined using OR operations to handle multiple input/output types:
from airtrain import SchemaOR

# Define possible input schemas
class TextInputSchema(BaseSchema):
    text: str
    language: str = "en"

class FileInputSchema(BaseSchema):
    file_path: str
    encoding: str = "utf-8"

# Combine schemas using OR
combined_input = SchemaOR(
    TextInputSchema,
    FileInputSchema,
)

class TextProcessingGraph(TaskGraph):
    input_schema = combined_input

    def build_graph(self):
        # Handle different input types
        input_handler = SkillNode(
            skill=InputHandlerSkill,
            input_schema=combined_input,
        )
        processor = SkillNode(TextProcessorSkill)

        self.connect(input_handler, processor)
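The handler skill can then dispatch on whichever schema actually matched. A hedged sketch of what such a skill might look like; `InputHandlerSkill`'s internals are not specified above, and this assumes validated inputs arrive as schema instances:

class InputHandlerSkill(Skill):
    input_schema = combined_input
    output_schema = TextInputSchema  # assumption: normalize everything to text

    def process(self, input_data):
        # Text input passes through unchanged
        if isinstance(input_data, TextInputSchema):
            return input_data
        # File input is read from disk and wrapped as text
        with open(input_data.file_path, encoding=input_data.encoding) as f:
            return TextInputSchema(text=f.read())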
Schema OR Validation
When using combined schemas, validation handles multiple possibilities:
def validate_or_schema(self, data: Any) -> None:
    """
    Validate data against multiple possible schemas.

    Args:
        data: Input data to validate

    Raises:
        ValidationError: If data doesn't match any schema
    """
    errors = []
    for schema in self.schemas:
        try:
            schema.validate(data)
            return  # Valid against this schema
        except ValidationError as e:
            errors.append(e)
    raise ValidationError(
        f"Data matches none of the schemas: {errors}"
    )
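In practice, validation succeeds as soon as any member schema accepts the data (assuming `SchemaOR` instances expose the same `validate` entry point):

# Matches FileInputSchema, so no error is raised
combined_input.validate({"file_path": "notes.txt"})

# Matches neither schema: raises ValidationError listing both failures
combined_input.validate({"unexpected": 123})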
Example with Control Flow and Combined Schemas
Here's a complete example combining control flow and schema combinations:
class DocumentProcessingGraph(TaskGraph):
    input_schema = SchemaOR(
        PDFSchema,
        WordDocSchema,
        TextFileSchema,
    )

    def build_graph(self):
        # Input handling
        document_input = SkillNode(DocumentInputSkill)

        # Format-specific processing: the multi-branch form of IfNode,
        # taking (condition, node) pairs plus a default branch
        format_switch = IfNode(
            conditions=[
                (lambda doc: doc.is_pdf(), SkillNode(PDFProcessorSkill)),
                (lambda doc: doc.is_word(), SkillNode(WordProcessorSkill)),
                (lambda doc: doc.is_text(), SkillNode(TextProcessorSkill)),
            ],
            default=SkillNode(DefaultProcessorSkill),
        )

        # Batch processing for large documents
        batch_processor = ForNode(
            iterator=lambda doc: doc.get_pages(),
            body=[
                SkillNode(PageProcessorSkill),
                IfNode(
                    condition=lambda page: page.has_images,
                    true_branch=SkillNode(ImageExtractionSkill),
                ),
            ],
        )

        # Connect nodes
        self.connect(document_input, format_switch)
        self.connect(format_switch, batch_processor)