# -*- coding: utf-8 -*-
"""TestAPI.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1WToaz7kQoFpI0_M8j6uWPigBrKlkL4ml
"""



import os

# Pin the process to the first GPU before torch initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from datetime import datetime
from typing import Any, List, Mapping, Optional

import mysql.connector
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.memory import ConversationSummaryBufferMemory
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM


model_name = "Open-Orca/OpenOrca-Platypus2-13B"


tokenizer = AutoTokenizer.from_pretrained(
    model_name, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    load_in_8bit = True,
    device_map = "auto",
)



model = PeftModel.from_pretrained(model, "teslalord/open-orca-platypus-2-medical")

model = model.merge_and_unload()



class CustomLLM(LLM):
    """LangChain LLM wrapper around the locally loaded model and tokenizer."""

    n: int

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_length=500,
                num_return_sequences=1,
                do_sample=True,
                top_k=50,
            )
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        # Prompts end with the "->:" marker, so the reply is whatever follows it.
        return generated_text.split("->:")[-1]

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"n": self.n}



def ask_bot(question):
    """Standalone helper that mirrors CustomLLM._call for ad-hoc queries."""
    input_ids = tokenizer.encode(question, return_tensors="pt").to("cuda")
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_length=500,
            num_return_sequences=1,
            do_sample=True,
            top_k=50,
        )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text.split("->:")[-1]


class DbHandler():
    def __init__(self):
        # NOTE: credentials are hardcoded as in the original script; in
        # practice they belong in environment variables or a secrets store.
        self.db_con = mysql.connector.connect(
            host="frwahxxknm9kwy6c.cbetxkdyhwsb.us-east-1.rds.amazonaws.com",
            user="j6qbx3bgjysst4jr",
            password="mcbsdk2s27ldf37t",
            port=3306,
            database="nkw2tiuvgv6ufu1z")
        self.cursorObject = self.db_con.cursor()

    def insert(self, fields, values):
        try:
            # Build the column list, but bind values through placeholders so
            # quotes in the text cannot break the statement (SQL injection).
            fields_str = ", ".join(fields)
            placeholders = ", ".join(["%s"] * len(values))
            query = f"INSERT INTO chatbot_conversation ({fields_str}) VALUES ({placeholders})"
            self.cursorObject.execute(query, tuple(values))
            self.db_con.commit()
            return True
        except Exception as e:
            print(e)
            return False

    def get_history(self, patient_id):
        try:
            # Parameterized query; rows are returned oldest-first.
            query = ("SELECT * FROM chatbot_conversation "
                     "WHERE patient_id = %s ORDER BY timestamp ASC")
            self.cursorObject.execute(query, (patient_id,))
            return self.cursorObject.fetchall()
        except Exception as e:
            print(e)
            return None

    def close_db(self):
        self.db_con.close()
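
# Assumed table layout, inferred from the insert/get_history calls above; the
# real schema is not part of this file:
#
#   CREATE TABLE chatbot_conversation (
#       id INT AUTO_INCREMENT PRIMARY KEY,
#       patient_id VARCHAR(64),
#       patient_text TEXT,
#       ai_text TEXT,
#       timestamp DATETIME,
#       summarized_text TEXT
#   );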





def get_conversation_history(db, patient_id):
    conversations = db.get_history(patient_id)
    if conversations:
        # Index 5 is presumed to be the summarized_text column of the newest
        # row, i.e. the running summary of the conversation so far.
        return conversations[-1][5]
    return ""



llm = CustomLLM(n=10)
app = FastAPI()

# Wide-open CORS so any frontend origin can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)

@app.get('/healthcheck')
async def root():
    return {'status': 'running'}
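
# Expected request body for the chat endpoint below (shape inferred from how
# user_data is indexed):
#
#   POST /<patient_id>
#   {"userObject": {"userInput": {"message": "I have a headache"}}}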

@app.post('/{patient_id}')
def chatbot(patient_id, user_data: dict = None):
    user_input = user_data["userObject"]["userInput"].get("message")
    db = DbHandler()
    try:
        history = get_conversation_history(db, patient_id)
        memory = ConversationSummaryBufferMemory(llm=llm, max_token_limit=200)
        prompt = ("You are now a medical chatbot, and I am a patient. I will "
                  "describe my conditions and symptoms and you will give me "
                  "medical suggestions. ")
        if history:
            human_input = (prompt
                           + "The following is the patient's previous conversation with you: "
                           + history
                           + " This is the current question: " + user_input + " ->:")
        else:
            human_input = prompt + user_input + " ->:"
        # Strip single quotes so the stored text stays clean.
        human_text = user_input.replace("'", "")
        response = llm._call(human_input).replace("'", "")
        # Summarize the exchange once and persist it alongside the raw texts.
        memory.save_context({"input": user_input}, {"output": response})
        summary = memory.load_memory_variables({})
        db.insert(
            ("patient_id", "patient_text", "ai_text", "timestamp", "summarized_text"),
            (patient_id, human_text, response, str(datetime.now()),
             summary["history"].replace("'", "")),
        )
        return {"response": response}
    finally:
        # Release the connection exactly once, even on errors.
        db.close_db()
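

# Minimal local runner sketch (assumes uvicorn is installed; the actual
# deployment command is not part of this file):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)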