linoyts HF Staff commited on
Commit
2cdceda
·
verified ·
1 Parent(s): 086c82a

Create optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +70 -0
optimization.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+
4
+ from typing import Any
5
+ from typing import Callable
6
+ from typing import ParamSpec
7
+ from torchao.quantization import quantize_
8
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
9
+ import spaces
10
+ import torch
11
+ from torch.utils._pytree import tree_map
12
+
13
+
14
+ P = ParamSpec('P')
15
+
16
+
17
+ TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim('image_seq_length')
18
+ TRANSFORMER_TEXT_SEQ_LENGTH_DIM = torch.export.Dim('text_seq_length')
19
+
20
+ TRANSFORMER_DYNAMIC_SHAPES = {
21
+ 'hidden_states': {
22
+ 1: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
23
+ },
24
+ 'encoder_hidden_states': {
25
+ 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
26
+ },
27
+ 'encoder_hidden_states_mask': {
28
+ 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
29
+ },
30
+ 'image_rotary_emb': ({
31
+ 0: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
32
+ }, {
33
+ 0: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
34
+ }),
35
+ }
36
+
37
+
38
+ INDUCTOR_CONFIGS = {
39
+ 'conv_1x1_as_mm': True,
40
+ 'epilogue_fusion': False,
41
+ 'coordinate_descent_tuning': True,
42
+ 'coordinate_descent_check_all_directions': True,
43
+ 'max_autotune': True,
44
+ 'triton.cudagraphs': True,
45
+ }
46
+
47
+
48
+ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
49
+
50
+ @spaces.GPU(duration=1500)
51
+ def compile_transformer():
52
+
53
+ with spaces.aoti_capture(pipeline.transformer) as call:
54
+ pipeline(*args, **kwargs)
55
+
56
+ dynamic_shapes = tree_map(lambda t: None, call.kwargs)
57
+ dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
58
+
59
+ # quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
60
+
61
+ exported = torch.export.export(
62
+ mod=pipeline.transformer,
63
+ args=call.args,
64
+ kwargs=call.kwargs,
65
+ dynamic_shapes=dynamic_shapes,
66
+ )
67
+
68
+ return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
69
+
70
+ spaces.aoti_apply(compile_transformer(), pipeline.transformer)