Backend optimisation #16
					 5 changed files with 104 additions and 31 deletions
				
			
		
							
								
								
									
										3
									
								
								.gitignore
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
										
									
									
										vendored
									
									
								
							|  | @ -1,4 +1,5 @@ | ||||||
| venv/ | venv/ | ||||||
| __pycache__ | __pycache__/ | ||||||
| .idea/ | .idea/ | ||||||
| .vscode/ | .vscode/ | ||||||
|  | token.txt | ||||||
|  |  | ||||||
							
								
								
									
										62
									
								
								py/api.py
									
										
									
									
									
								
							
							
						
						
									
										62
									
								
								py/api.py
									
										
									
									
									
								
							|  | @ -1,27 +1,52 @@ | ||||||
| import requests | import requests | ||||||
| import json | import json | ||||||
| from transformers import AutoTokenizer, LlamaForCausalLM | from gradio_client import Client | ||||||
| 
 | import os | ||||||
|  | from mistralai import Mistral | ||||||
| 
 | 
 | ||||||
| class API: | class API: | ||||||
|     # This method processes a message via transformers. (NOT FINISHED!) |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def process_text_transformers(prompt, model): |     def process_text_mistralai(prompt, model, system): | ||||||
|         model = LlamaForCausalLM.from_pretrained(model) |         with open("token.txt", "r") as f: | ||||||
|         tokenizer = AutoTokenizer.from_pretrained(model) |             token = f.readlines()[0].strip() | ||||||
| 
 | 
 | ||||||
|         inputs = tokenizer(prompt, return_tensors="pt") |         api_key = token | ||||||
| 
 | 
 | ||||||
|         generate_ids = model.generate(inputs.input_ids, max_length=30) |         client = Mistral(api_key=api_key) | ||||||
|         return tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] | 
 | ||||||
|  |         chat_response = client.chat.complete( | ||||||
|  |             model=model, | ||||||
|  |             messages=[ | ||||||
|  |                 { | ||||||
|  |                     "role": "user", | ||||||
|  |                     "content": prompt, | ||||||
|  |                 }, { | ||||||
|  |                     "role": "system", | ||||||
|  |                     "content": system, | ||||||
|  |                 }, | ||||||
|  |             ] | ||||||
|  |         ) | ||||||
|  |         return chat_response.choices[0].message.content | ||||||
|  |     @staticmethod | ||||||
|  |     def process_text_gradio(prompt, model, system): | ||||||
|  |         client = Client(model) | ||||||
|  |         result = client.predict( | ||||||
|  |             message=prompt, | ||||||
|  |             system_message=system, | ||||||
|  |             max_tokens=512, | ||||||
|  |             temperature=0.7, | ||||||
|  |             top_p=0.95, | ||||||
|  |             api_name="/chat" | ||||||
|  |         ) | ||||||
|  |         return result; | ||||||
| 
 | 
 | ||||||
|     # This method processes a message via ollama |     # This method processes a message via ollama | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def process_text_local(prompt, model): |     def process_text_local(prompt, model, system): | ||||||
|         ollama_url = "http://localhost:11434" |         ollama_url = "http://localhost:11434" | ||||||
| 
 | 
 | ||||||
|         response = requests.post( |         response = requests.post( | ||||||
|             f"{ollama_url}/api/generate", json={"model": model, "prompt": prompt} |             f"{ollama_url}/api/generate", json={"model": model, "prompt": prompt, "system": system} | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         if response.status_code == 200: |         if response.status_code == 200: | ||||||
|  | @ -37,17 +62,20 @@ class API: | ||||||
|             return "Error: " + response.text |             return "Error: " + response.text | ||||||
| 
 | 
 | ||||||
|     # This method sends a message to a certain AI. |     # This method sends a message to a certain AI. | ||||||
|     def send_message(self, message, model): | 
 | ||||||
|  |     def send_message(self, message, model, system): | ||||||
|         if model == 1: |         if model == 1: | ||||||
|             answer = self.process_text_local(message, "phi3.5") |             answer = self.process_text_local(message, "phi3.5", system) | ||||||
|         elif model == 2: |         elif model == 2: | ||||||
|             answer = self.process_text_local(message, "gemma2:2b") |             answer = self.process_text_local(message, "gemma2:9b", system) | ||||||
|         elif model == 3: |         elif model == 3: | ||||||
|             answer = self.process_text_local(message, "qwen2:0.5b") |             answer = self.process_text_local(message, "codegemma:2b", system) | ||||||
|         elif model == 4: |         elif model == 4: | ||||||
|             answer = self.process_text_local(message, "codegemma:2b") |             answer = self.process_text_gradio(message, "PatrickPluto/InterstellarAIChatbot", system) | ||||||
|         elif model == 5: |         elif model == 5: | ||||||
|             answer = self.process_text_transformers(message, "meta-llama/Meta-Llama-3.1-8B") |             answer = self.process_text_mistralai(message, "mistral-large-latest", system) | ||||||
|  |         elif model == 6: | ||||||
|  |             answer = self.process_text_mistralai(message, "codestral-latest", system) | ||||||
|         else: |         else: | ||||||
|             return "Invalid choice" |             return "Invalid choice" | ||||||
|         return answer |         return answer | ||||||
|  |  | ||||||
							
								
								
									
										17
									
								
								py/install.sh
									
										
									
									
									
										Executable file
									
								
							
							
						
						
									
										17
									
								
								py/install.sh
									
										
									
									
									
										Executable file
									
								
							|  | @ -0,0 +1,17 @@ | ||||||
|  | #!/bin/bash | ||||||
|  | 
 | ||||||
|  | python3 -m venv venv | ||||||
|  | source venv/bin/activate | ||||||
|  | pip install flask | ||||||
|  | pip install SpeechRecognition | ||||||
|  | pip install pyaudio | ||||||
|  | pip install pocketsphinx | ||||||
|  | pip install sentencepiece | ||||||
|  | pip install pyqt5 | ||||||
|  | pip install pyqtwebengine | ||||||
|  | pip install gradio_client | ||||||
|  | pip install mistralai | ||||||
|  | 
 | ||||||
|  | ollama pull phi3.5 | ||||||
|  | ollama pull codegemma:2b | ||||||
|  | ollama pull gemma2:9b | ||||||
							
								
								
									
										10
									
								
								py/venv.sh
									
										
									
									
									
								
							
							
						
						
									
										10
									
								
								py/venv.sh
									
										
									
									
									
								
							|  | @ -1,10 +0,0 @@ | ||||||
| #!/bin/bash |  | ||||||
| 
 |  | ||||||
| virtualenv venv |  | ||||||
| source venv/bin/activate |  | ||||||
| pip install transformers |  | ||||||
| pip install torch |  | ||||||
| pip install flask |  | ||||||
| pip install SpeechRecognition |  | ||||||
| pip install pyaudio |  | ||||||
| pip install pocketsphinx |  | ||||||
							
								
								
									
										43
									
								
								py/web_flask.py
									
										
									
									
									
										
										
										Normal file → Executable file
									
								
							
							
						
						
									
										43
									
								
								py/web_flask.py
									
										
									
									
									
										
										
										Normal file → Executable file
									
								
							|  | @ -1,6 +1,14 @@ | ||||||
|  | #!venv/bin/python | ||||||
|  | 
 | ||||||
| from flask import Flask, request, render_template | from flask import Flask, request, render_template | ||||||
| from api import API | from api import API | ||||||
| from voice_recognition import Voice | from voice_recognition import Voice | ||||||
|  | import sys | ||||||
|  | import threading | ||||||
|  | from PyQt5.QtCore import * | ||||||
|  | from PyQt5.QtWebEngineWidgets import * | ||||||
|  | from PyQt5.QtWidgets import * | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| APP = Flask(__name__) | APP = Flask(__name__) | ||||||
| api = API() | api = API() | ||||||
|  | @ -13,6 +21,8 @@ messages = [] | ||||||
| @APP.route('/', methods=['GET', 'POST']) | @APP.route('/', methods=['GET', 'POST']) | ||||||
| def index(): | def index(): | ||||||
|     global messages |     global messages | ||||||
|  |     system_prompt = 'You are a helpful assistant.' | ||||||
|  |     system = 'Your system prompt is: \"'+system_prompt+'\" The following is your chat log so far: \n' | ||||||
| 
 | 
 | ||||||
|     if request.method == 'POST': |     if request.method == 'POST': | ||||||
|         option = request.form['option'] |         option = request.form['option'] | ||||||
|  | @ -25,7 +35,12 @@ def index(): | ||||||
|         elif option == "chat": |         elif option == "chat": | ||||||
|             messages.append(f"User: {user_message}") |             messages.append(f"User: {user_message}") | ||||||
| 
 | 
 | ||||||
|         ai_response = "AI: " + api.send_message(user_message, 3) |         for line in messages: | ||||||
|  |             system += line + '\n' | ||||||
|  | 
 | ||||||
|  |         system += "The chat log is now finished." | ||||||
|  | 
 | ||||||
|  |         ai_response = "AI: " + api.send_message(user_message, 5, system) | ||||||
|         messages.append(ai_response) |         messages.append(ai_response) | ||||||
| 
 | 
 | ||||||
|     return render_template('index.html', messages=messages) |     return render_template('index.html', messages=messages) | ||||||
|  | @ -42,6 +57,28 @@ def contact(): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     APP.run(debug=True) |     qapp = QApplication(sys.argv) | ||||||
| 
 | 
 | ||||||
| # This is a comment --> test if this creates a merge conflict |     view = QWebEngineView() | ||||||
|  | 
 | ||||||
|  |     view.setGeometry(100, 100, 1280, 720) | ||||||
|  |     view.setWindowTitle("InterstellarAI") | ||||||
|  | 
 | ||||||
|  |     view.setUrl(QUrl("http://localhost:5000")) | ||||||
|  | 
 | ||||||
|  |     view.show() | ||||||
|  | 
 | ||||||
|  |     def run_flask(): | ||||||
|  |         APP.run() | ||||||
|  | 
 | ||||||
|  |     def stop_flask(): | ||||||
|  |         thread.join() | ||||||
|  |         qapp.quit() | ||||||
|  | 
 | ||||||
|  |     thread = threading.Thread(target=run_flask) | ||||||
|  |     thread.daemon = True | ||||||
|  |     thread.start() | ||||||
|  | 
 | ||||||
|  |     qapp.aboutToQuit.connect(stop_flask) | ||||||
|  | 
 | ||||||
|  |     sys.exit(qapp.exec_()) | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue