MiniGPT4
- class mmpretrain.models.multimodal.MiniGPT4(vision_encoder, q_former_model, lang_encoder, tokenizer, task='caption', freeze_vit=True, freeze_q_former=True, num_query_token=32, prompt_template='###Human: {} ###Assistant: ', raw_prompts=None, max_txt_len=32, end_sym='\n', generation_cfg={}, data_preprocessor=None, init_cfg=None)
The multi-modality model of MiniGPT-4.
This implementation is modified from https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/models/mini_gpt4.py
- Args:
vision_encoder (dict): The config for the vision encoder.
q_former_model (dict): The config for the Qformer.
lang_encoder (dict): The config for the language model.
tokenizer (dict): The config for the tokenizer.
task (str): The task to perform, which controls how text is processed. Defaults to ‘caption’.
freeze_vit (bool): Whether to freeze the ViT during training. Defaults to True.
freeze_q_former (bool): Whether to freeze the Qformer during training. Defaults to True.
num_query_token (int): Number of query tokens of the Qformer. Defaults to 32.
prompt_template (str): Prompt template of the model. Defaults to ‘###Human: {} ###Assistant: ’.
raw_prompts (list): Prompts used for training. Defaults to None.
max_txt_len (int): Maximum token length used during tokenization. Defaults to 32.
end_sym (str): End symbol of the sequence. Defaults to ‘\n’.
generation_cfg (dict): The config for text generation. Defaults to dict().
data_preprocessor (BaseDataPreprocessor): Used for preprocessing data sampled by the dataloader into the format accepted by forward(). Defaults to None.
init_cfg (dict): Initialization config dict. Defaults to None.
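To make the interplay of prompt_template, raw_prompts and end_sym concrete, here is a minimal sketch. It assumes that each raw prompt is substituted into the template's `{}` placeholder with `str.format` and that end_sym is appended to each training target; the helper names `build_prompts` and `build_target`, and the `<Img><ImageHere></Img>` image placeholder in the sample prompt, are hypothetical and not taken from mmpretrain.

```python
# Minimal sketch (assumption): the template is filled with str.format and
# end_sym is appended to target text. This is not mmpretrain's actual code.
from typing import List

PROMPT_TEMPLATE = '###Human: {} ###Assistant: '  # documented default
END_SYM = '\n'  # documented default


def build_prompts(raw_prompts: List[str],
                  template: str = PROMPT_TEMPLATE) -> List[str]:
    """Insert each raw prompt into the template's ``{}`` placeholder."""
    return [template.format(p) for p in raw_prompts]


def build_target(text: str, end_sym: str = END_SYM) -> str:
    """Append the end symbol so the model can learn where a reply stops."""
    return text + end_sym


# '<Img><ImageHere></Img>' is a hypothetical image placeholder format.
raw_prompts = ['<Img><ImageHere></Img> Describe this image.']
prompts = build_prompts(raw_prompts)
print(repr(prompts[0]))
print(repr(build_target('A cat on a sofa.')))
```

Keeping the template separate from the raw prompts lets the same conversational wrapper be reused across many instruction phrasings.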
- forward(images, data_samples=None, mode='predict', **kwargs)
The unified entry for a forward process in both training and test. The method accepts the following modes:
“predict”: Forward and return a list of data samples containing the prediction results.
- Parameters:
images (torch.Tensor) – The preprocessed image tensor of shape (N, C, H, W).
data_samples (List[DataSample], optional) – The annotation data of every sample. Defaults to None.
mode (str) – The kind of value to return. Defaults to ‘predict’.
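The "unified entry" pattern can be sketched as a method that dispatches on `mode`. In this sketch, `ToyMultimodalModel` and its `predict` method are hypothetical placeholders, and routing `mode='loss'` to `loss()` is an assumption based on the separate loss() method documented below; only “predict” is listed as a mode above.

```python
# Hypothetical sketch of a mode-dispatching forward(). The 'loss' branch is
# an assumption; only 'predict' is documented as a mode for MiniGPT4.
from typing import Any, Dict, List, Optional


class ToyMultimodalModel:
    def loss(self, images: List[Any],
             data_samples: Optional[list] = None) -> Dict[str, float]:
        # Placeholder: a real model would return a dict of loss tensors.
        return {'loss': 0.0}

    def predict(self, images: List[Any],
                data_samples: Optional[list] = None, **kwargs) -> List[str]:
        # Placeholder: a real model would generate text for each image.
        return [f'caption for image {i}' for i in range(len(images))]

    def forward(self, images: List[Any], data_samples: Optional[list] = None,
                mode: str = 'predict', **kwargs):
        if mode == 'loss':
            return self.loss(images, data_samples)
        if mode == 'predict':
            return self.predict(images, data_samples, **kwargs)
        raise RuntimeError(f'Invalid mode "{mode}".')


model = ToyMultimodalModel()
preds = model.forward([None, None])  # two dummy "images"
losses = model.forward([None], mode='loss')
```

A single entry point keeps training and inference loops generic: the runner only chooses the mode string.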
- loss(images, data_samples=None)
The forward function in training.
- Parameters:
images (torch.Tensor) – The input images of shape (N, C, H, W).
data_samples (List[DataSample]) – All elements required during the forward function.
- Returns:
A dictionary of loss components.
- Return type:
Dict[str, torch.Tensor]
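A common way to train a captioning head like this is to score only the target text tokens and mask out the image and prompt positions. The sketch below assumes that convention, with `-100` as the ignored label (the usual causal-LM choice); it is not confirmed as MiniGPT4's exact loss. Probabilities are plain dicts standing in for model outputs, the helper names are hypothetical, and the one-position shift a causal LM applies between inputs and labels is omitted for clarity.

```python
# Minimal sketch (assumption): mask image/prompt positions with an ignore
# label and average the negative log-likelihood over target tokens only.
import math
from typing import Dict, List

IGNORE_INDEX = -100  # assumed ignore label, as in common causal-LM training


def build_labels(num_prefix: int, target_ids: List[int]) -> List[int]:
    """Mask every image/prompt position so only target tokens are scored."""
    return [IGNORE_INDEX] * num_prefix + list(target_ids)


def masked_nll(token_probs: List[Dict[int, float]],
               labels: List[int]) -> Dict[str, float]:
    """Mean NLL over non-ignored positions, returned as a loss dict.

    token_probs[i] maps a token id to the predicted probability at
    position i; predictions are assumed to be aligned with labels.
    """
    losses = [
        -math.log(probs[label])
        for probs, label in zip(token_probs, labels)
        if label != IGNORE_INDEX
    ]
    return {'loss': sum(losses) / len(losses)}


labels = build_labels(3, [7, 8])  # 3 masked prefix positions, 2 targets
probs = [{0: 1.0}] * 3 + [{7: 0.5}, {8: 0.25}]
result = masked_nll(probs, labels)
```

Returning a dict mirrors the documented return type, so further loss terms can be added under new keys.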
- post_process(outputs, data_samples)
Post-process the generated outputs according to the task.
- Parameters:
outputs (torch.Tensor) – The generated outputs.
data_samples (List[DataSample], optional) – The annotation data of every sample.
- Returns:
A list of data samples.
- Return type:
List[DataSample]
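For the ‘caption’ task, post-processing plausibly means cutting the decoded text at the end symbol and storing the result on each data sample. This sketch works on already-decoded strings and assumes that behaviour; `SimpleDataSample` is a hypothetical stand-in for DataSample, the `pred_caption` field name is an assumption, and creating fresh samples when data_samples is None is also assumed.

```python
# Minimal sketch (assumption): truncate decoded text at end_sym and attach
# it to each sample. SimpleDataSample is a hypothetical DataSample stand-in.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SimpleDataSample:
    pred_caption: Optional[str] = None


def post_process_captions(decoded: List[str],
                          data_samples: Optional[List[SimpleDataSample]] = None,
                          end_sym: str = '\n') -> List[SimpleDataSample]:
    if data_samples is None:
        # Assumed behaviour: build one empty sample per output.
        data_samples = [SimpleDataSample() for _ in decoded]
    for text, sample in zip(decoded, data_samples):
        # Keep only the text before the first end symbol, then trim spaces.
        sample.pred_caption = text.split(end_sym)[0].strip()
    return data_samples


samples = post_process_captions(
    ['A dog runs on the beach.\n###Human: more', '  Two cats.  '])
```

Cutting at end_sym matters because generation may continue past the reply into a new, hallucinated dialogue turn.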
- prompt_wrap(img_embeds, atts_img, prompt)
Wrap the image embeddings with the prompt.
Currently, the function only supports applying one prompt to all input images in a batch.
- Parameters:
img_embeds (torch.Tensor) – The embedding of the input images.
atts_img (torch.Tensor) – Attention map of the image embeddings.
prompt (str) – The prompt of the batch data.
- Returns:
The wrapped embeddings and their attention map.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
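One way to "wrap" image embeddings with a prompt is to split the prompt at an image placeholder, embed the text on each side, and concatenate text, image and text along the sequence axis. The sketch below assumes that approach, a `<ImageHere>` placeholder, and that each wrapped position reuses the sample's first image attention value. Nested lists stand in for tensors, and `toy_embed` is a hypothetical text embedder. Because the same prompt embeddings are shared by every sample, the sketch has the same one-prompt-per-batch limitation described above.

```python
# Minimal sketch (assumption): split the prompt at a placeholder and place
# the image embeddings between the two text halves. Lists stand in for tensors.
from typing import Callable, List, Tuple

Vector = List[float]
EmbedSeq = List[Vector]  # one sample: a sequence of embedding vectors


def prompt_wrap(img_embeds: List[EmbedSeq], atts_img: List[List[int]],
                prompt: str, embed_text: Callable[[str], EmbedSeq],
                placeholder: str = '<ImageHere>',
                ) -> Tuple[List[EmbedSeq], List[List[int]]]:
    if not prompt:
        # No prompt: return the image embeddings and mask unchanged.
        return img_embeds, atts_img
    # Assumes the prompt contains exactly one placeholder.
    p_before, p_after = prompt.split(placeholder)
    before, after = embed_text(p_before), embed_text(p_after)
    # The same prompt embeddings are applied to every image in the batch.
    wrapped = [before + img + after for img in img_embeds]
    # Reuse each sample's first image attention value for every position.
    wrapped_atts = [[atts[0]] * len(seq) for atts, seq in zip(atts_img, wrapped)]
    return wrapped, wrapped_atts


def toy_embed(text: str) -> EmbedSeq:
    """Hypothetical embedder: one 1-d vector per word, valued by word length."""
    return [[float(len(word))] for word in text.split()]


img_embeds = [[[0.1], [0.2]]]  # batch of 1, two image tokens, dim 1
atts_img = [[1, 1]]
wrapped, atts = prompt_wrap(img_embeds, atts_img,
                            '<Img><ImageHere></Img> Describe it.', toy_embed)
```

Sharing one prompt across the batch keeps every wrapped sequence the same length, which avoids padding when the results are stacked into a tensor.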