@@ -245,14 +245,12 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
245245
246246 fuse_normalize = reader_cfg .get ('fuse_normalize' , False )
247247 sample_transforms = reader_cfg ['sample_transforms' ]
248- hpi_dynamic_shape = None
249248 for st in sample_transforms [1 :]:
250249 for key , value in st .items ():
251250 p = {'type' : key }
252251 if key == 'Resize' :
253252 if int (image_shape [1 ]) != - 1 :
254253 value ['target_size' ] = image_shape [1 :]
255- hpi_dynamic_shape = image_shape [1 :]
256254 value ['interp' ] = value .get ('interp' , 1 ) # cv2.INTER_LINEAR
257255 if fuse_normalize and key == 'NormalizeImage' :
258256 continue
@@ -277,7 +275,7 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
277275 preprocess_list .append (p )
278276 break
279277
280- return preprocess_list , label_list , hpi_dynamic_shape
278+ return preprocess_list , label_list
281279
282280
283281def _parse_tracker (tracker_cfg ):
@@ -287,7 +285,7 @@ def _parse_tracker(tracker_cfg):
287285 return tracker_params
288286
289287
290- def _dump_infer_config (config , path , image_shape , model ):
288+ def _dump_infer_config (config , path , image_shape , model , input_spec ):
291289 arch_state = False
292290 from ppdet .core .config .yaml_helpers import setup_orderdict
293291 setup_orderdict ()
@@ -381,34 +379,47 @@ def _dump_infer_config(config, path, image_shape, model):
381379 reader_cfg = config ['TestReader' ]
382380 dataset_cfg = config ['TestDataset' ]
383381
384- infer_cfg ['Preprocess' ], infer_cfg ['label_list' ], hpi_dynamic_shape = _parse_reader (
382+ infer_cfg ['Preprocess' ], infer_cfg ['label_list' ] = _parse_reader (
385383 reader_cfg , dataset_cfg , config ['metric' ], label_arch , image_shape [1 :])
386384 if config .get ("uniform_output_enabled" , None ):
385+ for d in input_spec :
386+ if 'image' in d :
387+ hpi_dynamic_shape = list (d ['image' ].shape [2 :])
387388 def get_dynamic_shapes (hpi_shape ):
388389 return [[1 , 3 ] + hpi_shape , [1 , 3 ] + hpi_shape , [8 , 3 ] + hpi_shape ]
389390
390- dynamic_shapes = get_dynamic_shapes (hpi_dynamic_shape ) if hpi_dynamic_shape else [
391+ dynamic_shapes = get_dynamic_shapes (hpi_dynamic_shape ) if hpi_dynamic_shape != [ - 1 , - 1 ] else [
391392 [1 , 3 , 320 , 320 ],
392393 [1 , 3 , 640 , 640 ],
393394 [8 , 3 , 1280 , 1280 ]
394395 ]
395396 shapes = {
396397 "image" : dynamic_shapes ,
397- "im_shape" : [[1 , 2 ], [1 , 2 ], [8 , 2 ]],
398398 "scale_factor" : [[1 , 2 ], [1 , 2 ], [8 , 2 ]]
399399 }
400- trt_dynamic_shape = [
401- [dim for _ in range (shape [0 ]) for dim in shape [2 :]]
402- for shape in dynamic_shapes
403- ]
404400 trt_dynamic_shape_input_data = {
405- "im_shape" : trt_dynamic_shape ,
406401 "scale_factor" : [
407402 [2 , 2 ],
408403 [1 , 1 ],
409404 [0.67 for _ in range (2 * shapes ["scale_factor" ][- 1 ][0 ])]
410405 ]
411406 }
407+ model_names_required_imgsize = [
408+ "DETR" ,
409+ "DINO" ,
410+ "RCNN" ,
411+ "YOLOv3" ,
412+ "CenterNet" ,
413+ "BlazeFace" ,
414+ "BlazeFace-FPN-SSH" ,
415+ ]
416+ if any (name in config .get ('pdx_model_name' , None ) for name in model_names_required_imgsize ):
417+ shapes ["im_shape" ] = [[1 , 2 ], [1 , 2 ], [8 , 2 ]]
418+ trt_dynamic_shape = [
419+ [dim for _ in range (shape [0 ]) for dim in shape [2 :]]
420+ for shape in dynamic_shapes
421+ ]
422+ trt_dynamic_shape_input_data ["im_shape" ] = trt_dynamic_shape
412423 hpi_config = OrderedDict ({
413424 "backend_configs" : OrderedDict ({
414425 "paddle_infer" : OrderedDict ({
0 commit comments