import requests
import json
import re
from urllib.parse import quote

def extract_between_tags(text, start_tag, end_tag):
    # Return the substring between start_tag and end_tag ("" if either tag is missing).
    start_index = text.find(start_tag)
    end_index = text.find(end_tag, start_index + len(start_tag))
    if start_index == -1 or end_index == -1:
        return ""
    return text[start_index + len(start_tag):end_index]
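
# Illustrative example (not called anywhere in this module): extract_between_tags()
# pairs with the %START_SNIPPET% / %END_SNIPPET% tags configured in get_body() below.
# The sample text here is made up:
#
#   extract_between_tags(
#       "intro %START_SNIPPET%a highlighted sentence%END_SNIPPET% outro",
#       "%START_SNIPPET%", "%END_SNIPPET%")
#   # -> "a highlighted sentence"
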
class VectaraQuery():
    def __init__(self, api_key: str, customer_id: str, corpus_id: str, prompt_name: str = None):
        self.customer_id = customer_id
        self.corpus_id = corpus_id
        self.api_key = api_key
        self.prompt_name = prompt_name if prompt_name else "vectara-experimental-summary-ext-2023-12-11-large"
        self.conv_id = None

    def get_body(self, user_response: str):
        corpora_key_list = [{
            'customer_id': self.customer_id,
            'corpus_id': self.corpus_id,
            'lexical_interpolation_config': {'lambda': 0.025}
        }]

        user_response = user_response.replace('"', '\\"')  # Escape double quotes for the JSON-style prompt below
        prompt = f'''
        [
            {{
                "role": "system",
                "content": "You are an assistant that provides information about drink names based on a given corpus."
            }},
            {{
                "role": "user",
                "content": "{user_response}"
            }}
        ]
        '''

        return {
            'query': [
                {
                    'query': user_response,
                    'start': 0,
                    'numResults': 10,
                    'corpusKey': corpora_key_list,
                    'context_config': {
                        'sentences_before': 2,
                        'sentences_after': 2,
                        'start_tag': "%START_SNIPPET%",
                        'end_tag': "%END_SNIPPET%",
                    },
                    # Request a generative summary and chat turn along with the search results;
                    # submit_query() below reads 'summary' and 'chat' from the streamed response,
                    # so the request must ask for them. Field names follow Vectara's v1 query API.
                    'summary': [
                        {
                            'responseLang': 'eng',
                            'maxSummarizedResults': 7,
                            'summarizerPromptName': self.prompt_name,
                            'promptText': prompt,
                            'chat': {'store': True, 'conversationId': self.conv_id},
                        }
                    ],
                }
            ]
        }

    def get_headers(self):
        return {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "customer-id": self.customer_id,
            "x-api-key": self.api_key,
            "grpc-timeout": "60S"
        }

    def submit_query(self, query_str: str):
        endpoint = "https://api.vectara.io/v1/stream-query"
        body = self.get_body(query_str)
        response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_headers(), stream=True)
        if response.status_code != 200:
            print(f"Query failed with code {response.status_code}, reason {response.reason}, text {response.text}")
            # submit_query() is a generator, so the error message must be yielded rather than returned
            yield "Sorry, something went wrong in my brain. Please try again later."
            return

        chunks = []
        accumulated_text = ""  # accumulates streamed summary text
        pattern_max_length = 50  # hold back this many trailing chars so citation markers split across chunks can still be cleaned
        for line in response.iter_lines():
            if line:  # filter out keep-alive new lines
                data = json.loads(line.decode('utf-8'))
                res = data['result']
                response_set = res['responseSet']
                if response_set is None:
                    # a None responseSet means this line carries a piece of the generated summary
                    summary = res.get('summary', None)
                    if summary is None or len(summary) == 0:
                        continue
                    else:
                        chat = summary.get('chat', None)
                        if chat and chat.get('status', None):
                            st_code = chat['status']
                            print(f"Chat query failed with code {st_code}")
                            # as above, error messages must be yielded from this generator
                            if st_code == 'RESOURCE_EXHAUSTED':
                                self.conv_id = None
                                yield 'Sorry, the number of Vectara chat turns exceeds the plan limit.'
                                return
                            yield 'Sorry, something went wrong in my brain. Please try again later.'
                            return
                        conv_id = chat.get('conversationId', None) if chat else None
                        if conv_id:
                            self.conv_id = conv_id
                        
                    chunk = summary['text']
                    accumulated_text += chunk  # append the current streamed chunk
                    if len(accumulated_text) > pattern_max_length:
                        # strip citation markers such as "[1]" and tidy spacing before periods
                        accumulated_text = re.sub(r"\[\d+\]", "", accumulated_text)
                        accumulated_text = re.sub(r"\s+\.", ".", accumulated_text)
                        # emit all but a short tail, held back in case a marker is split across chunks
                        out_chunk = accumulated_text[:-pattern_max_length]
                        chunks.append(out_chunk)
                        yield out_chunk
                        accumulated_text = accumulated_text[-pattern_max_length:]

                    if summary['done']:
                        break

        # yield whatever is left in the buffer
        if len(accumulated_text) > 0:
            accumulated_text = re.sub(r" \[\d+\]\.", ".", accumulated_text)
            chunks.append(accumulated_text)
            yield accumulated_text

        # the full response is also available as the generator's return value (StopIteration.value)
        return ''.join(chunks)
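

# Minimal usage sketch (not part of the original module): drives the streaming
# generator above. The environment variable names and the sample question are
# assumptions for illustration; substitute your own Vectara credentials.
if __name__ == "__main__":
    import os

    vq = VectaraQuery(
        api_key=os.environ["VECTARA_API_KEY"],          # assumed env var name
        customer_id=os.environ["VECTARA_CUSTOMER_ID"],  # assumed env var name
        corpus_id=os.environ["VECTARA_CORPUS_ID"],      # assumed env var name
    )
    # Print the generated summary as it streams in, chunk by chunk.
    for piece in vq.submit_query("Suggest a drink name for a citrus mocktail"):
        print(piece, end="", flush=True)
    print()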