@@ -565,36 +565,52 @@ def initialize_model_and_tokenizer(self):
         cache_dir = CACHE_DIR / save_dir
         cache_dir.mkdir(parents=True, exist_ok=True)
 
-        config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_compute_dtype=torch.bfloat16,
-            bnb_4bit_quant_type="nf4",
-            llm_int8_skip_modules=[
-                "vision_tower",
-                "multi_modal_projector",
-                "language_model.embed_tokens",
-                "language_model.norm",
-                "lm_head"
-            ]
-        )
-
         processor = AutoProcessor.from_pretrained(
             model_id,
             use_fast=True,
             cache_dir=cache_dir,
             token=False
         )
-        model = AutoModelForVision2Seq.from_pretrained(
-            model_id,
-            quantization_config=config,
-            torch_dtype=torch.bfloat16,
-            low_cpu_mem_usage=True,
-            cache_dir=cache_dir,
-            token=False
-        )
-        model.to(self.device)
+
+        if self.device == "cuda" and torch.cuda.is_available():
+            # Use quantization on CUDA
+            config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=torch.bfloat16,
+                bnb_4bit_quant_type="nf4",
+                llm_int8_skip_modules=[
+                    "vision_tower",
+                    "multi_modal_projector",
+                    "language_model.embed_tokens",
+                    "language_model.norm",
+                    "lm_head"
+                ]
+            )
+
+            model = AutoModelForVision2Seq.from_pretrained(
+                model_id,
+                quantization_config=config,
+                torch_dtype=torch.bfloat16,
+                low_cpu_mem_usage=True,
+                cache_dir=cache_dir,
+                token=False,
+                device_map="auto"
+            )
+            my_cprint("Granite Vision model loaded with quantization on CUDA", "green")
+
+        else:
+            # CPU mode - no quantization
+            model = AutoModelForVision2Seq.from_pretrained(
+                model_id,
+                torch_dtype=torch.float32,
+                low_cpu_mem_usage=True,
+                cache_dir=cache_dir,
+                token=False,
+                device_map={"": "cpu"}
+            )
+            my_cprint("Granite Vision model loaded on CPU (no quantization)", "yellow")
+
         model.eval()
-        my_cprint("Granite Vision model loaded into memory", "green")
         return model, None, processor
 
     @torch.inference_mode()
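Note on the change: 4-bit bitsandbytes checkpoints are dispatched to devices by from_pretrained via device_map, and calling .to() on such a model afterwards raises a ValueError in transformers, which is presumably why the old model.to(self.device) line is removed rather than kept alongside device_map="auto". A minimal standalone sketch of the same CUDA-or-CPU loading decision; the MODEL_ID value and load_granite_vision helper are illustrative assumptions, not taken from this commit:

import torch
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig

# Hypothetical model id for illustration; the commit resolves model_id elsewhere.
MODEL_ID = "ibm-granite/granite-vision-3.1-2b-preview"

def load_granite_vision(device: str):
    processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=True)
    if device == "cuda" and torch.cuda.is_available():
        # Quantized weights are placed on the GPU by device_map="auto".
        # Do not call model.to() afterwards: transformers raises a
        # ValueError for 4-bit/8-bit bitsandbytes models.
        config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4"
        )
        model = AutoModelForVision2Seq.from_pretrained(
            MODEL_ID,
            quantization_config=config,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
    else:
        # Full-precision fallback pinned to the CPU.
        model = AutoModelForVision2Seq.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float32,
            device_map={"": "cpu"}
        )
    return model.eval(), processor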