prithivMLmods committed
Commit 2144dd4 · verified · 1 Parent(s): bae4d72

Update app.py

Files changed (1): app.py +4 -3
app.py CHANGED
@@ -23,7 +23,8 @@ from transformers.image_utils import load_image
 # Constants for text generation
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
-MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+# Increase or disable input truncation to avoid token mismatches
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
@@ -84,7 +85,7 @@ def generate_image(text: str, image: Image.Image,
         images=[image],
         return_tensors="pt",
         padding=True,
-        truncation=True,
+        truncation=False,
         max_length=MAX_INPUT_TOKEN_LENGTH
     ).to("cuda")
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
@@ -128,7 +129,7 @@ def generate_video(text: str, video_path: str,
         add_generation_prompt=True,
         return_dict=True,
         return_tensors="pt",
-        truncation=True,
+        truncation=False,
         max_length=MAX_INPUT_TOKEN_LENGTH
     ).to("cuda")
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
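
For context, a minimal sketch of what the updated call sites now do, assuming a multimodal `processor` loaded via `transformers.AutoProcessor`; the `prepare_inputs` helper and its length guard are illustrative, not part of this commit. With `truncation=False`, prompts longer than `MAX_INPUT_TOKEN_LENGTH` are no longer clipped, which avoids mismatches between text tokens and image placeholder tokens, but also means over-length inputs reach the model unless checked explicitly:

```python
import os

import torch

MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))

# Hypothetical helper, not in app.py: mirrors the updated processor call.
def prepare_inputs(processor, prompt, image):
    inputs = processor(
        text=[prompt],
        images=[image],
        return_tensors="pt",
        padding=True,
        truncation=False,  # as of this commit: never clip the prompt
        max_length=MAX_INPUT_TOKEN_LENGTH,
    )
    # Illustrative guard: with truncation off, fail loudly on over-length
    # prompts instead of feeding the model a sequence it cannot handle.
    if inputs["input_ids"].shape[-1] > MAX_INPUT_TOKEN_LENGTH:
        raise ValueError(
            f"Prompt exceeds MAX_INPUT_TOKEN_LENGTH={MAX_INPUT_TOKEN_LENGTH} tokens."
        )
    return inputs.to("cuda:0" if torch.cuda.is_available() else "cpu")
```

Note that with truncation disabled, `max_length` generally no longer clips the input (depending on the transformers version it may only affect padding, or trigger a warning), so the constant effectively serves as a token budget to check against rather than a hard limit.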