sitammeur committed on
Commit
1a859c1
·
verified ·
1 Parent(s): c46849d

Update src/model.py

Browse files
Files changed (1) hide show
  1. src/model.py +6 -24
src/model.py CHANGED
@@ -1,5 +1,4 @@
1
  # Importing the requirements
2
- import uuid
3
  import torch
4
  from transformers import AutoModel, AutoTokenizer
5
  import spaces
@@ -23,24 +22,6 @@ tokenizer = AutoTokenizer.from_pretrained(
23
  model.eval()
24
 
25
 
26
- class _GeneratorPickleHack:
27
- def __init__(self, generator, generator_id=None):
28
- self.generator = generator
29
- self.generator_id = (
30
- generator_id if generator_id is not None else str(uuid.uuid4())
31
- )
32
-
33
- def __call__(self, *args, **kwargs):
34
- return self.generator(*args, **kwargs)
35
-
36
- def __reduce__(self):
37
- return (_GeneratorPickleHack_raise, (self.generator_id,))
38
-
39
-
40
- def _GeneratorPickleHack_raise(*args, **kwargs):
41
- raise AssertionError("cannot actually unpickle _GeneratorPickleHack!")
42
-
43
-
44
  @spaces.GPU()
45
  def describe_video(video, question):
46
  """
@@ -54,9 +35,7 @@ def describe_video(video, question):
54
  str: The generated answer to the question.
55
  """
56
  # Encode the video frames
57
- frames = _GeneratorPickleHack(encode_video)(video)
58
- #frames = encode_video(video)
59
- #frames = list(frames) # Convert generator or any iterable to list
60
 
61
  # Message format for the model
62
  msgs = [{"role": "user", "content": frames + [question]}]
@@ -79,5 +58,8 @@ def describe_video(video, question):
79
  **params
80
  )
81
 
82
- # Return the answer
83
- return answer
 
 
 
 
1
  # Importing the requirements
 
2
  import torch
3
  from transformers import AutoModel, AutoTokenizer
4
  import spaces
 
22
  model.eval()
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  @spaces.GPU()
26
  def describe_video(video, question):
27
  """
 
35
  str: The generated answer to the question.
36
  """
37
  # Encode the video frames
38
+ frames = encode_video(video)
 
 
39
 
40
  # Message format for the model
41
  msgs = [{"role": "user", "content": frames + [question]}]
 
58
  **params
59
  )
60
 
61
+ # Consume the generator and concatenate the results
62
+ full_answer = "".join(answer)
63
+
64
+ # Return the full answer as a string
65
+ return full_answer