arjunanand13 commited on
Commit
4bdfb46
·
verified ·
1 Parent(s): 719763c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -12
app.py CHANGED
@@ -16,18 +16,18 @@ HF_TOKEN = os.environ.get("HF_TOKEN", None)
16
  model_id = 'mistralai/Mistral-7B-Instruct-v0.3'
17
  client = InferenceClient(model_id)
18
 
19
- # Define stopping criteria
20
- class StopOnTokens:
21
- def __call__(self, input_ids, scores, **kwargs):
22
- for stop_ids in stop_token_ids:
23
- if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
24
- return True
25
- return False
26
-
27
- # Define stopping criteria list
28
- stop_list = ['\nHuman:', '\n```\n']
29
- stop_token_ids = [client.tokenizer(x)['input_ids'] for x in stop_list]
30
- stop_token_ids = [torch.LongTensor(x).to(cuda.current_device() if cuda.is_available() else 'cpu') for x in stop_token_ids]
31
 
32
  # Create text generation pipeline
33
  def generate(prompt, history, system_prompt=None, temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0):
 
16
  model_id = 'mistralai/Mistral-7B-Instruct-v0.3'
17
  client = InferenceClient(model_id)
18
 
19
+ # # Define stopping criteria
20
+ # class StopOnTokens:
21
+ # def __call__(self, input_ids, scores, **kwargs):
22
+ # for stop_ids in stop_token_ids:
23
+ # if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
24
+ # return True
25
+ # return False
26
+
27
+ # # Define stopping criteria list
28
+ # stop_list = ['\nHuman:', '\n```\n']
29
+ # stop_token_ids = [client.tokenizer(x)['input_ids'] for x in stop_list]
30
+ # stop_token_ids = [torch.LongTensor(x).to(cuda.current_device() if cuda.is_available() else 'cpu') for x in stop_token_ids]
31
 
32
  # Create text generation pipeline
33
  def generate(prompt, history, system_prompt=None, temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0):